ViT的极简pytorch实现及其即插即用

先放一张ViT的网络图
在这里插入图片描述
可以看到是把图像分割成小块,像NLP的句子那样按顺序进入transformer,经过MLP后,输出类别。每个小块是16x16,进入Linear Projection of Flattened Patches, 在每个的开头加上cls token和位置信息,也就是position embedding。
去掉数据读取部分,直接上一个极简的ViT代码:

import torch
from torch import nn

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)## 对tensor张量分块 x :1 197 1024   qkv 最后是一个元祖,tuple,长度是3,每个元素形状:1 197 1024
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        image_height, image_width = pair(image_size)   # 224*224
        patch_height, patch_width = pair(patch_size)   # 16 * 16

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            # (b,3,224,224) -> (b,196,768)    14*14=196  16*16*3=768
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),    # (b,196,1024)
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        x = self.to_patch_embedding(img)        # img 1 3 224 224  输出形状x : 1 196 1024
        b, n, _ = x.shape                       # 1 196

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)    # (1,1,1024)
        x = torch.cat((cls_tokens, x), dim=1)   # (1,197,1024)
        x += self.pos_embedding[:, :(n + 1)]    # (1,197,1024)
        x = self.dropout(x)                     # (1,197,1024)

        x = self.transformer(x)                 # (1,197,1024)

        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]     # (1,1024)

        x = self.to_latent(x)      # (1,1024)
        return self.mlp_head(x)    # (1,1000)


if __name__ == '__main__':
    v = ViT(
        image_size = 224,
        patch_size = 16,
        num_classes = 1000,
        dim = 1024,
        depth = 6,
        heads = 16,
        mlp_dim = 2048,
        dropout = 0.1,
        emb_dropout = 0.1
    )

    img = torch.randn(1, 3, 224, 224)

    preds = v(img)        # (1, 1000)

    print(preds.shape)

去掉cls和最后的全连接分类头,变成即插即用的模块:

import torch
from torch import nn

from einops import rearrange
from einops.layers.torch import Rearrange

# helpers

def pair(t):
    return t if isinstance(t, tuple) else (t, t)

# classes

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        qkv = self.to_qkv(x).chunk(3, dim = -1)## 对tensor张量分块 x :1 197 1024   qkv 最后是一个元祖,tuple,长度是3,每个元素形状:1 197 1024
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = self.heads), qkv)

        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale

        attn = self.attend(dots)

        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, dim = 1024, depth = 3, heads = 16, mlp_dim = 2048, dim_head = 64, dropout = 0.1, emb_dropout = 0.1):
        super().__init__()
        channels, image_height, image_width = image_size   # 256,64,80
        patch_height, patch_width = pair(patch_size)       # 4*4

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)     # 16*20
        patch_dim = 64 * patch_height * patch_width    # 64*8*10

        self.conv1 = nn.Conv2d(256, 64, 1)

        self.to_patch_embedding = nn.Sequential(
            # (b,64,64,80) -> (b,320,1024)    16*20=320  4*4*64=1024
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
            nn.Linear(patch_dim, dim),    # (b,320,1024)
        )

        self.to_img = nn.Sequential(
            # b c (h p1) (w p2) -> (b,64,64,80)      16*20=320  4*4*64=1024
            Rearrange('b (h w) (p1 p2 c) -> b c (h p1) (w p2)', \
                      p1 = patch_height, p2 = patch_width, h = image_height // patch_height, w = image_width // patch_width),
            nn.Conv2d(64, 256, 1),      # (b,64,64,80) -> (b,256,64,80)
        )
        # 位置编码
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

    def forward(self, img):
        x = self.conv1(img)                     # img 1 256 64 80 -> 1 64 64 80
        x = self.to_patch_embedding(x)          # 1 320 1024
        b, n, _ = x.shape                       # 1 320

        x += self.pos_embedding[:, :(n + 1)]    # (1,320,1024)
        x = self.dropout(x)                     # (1,320,1024)

        x = self.transformer(x)                 # (1,320,1024)

        x = self.to_img(x)

        return x                                # (1 256 64 80)


if __name__ == '__main__':

    v = ViT(image_size = (256,64,80), patch_size = 4)

    img = torch.randn(1, 256, 64, 80)

    preds = v(img)         # (1, 256, 64, 80)

    print(preds.shape)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/278745.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【软件测试】为bug而生

为什么定位问题如此重要? 可以明确一个问题是不是真的“bug” 很多时候,我们找到了问题的原因,结果发现这根本不是bug。原因明确,误报就会降低多个系统交互,可以明确指出是哪个系统的缺陷,防止“踢皮球”&…

【23.12.29期--Redis缓存篇】谈一谈Redis的集群模式

谈一谈Redis的集群模式 ✔️ 谈一谈Redis的集群模式✔️主从模式✔️ 特点✔️Redis主从模式Demo ✔️哨兵模式✔️Redis哨兵模式Demo✔️特点 ✔️Cluster模式✔️Redis Cluster模式Demo✔️特点 ✔️ 谈一谈Redis的集群模式 Redis有三种主要的集群模式,用于在分布…

电气产品外壳常用材质PA、PC、PBT、ABS究竟是什么?

在如今工业制造领域,各种改性塑料、复合材料以及轻质合金材料的运用日趋成熟。在电气领域,不同电气产品的外壳、组件材质采用不同材料,以同为科技(TOWE)电气产品为例,工业连接器系列产品采用PA6外壳材质、机…

【SD】一致性角色 - 表情差异生成 【1】

原理:通过segment 局部重绘 可以根据lora 产生面部表情图片 模型:sam_vit_h_4b8939.pth 导入图片到segment 开启:Enable GroundingDINO GroundingDINO Detection Prompt :输入 face 然后点击:Preview Segmentation …

如何文件从电脑传到iPhone,这里提供几个方法

本文介绍了如何使用Finder应用程序、iTunes for Windows、iCloud和谷歌照片将照片从Mac或PC传输到iPhone。 如何将照片从Mac传输到iPhone 如果你有一台Mac电脑,里面装满了你想转移到iPhone的照片,这是一件非常简单的事情。只需遵循以下步骤&#xff1a…

LIUNX进程程序替换

1.什么是程序替换 a.一个程序,只能执行自己的代码 b.如果想要一个程序执行,别的程序的代码呢? 我们就可以创建一个子进程,将这个子进程替换为我们想要执行的程序。 2.样例代码-----execl(接口) 返回值&…

跨进程通信 macOS XPC 创建实例

一:简介 XPC 是 macOS 里苹果官方比较推荐和安全的的进程间通信机制。 集成流程简单,但是比较绕。 主要需要集成 XPC Server 这个模块,这个模块最终会被 apple 的根进程 launchd 管理和以独立进程的方法唤起和关闭, 我们主app 进…

交叉编译aarch64架构支持openssl的curl、libcurl

本文档旨在指导读者在x86_64平台上交叉编译curl和openssl库以支持aarch64架构。在开始之前,请确保您的系统环境已正确配置。 1. 系统环境准备 系统是基于Ubuntu 20.04 LTS,高版本可能会有问题。首页,安装必要的开发工具和库文件。打开终端并…

QML 怎么调用 C++ 中的内容?

以下内容为本人的学习笔记,如需要转载,请声明原文链接 微信公众号「ENG八戒」https://mp.weixin.qq.com/s/z_JlmNe6cYldNf11Oad_JQ 先说明一下测试环境 编译器:vs2017x64 开发环境:Qt5.12 这里主要是总结一下,怎么在…

Java——猫猫图鉴微信小程序(前后端分离版)

目录 一、开源项目 二、项目来源 三、使用框架 四、小程序功能 1、用户功能 2、管理员功能 五、使用docker快速部署 六、更新信息 审核说明 一、开源项目 猫咪信息点-ruoyi-cat: 1、一直想做点项目进行学习与练手,所以做了一个对自己来说可以完成的…

MFC随对话框大小改变同时改变控件大小

先看一下效果; 初始; 窗口变大,控件也变大; 二个也可以; 窗口变大,控件变大; 默认生成的对话框没有WM_SIZE消息的处理程序;打开类向导,选中WM_SIZE消息,对CxxxDlg类添加该消息的处理程序;默认生成的函数名是OnSize; 添加了以后代码中会有三处变化; 在对话框类的…

嵌入式SOC之通用图像处理之OSD文字信息叠加的相关实践记录

机缘巧合 机缘巧合下, 在爱芯元智的xx开发板下进行sdk的开发.由于开发板目前我拿到是当前最新的一版(估计是样品),暂不公开开发板具体型号信息.以下简称板子 .很多优秀的芯片厂商,都会提供与开发板配套的完善的软件以及完善的技术支持(FAE),突然觉得爱芯…

鸿蒙(HarmonyOS 3.1) DevEco Studio 3.1开发环境汉化

鸿蒙(HarmonyOS 3.1) DevEco Studio 3.1开发环境汉化 一、安装环境 操作系统: Windows 10 专业版 IDE:DevEco Studio 3.1 SDK:HarmonyOS 3.1 二、设置过程 打开IDE,在第一个菜单File 中找到Settings...菜单 在Setting...中找到Plugins…

VSCode远程开发配置

目录 概要远程开发插件安装开始连接SSH无密码登录开发环境配置 概要 现在很多公司都是直接远程到服务器上写代码,使用远程开发,可以在与生产环境相同的环境中开发、测试和部署代码,减少因环境不同而导致的问题。当下VSCode远程开发是支持的比…

用通俗易懂的方式讲解大模型:基于 LangChain 和 ChatGLM2 打造自有知识库问答系统

随着人工智能技术的迅猛发展,问答机器人在多个领域中展示了广泛的应用潜力。在这个信息爆炸的时代,许多领域都面临着海量的知识和信息,人们往往需要耗费大量的时间和精力来搜索和获取他们所需的信息。 在这种情况下,垂直领域的 A…

初识Sringboot3+vue3环境准备

环境准备 后端环境准备 下载JDK17https://www.oracle.com/java/technologies/downloads/#jdk17-windows 安装就下一步下一步,选择安装路径 配置环境 环境 JDK17、IDEA2021、maven3.5、vscode 后端 基础:javaSE,javaWeb、JDBC、SMM框架(Spr…

Vscode新手安装与使用

安装与版本选择 VS Code 有两个不同的发布渠道:一个是我们经常使用的稳定版(Stable),每个月发布一个主版本;另外一个发布渠道叫做 Insiders,每周一到周五 UTC 时间早上6点从最新的代码发布一个版本&#x…

【网络安全 | CTF】FlatScience

该题考察SQL注入 正文 后台扫到robots.txt 页面内容如下&#xff1a; 进入login.php 页面源代码如图&#xff1a; 传参debug得到php代码&#xff1a; <?php if(isset($_POST[usr]) && isset($_POST[pw])){$user $_POST[usr];$pass $_POST[pw];$db new SQLite3…

小米手机小游戏隐私问题解决方案

1.由于laya底层代码调用获取设备信息&#xff0c;导致原先启动laya引擎后才去弹出隐私政策条款的功能是过不了审核的&#xff0c;所以需要在android的设计一个隐私条款的弹窗&#xff0c;玩家同意条款后才启动laya引擎&#xff1a; &#xff08;1&#xff09;定义隐私条款弹窗的…

mfc100u.dll文件丢失了要怎么解决?修复mfc100u.dll详细指南

mfc100u.dll文件丢失了要怎么解决?首先让我们扒一扒什么是 mfc100u.dll。这玩意儿是 Microsoft Visual Studio 2010 的一部分&#xff0c;它就像一款程序生活中不可或缺的零件&#xff0c;没了它&#xff0c;程序肯定跑不起来。想想看&#xff0c;没有一个重要的零件&#xff…