VIT总结

关于transformer、VIT和Swin T的总结

1.transformer

1.1.注意力机制

An attention function can be described as mapping a query and a set of key-value pairs to an output, where the query, keys, values, and output are all vectors. The output is computed as a weighted sum of the values, where the weight assigned to each value is computed by a compatibility function of the query with the corresponding key.[1]
输入是query和 key-value,注意力机制首先计算query与每个key的关联性(compatibility)每个关联性作为每个value的权重(weight),各个权重与value的乘积相加得到输出

Attention Is All You Need 中用到的attention叫做“Scaled Dot-Product Attention”,具体过程如下图所示:
在这里插入图片描述
代码实现:

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    def __init__(self, embed_size, heads):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert (self.head_dim * heads == embed_size), "Embed size needs  to  be div by heads"
        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    def forward(self, values, keys, query, mask):
        N = query.shape[0]  # the number of training examples
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split embedding into self.heads pieces
        values = values.reshape(N, value_len, self.heads, self.head_dim)
        keys = keys.reshape(N, key_len, self.heads, self.head_dim)
        queries = query.reshape(N, query_len, self.heads, self.head_dim)

        values = self.values(values)
        keys = self.keys(keys)
        queries = self.queries(queries)

        energy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])
        # queries shape: (N, query_len, heads, heads_dim)
        # keys shape: (N, key_len, heads, heads_dim)
        # energy shape: (N, heads, query_len, key_len)

        if mask is not None:
            energy = energy.masked_fill(mask == 0, float("-1e20"))
            # Fills elements of self tensor with value where mask is True

        attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)
        out = torch.einsum("nhql, nlhd->nqhd", [attention, values]).reshape(
            N, query_len, self.heads * self.head_dim
        )
        # attention shape: (N, heads, query_len, key_len)
        # values shape: (N, value_len, heads, head_dim)
        # after einsum (N, query_len, heads, head_dim) then flatten last two dimensions

        out = self.fc_out(out)
        return out

1.为什么有mask?
NLP处理不定长文本需要padding,但是padding的内容无意义,所以处理时需要mask.
2.关于qkv
qkv是相同的,需要查询的q,与每一个key相乘得到权重信息,权重与v相乘,这样结果受权重大的v影响
3.为什么除以根号dk

We suspect that for large values of dk, the dot products grow large in magnitude, pushing the softmax function into regions where it has extremely small gradients 4. To counteract this effect, we scale the dot products by 1 √dk
点积过大,经过softmax,进入饱和区,梯度很小

4.为什么需要多头
在这里插入图片描述
不同头部的output就是从不同层面(representation subspace)考虑关联性而得到的输出。

1.2.TransformerBlock

解码端的后面两部分和编码段一样,所以打包成一个类
在这里插入图片描述

class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout, forward_expansion):
        super(TransformerBlock, self).__init__()
        self.attention = SelfAttention(embed_size, heads)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, forward_expansion * embed_size),
            nn.ReLU(),
            nn.Linear(forward_expansion * embed_size, embed_size)
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, value, key, query, mask):
        attention = self.attention(value, key, query, mask)

        x = self.dropout(self.norm1(attention + query))
        forward = self.feed_forward(x)
        out = self.dropout(self.norm2(forward + x))
        return out

1.3.Encoder

关键的就是位置编码

class Encoder(nn.Module):
    def __init__(self,
                 src_vocab_size,
                 embed_size,
                 num_layers,
                 heads,
                 device,
                 forward_expansion,
                 dropout,
                 max_length
                 ):
        super(Encoder, self).__init__()
        self.embed_size = embed_size
        self.device = device
        self.word_embedding = nn.Embedding(src_vocab_size, embed_size)
        self.position_embedding = nn.Embedding(max_length, embed_size)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(
                    embed_size,
                    heads,
                    dropout=dropout,
                    forward_expansion=forward_expansion
                )
                for _ in range(num_layers)]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        N, seq_lengh = x.shape
        positions = torch.arange(0, seq_lengh).expand(N, seq_lengh).to(self.device)
        out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))

        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

2.VIT

在这里插入图片描述

Reference:

[1].Attention Is All You Need
[2].https://zhuanlan.zhihu.com/p/366592542
[3].代码实现:https://zhuanlan.zhihu.com/p/653170203
[4].An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/211872.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Windows微软常用运行库合集2023

微软常用运行库合集适用于Windows系统的运行库合集包,基于微软官方的运行库而制作的,包括了常用的vb,vc2005/2008/2010/2012/2013/2017/2019/2005-2022,Microsoft Universal C Runtime,VS 2010 Tools For Office Runti…

Windows远程桌面提示出现身份验证错误 要求的函数不支持

现象 解决方案: 在cmd运行框输入:gpedit.msc打开组策略编辑器路径:计算机配置→管理模板→Windows组件→远程桌面服务→远程桌面会话主机→安全开启远程连接要求使用指定的安全层 禁用要求使用网络级别的身份验证对远程连接的用户进行身份验…

HttpRunner自动化测试之实现参数化传递

参数化实现及重复执行 参数化测试:在接口测试中,为了实现不同组数据对同一个功能模块进行测试,需要准备多组测试数据对模块进行测试的过程。 在httprunner中可以通过如下方式实现参数化: 1、在YAML/JSON 中直接指定参数列表 2、…

Python标准库:math库【侯小啾python领航班系列(十六)】

Python标准库:math库【侯小啾python领航班系列(十六)】 大家好,我是博主侯小啾, 🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ…

面试 Java 基础八股文十问十答第三期

面试 Java 基础八股文十问十答第三期 作者:程序员小白条,个人博客 ⭐点赞⭐收藏⭐不迷路!⭐ 21.说下Java8的Stream流的常用方法 答: forEach遍历、find、match进行匹配reduce进行归约,比如求和,乘,除聚合…

上门按摩APP小程序,抓住机遇创新服务新模式;

上门按摩APP小程序:抓住机遇,创新服务新模式; 随着现代人对生活质量要求的提高,上门按摩服务正成为一种新的、受欢迎的生活方式。通过APP小程序,用户可以轻松预约按摩服务,解决身体疲劳问题,享受…

汇编语言实现音乐播放器

目标程序 用汇编语言实现一个音乐播放器,并支持点歌 Overview 乐曲是按照一定的高低、长短和强弱关系组成的音调,在一首乐曲中,每个音符的音高和音长与频率和节拍有关,因此我们要分别为3首要演奏的乐曲定义一个频率表和一个节拍…

Java Web——动态Web开发核心-Servlet

1. 官方文档 官方文档地址:Overview (Servlet 4.0 API Documentation - Apache Tomcat 9.0.83) servlet 与 Tomcat 的关系:Tomcat 支持 Servlet Tomcat 是一个开源的 Java 服务器,它主要用来提供 Web 服务,包括 HTTP 请求和响应…

<JavaEE> volatile关键字 -- 保证内存可见性、禁止指令重排序

目录 一、内存可见性 1.1 Java内存模型(JMM) 1.2 内存可见性演示 二、指令重排序 三、关键字 volatile 一、内存可见性 1.1 Java内存模型(JMM) 1)什么是Java内存模型(JMM)?Java内存模型即Java Memory Model,简…

【重点】【双指针】11. 盛最多水的容器

题目 注意&#xff1a;二维接雨水&#xff0c;有墙的&#xff0c;有线的&#xff0c;着这个属于线的。 class Solution {public int maxArea(int[] height) {if (height.length < 2) {return 0;}int left 0, right height.length - 1, res 0;while (left < right) {…

三轴加速度计LIS2DW12开发(2)----基于中断信号获取加速度数据

三轴加速度计LIS2DW12开发.2--轮基于中断信号获取加速度数据 概述视频教学样品申请生成STM32CUBEMX串口配置IIC配置CS和SA0设置INT1设置串口重定向参考程序初始换管脚获取ID复位操作BDU设置开启INT1中断设置传感器的量程配置过滤器链配置电源模式设置输出数据速率中断判断加速…

【动态规划】LeetCode-931.下降路径最小和

&#x1f388;算法那些事专栏说明&#xff1a;这是一个记录刷题日常的专栏&#xff0c;每个文章标题前都会写明这道题使用的算法。专栏每日计划至少更新1道题目&#xff0c;在这立下Flag&#x1f6a9; &#x1f3e0;个人主页&#xff1a;Jammingpro &#x1f4d5;专栏链接&…

深度学习记录--logistic回归函数的计算图

计算图用于logistic回归函数 先回顾一下单一样本的logistic回归损失函数的公式&#xff0c;公式如下&#xff1a; 将logistic函数用计算图表示出来(以两个基础量为例)&#xff0c;计算图如下&#xff1a; 前向传播已经完成&#xff0c;接下来完成后向传播 运用链式法则依次求…

Vue---Echarts

项目需要用echarts来做数据展示&#xff0c;现记录vue3引入并使用echarts的过程。 1. 使用步骤 安装 ECharts&#xff1a;使用 npm 或 yarn 等包管理工具安装 ECharts。 npm install echarts 在 Vue 组件中引入 ECharts&#xff1a;在需要使用图表的 Vue 组件中&#xff0c;引入…

如何选择适合的光电传感器与 STM32 微控制器进行接口设计

本文介绍了如何选择适合的光电传感器与 STM32 微控制器进行接口设计的方法。首先我们将介绍一些选择光电传感器的关键因素&#xff0c;包括测量范围、响应时间、分辨率和输出类型。然后我们将介绍如何根据所选传感器的特性进行硬件连接和接口设计。最后&#xff0c;我们将提供示…

20231201将RK3399的挖掘机开发板在Andorid12系统下的强制横屏

20231201将RK3399的挖掘机开发板在Andorid12系统下的强制横屏 2023/12/1 22:54 【不完美的地方&#xff1a;修改之后不满屏】 百度&#xff1a;rk3399 android12 横屏 不满屏 Android 显示不满屏 build.prop https://blog.csdn.net/weixin_39966398/article/details/105595184?…

9.ROS的TF坐标变换(三):坐标系关系查看与一个案例

1 查看目前的坐标系变化 我们先安装功能包&#xff1a; sudo apt install ros-melodic-tf2-tools安装成功&#xff01; 我们先启动上次的发布坐标变换的节点&#xff1a; liuhongweiliuhongwei-Legion-Y9000P-IRX8H:~/Desktop/final/my_catkin$ source devel/setup.bash liuho…

RocketMQ- 深入理解RocketMQ的消息模型

1、RocketMQ客户端基本流程 ​ RocketMQ基于Maven提供了客户端的核心依赖&#xff1a; <dependency><groupId>org.apache.rocketmq</groupId><artifactId>rocketmq-client</artifactId><version>4.9.5</version> </dependency&…

RxJava

Single 使用 Flowable 比较重一般使用Single onSubscribe 产生订阅时调用 线程切换1 2 发送顺序事件.just just 源码 钩子方法,进行验证再处理 Single 对象 订阅,RxJavaPlugins.onSubscribe 钩子方法,产生订阅和过滤 Single 核心方法,抽象的,实现为SingleJust 订阅和执行成功回…

同旺科技 USB TO SPI / I2C --- 调试W5500_Ping测试

所需设备&#xff1a; 内附链接 1、USB转SPI_I2C适配器(专业版); 首先&#xff0c;连接W5500模块与同旺科技USB TO SPI / I2C适配器&#xff0c;如下图&#xff1a; 设置寄存器&#xff1a; SHAR&#xff08;源MAC地址寄存器&#xff09;&#xff0c;该寄存器用来设置源MAC…