【深度学习】解析Vision Transformer (ViT): 从基础到实现与训练

之前介绍:

https://qq742971636.blog.csdn.net/article/details/132061304

文章目录

  • 背景
      • 实现代码示例
      • 解释
  • 训练
      • 数据准备
      • 模型定义
      • 训练和评估
      • 总结

在这里插入图片描述

Vision Transformer(ViT)是一种基于transformer架构的视觉模型,它最初是由谷歌研究团队在论文《An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale》中提出的。ViT将图像分割成固定大小的patches(例如16x16),并将每个patch视为一个词(类似于NLP中的单词)进行处理。以下是ViT的详细讲解:

背景

在计算机视觉领域,传统的卷积神经网络(CNNs)一直是处理图像的主流方法。然而,CNNs存在一些局限性,如在处理长距离依赖关系时表现不佳。ViT引入了transformer架构,通过全局注意力机制,有效地处理图像中的长距离依赖关系。

实现代码示例

ViT代码:

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = img_size // patch_size
        self.num_patches = self.grid_size ** 2

        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)  # [B, embed_dim, H, W]
        x = x.flatten(2)  # [B, embed_dim, num_patches]
        x = x.transpose(1, 2)  # [B, num_patches, embed_dim]
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., drop_path=0.):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = nn.Identity() if drop_path == 0 else nn.Dropout(drop_path)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(in_features=dim, hidden_features=int(dim * mlp_ratio), drop=drop)

    def forward(self, x):
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_channels=3, num_classes=1000, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
        super().__init__()
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.pos_drop = nn.Dropout(p=drop_rate)

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        self.blocks = nn.ModuleList([
            Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i])
            for i in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)

        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        nn.init.trunc_normal_(self.cls_token, std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.LayerNorm):
            nn.init.ones_(m.weight)
            nn.init.zeros_(m.bias)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        cls_token_final = x[:, 0]
        x = self.head(cls_token_final)
        return x

# 示例输入
img = torch.randn(1, 3, 224, 224)
model = VisionTransformer()
output = model(img)
print(output.shape)  # 输出大小为 [1, 1000]

解释

  1. PatchEmbedding:将输入图像分割为不重叠的patches,并通过卷积操作将其转换为embedding。
  2. Attention:实现自注意力机制。
  3. MLP:实现多层感知器(MLP),包括GELU激活函数和Dropout。
  4. Block:包含一个注意力层和一个MLP层,每层都有残差连接和层归一化。
  5. VisionTransformer:组合上述模块,形成完整的ViT模型。包含位置嵌入和分类头。

训练

为了在GPU上训练ViT模型,你可以使用PyTorch中的DataLoader来处理数据,并确保模型和数据都在GPU上。以下是一个详细的代码示例,包括数据准备、模型定义、训练和评估。

数据准备

假设你的数据结构如下:

dataset/
    class1/
        img1.jpg
        img2.jpg
        ...
    class2/
        img1.jpg
        img2.jpg
        ...
    ...

你可以使用 torchvision.datasets.ImageFolder 来加载数据。

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

# 数据转换和增强
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# 加载数据
data_dir = 'dataset'
train_dataset = datasets.ImageFolder(os.path.join(data_dir, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(data_dir, 'val'), transform=transform)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=4)

# 获取类别数
num_classes = len(train_dataset.classes)

模型定义

定义ViT模型并将其移动到GPU上。

# VisionTransformer定义(使用上面的定义)
model = VisionTransformer(num_classes=num_classes).cuda()

# 损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# 如果有多个GPU,使用DataParallel
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

训练和评估

定义训练和评估函数,并进行训练。

def train_one_epoch(model, criterion, optimizer, data_loader, device):
    model.train()
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in tqdm(data_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    epoch_loss = running_loss / len(data_loader.dataset)
    epoch_acc = running_corrects.double() / len(data_loader.dataset)
    return epoch_loss, epoch_acc


def evaluate(model, criterion, data_loader, device):
    model.eval()
    running_loss = 0.0
    running_corrects = 0

    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

    epoch_loss = running_loss / len(data_loader.dataset)
    epoch_acc = running_corrects.double() / len(data_loader.dataset)
    return epoch_loss, epoch_acc


# 训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
num_epochs = 25

for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, criterion, optimizer, train_loader, device)
    val_loss, val_acc = evaluate(model, criterion, val_loader, device)

    print(f'Epoch {epoch}/{num_epochs - 1}')
    print(f'Train Loss: {train_loss:.4f} Acc: {train_acc:.4f}')
    print(f'Val Loss: {val_loss:.4f} Acc: {val_acc:.4f}')

# 保存模型
torch.save(model.state_dict(), 'vit_model.pth')

总结

这段代码展示了如何使用PyTorch在GPU上训练Vision Transformer模型。包括数据加载、模型定义、训练和评估步骤。请根据你的实际需求调整批量大小、学习率和训练轮数等参数。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/714960.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

29.添加录入注入信息界面

上一个内容:28.启动与暂停程序 以 28.启动与暂停程序 它的代码为基础进行修改 效果图: 新建Dialog 给新建的dialog添加空间,如下图 给每个输入框创建一个变量 代码: void CWndAddGame::OnBnClickedButton1() {static TCHAR BASE…

基于springboot的学生宿舍管理系统(带 1w+字文档)

基于springboot的学生宿舍管理系统(带 1w字文档) 基于 springbootvue 前后端分离的学生宿舍管理系统:前端 vue2、elementui,后端 maven、springmvc、spring、mybatis; 项目简介 本项目可供学习参考,商业慎用。项目带完整安装部署…

FPGA----petalinux开机启动自定义脚本/程序的保姆级教程

1、petalinux的重启命令:reboot、关机命令:shutdown -h now、开机按键:在关机后,ZCU106的右上角指示灯会变为红色,此时按下左上角第一个按键可启动操作系统。 2、好久没写博客了,本次给大家带来的是petalin…

记录一次centos扩容

背景 在Vscode上连虚拟机写项目,突然提示磁盘空间不足(no space left on device),一开始打算删些东西,这里参考博客,写得挺清楚的,但是操作后我发现实在没啥文件可以删除,所以干脆不删了,直接扩…

爱心代码来喽

今天给大家分享一个爱心代码&#xff0c;送给我的粉丝们。愿你们天天开心&#xff0c;事事顺利&#xff0c;学业和事业有成。 下面是运行代码&#xff1a; #include<stdio.h> #include<Windows.h> int main() { system(" color 0c"); printf(&q…

【百度智能体】零代码创建职场高情商话术助手智能体

一、前言 作为一个程序猿&#xff0c;工科男思维&#xff0c;走上职场后&#xff0c;总会觉得自己不会处理人际关系&#xff0c;容易背锅说错话&#xff0c;这时候如果有个助手能够时时刻刻提醒自己该如何说话如何做事情就好了。 而我们现在可以通过百度文心智能体平台构建各…

c++编程(18)——deque的模拟实现(2)容器篇

欢迎来到博主的专栏——c编程 博主ID&#xff1a;代码小豪 文章目录 deque的数据结构deque的构造默认构造填充构造 deque的其他操作deque的插入、删除push_back和push_frontpop_back和pop_frontclear、erase和insert操作 传送门 在上一篇中&#xff0c;我们已经实现了deque最核…

循环队列

循环队列是一种线性数据结构&#xff0c;其操作表现基于 FIFO&#xff08;First In First Out&#xff0c;先进先出&#xff09;原则并且队尾被连接在队首以形成一个循环。 这种结构克服了普通队列在元素入队和出队时需要移动大量元素的缺点。 在循环队列中&#xff0c;当元素…

Centos实现Mysql8.4安装及主主同步

8.4的Msyql在同步的时候与之前的版本有很大不同&#xff0c;这里记录一下安装流程 Mysql安装 官网下载 选择自己的版本&#xff0c;选第一个 复制下载链接 在服务器上创建一个msyql目录 使用命令下载,链接换自己的 wget https://dev.mysql.com/get/mysql84-community-relea…

跟着刘二大人学pytorch(第---10---节课之卷积神经网络)

文章目录 0 前言0.1 课程链接&#xff1a;0.2 课件下载地址&#xff1a; 回忆卷积卷积过程&#xff08;以输入为单通道、1个卷积核为例&#xff09;卷积过程&#xff08;以输入为3通道、1个卷积核为例&#xff09;卷积过程&#xff08;以输入为N通道、1个卷积核为例&#xff09…

接口测试工作准备

前面已经讲了接口测试的原理&#xff0c;接下来讲接口测试如何准备。分为了解项目背景、收集项目相关资料、部署接口测试环境。 1、了解项目背景 1、首先我们应该去了解项目的应用范围&#xff0c;了解业务场景需要调用的接口&#xff0c;确定接口测试的接口个数、接口名字、接…

Spring配置那些事

一、引言 配置是一个项目中不那么起眼&#xff0c;但却有非常重要的东西。在工程项目中&#xff0c;我们一般会将可修改、易变、不确定的值作为配置项&#xff0c;在配置文件/配置中心中设置。 比方说&#xff0c;不同环境有不同的数据库地址、不同的线程池大小等&#xff0c…

创建STM32F10X空项目教程

创建STM32F10X系列的空项目工程 官网下载STM32标准外设软件库 STM32标准外设软件库 创建一个空文件夹作为主工程文件夹在主工程文件夹中&#xff0c;创建三个空文件夹 CMSIS - 存放内核函数及启动引导文件 FWLIB - 存放库函数 USER - 存放用户的函数将STM32标准外设软件库文件…

Git学习记录v1.0

1、常用操作 git clonegit configgit branchgitt checkoutgit statusgit addgit commitgit pushgit pullgit loggit tag 1.1 git clone 从git服务器拉取代码 git clone https://gitee.com/xxx/studyJava.git1.2 git config 配置开发者用户名和邮箱 git config user.name …

Javaweb04-Servlet技术2(HttpServletResponse, HttpServletRequest)

Servlet技术基础 HttpServletResponse对象 HttpServletResponce对象是继承ServletResponse接口&#xff0c;专门用于封装Http请求 HttpServletResponce有关响应行的方法 方法说明功能描述void setStatus(int stauts)用于设置HTTP响应消息的状态码&#xff0c;并生成响应状态…

【C++ 11 新特性】lambda 表达式详解

文章目录 1. 常见 lambda 面试题&#x1f58a; 1. 常见 lambda 面试题&#x1f58a; &#x1f34e;① 如果⼀个 lambda 表达式作为参数传递给⼀个函数&#xff0c;那这个函数可以使⽤这个 lambda 表达式捕获的变量吗 ? &#x1f427; 函数本身无法直接访问到 lambda表达式捕获…

教资认定报名照片要求小于190kb…

教资认定报名照片要求小于190kb…… 要求&#xff1a;文件小于190kb&#xff0c;宽度290-300&#xff0c;高度408-418 方法&#xff1a;vx搜随时照-教资认定 直接制作合规尺寸即可&#xff0c;还可以打印纸质版邮寄到家

ffmpeg解封装rtsp并录制视频-(2)使用VLC模拟一个rtsp服务器并用ffmpeg解封装该rtsp流

VCL模拟服务器并打开播放该视频文件&#xff1a; - 准备好一个mp4文件&#xff0c;打开vlc软件 - 选择“媒体”》“流” - 添加一个mp4文件 - 点击下方按钮选择“串流” - 下一步目标选择rtsp 点击“添加” - 端口默认8554 - 路径设置 /test - 用…

JavaFX VBox

VBox布局将子节点堆叠在垂直列中。新添加的子节点被放置在上一个子节点的下面。默认情况下&#xff0c;VBox尊重子节点的首选宽度和高度。 当父节点不可调整大小时&#xff0c;例如Group节点&#xff0c;最大垂直列的宽度基于具有最大优选宽度的节点。 默认情况下&#xff0c;…

揭秘神秘的种子:Adobe联合宾夕法尼亚大学发布文本到图像扩散模型大规模种子分析

文章链接&#xff1a;https://arxiv.org/pdf/2405.14828 最近对文本到图像&#xff08;T2I&#xff09;扩散模型的进展促进了创造性和逼真的图像合成。通过变化随机种子&#xff0c;可以为固定的文本提示生成各种图像。在技术上&#xff0c;种子控制着初始噪声&#xff0c;并…