Vision Transformer(vit)的Embedding层结构

代码:

class PatchEmbed(nn.Module):
    """
    2D Image to Patch Embedding
    """
    def __init__(self, img_size=224, patch_size=16, in_c=3, embed_dim=768, norm_layer=None):
        super().__init__()
        img_size = (img_size, img_size)
#图像尺寸默认224×224
        patch_size = (patch_size, patch_size)
#划分图像的块大小,默认为 16×16
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
#将输入图像的高度和宽度分别除以图像块的高度和宽度,得到图像在高度和宽度方向上划分出的块的
#数量,例如:输入(224,224)除以patch:(16,16),得到数量(14,14)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
#计算图像总共划分出的块的数量

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
#使用卷积层(nn.Conv2d)将每个补丁映射到一个高维向量
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
#根据传入的 norm_layer 参数决定是否应用归一化层

    def forward(self, x):
        B, C, H, W = x.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
#检查输入图像的高度和宽度是否与类初始化时设置的 img_size 一致。如果不一致,就会抛出异常并
#给出相应的提示信息,确保输入图像的尺寸符合模型的预期。

        # flatten: [B, C, H, W] -> [B, C, HW]
        # transpose: [B, C, HW] -> [B, HW, C]
        x = self.proj(x).flatten(2).transpose(1, 2)
        x = self.norm(x)
#如果在初始化时传入了有效的归一化层 self.norm,那么就会对经过上述处理后的张量 x 进行归一#化处理
        return x

其中embed_dim=768:

embed_dim 是一个重要的参数,它代表嵌入的维度,即每个图像补丁(patch)经过嵌入层后所映射到的向量空间的维度。在深度学习和特别是在视觉变换器(Vision Transformers,简称ViT)模型中,embed_dim 通常用来指定模型中特征向量的维度。

默认值768的来源:在许多ViT模型的实现中,embed_dim 的默认值被设置为768,这个值来源于一些流行的预训练模型,如BERT,它在自然语言处理任务中取得了很好的效果。这个维度被证明能够有效地捕捉和表示数据的特征,因此在ViT模型中也被广泛采用。

操作:

代码: 

img_size = (img_size, img_size)
#图像尺寸默认224×224
        patch_size = (patch_size, patch_size)
#划分图像的块大小,默认为 16×16
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
#将输入图像的高度和宽度分别除以图像块的高度和宽度,得到图像在高度和宽度方向上划分出的块的
#数量,例如:输入(224,224)除以patch:(16,16),得到数量(14,14)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
#计算图像总共划分出的块的数量

操作:

代码:

        self.proj = nn.Conv2d(in_c, embed_dim, kernel_size=patch_size, stride=patch_size)
#使用卷积层(nn.Conv2d)将每个补丁映射到一个高维向量
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
#根据传入的 norm_layer 参数决定是否应用归一化层

此处1、2、3...为位置编码暂时忽略

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/925364.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Spring Boot 实战:基于 Validation 注解实现分层数据校验与校验异常拦截器统一返回处理

1. 概述 本文介绍了在spring boot框架下,使用validation数据校验注解,针对不同请求链接的前端传参数据,进行分层视图对象的校验,并通过配置全局异常处理器捕获传参校验失败异常,自动返回校验出错的异常数据。 2. 依赖…

Linux查看网络基础命令

文章目录 Linux网络基础命令1. ifconfig 和 ip一、ifconfig命令二、ip命令 2. ss命令一、基本用法二、常用选项三、输出信息四、使用示例 3. sar 命令一、使用sar查看网络使用情况 4. ping 命令一、基本用法二、常用选项三、输出结果四、使用示例 Linux网络基础命令 1. ifconf…

Python酷库之旅-第三方库Pandas(245)

目录 一、用法精讲 1156、pandas.tseries.offsets.MonthEnd.is_month_start方法 1156-1、语法 1156-2、参数 1156-3、功能 1156-4、返回值 1156-5、说明 1156-6、用法 1156-6-1、数据准备 1156-6-2、代码示例 1156-6-3、结果输出 1157、pandas.tseries.offsets.Mon…

DAMODEL丹摩|Faster-Rcnn训练与部署实战

本文仅做测评体验,非广告。 文章目录 1. 丹摩介绍2. Faster-Rcnn介绍3. 准备3. 1 丹摩平台准备实例 3. 2 Faster-Rcnn4. 部署开始5. 训练5. 资源释放6. 结语 1. 丹摩介绍 详细介绍请看:丹摩平台介绍。 丹摩智算平台(DAMODEL)是…

NLP信息抽取大总结:三大任务(带Prompt模板)

信息抽取大总结 1.NLP的信息抽取的本质?2.信息抽取三大任务?3.开放域VS限定域4.信息抽取三大范式?范式一:基于自定义规则抽取(2018年前)范式二:基于Bert下游任务建模抽取(2018年后&a…

LLM*:路径规划的大型语言模型增强增量启发式搜索

路径规划是机器人技术和自主导航中的一个基本科学问题,需要从起点到目的地推导出有效的路线,同时避开障碍物。A* 及其变体等传统算法能够确保路径有效性,但随着状态空间的增长,计算和内存效率会严重降低。相反,大型语言…

C#基础题总结

16.一张单据上有一个5位数的号码为6**42,其中百位数和千位数已模糊不清,但知道该数能被 57 和 67 除尽。设计一个算法,找出该单据所有可能的号码。 17.编程序求2~10000以内的完全数。一个数的因子(除了这个数本身&…

Android数据存储——文件存储、SharedPreferences、SQLite、Litepal

数据存储全方案——详解持久化技术 Android系统中主要提供了3中方式用于简单地实现数据持久化功能,即文件存储、SharedPreference存储以及数据库存储。除了这三种方式外,还可以将数据保存在手机的SD卡中,不给使用文件、SharedPreference或者…

【动手学电机驱动】STM32-FOC(8)MCSDK Profiler 电机参数辨识

STM32-FOC(1)STM32 电机控制的软件开发环境 STM32-FOC(2)STM32 导入和创建项目 STM32-FOC(3)STM32 三路互补 PWM 输出 STM32-FOC(4)IHM03 电机控制套件介绍 STM32-FOC(5&…

ubuntu 安装proxychains

在Ubuntu上安装Proxychains,你可以按照以下步骤操作: 1、更新列表 sudo apt-update 2、安装Proxychains sudo apt-get install proxychains 3、安装完成后,你可以通过编辑/etc/proxychains.conf文件来配置代理规则 以下是一个简单的配置示例&…

ZooKeeper 基础知识总结

先赞后看,Java进阶一大半 ZooKeeper 官网这样介绍道:ZooKeeper 是一种集中式服务,用于维护配置信息、命名、提供分布式同步和提供组服务。 各位hao,我是南哥,相信对你通关面试、拿下Offer有所帮助。 ⭐⭐⭐一份南哥编写…

visionpro官方示例分析(一) 模板匹配工具 缺陷检测工具

1.需求:找出图像中的这个图形。 2.步骤 使用CogPMAlignTool工具,该工具是模板匹配工具,见名知意,所谓模板匹配工具就是说先使用该工具对一张图像建立模板,然后用这个模板在其他图像上进行匹配,匹配上了就说…

代码随想录算法训练营第六十天|Day60 图论

Bellman_ford 队列优化算法(又名SPFA) https://www.programmercarl.com/kamacoder/0094.%E5%9F%8E%E5%B8%82%E9%97%B4%E8%B4%A7%E7%89%A9%E8%BF%90%E8%BE%93I-SPFA.html 本题我们来系统讲解 Bellman_ford 队列优化算法 ,也叫SPFA算法&#xf…

LAMP环境的部署

一、软件安装介绍 在Linux系统中安装软件有rpm安装、yum安装、源码安装等方法,在这里主要给大家介绍 yum 安装,这是一种最简单方便的一种安装方法。 YUM(Yellow dog Upadate Modifie)是改进版的 RPM 管理器,很好地解…

搭建文件服务器并使用Qt实现文件上传和下载(带账号和密码)

文章目录 0 背景1 搭建文件服务器2 代码实现文件上传和下载2.1 在pro文件中添加网络支持2.2 创建网络管理类2.3 文件上传2.4 文件下载 3 扩展(其他方法实现文件上传和下载)3.1 python3.2 npm3.3 ftp服务器 4 完整的代码 0 背景 因为需要使程序具备在远程…

matlab导出3D彩色模型(surface类转stl,并对白模上色)

在matlab中绘制3维图形时,需要将3维图形导出到PPT中展示。但是直接导出图片效果欠佳,无法全方位展示。 最近学习了如何将matlab中的图形导出为stl模型,然后再采用简单的方法对模型上色。 中间尝试过matlab导出stl、ply、3dm等多种格式&…

Java项目中加缓存

Java项目中加缓存 1.更新频率低;但读写频率高的数据很适合加缓存; 2.可以加缓存的地方很多:浏览器的缓存;CDN的缓存;服务器的缓存; 本地内存;分布式远端缓存; 加缓存的时候不要…

VTK的基本概念(一)

文章目录 三维场景的基本要素1.灯光2.相机3.颜色4.纹理映射 三维场景的基本要素 1.灯光 在三维渲染场景中,可以有多个灯光的存在,灯光和相机是三维渲染场景的必备要素,如果没有指定的话,vtkRenderer会自动创建默认的灯光和相机。…

【C知道】数据包捕获(wire shark)

请解释一下数据包捕获和分析工具(如Wireshark)的工作原理和用途。 数据包捕获和分析工具,例如Wireshark(以前称为 Ethereal),是一种网络协议分析软件,它允许用户实时监控计算机网络中的数据传输…

浮点数计算,不丢失精度

在js中对于浮点数直接计算会存在精度丢失的情况,为了保证精度问题,可以做如下处理: 浮点数精度计算 主要流程如下: 浮点数转换成整数 示例代码如下 /** 将一个浮点数转成整数,返回整数和倍数。如 3.14 >> 314…