【前沿模型解析】潜在扩散模型 2-3 | 手撕感知图像压缩 基础块 自注意力块

1 注意力机制回顾

同ResNet一样,注意力机制应该也是神经网络最重要的一部分了。

想象一下你在观看一场电影,但你的朋友在给你发短信。虽然你正在专心观看电影,但当你听到手机响起时,你会停下来查看短信,然后这时候电影的内容就会被忽略。这就是注意力机制的工作原理。

在处理输入序列时,比如一句话中的每个单词,注意力机制允许模型像你一样,专注于输入中的不同部分。模型可以根据输入的重要性动态地调整自己的注意力,注意自己觉得比较重要的部分,忽略一些不太重要的部分,以便更好地理解和处理序列数据。

具体来说,是通过q,k,v实现的

q(查询),k(键值)之间先进行计算,获得重要性权重w,w再作用于v

利用卷积操作确定q,k,v

q,k做运算得到w,缩放w

w和v做运行

最后残差

得到

2 Atten块的实现

在这里插入图片描述

2.1 初始化函数

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=3, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)

2.2 前向传递函数

def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention 自注意力计算
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w) #[4,12,1024]
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)
  1. b,c,h,w = q.shape:假设q是一个四维张量,其中b表示batch size,c表示通道数,hw表示高度和宽度。

  2. q = q.reshape(b,c,h*w):将q张量重新形状为三维张量,其中第三维是原高度和宽度的乘积。这样做是为了方便后续计算。

  3. q = q.permute(0,2,1):交换张量维度,将第三维移动到第二维,这是为了后续计算方便。

  4. k = k.reshape(b,c,h*w):对k做和q类似的操作,将其形状改为三维张量。

  5. w_ = torch.bmm(q,k):计算qk的批次矩阵乘积(batch matrix multiplication),得到注意力权重的初始矩阵。这里的w_是一个b x (h*w) x (h*w)的张量,表示每个位置对应的注意力权重。

  6. w_ = w_ * (int(c)**(-0.5)):对初始注意力权重进行缩放,这里使用了一个缩放因子,通常是通道数的倒数的平方根。这个缩放是为了确保在计算注意力时不会因为通道数过大而导致梯度消失或梯度爆炸。

  7. w_ = torch.nn.functional.softmax(w_, dim=2):对注意力权重进行softmax操作,将其归一化为概率分布,表示每个位置的重要性。

这段代码的作用是实现自注意力机制中计算注意力权重的过程,其中qk分别代表查询(query)和键(key),通过计算它们的相似度得到注意力权重。

        # attend to values 加注意力到值上
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] [4,12,1024]*[4,1024,1024]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        return x+h_
  1. v = v.reshape(b,c,h*w):将值(value)张量v重新形状为三维张量,其中第三维是原高度和宽度的乘积。这样做是为了方便后续计算。

  2. w_ = w_.permute(0,2,1):交换注意力权重w_张量的维度,将第三维移动到第二维,这是为了后续计算方便。

  3. h_ = torch.bmm(v,w_):计算值v和经过缩放的注意力权重w_的批次矩阵乘积(batch matrix multiplication),得到自注意力的输出。这里的h_是一个b x c x (h*w)的张量,表示每个位置经过注意力计算后的输出。

  4. h_ = h_.reshape(b,c,h,w):将h_张量重新形状为四维张量,恢复其原始的高度和宽度。

  5. h_ = self.proj_out(h_):通过一个全连接层proj_out对自注意力的输出h_进行线性变换和非线性变换,这个操作有助于提取特征并保持网络的表达能力。

最后,将输入x和自注意力的输出h_相加,得到最终的自注意力输出。这样做是为了在保留原始输入信息的同时,加入了经过自注意力计算后的新信息,从而使模型能够更好地理解输入序列的语义信息。

2.3 Atten注意力完整代码

from torch import nn
import torch
from einops import rearrange


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=3, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)


    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention 自注意力计算
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w) #[4,12,1024]
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values 加注意力到值上
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] [4,12,1024]*[4,1024,1024]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        return x+h_

def make_attn(in_channels, attn_type="vanilla"):
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    else:
        return nn.Identity(in_channels)
    

atten_block=make_attn(12)
x=torch.ones(4,12,32,32)
y=atten_block(x)
print(y.shape)

3 源代码中的另一种注意力实现

源代码中还实现了LinearAttention,是另一种注意力机制

可以看看

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
        k = k.softmax(dim=-1)  
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)

class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""
    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)

对于forward函数

  1. b, c, h, w = x.shape:假设输入张量x是一个四维张量,其中b表示batch size,c表示通道数,hw表示高度和宽度。

  2. qkv = self.to_qkv(x):将输入张量x通过一个线性变换(可能包括分别计算查询(query)、键(key)和值(value))得到qkv张量,其形状为b x (3*heads*c) x h x w,其中heads是多头注意力的头数。

  3. q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3):将qkv张量重新排列为三个张量qkv,分别表示查询、键和值,形状为b x heads x c x (h*w)

  4. k = k.softmax(dim=-1):对键张量k进行softmax操作,将其归一化为概率分布,以便计算注意力权重。

  5. context = torch.einsum('bhdn,bhen->bhde', k, v):使用torch.einsum函数计算注意力权重与值的加权和,得到上下文张量context,形状为b x heads x c x (h*w)

  6. out = torch.einsum('bhde,bhdn->bhen', context, q):使用torch.einsum函数计算上下文张量与查询张量的加权和,得到输出张量out,形状为b x heads x c x (h*w)

  7. out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w):将输出张量out重新排列为形状b x (heads*c) x h x w,恢复其原始形状。

  8. return self.to_out(out):将输出张量out通过一个线性变换得到最终的输出。

如果注意力机制type=None的话,则不进行注意力机制的计算~

用一个torch函数

nn.Identity 这是一个恒等变化的一个函数,不做任何处理

4 完整代码及其测试

from torch import nn
import torch
from einops import rearrange

class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
        k = k.softmax(dim=-1)  
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)

class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""
    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)

class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = torch.nn.GroupNorm(num_groups=3, num_channels=in_channels, eps=1e-6, affine=True)
        self.q = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.k = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.v = torch.nn.Conv2d(in_channels,
                                 in_channels,
                                 kernel_size=1,
                                 stride=1,
                                 padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=1,
                                        stride=1,
                                        padding=0)


    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention 自注意力计算
        b,c,h,w = q.shape
        q = q.reshape(b,c,h*w) #[4,12,1024]
        q = q.permute(0,2,1)   # b,hw,c
        k = k.reshape(b,c,h*w) # b,c,hw
        w_ = torch.bmm(q,k)     # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
        w_ = w_ * (int(c)**(-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values 加注意力到值上
        v = v.reshape(b,c,h*w)
        w_ = w_.permute(0,2,1)   # b,hw,hw (first hw of k, second of q)
        h_ = torch.bmm(v,w_)     # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] [4,12,1024]*[4,1024,1024]
        h_ = h_.reshape(b,c,h,w)

        h_ = self.proj_out(h_)

        return x+h_

def make_attn(in_channels, attn_type="vanilla"):
    print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type=="line":
        return LinAttnBlock(in_channels)
    else:
        return nn.Identity(in_channels)
    

atten_block=make_attn(12)
x=torch.ones(4,12,32,32)
y=atten_block(x)
print(y.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/538576.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

CSS特效---纯CSS实现点击切换按钮

1、演示 2、一切尽在代码中 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" /><meta name"viewport" content"w…

第11版《中国网络安全行业全景图》发布,谁霸榜了软件供应链安全领域?

近日&#xff0c;知名网络安全行业媒体安全牛正式发布了第11版《中国网络安全行业全景图》&#xff08;以下简称”全景图“&#xff09;&#xff0c;共收录了国内网络安全企业454家&#xff0c;细分领域共收录2413项&#xff0c;旨在优先展现当前热门网络安全领域中具有较强市场…

mysql题目1

tj11: ​ select * from t_student where grade 大一 and major 软件工程 ​ tj12: SELECTt_student.name, count(t_choice.cid)FROMt_choiceINNER JOINt_courseON t_choice.cid t_course.idINNER JOINt_studentON t_choice.sid t_student.id GROUP BYt_choice.sid HAVIN…

如何免费搭建幻兽帕鲁服务器?

雨云是一家国内的云计算服务提供商&#xff0c;为了吸引用户推出了积分兑换云产品活动&#xff0c;只需要完成简单积分任务即可获得积分&#xff0c;积分可以兑换免费游戏云、对象存储或者虚拟主机。本文将给大家分享雨云免费游戏云领取及幻兽帕鲁开服教程。 第一步&#xff1a…

字节面试:ThreadLocal内存泄漏,怎么破?什么是 ITL、TTL、FTL?

尼恩说在前面 在40岁老架构师 尼恩的读者交流群(50)中&#xff0c;最近有小伙伴拿到了一线互联网企业如得物、阿里、滴滴、极兔、有赞、希音、百度、网易、美团的面试资格&#xff0c;遇到很多很重要的面试题&#xff1a; 1.请解释ThreadLocal是什么&#xff0c;以及它的主要用…

【Nacos】Nacos最新版的安装、配置过程记录和踩坑分享

Nacos是什么&#xff1f;有什么功能&#xff1f;大家可以自行联网&#xff08;推荐 https://cn.bing.com/&#xff09;搜索&#xff0c;这里就不做介绍了。 简单的看了下官网&#xff0c;安装最新版的Nacos&#xff08;v2.3.2&#xff09;需要使用到JDK&#xff08;1.8.0&…

【数据结构】——八大排序(详解+图+代码详解)看完你会有一个全新认识

创作不易&#xff0c;给一个免费的三连吧&#xff1f;&#xff01; 前言 排序在生活中是非常重要的&#xff0c;所以排序在数据结构中也占有很大的地位&#xff0c;相信大家可能被这些排序弄得比较混淆或者对某个排序原理没有弄清&#xff0c;相信看完本篇会对你有所帮助&…

力扣HOT100 - 41. 缺失的第一个正数

解题思路&#xff1a; 原地哈希 就相当于&#xff0c;让每个数字n都回到下标为n-1的家里。 而那些没有回到家里的就成了孤魂野鬼流浪在外&#xff0c;他们要么是根本就没有自己的家&#xff08;数字小于等于0或者大于nums.size()&#xff09;&#xff0c;要么是自己的家被别…

【报错】AttributeError: ‘NoneType‘ object has no attribute ‘pyplot_show‘(已解决)

【报错】AttributeError: ‘NoneType’ object has no attribute ‘pyplot_show’ 问题描述&#xff1a;python可视化出现下面报错 我的原始代码&#xff1a; import matplotlib.pyplot as pltplt.figure() plt.plot(x, y, bo-) plt.axis(equal) plt.xlabel(X) plt.ylabe…

了解何为vue-cli及其作用

Vue CLI是一个由Vue.js官方提供的命令行工具&#xff0c;用于快速搭建基于Vue.js的项目。它可以帮助开发者快速搭建项目结构、配置构建工具、添加插件等&#xff0c;从而更加高效地进行Vue.js项目的开发。 注&#xff1a;在创建工程前需要 先使用命令行&#xff1a;npm instal…

实战项目——智慧社区(三)之 门禁管理

1、人脸识别 实现思路 ①查询出所有的小区信息&#xff0c;下拉列表显示&#xff0c;用于后续判断人脸信息是否与所选小区匹配 ②人脸识别&#xff1a;调用腾讯人脸识别的API接口&#xff0c;首先判断传入图片是否为一张人脸&#xff1b;其次将这张人脸去服务器的人员库进行…

拥有一台阿里云服务器可以做什么?

阿里云ECS云服务器可以用来做什么&#xff1f;云服务器可以用来搭建网站、爬虫、邮件服务器、接口服务器、个人博客、企业官网、数据库应用、大数据计算、AI人工智能、论坛、电子商务、AI、LLM大语言模型、测试环境等&#xff0c;云服务器吧yunfuwuqiba.com整理阿里云服务器可以…

SpringBoot 中的日志原来是这么工作的

在有些场景&#xff0c;能通过调整日志的打印策略来提升我们的系统吞吐量,你知道吗&#xff1f; 我们以Springboot集成Log4j2为例&#xff0c;详细说明Springboot框架下Log4j2是如何工作的&#xff0c;你可能会担心&#xff0c;如果是使用Logback日志框架该怎么办呢&#xff1…

langchain-chatchat指定一个或多个文件回答,不允许回答内容有其他文件内容,即屏蔽其他文件内容

1.找到langchain-chatchat中的knowledge_base_chat.py 2.knowledge_base_chat.py的api内容加上一个flie_name参数&#xff0c;即传过来你需要指定一个文件名称&#xff0c;或多个文件名称&#xff0c;同时也可以不指定&#xff0c;加上以下代码&#xff1a; flie_name: List …

2024-4-10 群讨论:JFR 热点方法采样实现原理

以下来自本人拉的一个关于 Java 技术的讨论群。关注公众号&#xff1a;hashcon&#xff0c;私信拉你 什么是 JFR 热点方法采样&#xff0c;效果是什么样子&#xff1f; 其实对应的就是 jdk.ExecutionSample 和 jdk.NativeMethodSample 事件 这两个事件是用来采样的&#xff0c…

[SystemVerilog]Simulation and Test Benches

Simulation and Test Benches 测试语言中有很大一部分专门用于测试台和测试。在本章中&#xff0c;我们将介绍为硬件设计编写高效测试台的一些常用技术。 6.1 How SystemVerilog Simulator Works 在深入研究如何编写适当的测试台之前&#xff0c;我们需要深入了解模拟器的工作原…

git查看单独某一个文件的历史修改记录

git查看单独某一个文件的历史修改记录 git log -p 文件具体路径 注意&#xff0c;Windows下默认文件路径分隔符是 \&#xff0c;在git bash 里面需要改成 /。 git基于change代码修改与提交_git change-CSDN博客文章浏览阅读361次。git cherry-pick&#xff1a;复制多个提交comm…

使用 Citavi 和 NVivo 简化您的文献综述和研究分析

NVivo 是一款支持定性研究方法和混合研究方法的软件。它可以帮助您收集、整理和分析访谈、焦点小组讨论、问卷调查、音频等内容。NVivo&#xff08;1.0版&#xff09;是Windows和Mac的主要版本。遵循最新的主要版本NVivo 12&#xff08;Windows和Mac&#xff09;。 NVivo 强大…

Java Reflection(从浅入深理解反射)

本节的代码链接&#xff1a;reflection 1. 反射的由来 反射机制允许程序执行期借助于Reflection API取得任何类的内部信息&#xff0c;如成员变量、构造器、成员方法等&#xff0c;并能操作对象的属性及方法&#xff0c;在设计模式和框架底层都会用到。 1.1 引入需求 编写框…

Scala实战:打印九九表

本次实战的目标是使用不同的方法实现打印九九表的功能。我们将通过四种不同的方法来实现这个目标&#xff0c;并在day02子包中创建相应的对象。 方法一&#xff1a;双重循环 我们将使用双重循环来实现九九表的打印。在NineNineTable01对象中&#xff0c;我们使用两个嵌套的fo…