transformer的缺陷
自注意力机制的计算范围只限于窗口内,不能直接处理窗口外的元素,不能照顾到整个序列。
由于计算复杂度随着窗口的长度呈几何平方式增长,所以不能一味地增加窗口长度来解决。
Transformer本质上是通过位置编码将序列数据空间化,然后计算空间相关度反向建模时序相关度,忽视了数据内在结构的关联,比较简单暴力,参数效率低,冗余度高。
时序状态空间模型SSM简介
是一种基于RRN的用于描述系统状态随时间变化的数学模型,由状态方程和观测方程组成。
- 连续空间的时序建模
- 状态方程通常是一个一阶或高阶的线性或非线性微分方程,描述系统状态如何随时间演化。
- 观测方程通常是一个线性或非线性方程,表示观测数据与系统状态之间的关系,描述如何从系统状态中获得观测数据。
ABC是固定的,所以叫时不变。
SSM思路是从LTI线性时不变系统到线性最后到非线性,牺牲通用性,换得特定场景的高性能。
- 时序离散化与GNN
对连续系统进行离散化展开,使用零阶保持实现从连续系统转换为离散系统的ABC对应关系。 实际是求导变成了时间左移。
3. 并行处理与CNN
使用CNN对时序数据建模,不同时间尺度作为卷积核
输出可以表示为上式,与离散卷积公式(下式)对比,可以发现二者是类似的,外部都是求和号,对不同时间步和相应权重的线性组合,输入x也只差了一个k的时间偏移,权重都是与随k变化的量。
SSM只是换名字的CNN化的RNN。
Mamba:选择性SSM
由于SSM有两个强假设:线性和时不变,而实际系统大部分为非线性,时变的,所以SSM的应用范围较小。Mamba本质上是SSM模型的改进,放开这两个假设。他主要体现在设计一种机制,使得状态空间具备选择性,达到Transformer的建模能力,又在序列长度上实现了线性扩展,克服Transformer的缺陷。还使用GPU提高性能,并行扫描。
实现选择性,就是在中间部分使得B、C都变成t的函数,成为时变参数,A不是t的函数,但中间在使用delta进行离散化的时候也有t。蓝色部分就是选择性机制。delta_t相当于一个总开关,B_t和C_t相当于旋钮。
关注两种能力:一种是抓重点的能力,即从一句话中找出关键信息,忽略不相关的部分。另一种是上下文联想与推理能力,处理连续信息时保存逻辑一致性和上下文连贯性。
输入x_t通过三条通道影响B_t,两条通道影响C_t,两条通道影响A_t,delta函数是非线性的,所以就使得系统能够完成非线性时变。
步长delta像是放大镜观察窗口,较小时,模型倾向于忽略具体的单词,而更多依赖于之前的上下文信息。使用放大镜忽远忽近的看,实现注意力的选择。
Vision_Mamba的Vim.py源码学习
- 导入库
from typing import Optional
import torch.nn as nn
import torch
import torch.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_, lecun_normal_
from torch import Tensor
from functools import partial
from mamba_ssm.modules.mamba_simple import Mamba
from rope import *
import random
try:
from mamba_ssm.ops.triton.layernorm import RMSNorm, rms_norm_fn, layer_norm_fn
except ImportError:
RMSNorm, rms_norm_fn, layer_norm_fn = None, None, None
1. PatchEmbedding类
- 图像维度(B, C, H, W) -> (B, embed_dim, grid_size[0], grid_size[1]) -> (B, num_patches, embed_dim)
- 将图片切成小方块
- img_size图像大小, patch_size小方块大小, stride步长, in_channels输入维度, embed_dim特征向量维度,norm_layer是否归一化, flatten指是否展平
- grid_size其中两个元素都表示宽度和高度方向上的小块数量
class PatchEmbedding(nn.Module):
def __init__(self, img_size=224, patch_size=16, stride=16, in_channels=3, embed_dim=768, norm_layer=None, flatten=True):
# img_size图像大小, patch_size小方块大小, stride步长, in_channels输入维度, embed_dim特征向量维度,norm_layer是否归一化, flatten指是否展平
super(PatchEmbedding, self).__init__()
img_size = to_2tuple(img_size) # 输出是一个包含两个元素的元组,为图像的宽度和高度。
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
# grid_size其中两个元素都表示宽度和高度方向上的小块数量
self.grid_size = ((img_size[0] - patch_size[0]) // stride + 1, (img_size[0] - patch_size[0]) // stride + 1)
self.num_patches = self.patch_size[0] * self.patch_size[1] # 小方块数量
self.flatten = flatten
self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() # 初始化归一化层
def forward(self, x):
B, C, H, W = x.shape()
# 如果图像尺寸不符合则报错。
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input img size {(H) * (W)} doesen't match model({self.img_size[0]} * {self.img_size[1]})"
x = self.proj(x) # 切方块 (B, embed_dim, grid_size[0], grid_size[1])
if self.flatten:
x = x.flatten[2].transpose(1, 2) # (B, grid_size[0] * grid_size[1], embed_dim)
x = self.norm(x) # 归一化
return x
2. Block类
- Mamba的Encoder层
- dim是维度,mixer_cls是混合方式, norm_cls=nn.LayerNorm归一化方式, fused_add_norm=False是否使用融合的加法和归一化操作,residual_in_fp32=False残差连接是否以单精度浮点数(32位浮点数)进行,drop_path=0.表示路径丢弃(stochasticdepth)的概率
- 这个block接受两个输入,分别是hidden_states, residual
- 判断是否要使用混合相加标准化的方式
- 将 residual 的数据类型转换为与self.norm.weight的数据类型相同,然后归一化
- 返回更新后的 hidden_states 和 residual
class Block(nn.Module): # encoder layer of Mamba
def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm,
fused_add_norm=False, residual_in_fp32=False, drop_path=0.):
# dim是维度,mixer_cls是混合方式, norm_cls=nn.LayerNorm归一化方式, fused_add_norm=False是否使用融合的加法和归一化操作
# residual_in_fp32=False残差连接是否以单精度浮点数(32位浮点数)进行
# drop_path=0.表示路径丢弃(stochasticdepth)的概率
super(Block, self).__init__()
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.mixer = mixer_cls(dim)
self.norm = norm_cls(dim)
self.drop_path = DropPath(drop_path)
if self.fused_add_norm:
# 确保 RMSNorm 已经被正确导入
# 检查 self.norm 是否是 nn.LayerNorm 或 RMSNorm 类的一个实例
assert RMSNorm is not None, "RMSNorm import Fails"
assert isinstance(
self.norm, (nn.LayerNorm, RMSNorm)
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):
# 这个block接受两个输入,分别是hidden_states, residual(可选)
if not self.fused_add_norm: # 是否要使用混合相加标准化的方式
if residual is None:
residual = hidden_states
else:
residual = residual + self.drop_path(hidden_states) # 残差连接
# 用于将 residual 的数据类型转换为与self.norm.weight的数据类型相同。这是为了确保归一化层的权重和输入张量的数据类型一致,避免在计算过程中出现类型不匹配的问题。
# self.norm归一化层
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32) # 残差连接是否以单精度浮点数(32位浮点数)进行
else:
# 如果 self.norm 是 RMSNorm 类的一个实例,则 fused_add_norm 被赋值为 rms_norm_fn;否则,fused_add_norm 被赋值为 layer_norm_fn。
fused_add_norm = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
if residual is None:
hidden_states, residual = fused_add_norm(
hidden_states,
self.norm.weight,
self.norm.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps, # 用于避免除以零的情况。它被加到方差上,以保证数值稳定性。
)
else:
hidden_states, residual = fused_add_norm(
self.drop_path(hidden_states),
self.norm.weight,
self.norm.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
)
hidden_states = self.mixer(hidden_states, inference_params=None)
return hidden_states, residual
3. create_block函数
- 双向 Mamba 配置:如果 if_bi_mamba 为真,则将 bi_mamba_type 设置为 ‘v1’。
- 工厂函数参数:创建 factory_kwargs 字典,包含设备和数据类型的参数,这些参数将用于创建模型层。
- 创建 Mamba 混合器:使用 functools.partial 来创建一个 Mamba 类的偏函数,预设了 layer_idx、bi_mamba_type、if_devide_out、init_layer_scale 和 factory_kwargs。
- 创建归一化类:同样使用 functools.partial 来创建一个归一化类的偏函数,根据 rms_norm 决定是使用 nn.LayerNorm 还是 RMSNorm,并设置 eps 和 factory_kwargs。
- 创建 Block:实例化 Block 类,传入 d_model、mixer_cls、norm_cls、drop_path、fused_add_norm 和 residual_in_fp32 参数。
- 返回 Block 实例
def create_block(
d_model,
ssm_cfg=None, # ssm是否初始化
norm_epsilon=1e-5, # 归一化操作中使用的 epsilon 值,默认为 1e-5
drop_path=0.,
rms_norm=False, # 是否使用均方根归一化
residual_in_fp32=False, # 是否将残差连接以单精度浮点数(32位浮点数)进行
fused_add_norm=False, # 是否使用融合的加法和归一化操作
layer_idx=None, # 层的索引
device=None,
dtype=None,
if_bi_mamba=None, # 是否使用双向 Mamba
bi_mamba_type='none', # 双向 Mamba 的类型
if_devide_out=False,
init_layer_scale=None, # 初始化层缩放的配置
):
if if_bi_mamba:
bi_mamba_type = 'v1'
if ssm_cfg is None:
ssm_cfg = {}
# 使用 factory_kwargs 字典,可以方便地将这些参数传递给工厂函数,而不需要在函数调用时显式地列出每一个参数,工厂函数通常用于创建类的实例。
factory_kwargs = {"device": device, "dtype": dtype}
mixer_cls = partial(
Mamba,
layer_idx=layer_idx,
bi_mamba_type=bi_mamba_type,
if_devide_out=if_devide_out,
init_layer_scale=init_layer_scale,
**ssm_cfg,
**factory_kwargs
)
norm_cls = partial(
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
)
block = Block(
dim=d_model,
mixer_cls=mixer_cls,
norm_cls=norm_cls,
drop_path=drop_path,
fused_add_norm=fused_add_norm,
residual_in_fp32=residual_in_fp32
)
block.layer_idx = layer_idx
return block
4. VisionMamba类
- 特征提取:使用 self.patch_embed 将输入图像 x 转换为嵌入特征。
- 分类令牌处理:根据配置,选择某种方式向特征中添加分类令牌
- 位置编码:如果配置了绝对位置嵌入,将其添加到特征中。
- 随机化处理:根据配置,可能对令牌顺序进行随机化。
- 图像序列翻转:根据 flip_img_sequences_ratio 的配置,对图像序列进行翻转。
- 编码器层处理:遍历 self.layers,对特征进行处理,包括残差连接、DropPath 正则化、归一化和特征混合。
- 最终池化:根据 self.final_pool_type 的配置,对特征进行池化操作。
- self.head 进行分类预测。
- 根据 return_features 参数的设置,返回特征或分类结果。
class VisionMamba(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
stride=16, # 划分小块时的步长
depth=24, # 架构中层的数量
embed_dim=192,
channels=3,
num_classes=1000, # 类别数
ssm_cfg=None, # 状态空间模型(State-Space Model)的配置
drop_rate=0., # Dropout的正则化概率
drop_path_rate=0.1, # Droppath的正则化概率
norm_epsilon: float = 1e-5,
rms_norm=False,
fused_add_norm=False,
residual_in_fp32=False,
device=None,
dtype=None,
pt_hw_seq_len=14,
if_bi_directional=False, # 是否使用双向处理
final_pool_type='none', # 最终池化类型
if_abs_pos_embed=False, # 是否使用绝对位置嵌入
if_rope=False, # 是否使用相对位置编码
if_rope_residual=False, # RoPE 是否使用残差连接
flip_img_sequences_ratio=-1, # 图像序列翻转的比例
if_bi_mamba=False,
bi_mamba_type='none',
if_cls_token=False, # 是否使用分类令牌
if_devide_out=False,
init_layer_scale=None,
use_double_cls_token=False, # 首尾两处分类令牌
use_middle_cls_token=False, # 中间位置分类令牌
**kwargs
):
factory_kwarges = {"device": device, "dtype": dtype}
kwargs.update(factory_kwarges)
super(VisionMamba, self).__init__()
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.if_bi_directional = if_bi_directional
self.final_pool_type = final_pool_type
self.if_abs_pos_embed = if_abs_pos_embed
self.if_rope = if_rope
self.if_rope_residual = if_rope_residual
self.flip_img_sequences_ratio = flip_img_sequences_ratio
self.if_cls_token = if_cls_token
self.use_double_cls_token = use_double_cls_token
self.use_middle_cls_token = use_middle_cls_token
self.num_tokens = 1 if if_cls_token else 0
self.num_classes = num_classes
self.d_model = self.num_features = self.embed_dim = embed_dim
self.patch_embed = PatchEmbedding(
img_size=img_size, patch_size=patch_size, stride=stride, in_channels=channels, embed_dim=embed_dim
)
num_patches = self.patch_embed.num_patches
# cls_token
if if_cls_token:
if use_double_cls_token:
self.cls_token_head = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.cls_token_tail = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
self.num_tokens = 2
else:
self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
# position embedding
if if_abs_pos_embed:
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens + num_patches, self.embed_dim))
self.pos_drop = nn.Dropout(drop_rate)
# RoPE(Relative Positional Encoding)使用这些相对位置权重来调整不同元素之间的关系得分
if if_rope:
half_head_dim = embed_dim // 2
hw_seq_len = img_size // patch_size
self.rope = VisionRotaryEmbeddingFast(
dim=half_head_dim,
pt_seq_len=pt_hw_seq_len,
ft_seq_len=hw_seq_len,
)
self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
# drop path rate
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
# 创建一个从0开始逐渐增加到drop_path_rate的DropPath概率列表,列表的长度等于模型的深度。这样,模型的每一层都会有一个与之对应的DropPath概率,随着层的加深,DropPath的概率逐渐增加,从而实现更深层的正则化效果。
# .item() 方法用于将Tensor中的值转换为标准的Python数值类型
inter_dpr = [0.0] + dpr # 在第一个阶段或层不进行DropPath,而后续的元素来自 dpr 列表,表示从第二层开始每层的DropPath概率。
self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
# 初始化Encoder
self.layers = nn.ModuleList(
[
create_block(
d_model=embed_dim,
ssm_cfg=ssm_cfg,
norm_epsilon=norm_epsilon,
rms_norm=rms_norm,
residual_in_fp32=residual_in_fp32,
fused_add_norm=fused_add_norm,
layer_idx=i,
if_bi_mamba=if_bi_mamba,
bi_mamba_type=bi_mamba_type,
drop_path=inter_dpr[i],
if_devide_out=if_devide_out,
init_layer_scale=init_layer_scale,
**factory_kwarges
)
for i in range(depth)
]
)
# 标准化层
self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(embed_dim, eps=norm_epsilon, **factory_kwarges)
if if_abs_pos_embed:
trunc_normal_(self.pos_embed, std=.02) # 张量初始化为截断的正态分布
if if_cls_token:
if use_double_cls_token:
trunc_normal_(self.cls_token_head, std=.02)
trunc_normal_(self.cls_token_tail, std=.02)
else:
trunc_normal_(self.cls_token, std=.02)
# 前向特征传播
def forward_features(self, x, inference_params=None,
if_random_cls_token_position=False,
if_random_token_rank=False):
x = self.patch_embed(x)
B, M, _ = x.shape
# 拼接cls_token
if self.if_cls_token:
if self.use_double_cls_token:
cls_token_head = self.cls_token_head.expand(B, -1, -1)
cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
# 定义分类令牌的位置。这里,0 表示序列的开始处放置 cls_token_head,M + 1 表示在所有嵌入小块之后放置 cls_token_tail。
token_position = [0, M + 1]
# 沿着第二个维度(dim=1)拼接起来
x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
M = x.shape[1] # 更新 M 的值为新的序列长度
else:
if self.use_middle_cls_token:
cls_token = self.cls_token.expand(B, -1, -1)
token_position = M // 2
x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
elif if_random_cls_token_position:
cls_token = self.cls_token.expand(B, -1, -1)
token_position = random.randint(0, M)
x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
print("token_position", token_position)
else:
cls_token = self.cls_token.expand(B, -1, -1)
token_position = 0
x = torch.cat((cls_token, x), dim=1)
M = x.shape[1]
# 装上位置编码
if self.if_abs_pos_embed:
x = x + self.pos_embed
x = self.pos_drop(x)
# 随机序列生成索引
if if_random_token_rank:
shuffle_indices = torch.randperm(M) # 0到M-1的随机排列数组
if isinstance(token_position, list):
# 如果 token_position 是列表,并且列表中有2个元素,那么打印出序列中第一个批次(batch)中,
# 位于 token_position[0] 和 token_position[1] 位置的令牌的特定维度(这里是第一个维度,索引为0)的值。
print("original value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
else:
# 如果 token_position 不是列表,那么打印出序列中第一个批次中,位于 token_position 索引位置的令牌的特定维度的值。
print("original value: ", x[0, token_position, 0])
print("original token_position: ", token_position)
# 以 shuffle_indices 作为第二个维度的索引,从而实现整个序列的随机洗牌,结果存储回 x
x = x[:, shuffle_indices, :]
# 更新位置索引
# 使用 torch.where 函数找到 shuffle_indices 中等于 token_position[i] 的索引位置
# torch.where 返回一个元组,其中第一个元素包含满足条件的索引,通过 [0] 获取这个索引张量,然后通过 .item() 将其转换为一个标量值
if isinstance(token_position, list):
new_token_position = [torch.where(shuffle_indices == token_position[i])[0].item()
for i in range(len(token_position))]
token_position = new_token_position
print("new value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
else:
token_position = torch.where(shuffle_indices == token_position)[0].item()
print("new walue: ", x[0, token_position, 0])
print("new token_position: ", token_position)
# 翻转操作
if_flip_img_sequences = False # 记录是否执行了图像序列的翻转操作
# 即使 flip_img_sequences_ratio 为一个正值,翻转操作也不是每次都执行,而是以这个随机比率的概率执行。
if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:
x = x.flip([1]) # x.flip([1]) 表示沿着第一个维度进行翻转
if_flip_img_sequences = True
# Mamba
residual = None
hidden_states = x
# 根据模型的配置和之前的操作(如翻转和RoPE)来更新隐藏状态和残差连接。
if not self.if_bi_directional: # 单向的,则执行循环
for layer in self.layers:
if if_flip_img_sequences and self.if_rope:
# 如果之前已经执行了图像序列的翻转(if_flip_img_sequences 为 True),
# 并且模型包含相对位置编码(self.if_rope 为 True),则对隐藏状态(hidden_states)进行翻转操作
hidden_states = hidden_states.flip([1]) # 通常意味着沿着宽度或高度翻转图像。
if residual is not None: # 如果残差连接(residual)存在,则同样对其进行翻转操作。
residual = residual.flip([1])
if self.if_rope:
# 对 hidden_states 应用相对位置编码
hidden_states = self.rope(hidden_states)
# 如果残差连接存在,并且模型配置为在残差连接上应用 RoPE(self.if_rope_residual 为 True)
if residual is not None and self.if_rope_residual:
residual = self.rope(residual) # 对残差连接应用相对位置编码
hidden_states, residual = layer(
hidden_states, residual, inference_params=inference_params
)
else: # 双向
for i in range(len(self.layers) // 2):
if self.rope:
hidden_states = self.rope(hidden_states)
if residual is not None and self.if_rope_residual:
residual = self.rope(residual)
hidden_states_f, residual_f = self.layers[i*2](
hidden_states, residual, inference_params=inference_params
)
hidden_states_b, residual_b = self.layers[i*2 + 1](
hidden_states.flip([1]),
None if residual == None else residual.flip([1]),
inference_params=inference_params
)
hidden_states = hidden_states_f + hidden_states_b.flip([1])
residual = residual_f + residual_b.flip([1])
if not self.fused_add_norm:
if residual is None:
residual = hidden_states
else:
residual = residual + self.drop_path(hidden_states)
# 将残差连接的结果 residual 转换为归一化层 self.norm_f 的权重数据类型,然后应用归一化
hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
else:
fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
hidden_states = fused_add_norm_fn(
self.drop_path(hidden_states),
self.norm_f.weight,
self.norm_f.bias,
eps=self.norm_f.eps,
residual=residual,
residual_in_fp32=self.residual_in_fp32
)
# 根据不同的配置,从模型的输出中提取并返回分类令牌的特征表示。
if self.if_cls_token:
if self.use_double_cls_token:
# 返回两个分类令牌的特征表示的平均值。
return (hidden_states[:, token_position[0], :] + hidden_states[:, token_position[1], :]) / 2
else:
# 直接返回该令牌的特征表示。
if self.use_middle_cls_token:
return hidden_states[:, token_position, :]
elif if_random_cls_token_position:
return hidden_states[:, token_position, :]
else:
return hidden_states[:, token_position, :]
# 池化
if self.final_pool_type == 'none':
return hidden_states[:, -1, :]
elif self.final_pool_type == 'mean':
return hidden_states.mean(dim=1)
elif self.final_pool_type == 'max':
return hidden_states
elif self.final_pool_type == 'all':
return hidden_states
else:
raise NotImplementedError
def forward(self, x,
return_features=False, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position,
if_random_token_rank=if_random_token_rank)
if return_features:
return x
x = self.head(x)
if self.final_pool_type == "max":
x = x.max(dim=1)[0]
return x
5. test函数
- device 变量根据系统是否支持 CUDA(GPU)来选择运行设备,优先使用 GPU,如果不支持则回退到 CPU。
- 创建 VisionMamba 类的实例 model。
- .to(device) 将模型移动到选定的设备上。
- 创建一个随机初始化的输入张量 x,其尺寸为 (4, 3, 224, 224),表示一个包含 4 张图像的批次,每张图像有 3 个通道,尺寸为 224x224。然后将输入数据移动到与模型相同的设备上。
- 通过 model(x) 执行模型的前向传播,得到预测结果 preds。
def test():
device = "cuda" if torch.cuda.is_available() else "cpu"
model = VisionMamba(
patch_size=16,
embed_dim=192,
depth=24,
rms_norm=True,
residual_in_fp32=True,
fused_add_norm=True,
final_pool_type="mean",
if_abs_pos_embed=True,
if_rope=False,
if_rope_residual=False,
bi_mamba_type="V2",
if_cls_token=True,
if_devide_out=True,
use_middle_cls_token=True
).to(device)
x = torch.randn(size=(4, 3, 224, 224)).to(device)
preds = model(x)
print(f"preds shape if {preds.shape}")
if __name__ == "__main__":
test()