ViT如何支持变长序列(patches)输入?

问题:当增加输入图像的分辨率时,例如DeiT 从 224 到 384,一般来说会保持 patch size(例如9),因此 patch 的数量 N 会发生了变化。那么视觉transformer是如何处理变长序列输入的?

 

回答

在讨论视觉ViT中,对于图像分类任务,不管序列多长(所有patches加起来的长度),一般是在输入序列的开始添加一个特殊的cls token,这个cls token 不对应图像中的任何一个具体的 patch,而是作为一个全局的表示,用于汇总整个图像的信息。在经过 Transformer 层处理之后,这个 class token 被用来代表整个图像的特征(只取这个latent token用于MLP),用于最终的分类任务。这意味着不管输入的图像分辨率如何变化,从而导致序列的长度(即 patch 的数量)如何变化,模型总是会关注这一个特定的 token 来进行分类判断。意思就是不管你切分为多少个patches,这些patches进入transformer encoder得到latent token,我们只取第0个token(cls token)用于分类任务,剩余的token都不使用。

Transformer 架构中的MLP(多层感知机)和FFN(前馈神经网络)是共享的,这意味着模型中的所有tokens(在图像处理上下文中,可以认为是图像被分割成的小块或"patches")通过相同的MLP和FFN进行处理。无论输入序列的长度(即tokens总共的数量)如何变化,每个token都会被相同的MLP和FFN处理。这是因为MLP和FFN在Transformer架构中是以相同的方式应用于序列中的每一个元素,而不依赖于序列的总长度。这种设计允许Transformer模型处理可变长度的输入序列,因为每个token的处理方式是一样的。

但是由于提高输入图像的分辨率会增加序列的长度(因为 patch 的数量增多),原有的位置编码(position embedding)无法直接适用于新的序列长度。你只需要对现有的位置编码进行插值,以生成适应新序列长度的位置编码,然后通过微调(fine-tune)这些插值生成的编码,便可以使模型能够适应新的输入分辨率

插值方式可以看VAE的实现:

interpolate_pos_embed(model, checkpoint_model)

 这个过程确保了视觉 Transformer 模型可以处理不同分辨率的输入图像,同时保持了使用单一的 class token 来汇总和利用全图信息进行分类的策略。

归根到底:

在ViT(Vision Transformer)的上下文中,模型处理的输入维度会因为图像分辨率大小的不同而导致patch数量的变化,从而影响到位置编码层的输入维度。

然而,ViT的核心Transformer架构设计为处理任意长度的序列。这意味着无论patch的数量如何,Transformer的主体结构(多头自注意力机制和MLP)都能够处理。这是因为这些组件是基于序列中的每个patch独立操作的,而不是依赖于整个序列的维度。因此,Transformer核心是不直接受图像大小影响的。

可以看到下面的示例中“整个ViT过程中image-size只会影响PE的维度大小,而不会影响其他的任何参数,所以只需要在新的分辨率的图像来的时候,修改PE的维度就可以了(加大或减小到patch的长度)

image_height, image_width = pair(image_size)
patch_height, patch_width = pair(patch_size)
num_patches = (image_height // patch_height) * (image_width // patch_width)
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))

整个修改流程:

  1. 加载在低分辨率下的图像训练好的pretrain-checkpoint
  2. 修改低分辨率下的pretrain-checkpoint的PE维度,从而适应当前分辨率下的维度
  3. 保存修改后的pretrain-checkpoint
  4. 加载高维度下的model
  5. 将修改后的pretrain-checkpoint的用在高维度下的model(load_state_dict)
  6. 微调就可以了

代码实现:

下面的代码实现是256-->512维度的变化,类似于MAE的训练的时候用unmask的patch,微调的时候使用unmask+mask的patch

1、基本的ViT的实现:

import torch
from torch import nn, einsum, optim
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
from tqdm import tqdm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


def pair(t):
    return t if isinstance(t, tuple) else (t, t)



class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)



class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)


class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout),
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)  # (b, n(65), dim*3) ---> 3 * (b, n, dim)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)  # q, k, v   (b, h, n, dim_head(64))

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)


class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout=0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x


class ViT(nn.Module):
    def __init__(self, *, image_size=256, patch_size=8, num_classes=1000, dim=1024, depth=6, heads=16, mlp_dim=2048, pool='cls', channels=3, dim_head=64, dropout=0., emb_dropout=0.):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert  image_height % patch_height ==0 and image_width % patch_width == 0

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim)
        )
        # 当image-size改变的时候,只会影响pos_embedding的维度,所以当图像分辨率变化的时候,只有这一个地方需要修改
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)                                    # b c (h p1) (w p2) -> b (h w) (p1 p2 c) -> b (h w) dim
        b, n, _ = x.shape                                                   # b表示batchSize, n表示每个块的空间分辨率, _表示一个块内有多少个值
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b=b)  # self.cls_token: (1, 1, dim) -> cls_tokens: (batchSize, 1, dim)
        x = torch.cat((cls_tokens, x), dim=1)                        # 将cls_token拼接到patch token中去       (b, 65, dim)
        x += self.pos_embedding[:, :(n+1)]                                  # 加位置嵌入(直接加)      (b, 65, dim)
        x = self.dropout(x)
        x = self.transformer(x)                                                 # (b, 65, dim)
        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]                   # (b, dim)
        x = self.to_latent(x)                                                   # Identity (b, dim)
        return self.mlp_head(x)                                                 #  (b, num_classes)

2、训练的代码

# ============================== begin train ==============================
model = ViT(
        image_size = 256,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )


# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 模拟训练数据
train_data = torch.randn(16, 3, 256, 256)       # 示例数据,16个样本
train_labels = torch.randint(0, 1000, (16,))    # 示例标签,1000个类

# 训练模型
def train(model, data, labels, criterion, optimizer, epochs=1):
    model.train()
    for epoch in tqdm(range(epochs)):
        for i in range(len(data)):
            optimizer.zero_grad()
            output = model(data[i].unsqueeze(0))  # 前向传播
            loss = criterion(output, labels[i].unsqueeze(0))  # 计算损失
            loss.backward()
            optimizer.step()
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")

# 训练模型
train(model, train_data, train_labels, criterion, optimizer, epochs=10)

# 保存模型checkpoint
torch.save({
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
}, 'vit_model_checkpoint.pth')

3、修改PE维度并用在不一样的分辨率下的图像上

这里“resize_pos_embedding”我直接固定了,有更优雅的实现方式,懒得写了

# --------------------------------------------------------
# Interpolate position embeddings for high-resolution
# References:
# DeiT: https://github.com/facebookresearch/deit
# --------------------------------------------------------



def resize_pos_embedding(checkpoint_model, new_image_size, patch_size=32):
    num_patches = (new_image_size // patch_size) ** 2
    new_num_tokens = num_patches + 1  # 加上一个cls_token
    old_pos_embedding = checkpoint_model['pos_embedding']
    old_num_tokens, embedding_dim = old_pos_embedding.shape[1], old_pos_embedding.shape[2]

    # # 插值pos_embedding
    # new_pos_embedding = F.interpolate(
    #     old_pos_embedding.permute(0, 2, 1).reshape(1, embedding_dim, int(old_num_tokens ** 0.5), int(old_num_tokens ** 0.5)),
    #     size=int(new_num_tokens ** 0.5),
    #     mode='bilinear',
    #     align_corners=False
    # ).reshape(1, embedding_dim, new_num_tokens).permute(0, 2, 1)
    #
    # # 更新模型的pos_embedding
    # checkpoint_model['pos_embedding'] = nn.Parameter(new_pos_embedding)

    new_pos_embedding = nn.Parameter(torch.randn(1,257,1024))
    checkpoint_model['pos_embedding'] = new_pos_embedding


# ============================== begin test ==============================
image_size = 512

# Step 1: 加载新尺寸下的模型
model = ViT(
        image_size = image_size,
        patch_size = 32,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

# Step 2: 加载原来小尺寸下训练好的checkpoint
checkpoint_model  = torch.load('vit_model_checkpoint.pth')

# Step 3: interpolate position embedding,这是新分辨率下模型中唯一需要修改的地方
model_state = checkpoint_model['model_state_dict']
resize_pos_embedding(model_state, image_size)

# Step 4: 加载新分辨率下的模型
model.load_state_dict(model_state)

# 准备评估数据
test_loader = DataLoader(datasets.FakeData(size=16, image_size=(3, image_size, image_size), num_classes=1000, transform=transforms.ToTensor()), batch_size=16, shuffle=True)

model.eval()
correct = 0
total = 0
with torch.no_grad():
    for imgs, labels in test_loader:
        outputs = model(imgs)
        _, predicted = torch.max(outputs.data, 1)
        print(predicted)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Accuracy of the model on the test images: {16 * correct / total}%')

ViT、Deit这类视觉transformer是如何处理变长序列输入的? - 知乎

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/471581.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MySQL的目录结构

安装目录 /usr/local/mysql数据目录 /usr/local/mysql/data配置目录 /usr/local/etc/my.cnf点击返回 MySQL 快速学习目录

【NLP】TF-IDF算法原理及其实现

🌻个人主页:相洋同学 🥇学习在于行动、总结和坚持,共勉! #学习笔记# 目录 01 TF-IDF算法介绍 02 TF-IDF应用 03 Sklearn实现TF-IDF算法 04 使用TF-IDF算法提取关键词 05 TF-IDF算法的不足 TF-IDF算法非常容易理…

matlab 基于小波变换的油气管道泄露信号检测

1、内容简介 略 71-可以交流、咨询、答疑 基于小波变换的油气管道泄露信号检测 去噪、小波变换、油气管道泄露、信号检测 2、内容说明 摘 要: 油气管道泄漏会造成严重危害,因此,亟需寻找一种能快速检测油气管道信号的技术。传统的 傅里…

Vue2(八):脚手架结构、render函数、ref属性、props配置项、mixin(混入)、插件、scoped样式

一、脚手架结构分析 crlc终止刚刚搭建的vue。 ├── node_modules ├── public │ ├── favicon.ico: 页签图标 │ └── index.html: 主页面 ├── src │ ├── assets: 存放静态资源 │ │ └── logo.png │ │── component: 存放组件 │ │ …

Gin框架 源码解析

https://zhuanlan.zhihu.com/p/136253346 https://www.cnblogs.com/randysun/category/2071204.html 这个博客里其他go的内容也讲的很好 启动 因为 gin 的安装教程已经到处都有了,所以这里省略如何安装, 建议直接去 github 官方地址的 README 中浏览安装…

【数据库基础增删改查】条件查询、分页查询

系列文章目录 🌈座右铭🌈:人的一生这么长、你凭什么用短短的几年去衡量自己的一生! 💕个人主页:清灵白羽 漾情天殇_计算机底层原理,深度解析C,自顶向下看Java-CSDN博客 ❤️相关文章❤️:清灵白羽 漾情天…

AI浸入社交领域,泛娱乐APP如何抓住新风口?

2023年是大模型技术蓬勃发展的一年,自ChatGPT以惊艳姿态亮相以来,同年年底多模态大模型技术在国内及全球范围内的全面爆发,即模型能够理解并生成包括文本、图像、视频、音频等多种类型的内容。例如,基于大模型的文本到图像生成工具…

Samtec科普 | 一文了解患者护理应用连接器

【摘要/前言】 通过医疗专业人士为患者提供护理的种种需求,已经不限于手术室与医院的各种安全状况。当今许多患者的护理都是在其他环境进行,例如医生办公室、健康中心,还有越来越普遍的住家。尤其是需要长期看护的患者,所需的科技…

Mysql数据库概念与安装

目录 一、数据库概述 1、数据库的基本概念 2、数据库管理系统(DBMS) 2.1 数据库管理系统概念 2.2 数据库管理系统工作模式 3、数据库系统(DBS) 3.1 数据库系统概念 3.2 数据库系统发展史 4、关系型数据库与非关系型数据库…

机器学习——终身学习

终身学习 AI不断学习新的任务,最终进化成天网控制人类终身学习(LLL),持续学习,永不停止的学习,增量学习 用线上收集的资料不断的训练模型 问题就是对之前的任务进行遗忘,在之前的任务上表现不好…

【机器学习】无监督学习算法之:K均值聚类

K均值聚类 1、引言2、K均值聚类2.1 定义2.2 原理2.3 实现方式2.4 算法公式2.4.1 距离计算公式2.4.1 中心点计算公式 2.5 代码示例 3、总结 1、引言 小屌丝:鱼哥, K均值聚类 我不懂,能不能给我讲一讲? 小鱼:行&#xf…

AI助手 - Fitten Code

前言 上一篇介绍了商汤AI编程小助手小浣熊 Raccoon,过程中又发现了另外一款国产AI编程助手,那就是本篇要介绍的非十科技出品的Fitten Code。 ​ Fitten Code 主打一个快:超高准确率、超快的响应速度。号称代码生成比GitHub Copilot 快两倍&am…

蓝桥杯模块综合——高质量讲解AT24C02,BS18B20,BS1302,AD/DA(PCF8591),超声波模块

AT24C02——就是一个存储的东西,可以给他写东西,掉电不丢失。 void EEPROM_Write(unsigned char * EEPROM_String,unsigned char addr , unsigned char num) {IIC_Start();IIC_SendByte(0xA0);IIC_WaitAck();IIC_SendByte(addr);IIC_WaitAck();while(nu…

奶牛均分

解法&#xff1a; 假设编号从左到右递增&#xff0c;奶牛每次只能去往左边的牛圈。因此等分最大奶牛数小于等于最右边牛圈奶牛数&#xff0c;不妨设数为k&#xff0c;那么a[i]>k&#xff0c;a[i-1]>2k。。。 做后缀和二分答案就可找到k #include<iostream> #inc…

字符串筛选排序 - 华为OD统一考试(C卷)

OD统一考试&#xff08;C卷&#xff09; 分值&#xff1a; 100分 题解&#xff1a; Java / Python / C 题目描述 输入一个由n个大小写字母组成的字符串&#xff0c; 按照 ASCII 码值从小到大的排序规则&#xff0c;查找字符串中第 k 个最小ASCII 码值的字母(k>1) , 输出该…

CSS学习(3)-浮动和定位

一、浮动 1. 元素浮动后的特点 脱离文档流。不管浮动前是什么元素&#xff0c;浮动后&#xff1a;默认宽与高都是被内容撑开&#xff08;尽可能小&#xff09;&#xff0c;而且可以设置宽 高。不会独占一行&#xff0c;可以与其他元素共用一行。不会 margin 合并&#xff0c;…

C语言易错知识点

1、数组长度及所占字节数 char x[] {"Hello"},y[]{H,e,l,l,o}; x数组的长度为5&#xff0c;y的长度也是5 x、y数组所占字符串为6为 51(\0)6 strlen&#xff08;&#xff09;函数得到的是数组的长度 2、%%与%的优先级 #include<stdio.h> int main(){ int a…

HarmonyOS4.0—自定义渐变导航栏开发教程

前言 今天要分享的是一个自定义渐变导航栏&#xff0c;本项目基于鸿蒙4.0。 先看效果&#xff1a; 这种导航栏在开发中也比较常见&#xff0c;特点是导航栏背景色从透明到不透明的渐变&#xff0c;以及导航栏标题和按钮颜色的变化。 系统的导航栏无法满足要求&#xff0c;我们…

Visual Studio 2013 - 高亮设置括号匹配 (方括号)

Visual Studio 2013 - 高亮设置括号匹配 [方括号] 1. 高亮设置 括号匹配 (方括号)References 1. 高亮设置 括号匹配 (方括号) 工具 -> 选项… -> 环境 -> 字体和颜色 References [1] Yongqiang Cheng, https://yongqiang.blog.csdn.net/

基于信号分解的几种一维时间序列降噪方法(MATLAB R2021B)

自适应信号分解算法是一种适合对非平稳信号分析的方法&#xff0c;它将一个信号分解为多个模态叠加的形式&#xff0c;进而可以准确反应信号中所包含的频率分量以及瞬时频率随时间变化的规律。自适应信号分解算法与众多“刚性”方法(如傅里叶变换&#xff0c;小波变换)不同&…