Mamba-minimal Mamba的最小限度实现 (一)

文章目录

    • 参数和数据尺寸约定
    • class MambaBlock
      • def forward
      • def __ int__
      • def ssm
      • def selective_scan

johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch. (github.com)

manba的简单最小限度实现,和原始论文实现state-spaces/mamba (github.com)](https://github.com/state-spaces/mamba/tree/main)相比,为了可读性对参数没有很好的初始化,原论文用CUDA写了并行扫描,所以速度会快。
这里介绍Mamba Block的实现

参数和数据尺寸约定

之后的数据尺寸以(b, l, d_in) 或者(b, l, d_model, d_state)简单表示

参数及简写Mamba论文简写
batch_size bB
序列长度 lL
隐藏维度 d / d_model
潜在状态维度 n / d_stateN
扩展因子 expandE
d_in / d_innerD
数据依赖步长 Δ \Delta Δ / delta
delta秩 dt_rank

class MambaBlock

def forward

根据forward简单梳理MambaBlock的结构

在这里插入图片描述

中间变量来源shape
输入x(b, l, d_model)
x_and_resx经过输入映射后(b, l, 2* d_in)
x切分后作为ssm分支输入(b, l, d_in)
res切分后作为门控分支输入(b, l, d_in)
y经过卷积,激活,ssm,门控后的输出(b, l, d_in)
outputy经过输出映射后得到(b, l, d_model)
def forward(self, x):
  
        (b, l, d) = x.shape
        
        x_and_res = self.in_proj(x)  # shape (b, l, 2 * d_in)
        (x, res) = x_and_res.split(split_size=[self.args.d_inner, self.args.d_inner], dim=-1)

        x = rearrange(x, 'b l d_in -> b d_in l')
        x = self.conv1d(x)[:, :, :l]
        x = rearrange(x, 'b d_in l -> b l d_in')
        
        x = F.silu(x)

        y = self.ssm(x)
        
        y = y * F.silu(res)
        
        output = self.out_proj(y)

        return output
    

def __ int__

初始化主要初始了几个部分

组件定义

操作及简写维度变换
输入映射 in_proj(b, l, d_model) -> (b, l, 2*d_in)
序列变换 conv1d只取前l (b, d_in, l) -> (b, d_in, l)
非线性激活 silu
输出映射 out_proj(b, l, d_in) -> (b, l, d)

在这里插入图片描述
ssm初始化

操作及简写作用
参数生成映射 x_proj生成数据依赖的参数B, C, Δ \Delta Δ
delta映射 dt_proj Δ \Delta Δ从dt_rank映射到d_in
矩阵A初始化简单初始化
矩阵D初始化简单初始化
def __init__(self, args: ModelArgs):

        super().__init__()
        self.args = args

        self.in_proj = nn.Linear(args.d_model, args.d_inner * 2, bias=args.bias)

        self.conv1d = nn.Conv1d(
            in_channels=args.d_inner,
            out_channels=args.d_inner,
            bias=args.conv_bias,
            kernel_size=args.d_conv,
            groups=args.d_inner,
            padding=args.d_conv - 1,
        )
        
         # ssm模型的初始化部分
         # x_proj takes in `x` and outputs the input-specific Δ, B, C
        self.x_proj = nn.Linear(args.d_inner, args.dt_rank + args.d_state * 2, bias=False)
        
        # dt_proj projects Δ from dt_rank to d_in
        self.dt_proj = nn.Linear(args.dt_rank, args.d_inner, bias=True)

        A = repeat(torch.arange(1, args.d_state + 1), 'n -> d n', d=args.d_inner)
        self.A_log = nn.Parameter(torch.log(A))
        self.D = nn.Parameter(torch.ones(args.d_inner))
        self.out_proj = nn.Linear(args.d_inner, args.d_model, bias=args.bias)

def ssm

这是我们数据处理流水线的搭建,这一部分是ssm模型参数定义,是ssm模型中相对于数据“不变”的部分。

SSM参数shape来源
状态矩阵A(d_in, n)在初始化中定义,非数据依赖
输入矩阵B(b, l, n)由x_db1切分而来,因此数据依赖
输出矩阵C(b, l, n)由x_db1切分而来,因此数据依赖
直接传递矩阵D(d_in)在初始化中定义,非数据依赖
数据依赖步长 Δ \Delta Δ(b, l, d_in)由x_db1切分而来,因此数据依赖

其中一部分变量初始化于class MambaBlock的初始化部分

中间变量及简写来源
数据生成变量 x_db1x经过参数映射x_proj生成
最终delta Δ \Delta Δ切分而来的 Δ \Delta Δ经过映射和softplus
 def ssm(self, x):

        (d_in, n) = self.A_log.shape
        
        A = -torch.exp(self.A_log.float())  # shape (d_in, n)
        D = self.D.float()

        x_dbl = self.x_proj(x)  # (b, l, dt_rank + 2*n)
        
        (delta, B, C) = x_dbl.split(split_size=[self.args.dt_rank, n, n], dim=-1)  # delta: (b, l, dt_rank). B, C: (b, l, n)
        delta = F.softplus(self.dt_proj(delta))  # (b, l, d_in)
        
        y = self.selective_scan(x, delta, A, B, C, D)  # This is similar to run_SSM(A, B, C, u) in The Annotated S4 [2]
        
        return y
SSM参数shape
状态矩阵A(d_in, n)
输入矩阵B(b, l, n)
输出矩阵C(b, l, n)
直接传递矩阵D(d_in)

def selective_scan

我们的数据流水线搭建好以后,接下来就要让它动起来,这一部分是数据处理的动态或者动力。

在这里插入图片描述

在这里, A A A使用ZOH零阶保持离散化, B B B则简化为欧拉离散化

前向欧拉离散化
x k = ( I + Δ k A ) x k − 1 + Δ k B ⋅ u k x ( t + Δ ) = ( I + Δ A ) x ( t ) + Δ B ⋅ u ( t ) \begin{aligned} x_{k}& \begin{aligned}=(\boldsymbol{I}+\Delta_{k}\boldsymbol{A})x_{k-1}+\Delta_{k}\boldsymbol{B}\cdot u_{k}\end{aligned} \\ x(t+\Delta)& =(\boldsymbol{I}+\Delta\boldsymbol{A})x(t)+\Delta\boldsymbol{B}\cdot u(t) \end{aligned} xkx(t+Δ)=(I+ΔkA)xk1+ΔkBuk=(I+ΔA)x(t)+ΔBu(t)

零阶保持离散化
x k = e Δ k A x k − 1 + ( Δ k A ) − 1 ( e Δ k A − I ) ⋅ Δ k B ⋅ u k x ( t + Δ ) = e Δ A x ( t ) + ( Δ A ) − 1 ( e Δ A − I ) ⋅ Δ B ⋅ u ( t ) \begin{aligned} x_{k}& =e^{\Delta_{k}\boldsymbol A}x_{k-1}+(\Delta_{k}\boldsymbol A)^{-1}(e^{\Delta_{k}\boldsymbol A}-\boldsymbol{I})\cdot\Delta_{k}\boldsymbol B\cdot u_{k} \\ x(t+\Delta)& =e^{\Delta \boldsymbol A}x(t)+(\Delta \boldsymbol A)^{-1}(e^{\Delta \boldsymbol A}-\boldsymbol{I})\cdot\Delta \boldsymbol B\cdot u(t) \end{aligned} xkx(t+Δ)=eΔkAxk1+(ΔkA)1(eΔkAI)ΔkBuk=eΔAx(t)+(ΔA)1(eΔAI)ΔBu(t)

这里selective_scan是顺序形式,因此与原论文CUDA编写的并行感知算法相比要慢

def selective_scan(self, u, delta, A, B, C, D):

        (b, l, d_in) = u.shape
        n = A.shape[1]
        

        deltaA = torch.exp(einsum(delta, A, 'b l d_in, d_in n -> b l d_in n'))
        deltaB_u = einsum(delta, B, u, 'b l d_in, b l n, b l d_in -> b l d_in n')
        
        # Perform selective scan (see scan_SSM() in The Annotated S4 [2])

        x = torch.zeros((b, d_in, n), device=deltaA.device)
        ys = []    
        for i in range(l):
            x = deltaA[:, i] * x + deltaB_u[:, i]
            y = einsum(x, C[:, i, :], 'b d_in n, b n -> b d_in')
            ys.append(y)
        y = torch.stack(ys, dim=1)  # shape (b, l, d_in)
        
        y = y + u * D
    
        return y

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/441951.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

智能音箱技术解析

目录 前言智能音箱执行步骤解析1.1 探测唤醒词或触发词1.2 语音识别1.3 意图识别1.4 执行指令 2 典型的智能音箱2.1 百度小度音响2.2 小米小爱同学2.3 苹果 HomePod 3 功能应用举例3.1 设置计时器3.2 播放音乐 结语 前言 智能音箱已经成为日常生活中不可或缺的一部分&#xff…

亚信安慧AntDB:为数据安全和稳定而生

AntDB充分考虑了用户的需求,将用户体验置于优先位置,通过深入分析用户的使用情况,对数据库的性能和功能进行了全方位的优化。无论是对于小规模应用还是大规模企业级系统,AntDB都能够提供稳定高效的数据库服务,满足不同…

[BUG] docker运行Java程序时配置代理-Dhttp.proxyHost后启动报错

[BUG] docker运行Java程序时配置代理-Dhttp.proxyHost后启动报错 bug现象描述 版本:2.0.4(客户端和服务端都是) 环境:私有云环境,只有少量跳板机器可以访问公网,其他机器均通过配置代理方式访问公网 bug现…

新一代 Git 工具,AI 赋能!深度集成、简化操作 | 开源日报 No.194

gitbutlerapp/gitbutler Stars: 7.2k License: NOASSERTION gitbutler 是一个基于 Git 的版本控制客户端。旨在为现代工作流程构建一个全新的 Git 分支管理工具。 虚拟分支:可以同时在多个分支上工作,而无需不断切换分支简化提交管理:通过拖…

码垛【FB块】

转载&#xff1a; FUNCTION BLOCK 码垛 VAR INPUT 当前数:INT; 点l:Point; 点2:Point; X行数:REAL; Y列数:REAL; 2层数:REAL; END VAR VAR OUTPUT 目标点:Point; 点数量:INT; END VAR VAR // X差值:点2.x-点1.x; IF X行数>1 AND X差值<>0 THEN X间隔:X差值/(X行数-1)…

07.axios封装实例

一.简易axios封装-获取省份列表 1. 需求&#xff1a;基于 Promise 和 XHR 封装 myAxios 函数&#xff0c;获取省份列表展示到页面 2. 核心语法&#xff1a; function myAxios(config) {return new Promise((resolve, reject) > {// XHR 请求// 调用成功/失败的处理程序}) …

偶极子和环形天线的辐射机理仿真分析

目录 0 引言 1 偶极子天线的辐射因素分析 1.1 偶极子天线模型设计 1.2 谐振点的出现规律 1.3 天线尺寸对辐射的影响 1.4 天线角度对辐射的影响

c++ primer plus 第十五章笔记 友元,异常和其他

友元类&#xff1a; 两个类不存在继承和包含的关系&#xff0c;但是我想通过一个类的成员函数来修改另一个类的私有成员和保护成员的时候&#xff0c;可以使用友元类。 class A {private:int num;//私有成员//...public: //...friend class B;//声明一个友元类 }class…

ChatGPT Plus 自动扣费失败,如何续订

ChatGPT Plus 自动扣费失败&#xff0c;如何续订 如果您的 ChatGPT Plus 订阅过期或扣费失败&#xff0c;本教程将指导您如何重新订阅。 本周更新 ChatGPT Plus 是一种每月20美元的订阅服务。扣费会自动进行&#xff0c;如果您的账户余额不足&#xff0c;OpenAI 将在一次扣费…

css 背景图片居中显示

background 简写 background: #ffffff url(https://profile-avatar.csdnimg.cn/b9abdd57de464582860bf8ade52373b6_misnice.jpg) center center / 100% no-repeat;效果如图&#xff1a;

git - 笔记

为什么要学习Git 为什么要学习Git软件 为什么学习 因为在主流开发中&#xff0c;基于互联网软件开发的项目都会使用Git软件来进行项目开发过程中的资源管理 比如人力资源 代码资源 比如前端资源 .html .java等代码资源 文档资源 像项目开发中涉及到的需求文档等 这种项目中管理…

CRM术语速览:掌握这十个专业名词,成为CRM专家

无论您是销售人员还是采购经理&#xff0c;熟悉CRM管理系统专业术语都是一门必修课。擅于运用CRM专业术语帮助您理解CRM管理系统的功能、更好的开展业务。本文与您分享不得不知道的十大CRM专业术语&#xff0c;CRM常用术语合集。常见的CRM术语包括MQL、SQL、SDR、销售漏斗等等。…

带摄像头的 AirPods,苹果会怎么做出来?

苹果对智能产品的设计&#xff0c;正在放飞自我。 根爆料&#xff0c;苹果在「未来设备」的规划里&#xff0c;有两个大胆的想法&#xff1a; 一是带有屏幕的 HomePod 正在研发中&#xff0c;当中将集成 Apple TV、FaceTime 等重多功能&#xff1b;二是配备摄像头的 AirPod…

【ARM Trace32(劳特巴赫) 高级篇 21 -- SystemTrace ITM 使用介绍】

文章目录 SystemTrace ITMSystemTrace ITM 常用命令Trace Data AnalysisSystemTrace ITM CoreSight ITM (Instrumentation Trace Macrocell) provides the following information: Address, data value and instruction address for selected data cyclesInterrupt event info…

201909青少年软件编程(Scratch)等级考试试卷(三级)

青少年软件编程&#xff08;Scratch&#xff09;等级考试试卷&#xff08;三级&#xff09;2019年9月 第1题&#xff1a;【 单选题】 执行下面的脚本后&#xff0c;变量“分数”的值是多少&#xff1f;&#xff08;&#xff09; A:5 B:6 C:10 D:25 【正确答案】: C 【试题…

【Java网络编程】TCP核心特性(下)

1. 拥塞控制 拥塞控制&#xff1a;是基于滑动窗口机制下的一大特性&#xff0c;与流量控制类似都是用来限制发送方的传送速率的 区别就在于&#xff1a;"流量控制"是从接收方的角度出发&#xff0c;根据接收方剩余接收缓冲区大小来动态调整发送窗口的&#xff1b;而…

文章解读与仿真程序复现思路——电网技术EI\CSCD\北大核心《考虑多能互补灵活性和用户低碳意愿的区域综合能源系统鲁棒优化调度》

本专栏栏目提供文章与程序复现思路&#xff0c;具体已有的论文与论文源程序可翻阅本博主免费的专栏栏目《论文与完整程序》 论文与完整源程序_电网论文源程序的博客-CSDN博客https://blog.csdn.net/liang674027206/category_12531414.html 电网论文源程序-CSDN博客电网论文源…

STM32 SDRAM知识点

1.SDRAM和SRAM的区别 SRAM不需要刷新电路即能保存它内部存储的数据。而SDRAM&#xff08;Dynamic Random Access Memory&#xff09;每隔一段时间&#xff0c;要刷新充电一次&#xff0c;否则内部的数据即会消失&#xff0c;因此SRAM具有较高的性能&#xff0c;但是SRAM也有它…

npm 操作报错记录1- uninstall 卸载失效

npm 操作报错记录1- uninstall 卸载失效 1、问题描述 安装了包 vue/cli-plugin-eslint4.5.0 vue/eslint-config-prettier9.0.0 但是没有使用 -d &#xff0c;所以想重新安装&#xff0c;就使用 uninstall 命令卸载&#xff0c;结果卸载了没反应&#xff0c;也没有报错&#xf…

一文帮助快速入门Django

文章目录 创建django项目应用app配置pycharm虚拟环境打包依赖 路由传统路由include路由分发namenamespace 视图中间件orm关系对象映射操作表数据库配置model常见字段及参数orm基本操作 cookie和sessiondemo类视图 创建django项目 指定版本安装django&#xff1a;pip install dj…