昇思25天学习打卡营第16天|应用实践之Vision Transformer图像分类

基本介绍

        今天同样是图像分类任务,也更换了模型,使用的时候计算机视觉版的Transformer,即Vision Transformer,简称ViT。Transformer本是应用于自然语言处理领域的模型,用于处理语言序列,而要将其应用于图像,关键在于如何把图像转化为序列数据。ViT对图像进行分割,分割成一个一个的patch,然后再加上空间编码,由此便把图像转化为序列数据,使得Transformer也可以应用于计算机视觉领域。下面会对ViT进行简单介绍,然后使用ImageNet数据集进行训练,简单训练10轮,并进行推理,以进一步了解ViT。

Vision Transformer简介

        ViT模型的主体结构是基于Transformer模型的Encoder部分,结构如下图所示

ViT由Transformer变换而成,而Transformer的核心是Self-Attention,要学习ViT就得搞懂Self-Attention。Self-Attention的核心内容是为输入向量的每个单词学习一个权重。通过给定一个任务相关的查询向量Query向量,计算Query和各个Key的相似性或者相关性得到注意力分布,即得到每个Key对应Value的权重系数,然后对Value进行加权求和得到最终的Attention数值。具体如下(以下的Self-Attention计算过程来自MindSpore官方教程,并非本人原创):

1. 最初的输入向量首先会经过Embedding层映射成Q(Query),K(Key),V(Value)三个向量,由于是并行操作,所以代码中是映射成为dim x 3的向量然后进行分割,换言之,如果你的输入向量为一个向量序列(𝑥1,𝑥2,𝑥3),其中的𝑥1,𝑥2,𝑥3都是一维向量,那么每一个一维向量都会经过Embedding层映射出Q,K,V三个向量,只是Embedding矩阵不同,矩阵参数也是通过学习得到的。这里大家可以认为,Q,K,V三个矩阵是发现向量之间关联信息的一种手段,需要经过学习得到,至于为什么是Q,K,V三个,主要是因为需要两个向量点乘以获得权重,又需要另一个向量来承载权重向加的结果,所以,最少需要3个矩阵。

2. 自注意力机制的自注意主要体现在它的Q,K,V都来源于其自身,也就是该过程是在提取输入的不同顺序的向量的联系与特征,最终通过不同顺序向量之间的联系紧密性(Q与K乘积经过Softmax的结果)来表现出来。Q,K,V得到后就需要获取向量间权重,需要对Q和K进行点乘并除以维度的平方根,对所有向量的结果进行Softmax处理,通过公式(2)的操作,我们获得了向量之间的关系权重

3. 其最终输出则是通过V这个映射后的向量与Q,K经过Softmax结果进行weight sum获得,这个过程可以理解为在全局上进行自注意表示。每一组Q,K,V最后都有一个V输出,这是Self-Attention得到的最终结果,是当前向量在结合了它与其他向量关联权重后得到的结果。

有了Self-Attention结构之后,通过与Feed Forward,Residual Connection等结构的拼接就可以形成Transformer的基础结构,如下图所示

ViT就是由上述的结构搭建而成。ViT的完整使用流程如下:

对ViT有了基本了解后,我们上手代码,加深理解。ViT(MindSpore版)的代码如下:

class ViT(nn.Cell):
    def __init__(self,
                 image_size: int = 224,
                 input_channels: int = 3,
                 patch_size: int = 16,
                 embed_dim: int = 768,
                 num_layers: int = 12,
                 num_heads: int = 12,
                 mlp_dim: int = 3072,
                 keep_prob: float = 1.0,
                 attention_keep_prob: float = 1.0,
                 drop_path_keep_prob: float = 1.0,
                 activation: nn.Cell = nn.GELU,
                 norm: Optional[nn.Cell] = nn.LayerNorm,
                 pool: str = 'cls') -> None:
        super(ViT, self).__init__()

        self.patch_embedding = PatchEmbedding(image_size=image_size,
                                              patch_size=patch_size,
                                              embed_dim=embed_dim,
                                              input_channels=input_channels)
        num_patches = self.patch_embedding.num_patches

        self.cls_token = init(init_type=Normal(sigma=1.0),
                              shape=(1, 1, embed_dim),
                              dtype=ms.float32,
                              name='cls',
                              requires_grad=True)

        self.pos_embedding = init(init_type=Normal(sigma=1.0),
                                  shape=(1, num_patches + 1, embed_dim),
                                  dtype=ms.float32,
                                  name='pos_embedding',
                                  requires_grad=True)

        self.pool = pool
        self.pos_dropout = nn.Dropout(p=1.0-keep_prob)
        self.norm = norm((embed_dim,))
        self.transformer = TransformerEncoder(dim=embed_dim,
                                              num_layers=num_layers,
                                              num_heads=num_heads,
                                              mlp_dim=mlp_dim,
                                              keep_prob=keep_prob,
                                              attention_keep_prob=attention_keep_prob,
                                              drop_path_keep_prob=drop_path_keep_prob,
                                              activation=activation,
                                              norm=norm)
        self.dropout = nn.Dropout(p=1.0-keep_prob)
        self.dense = nn.Dense(embed_dim, num_classes)

    def construct(self, x):
        """ViT construct."""
        x = self.patch_embedding(x)
        cls_tokens = ops.tile(self.cls_token.astype(x.dtype), (x.shape[0], 1, 1))
        x = ops.concat((cls_tokens, x), axis=1)
        x += self.pos_embedding

        x = self.pos_dropout(x)
        x = self.transformer(x)
        x = self.norm(x)
        x = x[:, 0]
        if self.training:
            x = self.dropout(x)
        x = self.dense(x)

        return x

模型训练

       由于数据集准备并不难,所以不做展示,直接使用模型进行训练,训练代码如下:

# define super parameter
epoch_size = 10
momentum = 0.9
num_classes = 1000
resize = 224
step_size = dataset_train.get_dataset_size()

# construct model
network = ViT()

# load ckpt
vit_url = "https://download.mindspore.cn/vision/classification/vit_b_16_224.ckpt"
path = "./ckpt/vit_b_16_224.ckpt"

vit_path = download(vit_url, path, replace=True)
param_dict = ms.load_checkpoint(vit_path)
ms.load_param_into_net(network, param_dict)

# define learning rate
lr = nn.cosine_decay_lr(min_lr=float(0),
                        max_lr=0.00005,
                        total_step=epoch_size * step_size,
                        step_per_epoch=step_size,
                        decay_epoch=10)

# define optimizer
network_opt = nn.Adam(network.trainable_params(), lr, momentum)


# define loss function
class CrossEntropySmooth(LossBase):
    """CrossEntropy."""

    def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
        super(CrossEntropySmooth, self).__init__()
        self.onehot = ops.OneHot()
        self.sparse = sparse
        self.on_value = ms.Tensor(1.0 - smooth_factor, ms.float32)
        self.off_value = ms.Tensor(1.0 * smooth_factor / (num_classes - 1), ms.float32)
        self.ce = nn.SoftmaxCrossEntropyWithLogits(reduction=reduction)

    def construct(self, logit, label):
        if self.sparse:
            label = self.onehot(label, ops.shape(logit)[1], self.on_value, self.off_value)
        loss = self.ce(logit, label)
        return loss


network_loss = CrossEntropySmooth(sparse=True,
                                  reduction="mean",
                                  smooth_factor=0.1,
                                  num_classes=num_classes)

# set checkpoint
ckpt_config = CheckpointConfig(save_checkpoint_steps=step_size, keep_checkpoint_max=100)
ckpt_callback = ModelCheckpoint(prefix='vit_b_16', directory='./ViT', config=ckpt_config)

# initialize model
# "Ascend + mixed precision" can improve performance
ascend_target = (ms.get_context("device_target") == "Ascend")
if ascend_target:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O2")
else:
    model = train.Model(network, loss_fn=network_loss, optimizer=network_opt, metrics={"acc"}, amp_level="O0")

# train model
model.train(epoch_size,
            dataset_train,
            callbacks=[ckpt_callback, LossMonitor(125), TimeMonitor(125)],
            dataset_sink_mode=False,)

完整训练的话起码有80个轮次,时间太长,再加上我们使用了预训练参数,所以我们只训练10轮

模型验证

        与训练过程相似,首先进行数据增强,然后定义ViT网络结构,加载预训练模型参数。随后设置损失函数,评价指标等,编译模型后进行验证。本案例采用了业界通用的评价标准Top_1_Accuracy和Top_5_Accuracy评价指标来评价模型表现。模型表现如下:

因为预训练参数的原因,效果还是不错的

模型推理

        使用一张杜宾犬的图片进行预测,结果如下,是准确的。

总结

        今日学习使用ViT,若之前对Attention完全没有了解,直接上手难度很大的,不过官方文档写的很好,加上本人有些Transformer的基础,所以认真花费一些时间,结合代码,对ViT的结构和流程有了一个基本了解。ViT可以应用的任务很多,希望下次可以尝试将其应用到目标检测。

Jupyter在线运行情况

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/788446.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

CentOS6用文件配置IP模板

CentOS6用文件配置IP模板 到 CentOS6.9 , 默认还不能用 systemctl , 能用 service chkconfig sshd on 对应 systemctl enable sshd 启用,开机启动该服务 ### chkconfig sshd on 对应 systemctl enable sshd 启用,开机启动该服务 sudo chkconfig sshd onservice sshd start …

三级_网络技术_12_路由设计技术基础

1.R1、R2是一个自治系统中采用RIP路由协议的两个相邻路由器,R1的路由表如下图(a)所示,当R1收到R2发送的如下图(b)的(V.D)报文后,R1更新的4个路由表项中距离值从上到下依次为0、3、3、4 那么,①②③④可能的取值依次为()。 0、4、…

LaySNS模板仿RiPro日主题素材源码资源下载响应式CMS模板

该主题是网上泛滥的RiPro主题仿制而成的laysns模板,原主题是很强大的, 全站功能是通过ajax响应实现的,但本人技术有限,只会仿,不会移植,(主要ajax这里不知道怎么弄)。 另外就是网上…

【linux服务器篇】-Redis-RDM远程连接redis

redis desktop manager 使用远程连接工具RDM连接redis 市面上比较常见的其中一款工具redis desktop manager 简单的说: Redis Desktop Manager 简单的来讲就是Redis可视化工具,可以让我们看到Redis中存储的内容。 redis desktop manager是一款功能强…

大型综合医院、妇幼保健院智慧产科信息系统源码,支持二次开发,授权后可商用。

一套采用java语言开发,前端框架为Vue,ElementUIMySQL数据库,前后端分离架构的数字化产科管理系统源码,自主版权,多个大型综合医院、妇幼保健院应用案例,支持二次开发,授权后可商用。 系统特点&a…

Qt/QML学习-ListView

QML学习 ListView例程视频讲解代码 main.qml import QtQuick 2.15 import QtQuick.Window 2.15 import QtQuick.Controls 2.15Window {id: windowwidth: 640height: 480visible: truetitle: qsTr("ListView")Rectangle {height: listView.heightwidth: listView.wi…

Pearson 相关系数的可视化辅助判断和怎么用

Pearson 相关系数的可视化辅助判断和怎么用 flyfish Pearson 相关系数 是一种用于衡量两个连续型变量之间线性相关程度的统计量。其定义为两个变量协方差与标准差的乘积的比值。公式如下: r ∑ ( x i − x ˉ ) ( y i − y ˉ ) ∑ ( x i − x ˉ ) 2 ∑ ( y i −…

328. 奇偶链表

https://leetcode.cn/problems/odd-even-linked-list/https://leetcode.cn/problems/odd-even-linked-list/ 解题思路: 把第一个和第二个节点分别作为奇数、偶数的头节点,当遇到奇节点,删除,并插入到奇数头节点后,这样…

【普中】基于51单片机的矩阵电子密码锁LCD1602液晶显示 proteus仿真+程序+设计报告+讲解视频

【普中】基于51单片机的矩阵电子密码锁LCD1602液晶显示设计 1.主要功能:讲解视频:2.仿真3. 程序代码4. 设计报告5. 设计资料内容清单&&下载链接资料下载链接: 【普中】基于51单片机的矩阵电子密码锁LCD1602液晶显示设计 ( proteus仿真…

如何切换手机的ip地址

在数字时代的浪潮中,智能手机已成为我们日常生活中不可或缺的一部分。然而,随着网络安全问题的日益凸显,保护个人隐私和数据安全变得尤为重要。其中,IP地址作为网络身份的重要标识,其安全性与隐私性备受关注。本文将详…

使用 Hugging Face 的 Transformers 库加载预训练模型遇到的问题

题意: Size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint - Huggingface PyTorch 这个错误信息 "Size mismatch for embed_out.weight: copying a param with shape torch.Size([0]) from checkpoint - Hugg…

[高频 SQL 50 题(基础版)]第一千七百五十七题,可回收且低脂产品

题目: 表:Products ---------------------- | Column Name | Type | ---------------------- | product_id | int | | low_fats | enum | | recyclable | enum | ---------------------- product_id 是该表的主键(具有唯…

解决树形表格 第一列中文字没有对齐

二级分类与一级分类的文字没有对齐 <el-table:data"templateStore.hangyeList"style"width: 100%"row-key"id":tree-props"{ children: subData, hasChildren: hasChildren }" ><el-table-column prop"industryCode&quo…

Kettle常用参数配置

目录 一、时区二、时间戳三、tinyint类型转换 一、时区 Kettle链接mysql出现报错&#xff1a;Connection failed. Verify all connection parameters and confirm that the appropriate driver is installed. The server time zone value is unrecognized or represents more…

无人机之穿越机注意事项篇

一、检查设备 每次飞行前都要仔细检查穿越机的每个部件&#xff0c;确保所有功能正常&#xff0c;特别是电池和电机。 二、遵守法律 了解并遵循你所在地区关于无人机的飞行规定&#xff0c;避免非法飞行。 三、评估环境 在飞行前检查周围环境&#xff0c;确保没有障碍物和…

[激光原理与应用-102]:南京科耐激光-激光焊接-焊中检测-智能制程监测系统IPM介绍 - 6 - 激光焊接系统的组成

目录 一、激光焊接系统的组成概述 1.1、核心部件 1.2、焊接执行部件 1.3、辅助系统 1.4、控制系统 1.5、其他辅助设备 二、激光器 2.1 按出光类型分 1. 脉冲激光器 2. 连续激光器 3. 准连续激光器&#xff08;QCW&#xff09; 4. 其他常见激光器 5. 应用领域 2.2…

HTTP入门

目录 1. 原理介绍 2. HTTP协议简介 2.1简介 ​编辑 2.2基本工作原理 2.3HTTP三个要点 3. chrome浏览器和开发者工具 4. HTTP的消息结构 4.1主要流程和概念 4.2请求和响应例子 5. 完整的网页请求过程 6. 请求 6.1 请求行 6.2请求方法 6.3请求参数 GET请求的参数…

【问题记录】Windows中Node的express无法直接识别

问题描述 在使用express_generator的时候windows平台中出现无法识别express命令的问题&#xff0c;另外就算添加了全局环境变量也没用。 问题解决 查看官方文档发现在node版本8之前的时候使用的是express&#xff0c;但是之后的版本使用npx&#xff0c;这个工具的出现主要想…

头歌资源库(23)资源分配

一、 问题描述 某工业生产部门根据国家计划的安排&#xff0c;拟将某种高效率的5台机器&#xff0c;分配给所属的3个工厂A,B,C&#xff0c;各工厂在获得这种机器后&#xff0c;可以为国家盈利的情况如表1所示。问&#xff1a;这5台机器如何分配给各工厂&#xff0c;才能使国家盈…

Flutter 开启混淆打包apk,并反编译apk确认源码是否被混淆

第一步&#xff1a;开启混淆并打包apk flutter build apk --obfuscate --split-debug-info./out/android/app.android-arm64.symbols 第二步&#xff1a;从dex2jar download | SourceForge.net 官网下载dex2jar 下载完终端进入该文件夹&#xff0c;然后运行以下命令就会在该…