【Vit】Vision Transformer 入门与理解

在学习VIT之前,建议先把 Transformer 搞明白了:【transformer】入门与理解

做了那些改进?

在这里插入图片描述

看图就比较明白了,VIT只用了Encoder的部分,把每一个图片裁剪成若干子图,然后把一个子图flatten一下,当成nlp中的一个token处理。
值得注意的是,在首个 token中嵌入了一个 class_token,维度为(1,embed_dim=768),这个class_token在预测的时候比较有意思,见下图:

在这里插入图片描述
注意上图中有些细节遗漏,全流程应该是:先把输入进行 patch_embedding 变成 visual tokens,然后和 class_token 合并,最后 position_embedding。

另外需要注意的是,class_token 是一个可学习的参数,并不是每次输入时都需要输入的类别数值。

self.class_token = nn.Parameter(torch.ones(1, 1, embed_dim) * 0.98)  #(1,1,768)

代码

其实有了 Transformer 的基础后,直接看代码就知道VIT是怎么做的了。

import copy
import torch
import torch.nn as nn

# 所有基于nn.Module结构的模版,可以删掉
class Identity(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x


class Mlp(nn.Module):
    def __init__(self, embed_dim, mlp_ratio, dropout=0.):
        super().__init__()
        self.fc1 = nn.Linear(embed_dim, int(embed_dim * mlp_ratio)) # 中间层扩增
        self.fc2 = nn.Linear(int(embed_dim * mlp_ratio), embed_dim)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # TODO
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class PatchEmbedding(nn.Module):
    def __init__(self, image_size=224, patch_size=16, in_channels=3, embed_dim=768, dropout=0.):
        super().__init__()
        n_patches = (image_size // patch_size) * (image_size // patch_size)  # 196 个 patch
        self.patch_embedding = nn.Conv2d(in_channels=in_channels,  # embedding 操作后变成 torch.Size([10, 768, 14, 14])
                                         out_channels=embed_dim,
                                         kernel_size=patch_size,
                                         stride=patch_size)
        self.dropout = nn.Dropout(dropout)

        # TODO: add class token
        self.class_token = nn.Parameter(torch.ones(1, 1, embed_dim) * 0.98)  #(1,1,768)

        # TODO: add position embedding
        self.position_embedding = nn.Parameter(torch.ones(1, n_patches+1, embed_dim) * 0.98)  #(1,196+1,768)

    def forward(self, x): # 先把 x patch_embedding,然后和 class_token 合并,最后 position_embedding
        # [n, c, h, w]
        cls_tokens = self.class_token.expand([x.shape[0], -1, -1]) #(10,1,768) 根据batch扩增 class_token
        x = self.patch_embedding(x) # [n, embed_dim, h', w']
        x = x.flatten(2) # torch.Size([10, 768, 196])
        x = x.permute([0, 2, 1]) # torch.Size([10, 196, 768])
        x = torch.concat([cls_tokens, x], axis=1) # (10,196+1,768)
        x = x + self.position_embedding
        return x # torch.Size([10, 197, 768])


class Attention(nn.Module):
    """multi-head self attention"""
    def __init__(self, embed_dim, num_heads, qkv_bias=True, dropout=0., attention_dropout=0.):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = int(embed_dim / num_heads) # 768/4=192
        self.all_head_dim = self.head_dim * num_heads
        self.scales = self.head_dim ** -0.5

        self.qkv = nn.Linear(embed_dim,
                             self.all_head_dim * 3) # [768, 768*3]

        self.proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)
        self.attention_dropout = nn.Dropout(attention_dropout)
        self.softmax = nn.Softmax()

    def transpose_multihead(self, x):
        # x: [N, num_patches 197, all_head_dim 768] -> [N, n_heads, num_patches, head_dim]
        new_shape = [x.shape[:-1][0], x.shape[:-1][1], self.num_heads, self.head_dim] # [10, 197, 4, 192]
        x = x.reshape(new_shape) 
        x = x.permute([0, 2, 1, 3]) # [10, 4, 197, 192]
        return x

    def forward(self, x): # Attention 前后输入输出维度不变,都是 [10, 197, 768]
        B, N, _ = x.shape   # torch.Size([10, 197, 768])
        qkv = self.qkv(x).chunk(3, axis=-1) # 含有三个元素的列表,每一个元素大小 [10, 197, 768]
        q, k, v = map(self.transpose_multihead, qkv) # [10, 4, 197, 192]
        attn = torch.matmul(q, k.transpose(2,3)) # [10, 4, 197, 197]
        attn = attn * self.scales
        attn = self.softmax(attn)
        attn = self.attention_dropout(attn)

        out = torch.matmul(attn, v) # [10, 4, 197, 192]
        out = out.permute([0, 2, 1, 3]) # [10, 197, 4, 192]
        out = out.reshape([B, N, -1]) # [10, 197, 768]

        out = self.proj(out) # [10, 197, 768]
        out = self.dropout(out)

        return out

class EncoderModule(nn.Module):
    def __init__(self, embed_dim=768, num_heads=4, qkv_bias=True, mlp_ratio=4.0, dropout=0., attention_dropout=0.):
        super().__init__()
        self.attn_norm = nn.LayerNorm(embed_dim)
        self.attn = Attention(embed_dim, num_heads)
        self.mlp_norm = nn.LayerNorm(embed_dim)
        self.mlp = Mlp(embed_dim, mlp_ratio)

    def forward(self, x):
        h = x # residual
        x = self.attn_norm(x)
        x = self.attn(x)
        x = x + h

        h = x # residual
        x = self.mlp_norm(x)
        x = self.mlp(x)
        x = x + h
        return x

class Encoder(nn.Module):
    def __init__(self, embed_dim, depth):
        super().__init__()
        Module_list = []
        for i in range(depth):
            encoder_Module = EncoderModule()
            Module_list.append(encoder_Module)
        self.Modules = nn.ModuleList(Module_list)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        for Module in self.Modules:
            x = Module(x)
        x = self.norm(x)
        return x
        

class VisualTransformer(nn.Module):
    def __init__(self,
                 image_size=224,
                 patch_size=16,
                 in_channels=3,
                 num_classes=1000,
                 embed_dim=768,
                 depth=3,
                 num_heads=8,
                 ):
        super().__init__()
        self.patch_embedding = PatchEmbedding(image_size, patch_size, in_channels, embed_dim)
        self.encoder = Encoder(embed_dim, depth)
        self.classifier = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        # x:[N, C, H, W]
        x = self.patch_embedding(x) # torch.Size([10, 197, 768])
        x = self.encoder(x) # torch.Size([10, 197, 768])
        x = self.classifier(x[:, 0]) # 注意这里的处理很奇妙哦,参考 x = torch.concat([cls_tokens, x], axis=1) # (10,196+1,768)
        return x


vit = VisualTransformer()
print(vit)

input_data = torch.randn([10,3,224,224]) # 每批次输入10张图片
print(vit(input_data).shape) # torch.Size([10, 1000])

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/544900.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[大模型]Langchain-Chatchat安装和使用

项目地址: https://github.com/chatchat-space/Langchain-Chatchat 快速上手 1. 环境配置 首先,确保你的机器安装了 Python 3.8 - 3.11 (我们强烈推荐使用 Python3.11)。 $ python --version Python 3.11.7接着,创建一个虚拟环境&#xff…

力扣HOT100 - 240. 搜索二维矩阵 II

解题思路&#xff1a; 从左下角开始&#xff0c;根据条件删除行和列。 class Solution {public boolean searchMatrix(int[][] matrix, int target) {int row matrix.length - 1;int col matrix[0].length - 1;int l 0;while (row > 0 && l < col) {if (targ…

宝宝洗衣机怎么选?四款畅销卓越婴儿洗衣机深度剖析!

近几年科技高速发展&#xff0c;我们的生活也因此变得更加便捷、健康高效。尤其是在家庭生活中&#xff0c;各种新兴家电的出现让我们的生活变得更加健康卫生。婴儿洗衣机也为现代家庭提供了极大的便捷。由于婴儿刚出生免疫力比较弱&#xff0c;所以建议婴儿的衣物尽量和大人的…

【面试八股总结】排序算法(一)

参考资料 &#xff1a;阿秀 一、冒泡排序 冒泡排序就是把小的元素往前交换或者把大的元素往后交换&#xff0c;比较相邻的两个元素&#xff0c;交换也发生在这两个元素之间。具体步骤&#xff1a; 比较相邻的元素。如果第一个比第二个大&#xff0c;就交换他们两个。对每一对…

细胞活性和细胞增殖检测试剂盒--CCK8试剂盒

Cell Counting Kit-8又称CCK-8试剂盒或者CCK8试剂盒。CCK8试剂盒基于WST-8法检测细胞的细胞活性和细胞增殖。 WST-8是MTT的升级产品&#xff0c;试剂盒的工作原理是在电子耦合试剂存在的情况下&#xff0c;可以被线粒体内的脱氢酶还原生成高度 水溶性的橙黄色的甲臜产物&#…

域权限维持—黄金票据和白金票据

黄金票据和白金票据 前言 某老哥的一次面试里问到了这个问题&#xff0c;故来做一番了解 该攻击方式在BlackHat 2014被提出&#xff0c;演讲者为Alva Duckwall & Benjamin Delpy&#xff08;gentilkiwi)进行了演示&#xff0c;该演讲提出了Kerberos协议实现过程中的设计…

Python数据分析案例42——基于Attention-BiGRU的时间序列数据预测

承接上一篇的学术缝合&#xff0c;排列组合模型&#xff0c;本次继续缝合模型演示。 Python数据分析案例41——基于CNN-BiLSTM的沪深300收盘价预测-CSDN博客 案例背景 虽然我自己基于各种循环神经网络做时间序列的预测已经做烂了.....但是还是会有很多刚读研究生或者是别的领…

2024-4-11-arm作业

汇编实现三个灯的闪烁 源代码&#xff1a; .text .global _start _start: 时钟使能LDR r0,0x50000A28ldr r1,[r0]orr r1,r1,#(0x3<<4)str r1,[r0]设置PE10输出LDR r0,0x50006000ldr r1,[r0]bic r1,r1,#(0x3<<20)orr r1,r1,#(0x1<<20)str r1,[r0]设置PE1…

TikTok如何矩阵养号?TK防关联引流系统助力TK账号安全运营

TK是 TikTok旗下的短视频社交媒体&#xff0c;平台目前是全球最火的短视频平台&#xff0c;目前全球活跃用户已经超过8亿。其中 TikTok的用户已经达到8亿。TK这款短视频社交媒体平台在海外的发展潜力非常大&#xff0c;也是国内很多人的创业目标&#xff0c;很多人都想从 TK这个…

Lua脚本使用手册(Redis篇)

Lua脚本 **简介&#xff1a;**Lua是一种功能强大的&#xff0c;高效&#xff0c;轻量级&#xff0c;可嵌入的脚本语言。它是动态类型语言&#xff0c;通过使用基于寄存器的虚拟机解释字节码运行&#xff0c;并具有增量垃圾收集的自动内存管理&#xff0c;是配置&#xff0c;脚…

【源码】2024最新海外刷单抢单平台源码/自带利息宝/理财活动/带搭建教程

源码描述&#xff1a; 前台是单语言 全开源 可二开的版本 CD&#xff1a;获取方式联系小编 微信&#xff1a;uucodes 公众号&#xff1a;资源猿

房贷还款(C语言)

一、运行结果&#xff1b; 二、源代码&#xff1b; # define _CRT_SECURE_NO_WARNINGS # include <stdio.h> # include <math.h>int main() {//初始化变量值&#xff1b;double m, r 0.01;float d 300000;float p 6000;//运算还款所需月份&#xff1b;m log10…

婴儿洗衣机哪个牌子好?四款超值婴儿洗衣机汇总安利

婴儿洗衣机的优点很多&#xff0c;一是省时省力&#xff0c;二是安全卫生&#xff0c;虽说我们无法为孩子营造一个无菌的成长环境&#xff0c;但哪个宝妈宝爸不希望自己的孩子随时都能保持自己的清洁卫生呢&#xff1f;随着市场的不断增长&#xff0c;婴儿洗衣机的品牌也在不断…

llama-factory SFT系列教程 (一),大模型 API 部署与使用

文章目录 背景简介难点 前置条件1. 大模型 api 部署下一步阅读 背景 本来今天没有计划学 llama-factory&#xff0c;逐步跟着github的文档走&#xff0c;发现这框架确实挺方便&#xff0c;逐渐掌握了一些。 最近想使用 SFT 微调大模型&#xff0c;llama-factory 是使用非常广泛…

2024年MathorCup数学应用挑战赛C题思路分析(妈妈杯)

2024年第十四届MathorCup数学应用挑战赛C题解析 文章目录 题目概览第一问&#xff1a;货量预测第二问&#xff1a;运输线路变化的预测第三问&#xff1a;单目标优化第四问&#xff1a;排班计划的优化 MATLAB代码框架货量预测人员排班 2024年的MathorCup数学应用挑战赛再次为我…

在家如何查找下载外文文献

查找下载外文文献的数据库大部分都需要使用权限的&#xff0c;那么我们如何在家进入这些数据库查找下载文献资源呢&#xff1f;请看本文的经验分享&#xff1a; 举例1、 一位同学的文献求助&#xff1a;Performance of financial hedging and earnings management under dive…

关于亚马逊、速卖通等平台,成熟的自养号测评系统需具备哪些条件

在亚马逊等跨境电商平台的严格监管下&#xff0c;众多卖家和买家不幸遭遇了封号&#xff0c;这对于依赖线上销售的小型卖家来说无疑是沉重的打击。经过深入调查&#xff0c;发现大部分账号被封的根源在于底层环境搭建不当。不论是亚马逊还是其他跨境电商巨头如eBay、速卖通、虾…

第一节:什么是操作系统

什么是操作系统 一、一台计算机的组成部分1、计算机能干啥2、谈谈计算机硬件 二、什么是操作系统三、学习操作系统的层次 一、一台计算机的组成部分 如下图所示&#xff1a; 这就是就是构成一台计算机的组成部分 1、计算机能干啥 ∙ \bullet ∙计算机是我们专业吃饭的家伙&a…

绝地求生:杜卡迪来了,这些摩托车技巧不学一下吗?

摩托车在远古版本和现在完全不一样&#xff0c;虽然容易翻车造就了一批玩家“摩托杀手”的外号&#xff0c;但是速度可比今天快多了。 后来在蓝洞的削弱了其加速度&#xff0c;虽然资料上写着最高时速155km/h&#xff0c;但是平时游戏中一般只能拉到110~120km/h。这里写一点摩托…

电脑直播录屏软件怎么选?看这一篇就够了

随着网络直播的日益普及&#xff0c;越来越多的用户希望将直播内容保存下来&#xff0c;以供日后观看或分享。电脑直播录屏软件应运而生&#xff0c;它们不仅能够帮助用户实现录屏需求&#xff0c;还能保证录屏的高清和流畅。本文将介绍两种常用的电脑直播录屏软件&#xff0c;…