SSM与Mamba模型学习

transformer的缺陷

自注意力机制的计算范围只限于窗口内,不能直接处理窗口外的元素,不能照顾到整个序列。
由于计算复杂度随着窗口的长度呈几何平方式增长,所以不能一味地增加窗口长度来解决。
Transformer本质上是通过位置编码将序列数据空间化,然后计算空间相关度反向建模时序相关度,忽视了数据内在结构的关联,比较简单暴力,参数效率低,冗余度高。

时序状态空间模型SSM简介

是一种基于RRN的用于描述系统状态随时间变化的数学模型,由状态方程和观测方程组成。

  1. 连续空间的时序建模
  • 状态方程通常是一个一阶或高阶的线性或非线性微分方程,描述系统状态如何随时间演化。
    在这里插入图片描述
  • 观测方程通常是一个线性或非线性方程,表示观测数据与系统状态之间的关系,描述如何从系统状态中获得观测数据。
    在这里插入图片描述
    ABC是固定的,所以叫时不变。
    SSM思路是从LTI线性时不变系统到线性最后到非线性,牺牲通用性,换得特定场景的高性能。
  1. 时序离散化与GNN
    在这里插入图片描述
    在这里插入图片描述

对连续系统进行离散化展开,使用零阶保持实现从连续系统转换为离散系统的ABC对应关系。 实际是求导变成了时间左移。
3. 并行处理与CNN
使用CNN对时序数据建模,不同时间尺度作为卷积核
在这里插入图片描述

在这里插入图片描述输出可以表示为上式,与离散卷积公式(下式)对比,可以发现二者是类似的,外部都是求和号,对不同时间步和相应权重的线性组合,输入x也只差了一个k的时间偏移,权重都是与随k变化的量。
离散卷积公式SSM只是换名字的CNN化的RNN。

Mamba:选择性SSM

由于SSM有两个强假设:线性和时不变,而实际系统大部分为非线性,时变的,所以SSM的应用范围较小。Mamba本质上是SSM模型的改进,放开这两个假设。他主要体现在设计一种机制,使得状态空间具备选择性,达到Transformer的建模能力,又在序列长度上实现了线性扩展,克服Transformer的缺陷。还使用GPU提高性能,并行扫描。
在这里插入图片描述
实现选择性,就是在中间部分使得B、C都变成t的函数,成为时变参数,A不是t的函数,但中间在使用delta进行离散化的时候也有t。蓝色部分就是选择性机制。delta_t相当于一个总开关,B_t和C_t相当于旋钮。

关注两种能力:一种是抓重点的能力,即从一句话中找出关键信息,忽略不相关的部分。另一种是上下文联想与推理能力,处理连续信息时保存逻辑一致性和上下文连贯性。

输入x_t通过三条通道影响B_t,两条通道影响C_t,两条通道影响A_t,delta函数是非线性的,所以就使得系统能够完成非线性时变。

步长delta像是放大镜观察窗口,较小时,模型倾向于忽略具体的单词,而更多依赖于之前的上下文信息。使用放大镜忽远忽近的看,实现注意力的选择。

Vision_Mamba的Vim.py源码学习

  • 导入库
from typing import Optional

import torch.nn as nn
import torch
import torch.functional as F
from timm.models.layers import DropPath, to_2tuple, trunc_normal_, lecun_normal_
from torch import Tensor
from functools import partial
from mamba_ssm.modules.mamba_simple import Mamba
from rope import *
import random

try:
    from mamba_ssm.ops.triton.layernorm import RMSNorm, rms_norm_fn, layer_norm_fn
except ImportError:
    RMSNorm, rms_norm_fn, layer_norm_fn = None, None, None
1. PatchEmbedding类
  • 图像维度(B, C, H, W) -> (B, embed_dim, grid_size[0], grid_size[1]) -> (B, num_patches, embed_dim)
  • 将图片切成小方块
  • img_size图像大小, patch_size小方块大小, stride步长, in_channels输入维度, embed_dim特征向量维度,norm_layer是否归一化, flatten指是否展平
  • grid_size其中两个元素都表示宽度和高度方向上的小块数量
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=224, patch_size=16, stride=16, in_channels=3, embed_dim=768, norm_layer=None, flatten=True):
        # img_size图像大小, patch_size小方块大小, stride步长, in_channels输入维度, embed_dim特征向量维度,norm_layer是否归一化, flatten指是否展平
        super(PatchEmbedding, self).__init__()
        img_size = to_2tuple(img_size)   # 输出是一个包含两个元素的元组,为图像的宽度和高度。
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size

        # grid_size其中两个元素都表示宽度和高度方向上的小块数量
        self.grid_size = ((img_size[0] - patch_size[0]) // stride + 1, (img_size[0] - patch_size[0]) // stride + 1)
        self.num_patches = self.patch_size[0] * self.patch_size[1]  # 小方块数量
        self.flatten = flatten

        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=stride)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()  # 初始化归一化层

    def forward(self, x):
        B, C, H, W = x.shape()
        # 如果图像尺寸不符合则报错。
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input img size {(H) * (W)} doesen't match model({self.img_size[0]} * {self.img_size[1]})"

        x = self.proj(x)    # 切方块  (B, embed_dim, grid_size[0], grid_size[1])
        if self.flatten:
            x = x.flatten[2].transpose(1, 2)    # (B, grid_size[0] * grid_size[1], embed_dim)
        x = self.norm(x)    # 归一化

        return x
2. Block类
  • Mamba的Encoder层
  • dim是维度,mixer_cls是混合方式, norm_cls=nn.LayerNorm归一化方式, fused_add_norm=False是否使用融合的加法和归一化操作,residual_in_fp32=False残差连接是否以单精度浮点数(32位浮点数)进行,drop_path=0.表示路径丢弃(stochasticdepth)的概率
  • 这个block接受两个输入,分别是hidden_states, residual
  • 判断是否要使用混合相加标准化的方式
  • 将 residual 的数据类型转换为与self.norm.weight的数据类型相同,然后归一化
  • 返回更新后的 hidden_states 和 residual
class Block(nn.Module):  # encoder layer of Mamba
    def __init__(self, dim, mixer_cls, norm_cls=nn.LayerNorm,
                 fused_add_norm=False, residual_in_fp32=False, drop_path=0.):
        # dim是维度,mixer_cls是混合方式, norm_cls=nn.LayerNorm归一化方式, fused_add_norm=False是否使用融合的加法和归一化操作
        # residual_in_fp32=False残差连接是否以单精度浮点数(32位浮点数)进行
        # drop_path=0.表示路径丢弃(stochasticdepth)的概率

        super(Block, self).__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.mixer = mixer_cls(dim)
        self.norm = norm_cls(dim)

        self.drop_path = DropPath(drop_path)

        if self.fused_add_norm:
            # 确保 RMSNorm 已经被正确导入
            # 检查 self.norm 是否是 nn.LayerNorm 或 RMSNorm 类的一个实例
            assert RMSNorm is not None, "RMSNorm import Fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None):
        # 这个block接受两个输入,分别是hidden_states, residual(可选)
        if not self.fused_add_norm:  # 是否要使用混合相加标准化的方式
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)  # 残差连接
            # 用于将 residual 的数据类型转换为与self.norm.weight的数据类型相同。这是为了确保归一化层的权重和输入张量的数据类型一致,避免在计算过程中出现类型不匹配的问题。
            # self.norm归一化层
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))

            if self.residual_in_fp32:
                residual = residual.to(torch.float32)   # 残差连接是否以单精度浮点数(32位浮点数)进行

        else:
            # 如果 self.norm 是 RMSNorm 类的一个实例,则 fused_add_norm 被赋值为 rms_norm_fn;否则,fused_add_norm 被赋值为 layer_norm_fn。
            fused_add_norm = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
            if residual is None:
                hidden_states, residual = fused_add_norm(
                    hidden_states,
                    self.norm.weight,
                    self.norm.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm.eps,  # 用于避免除以零的情况。它被加到方差上,以保证数值稳定性。
                )
            else:
                hidden_states, residual = fused_add_norm(
                    self.drop_path(hidden_states),
                    self.norm.weight,
                    self.norm.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm.eps,
                )
        hidden_states = self.mixer(hidden_states, inference_params=None)

        return hidden_states, residual
3. create_block函数
  • 双向 Mamba 配置:如果 if_bi_mamba 为真,则将 bi_mamba_type 设置为 ‘v1’。
  • 工厂函数参数:创建 factory_kwargs 字典,包含设备和数据类型的参数,这些参数将用于创建模型层。
  • 创建 Mamba 混合器:使用 functools.partial 来创建一个 Mamba 类的偏函数,预设了 layer_idx、bi_mamba_type、if_devide_out、init_layer_scale 和 factory_kwargs。
  • 创建归一化类:同样使用 functools.partial 来创建一个归一化类的偏函数,根据 rms_norm 决定是使用 nn.LayerNorm 还是 RMSNorm,并设置 eps 和 factory_kwargs。
  • 创建 Block:实例化 Block 类,传入 d_model、mixer_cls、norm_cls、drop_path、fused_add_norm 和 residual_in_fp32 参数。
  • 返回 Block 实例
def create_block(
        d_model,
        ssm_cfg=None,   # ssm是否初始化
        norm_epsilon=1e-5,  # 归一化操作中使用的 epsilon 值,默认为 1e-5
        drop_path=0.,
        rms_norm=False,  # 是否使用均方根归一化
        residual_in_fp32=False,   # 是否将残差连接以单精度浮点数(32位浮点数)进行
        fused_add_norm=False,   # 是否使用融合的加法和归一化操作
        layer_idx=None,  # 层的索引
        device=None,
        dtype=None,
        if_bi_mamba=None,   # 是否使用双向 Mamba
        bi_mamba_type='none',   # 双向 Mamba 的类型
        if_devide_out=False,
        init_layer_scale=None,  # 初始化层缩放的配置
):
    if if_bi_mamba:
        bi_mamba_type = 'v1'
    if ssm_cfg is None:
        ssm_cfg = {}
    # 使用 factory_kwargs 字典,可以方便地将这些参数传递给工厂函数,而不需要在函数调用时显式地列出每一个参数,工厂函数通常用于创建类的实例。
    factory_kwargs = {"device": device, "dtype": dtype}

    mixer_cls = partial(
        Mamba,
        layer_idx=layer_idx,
        bi_mamba_type=bi_mamba_type,
        if_devide_out=if_devide_out,
        init_layer_scale=init_layer_scale,
        **ssm_cfg,
        **factory_kwargs
    )
    norm_cls = partial(
        nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
    )

    block = Block(
        dim=d_model,
        mixer_cls=mixer_cls,
        norm_cls=norm_cls,
        drop_path=drop_path,
        fused_add_norm=fused_add_norm,
        residual_in_fp32=residual_in_fp32
    )
    block.layer_idx = layer_idx

    return block
4. VisionMamba类
  • 特征提取:使用 self.patch_embed 将输入图像 x 转换为嵌入特征。
  • 分类令牌处理:根据配置,选择某种方式向特征中添加分类令牌
  • 位置编码:如果配置了绝对位置嵌入,将其添加到特征中。
  • 随机化处理:根据配置,可能对令牌顺序进行随机化。
  • 图像序列翻转:根据 flip_img_sequences_ratio 的配置,对图像序列进行翻转。
  • 编码器层处理:遍历 self.layers,对特征进行处理,包括残差连接、DropPath 正则化、归一化和特征混合。
  • 最终池化:根据 self.final_pool_type 的配置,对特征进行池化操作。
  • self.head 进行分类预测。
  • 根据 return_features 参数的设置,返回特征或分类结果。
class VisionMamba(nn.Module):
    def __init__(
            self,
            img_size=224,
            patch_size=16,
            stride=16,   # 划分小块时的步长
            depth=24,   # 架构中层的数量
            embed_dim=192,
            channels=3,
            num_classes=1000,   # 类别数
            ssm_cfg=None,   # 状态空间模型(State-Space Model)的配置
            drop_rate=0.,   # Dropout的正则化概率
            drop_path_rate=0.1,  # Droppath的正则化概率
            norm_epsilon: float = 1e-5,
            rms_norm=False,
            fused_add_norm=False,
            residual_in_fp32=False,
            device=None,
            dtype=None,
            pt_hw_seq_len=14,
            if_bi_directional=False,    # 是否使用双向处理
            final_pool_type='none',  # 最终池化类型
            if_abs_pos_embed=False,  # 是否使用绝对位置嵌入
            if_rope=False,   # 是否使用相对位置编码
            if_rope_residual=False,  # RoPE 是否使用残差连接
            flip_img_sequences_ratio=-1,    # 图像序列翻转的比例
            if_bi_mamba=False,
            bi_mamba_type='none',
            if_cls_token=False,  # 是否使用分类令牌
            if_devide_out=False,
            init_layer_scale=None,
            use_double_cls_token=False,  # 首尾两处分类令牌
            use_middle_cls_token=False,  # 中间位置分类令牌
            **kwargs
    ):
        factory_kwarges = {"device": device, "dtype": dtype}
        kwargs.update(factory_kwarges)
        super(VisionMamba, self).__init__()

        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.if_bi_directional = if_bi_directional
        self.final_pool_type = final_pool_type
        self.if_abs_pos_embed = if_abs_pos_embed
        self.if_rope = if_rope
        self.if_rope_residual = if_rope_residual
        self.flip_img_sequences_ratio = flip_img_sequences_ratio
        self.if_cls_token = if_cls_token
        self.use_double_cls_token = use_double_cls_token
        self.use_middle_cls_token = use_middle_cls_token
        self.num_tokens = 1 if if_cls_token else 0
        self.num_classes = num_classes
        self.d_model = self.num_features = self.embed_dim = embed_dim

        self.patch_embed = PatchEmbedding(
            img_size=img_size, patch_size=patch_size, stride=stride, in_channels=channels, embed_dim=embed_dim
        )
        num_patches = self.patch_embed.num_patches

        # cls_token
        if if_cls_token:
            if use_double_cls_token:
                self.cls_token_head = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                self.cls_token_tail = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
                self.num_tokens = 2
            else:
                self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))

        # position embedding
        if if_abs_pos_embed:
            self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens + num_patches, self.embed_dim))
            self.pos_drop = nn.Dropout(drop_rate)

        # RoPE(Relative Positional Encoding)使用这些相对位置权重来调整不同元素之间的关系得分
        if if_rope:
            half_head_dim = embed_dim // 2
            hw_seq_len = img_size // patch_size
            self.rope = VisionRotaryEmbeddingFast(
                dim=half_head_dim,
                pt_seq_len=pt_hw_seq_len,
                ft_seq_len=hw_seq_len,
            )
            self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()

        # drop path rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        # 创建一个从0开始逐渐增加到drop_path_rate的DropPath概率列表,列表的长度等于模型的深度。这样,模型的每一层都会有一个与之对应的DropPath概率,随着层的加深,DropPath的概率逐渐增加,从而实现更深层的正则化效果。
        # .item() 方法用于将Tensor中的值转换为标准的Python数值类型
        inter_dpr = [0.0] + dpr  # 在第一个阶段或层不进行DropPath,而后续的元素来自 dpr 列表,表示从第二层开始每层的DropPath概率。
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()

        # 初始化Encoder
        self.layers = nn.ModuleList(
            [
                create_block(
                    d_model=embed_dim,
                    ssm_cfg=ssm_cfg,
                    norm_epsilon=norm_epsilon,
                    rms_norm=rms_norm,
                    residual_in_fp32=residual_in_fp32,
                    fused_add_norm=fused_add_norm,
                    layer_idx=i,
                    if_bi_mamba=if_bi_mamba,
                    bi_mamba_type=bi_mamba_type,
                    drop_path=inter_dpr[i],
                    if_devide_out=if_devide_out,
                    init_layer_scale=init_layer_scale,
                    **factory_kwarges
                )
                for i in range(depth)
            ]
        )

        # 标准化层
        self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(embed_dim, eps=norm_epsilon, **factory_kwarges)

        if if_abs_pos_embed:
            trunc_normal_(self.pos_embed, std=.02)  # 张量初始化为截断的正态分布
        if if_cls_token:
            if use_double_cls_token:
                trunc_normal_(self.cls_token_head, std=.02)
                trunc_normal_(self.cls_token_tail, std=.02)
            else:
                trunc_normal_(self.cls_token, std=.02)

    # 前向特征传播
    def forward_features(self, x, inference_params=None,
                         if_random_cls_token_position=False,
                         if_random_token_rank=False):
        x = self.patch_embed(x)
        B, M, _ = x.shape

        # 拼接cls_token
        if self.if_cls_token:
            if self.use_double_cls_token:
                cls_token_head = self.cls_token_head.expand(B, -1, -1)
                cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
                # 定义分类令牌的位置。这里,0 表示序列的开始处放置 cls_token_head,M + 1 表示在所有嵌入小块之后放置 cls_token_tail。
                token_position = [0, M + 1]
                # 沿着第二个维度(dim=1)拼接起来
                x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
                M = x.shape[1]  # 更新 M 的值为新的序列长度

            else:
                if self.use_middle_cls_token:
                    cls_token = self.cls_token.expand(B, -1, -1)
                    token_position = M // 2
                    x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
                elif if_random_cls_token_position:
                    cls_token = self.cls_token.expand(B, -1, -1)
                    token_position = random.randint(0, M)
                    x = torch.cat((x[:, :token_position, :], cls_token, x[:, token_position:, :]), dim=1)
                    print("token_position", token_position)
                else:
                    cls_token = self.cls_token.expand(B, -1, -1)
                    token_position = 0
                    x = torch.cat((cls_token, x), dim=1)
                M = x.shape[1]

        # 装上位置编码
        if self.if_abs_pos_embed:
            x = x + self.pos_embed
            x = self.pos_drop(x)

        # 随机序列生成索引
        if if_random_token_rank:
            shuffle_indices = torch.randperm(M)  # 0到M-1的随机排列数组

            if isinstance(token_position, list):
                # 如果 token_position 是列表,并且列表中有2个元素,那么打印出序列中第一个批次(batch)中,
                # 位于 token_position[0] 和 token_position[1] 位置的令牌的特定维度(这里是第一个维度,索引为0)的值。
                print("original value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
            else:
                # 如果 token_position 不是列表,那么打印出序列中第一个批次中,位于 token_position 索引位置的令牌的特定维度的值。
                print("original value: ", x[0, token_position, 0])
            print("original token_position: ", token_position)
            # 以 shuffle_indices 作为第二个维度的索引,从而实现整个序列的随机洗牌,结果存储回 x
            x = x[:, shuffle_indices, :]

            # 更新位置索引
            # 使用 torch.where 函数找到 shuffle_indices 中等于 token_position[i] 的索引位置
            # torch.where 返回一个元组,其中第一个元素包含满足条件的索引,通过 [0] 获取这个索引张量,然后通过 .item() 将其转换为一个标量值
            if isinstance(token_position, list):
                new_token_position = [torch.where(shuffle_indices == token_position[i])[0].item()
                                      for i in range(len(token_position))]
                token_position = new_token_position
                print("new value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
            else:
                token_position = torch.where(shuffle_indices == token_position)[0].item()
                print("new walue: ", x[0, token_position, 0])
            print("new token_position: ", token_position)

        # 翻转操作
        if_flip_img_sequences = False   # 记录是否执行了图像序列的翻转操作
        # 即使 flip_img_sequences_ratio 为一个正值,翻转操作也不是每次都执行,而是以这个随机比率的概率执行。
        if self.flip_img_sequences_ratio > 0 and (self.flip_img_sequences_ratio - random.random()) > 1e-5:
            x = x.flip([1])  # x.flip([1]) 表示沿着第一个维度进行翻转
            if_flip_img_sequences = True

        # Mamba
        residual = None
        hidden_states = x

        # 根据模型的配置和之前的操作(如翻转和RoPE)来更新隐藏状态和残差连接。
        if not self.if_bi_directional:  # 单向的,则执行循环
            for layer in self.layers:

                if if_flip_img_sequences and self.if_rope:
                    # 如果之前已经执行了图像序列的翻转(if_flip_img_sequences 为 True),
                    # 并且模型包含相对位置编码(self.if_rope 为 True),则对隐藏状态(hidden_states)进行翻转操作
                    hidden_states = hidden_states.flip([1])     # 通常意味着沿着宽度或高度翻转图像。
                    if residual is not None:    # 如果残差连接(residual)存在,则同样对其进行翻转操作。
                        residual = residual.flip([1])

                if self.if_rope:
                    # 对 hidden_states 应用相对位置编码
                    hidden_states = self.rope(hidden_states)
                    # 如果残差连接存在,并且模型配置为在残差连接上应用 RoPE(self.if_rope_residual 为 True)
                    if residual is not None and self.if_rope_residual:
                        residual = self.rope(residual)  # 对残差连接应用相对位置编码

                hidden_states, residual = layer(
                    hidden_states, residual, inference_params=inference_params
                )

        else:   # 双向
            for i in range(len(self.layers) // 2):
                if self.rope:
                    hidden_states = self.rope(hidden_states)
                    if residual is not None and self.if_rope_residual:
                        residual = self.rope(residual)

                hidden_states_f, residual_f = self.layers[i*2](
                    hidden_states, residual, inference_params=inference_params
                )
                hidden_states_b, residual_b = self.layers[i*2 + 1](
                    hidden_states.flip([1]),
                    None if residual == None else residual.flip([1]),
                    inference_params=inference_params
                )

                hidden_states = hidden_states_f + hidden_states_b.flip([1])
                residual = residual_f + residual_b.flip([1])

        if not self.fused_add_norm:
            if residual is None:
                residual = hidden_states
            else:
                residual = residual + self.drop_path(hidden_states)
            # 将残差连接的结果 residual 转换为归一化层 self.norm_f 的权重数据类型,然后应用归一化
            hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))

        else:
            fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
            hidden_states = fused_add_norm_fn(
                self.drop_path(hidden_states),
                self.norm_f.weight,
                self.norm_f.bias,
                eps=self.norm_f.eps,
                residual=residual,
                residual_in_fp32=self.residual_in_fp32
            )

        # 根据不同的配置,从模型的输出中提取并返回分类令牌的特征表示。
        if self.if_cls_token:
            if self.use_double_cls_token:
                # 返回两个分类令牌的特征表示的平均值。
                return (hidden_states[:, token_position[0], :] + hidden_states[:, token_position[1], :]) / 2
            else:
                # 直接返回该令牌的特征表示。
                if self.use_middle_cls_token:
                    return hidden_states[:, token_position, :]
                elif if_random_cls_token_position:
                    return hidden_states[:, token_position, :]
                else:
                    return hidden_states[:, token_position, :]

        # 池化
        if self.final_pool_type == 'none':
            return hidden_states[:, -1, :]
        elif self.final_pool_type == 'mean':
            return hidden_states.mean(dim=1)
        elif self.final_pool_type == 'max':
            return hidden_states
        elif self.final_pool_type == 'all':
            return hidden_states
        else:
            raise NotImplementedError

    def forward(self, x,
                return_features=False, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
        x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position,
                                  if_random_token_rank=if_random_token_rank)
        if return_features:
            return x

        x = self.head(x)
        if self.final_pool_type == "max":
            x = x.max(dim=1)[0]
        return x
5. test函数
  • device 变量根据系统是否支持 CUDA(GPU)来选择运行设备,优先使用 GPU,如果不支持则回退到 CPU。
  • 创建 VisionMamba 类的实例 model。
  • .to(device) 将模型移动到选定的设备上。
  • 创建一个随机初始化的输入张量 x,其尺寸为 (4, 3, 224, 224),表示一个包含 4 张图像的批次,每张图像有 3 个通道,尺寸为 224x224。然后将输入数据移动到与模型相同的设备上。
  • 通过 model(x) 执行模型的前向传播,得到预测结果 preds。
def test():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = VisionMamba(
        patch_size=16,
        embed_dim=192,
        depth=24,
        rms_norm=True,
        residual_in_fp32=True,
        fused_add_norm=True,
        final_pool_type="mean",
        if_abs_pos_embed=True,
        if_rope=False,
        if_rope_residual=False,
        bi_mamba_type="V2",
        if_cls_token=True,
        if_devide_out=True,
        use_middle_cls_token=True
    ).to(device)

    x = torch.randn(size=(4, 3, 224, 224)).to(device)
    preds = model(x)
    print(f"preds shape if {preds.shape}")


if __name__ == "__main__":
    test()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/671808.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【自然语言处理】【Scaling Law】Observational Scaling Laws:跨不同模型构建Scaling Law

相关博客 【自然语言处理】【Scaling Law】Observational Scaling Laws:跨不同模型构建Scaling Law 【自然语言处理】【Scaling Law】语言模型物理学 第3.3部分:知识容量Scaling Laws 【自然语言处理】Transformer中的一种线性特征 【自然语言处理】【大…

关于苹果发布IOS18系统,以及Siri升级贾维斯

随着科技的不断进步,手机操作系统也在持续升级,为用户提供更加智能化、便捷化的体验。近期,苹果公司即将推出的iOS 18系统引起了广泛关注。作为iPhone历史上的重大更新,iOS 18系统带来了众多新功能,将进一步提升iPhone…

美国科技股为何突然崩了?

英伟达毛利率那么高,谁来“买单”?高盛认为,投资AI的成本巨大,引发了市场对科技股盈利能力和估值合理性的担忧。软件股今年以来的疲态,可能也反映了投资者对AI的担忧。 直到最近还势不可挡的科技股突然崩塌。 隔夜美…

Java基础知识点(标识符、数据类型、变量、运算符、包机制、流程控制、方法、数组)

文章目录 标识符数据类型强弱类型语言数据类型基础类型 类型转换 常量与变量变量的定义变量作用域变量命名规范常量 运算符包机制流程控制选择结构循环结构 方法(Method)数组概述申明创建java.util.Arrays类 标识符 Java标识符的命名规则如下&#xff1…

SIMBA:单细胞嵌入与特征

目前大多数单细胞分析管道仅限于细胞嵌入,并且严重依赖于聚类,而缺乏显式建模不同特征类型之间相互作用的能力。此外,这些方法适合于特定的任务,因为不同的单细胞问题的表述方式不同。为了解决这些缺点,SIMBA作为一种图…

RabbitMQ二、RabbitMQ的六种模式

一、RabbitMQ的六种模式 RabbitMQ共有六种工作模式: 简单模式(Simple)工作队列模式(Work Queue)发布订阅模式(Publish/Subscribe)路由模式(Routing)通配符模式&#xff…

ThinkPHP5发送邮件如何配置?有哪些技巧?

ThinkPHP5发送邮件的性能怎么优化?批量发信的方法? 邮件发送功能是许多应用程序的关键组成部分,尤其是在用户注册、密码重置和通知等功能中尤为重要。AokSend将详细介绍如何在thinkphp5中配置和使用邮件发送功能,并确保你可以轻松…

DPDK基础组件二(igb_uio、kni、rcu)

The Linux driver implementer’s API guide — The Linux Kernel documentation 一、igb_uid驱动 参考博客:https://zhuanlan.zhihu.com/p/543217445 UIO(Userspace I/O)是运行在用户空间的I/O技术 代码位置:dpdk----/kernel/linux/igb_uio目录 igb_uio 是 dpdk 内部实…

从0开发一个Chrome插件:搭建开发环境

前言 这是《从0开发一个Chrome插件》系列的第三篇文章,本系列教你如何从0去开发一个Chrome插件,每篇文章都会好好打磨,写清楚我在开发过程遇到的问题,还有开发经验和技巧。 《从0开发一个Chrome插件》专栏: 从0开发一…

文章解读与仿真程序复现思路——电力系统自动化EI\CSCD\北大核心《考虑动态定价的新能源汽车能源站优化运行》

本专栏栏目提供文章与程序复现思路,具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

Linux网络-守护进程版字典翻译服务器

文章目录 前言一、pid_t setsid(void);二、守护进程翻译字典服务器(守护线程版)效果图 前言 根据上章所讲的后台进程组和session会话,我们知道如果可以将一个进程放入一个独立的session,可以一定程度上守护该进程。 一、pid_t se…

Vite项目构建chrome extension,实现多入口

本项目使用Vite5 Vue3进行构建。 要使用vite工程构建浏览器插件,无非就是要实现popup页面和options页面。这就需要在项目中用到多入口打包(生成多个html文件)。 实现思路: 通过配置vite工程,使得项目打包后有两个h…

项目3 构建移动电商服务器集群

项目引入 经过前期加班加点地忙碌,我们的网站顺利上线了!年中促销活动也如约而至,虽然公司全体对这次活动进行多方面地准备和“布防”,可是意外还是发生了。就在促销优惠购物活动的当天,猛然增加的用户访问量直接导致浏…

SpringBoot-SchedulingConfigurer源码初识:理解定时任务抛异常终止本次调度,但不会影响下一次执行调度

SchedulingConfigurer源码初识:理解定时任务抛异常终止本次调度,但不会影响下一次执行调度 EnableSchedulingScheduledAnnotationBeanPostProcessor进入finishRegistration方法 ScheduledTaskRegistrar处理触发器任务(TriggerTask&#xff09…

回溯算法之电话号码字母组合

题目: 给定一个仅包含数字 2-9 的字符串,返回所有它能表示的字母组合。答案可以按 任意顺序 返回。 给出数字到字母的映射如下(与电话按键相同)。注意 1 不对应任何字母。 示例 1: 输入:digits "2…

【python】多线程(3)queue队列之不同延时时长的参数调用问题

链接1:【python】多线程(笔记)(1) 链接2:【python】多线程(笔记)(2)Queue队列 0.问题描述 两个线程,但是不同延时时长,导致数据输出…

vue 引用第三方库 Swpier轮播图

本文全程干货,没有废话 1.使用 npm 安装 swiper,使用 save 保存到 packjson 中 npm install --save swiper 2、把 swiper看成是第三方库或者是组件,然后按照,引用,挂载组件,使用组件三步法。 3、在 script…

overleaf 写参考文献引用

目录 1、 新建.bib 文件 2、导入引用 3、在文档中引用参考文献 4、生成参考文献列表 1、 新建.bib 文件 在Overleaf项目中,你可以选择导入现有的 .bib 文件或在项目中创建一个新的 .bib 文件来管理你的参考文献。 导入.bib 文件: 在项目文件树中点击…

1985-2020 年阿拉斯加和育空地区按植物功能类型划分的模型表层覆盖率

ABoVE: Modeled Top Cover by Plant Functional Type over Alaska and Yukon, 1985-2020 1985-2020 年阿拉斯加和育空地区按植物功能类型划分的模型表层覆盖率 简介 文件修订日期:2022-05-31 数据集版本: 1.1 本数据集包含阿拉斯加和育空地区北极和北方地区按…

C语言| 输出菱形*

C语言| 输出*三角形-CSDN博客 输出菱形。 【分析思路】 学会输出*的三角形之后输出菱形就很简单了。我们分析一下,菱形是由两个对称的三角形组成的,也因为是对称的,所以输出的菱形的行数肯定是一个奇数。 1 我们在编程的时候,要…