Mamba-minimal Mamba的最小限度实现 (二)

文章目录

    • 链接
    • 导入所需包
    • class ModelArgs
    • class Mamba
      • def __ init __
      • def forward
    • class ResidualBlock
    • class RNSNorm
    • 文本生成demo

manba的简单最小限度实现,和原始论文实现 state-spaces/mamba (github.com)相比,为了可读性对参数没有很好的初始化,原论文用CUDA写了并行扫描,所以速度会快。
这里是剩余部分介绍,主要包括利用MambaBlock和其他组件如残差连接,归一化等定义一个序列模型。

MambaBlock的介绍Mamba-minimal Mamba的最小限度实现 (一)-CSDN博客

链接

来自johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch. (github.com)

导入所需包

from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from einops import rearrange, repeat, einsum

class ModelArgs

模型参数设置

参数介绍
d_model模型维度,和输入数据通道对应
n_layer残差块的数目
d_state潜在状态维度
expand扩展因子,d_in = d_state * state
dt_rankdelta的秩
d_conv1D卷积的卷积核大小
vocab_size词汇表的大小
pad_vocab_size_multiple确保vocab_size是设定值的倍数
conv_bias1D卷积的bias选项
biaslm_head映射的bias选项
@dataclass
class ModelArgs:
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4 
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False
   
    
    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)
        
        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)
            
        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)

class Mamba

一个完整的序列处理Mamba模型,包含多个被包裹的MambaBlock。

nn.Embedding参照深度学习:pytorch nn.Embedding详解-CSDN博客

lm_head层则是预测下一个token的输出层,它将模型的输出映射到一个概率分布上,以便于模型预测下一个token,权重和Embedding公用。

输入一个序列 x ( b a t c h _ s i z e , l e n g t h ) x(batch\_size, length) x(batch_size,length) 简写为 ( b , l ) (b, l) (b,l),输出取词的概率 ( b , l , v o c a b _ s i z e ) (b, l, vocab\_size) (b,l,vocab_size)

组件尺寸变换
embedding(b, l) -> (b, l, d_model)
layers(b, l, d_model) -> (b, l, d_model)
norm_f\
lm_head(b, l, d_model) -> (b, l, vocab_size)

def __ init __

class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model."""
        super().__init__()
        self.args = args
        
        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.
                                                     # See "Weight Tying" paper

def forward

 def forward(self, input_ids):
      
        x = self.embedding(input_ids)
        
        for layer in self.layers:
            x = layer(x)
            
        x = self.norm_f(x)
        logits = self.lm_head(x)
        return logits

class ResidualBlock

一个包裹MambaBlock的一个残差块

MambaBlock的介绍Mamba-minimal Mamba的最小限度实现 (一)-CSDN博客

class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):

        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)
        

    def forward(self, x):
   
        output = self.mixer(self.norm(x)) + x

        return output
            

class RNSNorm

所用到的归一化

可以参考RMSNorm论文阅读-CSDN博客

LLM中的RMSNorm - 知乎 (zhihu.com)

class RMSNorm(nn.Module):
    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))


    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight

        return output

文本生成demo

来自demo.ipynb

这里是一个colab_demo

加载模型

from model import Mamba, ModelArgs
from transformers import AutoTokenizer

# One of:
#     'state-spaces/mamba-2.8b-slimpj'
#     'state-spaces/mamba-2.8b'
#     'state-spaces/mamba-1.4b'
#     'state-spaces/mamba-790m'
#     'state-spaces/mamba-370m'
#     'state-spaces/mamba-130m'
pretrained_model_name = 'state-spaces/mamba-370m'

model = Mamba.from_pretrained(pretrained_model_name)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')

生成文本
在概率为top-k的输出中采样

import torch
import torch.nn.functional as F


def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 50,
             sample: bool = True,
             top_k: int = 40):
    model.eval()
    
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    
    for token_n in range(n_tokens_to_gen):
        with torch.no_grad():
            indices_to_input = input_ids
            next_token_logits = model(indices_to_input)[:, -1]
        
        probs = F.softmax(next_token_logits, dim=-1)
        (batch, vocab_size) = probs.shape
        
        if top_k is not None:
            (values, indices) = torch.topk(probs, k=top_k)
            probs[probs < values[:, -1, None]] = 0
            probs = probs / probs.sum(axis=1, keepdims=True)
        
        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            next_indices = torch.argmax(probs, dim=-1)[:, None]
        
        input_ids = torch.cat([input_ids, next_indices], dim=1)

    output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]
    
    return output_completions

在这里插入图片描述

print(generate(model, tokenizer, 'Mamba is the'))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/442781.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

c++ 常用的STL

前言 写这篇博客目的是为了记录在刷算法题中使用过的STL&#xff0c;因为有些不太常用的会遗忘。这篇博客只是作为笔记&#xff0c;不是详细的STL&#xff0c;因此只会对常用方法说明&#xff0c;不会详细介绍。此外在后面用到新的STL内容时会再补充。 列队 基础列队 基本列…

Prompt进阶系列1:LangGPT(从编程语言反思LLM的结构化可复用提示设计框架)

Prompt进阶系列1:LangGPT(从编程语言反思LLM的结构化可复用提示设计框架) 大语言模型 (Large Language Models, LLMs) 在不同领域都表现出了优异的性能。然而&#xff0c;对于非AI专家来说&#xff0c;制定高质量的提示来引导 LLMs 是目前AI应用领域的一项重要挑战。现有的提示…

ARM/Linux嵌入式面经(二):芯片原厂

uart如何进行通信&#xff0c;模块发给uart数据信息后经历了什么 UART&#xff08;Universal Asynchronous Receiver/Transmitter&#xff0c;通用异步收发传输器&#xff09;是一种用于串行通信的协议&#xff0c;它使用一对传输线&#xff08;TX和RX&#xff09;进行双向通信…

阿里云服务器地域所在城市分布表_北京杭州深圳及全球区域

2024年最新阿里云服务器地域分布表&#xff0c;地域指数据中心所在的地理区域&#xff0c;通常按照数据中心所在的城市划分&#xff0c;例如华北2&#xff08;北京&#xff09;地域表示数据中心所在的城市是北京。阿里云地域分为四部分即中国、亚太其他国家、欧洲与美洲和中东&…

Autodl中tar.gz无法解压问题

在云服务器中&#xff0c;我们常把文件压缩转移至其他实例中&#xff0c;但是我们有时会发现&#xff0c;压缩可以但是解压缩却遇到了问题&#xff0c;下面给出了正确的压缩和解压指令。 1 压缩 tar -cvf archieve.tar.gz aaa(archieve.tar.gz为压缩后的压缩包名称&#xff0c…

【Java.mysql】——增删查改(CRUD)之 增查(CR) 附加数据库基础知识

目录 &#x1f6a9;数据库操作 &#x1f388;创建数据库 &#x1f388;使用数据库 &#x1f388;删除数据库 &#x1f6a9;数据类型 &#x1f6a9;表的操作 &#x1f388;创建表 &#x1f308;查看表结构 &#x1f388;删除表 ❗练习(综合运用) &#x1f5a5;️新增…

【黑马程序员】STL实战--演讲比赛管理系统

文章目录 演讲比赛管理系统需求说明比赛规则程序功能 创建管理类功能描述创建演讲比赛管理类 菜单功能添加菜单成员函数声明菜单成员函数实现菜单功能测试 退出功能添加退出功能声明退出成员函数实现退出功能测试 演讲比赛功能功能分析创建选手类比赛成员属性添加初始化属性创建…

KVM技术原理及安装KVM并且在KVM里面安装RHEL8

目录 一、kvm原理 1..1虚拟化概念 1.2 虚拟化产生背景 1.3虚拟化架构 1.4主流的虚拟化技术 1.5阐述个人对虚拟化技术的几种分类认知 二、安装KVM并且在KVM里面安装RHEL8 2.1在RHEL8主机上安装KVM 2.2安装完成后&#xff0c;使用virt-manager命令打开虚拟机管理图形界面…

Python错题集-8:AttributeError(找不到对应的对象的属性)

1问题描述 AttributeError: AxesSubplot object has no attribute arc 2代码详情 import matplotlib.pyplot as plt# 创建一个新的图形和坐标轴 fig, ax plt.subplots()# 定义弧线的参数 center (0.5, 0.5) # 圆心坐标 (x, y) width 1.0 # 半径 height 0.5 # 半径 ang…

Charles抓包工具使用

Charles简介 Charles是一款基于HTTP协议的代理服务器和HTTP监视器&#xff0c;通过将自己设置为电脑或浏览器的网络访问代理&#xff0c;能够截取请求和请求结果&#xff0c;从而达到分析抓包的目的。它允许开发者查看所有连接互联网的HTTP通信&#xff0c;包括请求、响应和HTT…

【漏洞复现】Linksys E2000 position.js 身份验证绕过漏洞(CVE-2024-27497)

0x01 产品简介 Linksys E2000是一款由思科&#xff08;Cisco&#xff09;品牌推出的无线路由器&#xff0c;它是一款支持2.4GHz和5GHz双频段的无线路由器&#xff0c;用户可以避开拥挤的2.4GHz频段&#xff0c;独自享受5GHz频段的高速无线生活。 0x02 漏洞概述 Linksys E200…

JAVA虚拟机、Dalvik虚拟机和ART虚拟机简要对比

1、什么是JVM&#xff1f; JVM本质上就是一个软件&#xff0c;是计算机硬件的一层软件抽象&#xff0c;在这之上才能够运行Java程序&#xff0c;JAVA在编译后会生成类似于汇编语言的JVM字节码&#xff0c;与C语言编译后产生的汇编语言不同的是&#xff0c;C编译成的汇编语言会…

hadoop集群部署教程

文章目录 前言一、相关介绍1. 配置文件位置1.1 只读默认配置文件1.2 可修改配置文件1.3 相关环境变量配置文件 二、安装准备1. 准备centos2. 配置集群免密登录3. 部署规划4. 安装条件5. 安装jdk 三、安装hadoop1. 下载并解压hadoop2. 设置环境变量2.1 设置hadoop安装目录环境变…

Diddler抓包工具——学习笔记

F12抓包 302【重定向】&#xff1a;当你发送了一个请求之后&#xff0c;那么这个请求重定向到了另外的资源 跳转和重定向的区别&#xff1a; 跳转是会把数据传到新的地址 重定向不会把新的数据传到新的地址 使用F12抓包时一定要打开Preserve Log开关&#xff0c;作用是保留…

【CSP试题回顾】202009-1-称检测点查询

CSP-202009-1-称检测点查询 解题代码 #include <iostream> #include <vector> #include <cmath> #include <algorithm> using namespace std;int n, X, Y, x, y; struct MyDistance {int index;int dis; }; vector<MyDistance>list; bool cmp(…

福州·名城银河湾220㎡现代简约风装修案例分享。福州中宅装饰,福州装修

以手作维度构境, 跳脱约定成俗的风格, 转化内外地域分际, 于静谧中凝聚丰厚的美学能量, 谦虚且沉默以对。 平面设计图 项目信息 项目名称 | 名城银河湾 设计地址 | 福建福州 项目面积 | 220㎡ 项目户型 | 5室2厅2厨3卫 设计风格 | 现代轻奢 首席设计师丨欧阳光玉 中…

C#,老鼠迷宫问题的回溯法求解(Rat in a Maze)算法与源代码

1 老鼠迷宫问题 迷宫中的老鼠&#xff0c;作为另一个可以使用回溯解决的示例问题。 迷宫以块的NN二进制矩阵给出&#xff0c;其中源块是最左上方的块&#xff0c;即迷宫[0][0]&#xff0c;目标块是最右下方的块&#xff0c;即迷宫[N-1][N-1]。老鼠从源头开始&#xff0c;必须…

《C缺陷和陷阱》-笔记(2)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 目录 文章目录 前言 一、理解函数声明 1.(*(void(*)( ))0)( ); 2.signal 函数接受两个参数&#xff1a; 3.使用typedef 简化函数声明&#xff1a; 二、运算符的优先级…

C/C++实现代码雨效果

C/C实现代码雨效果 目录 C/C实现代码雨效果 说明使用的库说明测试代码效果图 说明 最近整理电脑资料&#xff0c;翻出了以前写的代码&#xff0c;顺便整理一下到博客上&#xff0c;当做一次备份记录 先看看静态效果 需要分为以下步骤实现 生成代码串把代码串绘制到窗口中使…

差分约束

&#xff08;1&#xff09;求不等式组的可行解 源点需满足的条件&#xff1a;从源点出发&#xff0c;一定可以到达所有的边 求最短路 步骤&#xff1a;1.先将每个不等式xi<xj ck&#xff0c;转化成一条从xj走到xi&#xff0c;长度为ck的一条边 2.找一个超级源点&#xff0c…