Transformer实战-系列教程5:Vision Transformer 源码解读3

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传

6、Block类------构造函数

class Block(nn.Module):
    def __init__(self, config, vis):
        super(Block, self).__init__()
        self.hidden_size = config.hidden_size
        self.attention_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn_norm = LayerNorm(config.hidden_size, eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config, vis)
  1. 隐藏特征,768维向量
  2. LayerNorm,层归一化
  3. 定义一个mlp:两次全连接+relu+Dropout
  4. 定义一个Attention类的实例

7、Attention类------构造函数

class Attention(nn.Module):
    def __init__(self, config, vis):
        super(Attention, self).__init__()
        self.vis = vis
        self.num_attention_heads = config.transformer["num_heads"]
        self.attention_head_size = int(config.hidden_size / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config.hidden_size, self.all_head_size)
        self.key = Linear(config.hidden_size, self.all_head_size)
        self.value = Linear(config.hidden_size, self.all_head_size)

        self.out = Linear(config.hidden_size, config.hidden_size)
        self.attn_dropout = Dropout(config.transformer["attention_dropout_rate"])
        self.proj_dropout = Dropout(config.transformer["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)
  1. num_attention_heads ,多头注意力机制,配置参数中定义了要多少头
  2. attention_head_size ,每头注意力处理多少维向量,隐层特征是768维/12头=64维向量
  3. all_head_size 应该要等于config.hidden_size
  4. 全连接层生成Q、K、V
  5. 全连接层、dropout、softmax

Transformer

8、Vision Transformer类------前向传播

class VisionTransformer(nn.Module):
    def forward(self, x, labels=None):
        x, attn_weights = self.transformer(x)
        # print(x.shape)
        logits = self.head(x[:, 0])
        # print(logits.shape)

        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_classes), labels.view(-1))
            return loss
        else:
            return logits, attn_weights

前向传播函数打上断点,开启debug模式,查看数据维度变化:
输入x=[16,3,224,224],依次为batch_size,通道数,图像长和宽
经过self.transformer()后
x=[16,197,768],依次为batch_size,197=196+1其中196为序列长度、1为分类标记,自定义的向量维度

logits = [16,10],依次为batch_size,10分类的分数

9、Transformer类------前向传播

class Transformer(nn.Module):
    def __init__(self, config, img_size, vis):
        super(Transformer, self).__init__()
        self.embeddings = Embeddings(config, img_size=img_size)
        self.encoder = Encoder(config, vis)

    def forward(self, input_ids):
        embedding_output = self.embeddings(input_ids)
        encoded, attn_weights = self.encoder(embedding_output)
        return encoded, attn_weights

前向传播函数打上断点,开启debug模式,查看数据维度变化:

  • input_ids.shape = torch.Size([16, 3, 224, 224])
  • embedding_output.shape = torch.Size([16, 197, 768])
  • encoded.shape = torch.Size([16, 197, 768])

Embeddings输入是彩色图,输出每个位置得到768维向量
Encoder输入与输出都是768维向量

10、Embeddings------前向传播

class Embeddings(nn.Module):
    def forward(self, x):
        B = x.shape[0]
        cls_tokens = self.cls_token.expand(B, -1, -1)
        if self.hybrid:
            x = self.hybrid_model(x)
        x = self.patch_embeddings(x)
        x = x.flatten(2)
        x = x.transpose(-1, -2)
        x = torch.cat((cls_tokens, x), dim=1)
        embeddings = x + self.position_embeddings
        embeddings = self.dropout(embeddings)
        return embeddings
  1. 输入图像:x.shape= torch.Size([16, 3, 224, 224])
  2. B为batch_size=16
  3. cls_tokens.shape= torch.Size([16, 1, 768]),cls_token:这是一个特殊的“分类(class)标记”,其设计灵感来源于BERT模型中的[CLS]标记。在Vision Transformer中,这个cls_token被添加到图像块(patch)嵌入的序列的前面,并且在整个Transformer模型的处理过程中一直携带着。模型的最终目标是使用这个cls_token的表示(经过Transformer模型的多层处理后的输出)来进行分类任务。换句话说,cls_token在模型的最后一层的输出被用作图像分类或其他下游任务的基础。这个就是197=196+1的1的由来
  4. x.shape = self.patch_embeddings(x).shape = torch.Size([16, 768, 14, 14]),patch_embeddings:这是将输入图像分割成多个图像块(patches),然后将每个图像块转换成模型可以处理的嵌入向量的过程。在Vision Transformer中,输入图像首先被划分为多个固定大小的小块,每个小块接着通过一个卷积层(在这个代码中是Conv2d层)转换成一个嵌入向量。这个卷积层的输出通道数等于模型的隐藏层大小(config.hidden_size),这样每个图像块就被映射到了一个高维空间,以便后续由Transformer处理。patch_embeddings实质上是对图像进行了一种“词嵌入”操作,将图像的原始像素值转换为模型可以理解的语义向量
  5. x.flatten(2).torch.Size([16, 768, 196]),展开为一个序列,现在是self-Attention中标准的输入格式
  6. x.transpose(-1, -2).torch.Size([16, 196, 768]),转换维度
  7. torch.cat((cls_tokens, x), dim=1).torch.Size([16, 197, 768]),x和cls_tokens拼接
  8. (x + self.position_embeddings).torch.Size([16, 197, 768]),加上位置编码,维度保存不变,self.position_embeddings位置编码是为每一个小块都增加了一个初始化的全零的可学习的位置嵌入张量,
  9. self.dropout(embeddings).torch.Size([16, 197, 768]),加上Dropout

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/371740.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【C++】类和对象(2)

这篇博客继续学习类和对象~,主要介绍了类的6个默认成员函数。 目录 类的6个默认成员函数 构造函数 概念 特性 析构函数 概念 特性 拷贝构造函数 特性 赋值运算符重载 运算符重载 赋值运算符重载 前置和后置重载 日期类的实现 const成员 取地址及cons…

FTP口令问题

FTP(File Transfer Protocol ,文件传输协议)是一个文件传输协议,用户通 过FTP可从客户机程序向远程主机上传或下载文件,常用于网站代码维护、 日常 源码备份等。如果攻击者通过FTP匿名访问或者通过弱口令破解获取FTP…

Linux进程信号(2)--信号的保存

目录 1.阻塞信号 1.1 信号其他相关常见概念 1.实际执行信号的处理动作称为信号递达(Delivery) 2.信号从产生到递达之间的状态,称为信号未决(Pending)。 3.进程可以选择阻塞 (Block )某个信号。 1.2信号在内核中的表示 sigset_t 信号集操作函数 使用sigprocm…

JAVA-File五个练习

下面习题思路大多都是: 1.获取路径下所有列表(listfiles),2.遍历文件或文件夹(增强for),3.判断是否是文件(isFile)并直接执行逻辑,4.判断当前是文件夹的情况&…

【React】redux状态管理、react-redux状态管理高级封装模块化

【React】react组件传参、redux状态管理 一、redux全局状态管理1、redux概述2、redux的组成1.1 State-状态1.2 Action-事件1.3 Reducer1.4 Store 3、redux入门案例1.1 前期准备1.2 构建store1.2.1 在src下新建store文件夹1.2.2 在store文件夹下新建index.ts文件1.2.3 在index.t…

Swift Vapor 教程(查询数据、插入数据)

上一篇简单写了 怎么创建 Swift Vapor 项目以及在开发过程中使用到的软件。 这一篇写一个怎么在创建的项目中创建一个简单的查询数据和插入数据。 注:数据库配置比较重要 先将本地的Docker启动起来,用Docker管理数据库 将项目自己创建的Todo相关的都删掉…

【python】python爱心代码【附源码】

一、实现效果: 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 二、完整代码: import math import random import threading import time from math import sin, cos, pi, log from tkinter import * import re# 烟花相关设置 Fireworks [] m…

STM32F407 CAN参数配置 500Kbps

本篇CAN参数适用 芯片型号:STM32F407xx系统时钟:168MHz,CAN挂载总线APB1为42M波 特 率 :500Kpbs引脚使用:TX_PB9,RX_PB8;修改为PA11PA12后,参数不变。 步骤一、打勾开启CAN&#xf…

vue使用es的reduce方法编译报错Error: Can‘t resolve ‘core-js/modules/es.array.reduce.js‘

哈喽 大家好啊 最近在vue使用es的reduce方法编译报错Error: Cant resolve core-js/modules/es.array.reduce.js 报错如图所示: 解决方案: npm install --save core-js 然后重新编译下将正常了 参考原文: 使用import异步加载语法报错_module not foun…

CAD-autolisp(四)——编译

目录 一、编译1.1 界面操作1.2 生成的应用程序(二选一) 二、后续学习 一、编译 编译:lsp后缀名为原文件,后缀名为fas、vlx为编译后文件,其会把sld、dcl、lsp等文件都编译进一个应用程序文件中加载:cad命令…

ZigBee学习——在官方例程基础实现点灯

IAR版本 :10.10.1 Z-stack版本 :3.0.2 文章目录 一、买的板子原理图二、实现过程2.1 重定义LED的物理映射(HAL层)2.2 创建LED事件(应用层)2.2.1 定义用户事件2.2.2 修改zclGenericApp_event_loop() 2.3 触发事件 一、买的板子原理图 二、实现过程 2.1 重定义LED的物理映射(HAL…

windows10忘记密码的解决方案

大家好,我是爱编程的喵喵。双985硕士毕业,现担任全栈工程师一职,热衷于将数据思维应用到工作与生活中。从事机器学习以及相关的前后端开发工作。曾在阿里云、科大讯飞、CCF等比赛获得多次Top名次。现为CSDN博客专家、人工智能领域优质创作者。喜欢通过博客创作的方式对所学的…

处理SERVLET中的错误

处理SERVLET中的错误 问题陈述 一位用户在使用在线计算机应用程序时输入一个非数字字符做数字加法。servlet试图将用户输入的值转换成整数型时,引发了NumberFormException类型的异常。要创建一个Web应用程序来使用自定义错误页面处理该异常。该自定义错误页面需要向用户提供关…

uniapp android和微信小程序实现PDF在线预览

在使用uniapp开发移动端时,微信开发者工具里webview能正常打开后端接口返回的pdf文件流。正式发布后,在配置了业务域名和服务器域名的前提下,预览pdf文件却只能看到白屏,因此我猜测微信小程序不能通过webview读取文件流。这个想法…

笔记本电脑的WIFI模块,突然不显示了,网络也连接不上

问题复现: 早上,在更新完笔记本电脑的系统之后,连网之后,网络突然直接断开,一查看,WiFi模块居然不见了,开机重启也是如此,这种情况常常出现在更新系统之后,WiFi模块驱动就…

Unix五种I/O模型(阻塞、非阻塞、多路复用、信号驱动、异步)

文章目录 概要一、I/O基础二、阻塞式I/O三、非阻塞式I/O三、I/O多路复用四、信号驱动I/O五、异步I/O六、小结 概要 在工作中,经常使用Nginx、Redis等开源组件,常提到其高性能的原因是网络I/O的实现是基于epoll(多路复用)。这次呢…

复制命令COPY

copy在英文中是复制的意思,所谓复制就是原来的文件并没有任何改变,重新产生了一个内容和原来文件没有任何差别的文件。复制(copy)命令主要用于复制一个或更多个文件到指定的位置,该命令可以被用于合并文件,…

ES6-数组的解构赋值

一、数组的解构赋值的规律 - 只要等号两边的模式相同,左边的变量就会被赋予对应的值二、数组的解构赋值的例子讲解 1)简单的示例(完整的解构赋值) 示例 //基本的模式匹配 // a,b,c依次和1,2&#xff0c…

Element UI+Spring Boot进行CRUD的实例

ElementUI安装与使用指南 前端代码:点击查看learnelementuispringboot项目源码 后端代码:点击查看 LearnElementUiAndSpringBoot 一、前端配置 安装axios Gitee的axios介绍与使用 GitHub的axios介绍与使用 方式一:使用npm安装 $ npm in…

深度学习(生成式模型)—— Consistency Models

文章目录 前言预备知识:SDE与ODEMethod实验结果 前言 Diffusion model需要多次推断才能生成最终的图像,这将耗费大量的计算资源。前几篇博客我们已经介绍了加速Diffusion model生成图像速率的DDIM和Stable Diffusion,本节将介绍最近大火的Co…