Table of Contents
- Links
- Imports
- class ModelArgs
- class Mamba
- def __init__
- def forward
- class ResidualBlock
- class RMSNorm
- Text generation demo
A simple, minimal implementation of Mamba. Compared with the original paper's implementation, state-spaces/mamba (github.com), the parameters are not carefully initialized, in favor of readability; the original also implements the parallel scan in CUDA, so it is much faster.
This post covers the remaining part: building a full sequence model from MambaBlock plus other components such as residual connections and normalization.
For the MambaBlock itself, see Mamba-minimal Mamba的最小限度实现 (一)-CSDN博客.
Links
From johnma2006/mamba-minimal: Simple, minimal implementation of the Mamba SSM in one file of PyTorch. (github.com)
Imports
```python
from __future__ import annotations
import math
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from typing import Union  # used by the dt_rank annotation in ModelArgs
from einops import rearrange, repeat, einsum
```
class ModelArgs
Model hyperparameters.
Parameter | Description |
---|---|
d_model | model dimension, matching the number of input channels |
n_layer | number of residual blocks |
d_state | latent state dimension |
expand | expansion factor; d_inner = expand * d_model |
dt_rank | rank of Δ (delta) |
d_conv | kernel size of the 1D convolution |
vocab_size | vocabulary size |
pad_vocab_size_multiple | pads vocab_size up to a multiple of this value |
conv_bias | bias option for the 1D convolution |
bias | bias option for the linear projections (lm_head itself is created with bias=False) |
```python
@dataclass
class ModelArgs:
    d_model: int
    n_layer: int
    vocab_size: int
    d_state: int = 16
    expand: int = 2
    dt_rank: Union[int, str] = 'auto'
    d_conv: int = 4
    pad_vocab_size_multiple: int = 8
    conv_bias: bool = True
    bias: bool = False

    def __post_init__(self):
        self.d_inner = int(self.expand * self.d_model)

        if self.dt_rank == 'auto':
            self.dt_rank = math.ceil(self.d_model / 16)

        if self.vocab_size % self.pad_vocab_size_multiple != 0:
            self.vocab_size += (self.pad_vocab_size_multiple
                                - self.vocab_size % self.pad_vocab_size_multiple)
```
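A quick sanity check of the derived fields (a minimal sketch; the sizes are illustrative, not from the original post):

```python
args = ModelArgs(d_model=64, n_layer=4, vocab_size=1000)

print(args.d_inner)     # 128 (= expand * d_model = 2 * 64)
print(args.dt_rank)     # 4   (= ceil(d_model / 16))
print(args.vocab_size)  # 1000 (already a multiple of 8, so unchanged)
```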
class Mamba
A complete Mamba sequence model, containing a stack of wrapped MambaBlocks.
For nn.Embedding, see 深度学习:pytorch nn.Embedding详解-CSDN博客.
The lm_head layer is the output layer for predicting the next token: it maps the model's output to a distribution over the vocabulary, and its weights are tied to (shared with) the embedding weights.
The input is a token sequence $x$ of shape $(batch\_size, length)$, abbreviated $(b, l)$; the output is the next-token distribution (logits) of shape $(b, l, vocab\_size)$.
Component | Shape transform |
---|---|
embedding | (b, l) -> (b, l, d_model) |
layers | (b, l, d_model) -> (b, l, d_model) |
norm_f | (b, l, d_model) -> (b, l, d_model) |
lm_head | (b, l, d_model) -> (b, l, vocab_size) |
def __init__
```python
class Mamba(nn.Module):
    def __init__(self, args: ModelArgs):
        """Full Mamba model."""
        super().__init__()
        self.args = args

        self.embedding = nn.Embedding(args.vocab_size, args.d_model)
        self.layers = nn.ModuleList([ResidualBlock(args) for _ in range(args.n_layer)])
        self.norm_f = RMSNorm(args.d_model)

        self.lm_head = nn.Linear(args.d_model, args.vocab_size, bias=False)
        self.lm_head.weight = self.embedding.weight  # Tie output projection to embedding weights.
                                                     # See "Weight Tying" paper
```
def forward
```python
    def forward(self, input_ids):
        x = self.embedding(input_ids)  # (b, l) -> (b, l, d_model)

        for layer in self.layers:
            x = layer(x)               # shape preserved: (b, l, d_model)

        x = self.norm_f(x)
        logits = self.lm_head(x)       # (b, l, d_model) -> (b, l, vocab_size)

        return logits
```
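An end-to-end shape check matching the table above (a minimal sketch: the sizes are illustrative, and it assumes MambaBlock from part one is defined in scope):

```python
args = ModelArgs(d_model=64, n_layer=2, vocab_size=1000)
model = Mamba(args)

# Weight tying: lm_head and the embedding share the same Parameter object.
assert model.lm_head.weight is model.embedding.weight

input_ids = torch.randint(0, args.vocab_size, (2, 10))  # (b, l) = (2, 10)
logits = model(input_ids)
print(logits.shape)  # torch.Size([2, 10, 1000]) -> (b, l, vocab_size)
```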
class ResidualBlock
A residual block that wraps a single MambaBlock.
For the MambaBlock itself, see Mamba-minimal Mamba的最小限度实现 (一)-CSDN博客.
```python
class ResidualBlock(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.mixer = MambaBlock(args)
        self.norm = RMSNorm(args.d_model)

    def forward(self, x):
        output = self.mixer(self.norm(x)) + x  # pre-norm: normalize, mix, then add the input back
        return output
```
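Note the pre-norm ordering: the input is normalized before entering the MambaBlock, and the raw input is added back afterwards, so every block maps (b, l, d_model) to (b, l, d_model). A minimal sketch (illustrative sizes; MambaBlock from part one assumed in scope):

```python
block = ResidualBlock(ModelArgs(d_model=64, n_layer=1, vocab_size=1000))
x = torch.randn(2, 10, 64)  # (b, l, d_model)
print(block(x).shape)       # torch.Size([2, 10, 64]) -- shape preserved
```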
class RMSNorm
The normalization used throughout the model.
For background, see RMSNorm论文阅读-CSDN博客
and LLM中的RMSNorm - 知乎 (zhihu.com).
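In formula form, RMSNorm rescales each position by its root mean square (this is the standard RMSNorm definition, matching the code below):

$$\mathrm{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{d}\sum_{i=1}^{d}x_i^2 + \epsilon}} \odot g$$

where $d$ is `d_model`, $\epsilon$ is `eps`, and $g$ is the learnable `weight` (initialized to ones).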
```python
class RMSNorm(nn.Module):
    def __init__(self,
                 d_model: int,
                 eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(d_model))

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) * self.weight
        return output
```
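A quick numerical check (a sketch): with the default weight of ones, the output has root-mean-square ≈ 1 at every position, regardless of the input scale.

```python
norm = RMSNorm(d_model=64)
x = torch.randn(2, 10, 64) * 5.0  # arbitrary input scale
y = norm(x)
print(y.pow(2).mean(-1).sqrt())   # ~1.0 everywhere
```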
Text generation demo
From demo.ipynb.
A notebook version is available here: colab_demo.
Load the model:
from model import Mamba, ModelArgs
from transformers import AutoTokenizer
# One of:
# 'state-spaces/mamba-2.8b-slimpj'
# 'state-spaces/mamba-2.8b'
# 'state-spaces/mamba-1.4b'
# 'state-spaces/mamba-790m'
# 'state-spaces/mamba-370m'
# 'state-spaces/mamba-130m'
pretrained_model_name = 'state-spaces/mamba-370m'
model = Mamba.from_pretrained(pretrained_model_name)
tokenizer = AutoTokenizer.from_pretrained('EleutherAI/gpt-neox-20b')
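A quick look at what the tokenizer feeds the model (a sketch): it returns token ids of shape (1, seq_len), and decoding round-trips back to the prompt.

```python
ids = tokenizer('Mamba is the', return_tensors='pt').input_ids
print(ids.shape)                 # torch.Size([1, seq_len])
print(tokenizer.decode(ids[0]))  # 'Mamba is the'
```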
Generate text
At each step, sample from the top-k highest-probability tokens.
```python
import torch
import torch.nn.functional as F


def generate(model,
             tokenizer,
             prompt: str,
             n_tokens_to_gen: int = 50,
             sample: bool = True,
             top_k: int = 40):
    model.eval()

    input_ids = tokenizer(prompt, return_tensors='pt').input_ids

    for token_n in range(n_tokens_to_gen):
        with torch.no_grad():
            indices_to_input = input_ids
            next_token_logits = model(indices_to_input)[:, -1]  # logits for the last position only

        probs = F.softmax(next_token_logits, dim=-1)
        (batch, vocab_size) = probs.shape

        if top_k is not None:
            (values, indices) = torch.topk(probs, k=top_k)
            probs[probs < values[:, -1, None]] = 0           # zero out everything below the k-th largest prob
            probs = probs / probs.sum(dim=1, keepdim=True)   # renormalize over the surviving tokens

        if sample:
            next_indices = torch.multinomial(probs, num_samples=1)
        else:
            next_indices = torch.argmax(probs, dim=-1)[:, None]

        input_ids = torch.cat([input_ids, next_indices], dim=1)

    output_completions = [tokenizer.decode(output.tolist()) for output in input_ids][0]

    return output_completions
```
print(generate(model, tokenizer, 'Mamba is the'))
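Passing `sample=False` switches the same function to greedy decoding (deterministic argmax at each step):

```python
# Greedy decoding: deterministic, always picks the most probable token.
print(generate(model, tokenizer, 'Mamba is the', sample=False))
```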