medusa简单实现和transformer的seq2seq的实现区别
medusa代码简单如下
import tensorflow as tf
from tensorflow.keras import layers, Model
# 基础 Transformer 模型
class SimpleTransformer(Model):
def __init__(self, vocab_size, hidden_dim, num_heads, num_layers):
super(SimpleTransformer, self).__init__()
self.embedding = layers.Embedding(vocab_size, hidden_dim)
self.transformer = [
layers.MultiHeadAttention(num_heads=num_heads, key_dim=hidden_dim) for _ in range(num_layers)
]
self.fc_out = layers.Dense(vocab_size) # 最后的语言模型输出层
def call(self, inputs):
x = self.embedding(inputs)
for mha in self.transformer:
x = mha(x, x)
return self.fc_out(x)
# 定义带有多个解码头的 MEDUSA 模型
class MedusaTransformer(Model):
def __init__(self, base_model, num_heads, hidden_dim, vocab_size):
super(MedusaTransformer, self).__init__()
self.base_model = base_model # 基础的 Transformer 模型
self.medusa_heads = [ # 定义多个解码头
layers.Dense(hidden_dim, activation='swish') for _ in range(num_heads)
]
self.output_heads = [ # 每个解码头的输出层
layers.Dense(vocab_size) for _ in range(num_heads)
]
def call(self, inputs):
# 获取基础模型的最后隐藏状态
hidden_states = self.base_model(inputs)
# 使用多个解码头并行生成多个候选词的分布
outputs = [output_head(medusa_head(hidden_states)) for medusa_head, output_head in zip(self.medusa_heads, self.output_heads)]
return outputs
# 初始化模型
vocab_size = 10000
hidden_dim = 512
num_heads = 8
num_layers = 6
base_model = SimpleTransformer(vocab_size, hidden_dim, num_heads, num_layers)
# 添加 3 个解码头
medusa_model = MedusaTransformer(base_model, num_heads=3, hidden_dim=hidden_dim, vocab_size=vocab_size)
# 示例输入
input_tokens = tf.constant([[1, 2, 3]]) # 示例输入
# 使用 MEDUSA 模型进行推理
outputs = medusa_model(input_tokens)
print("Medusa Outputs:", outputs)
seq2seq如下
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleTransformer(nn.Module):
def __init__(self, vocab_size, hidden_dim, num_heads, num_layers):
super(SimpleTransformer, self).__init__()
# 词嵌入层
self.embedding = nn.Embedding(vocab_size, hidden_dim)
# Transformer 编码器层
self.transformer = nn.Transformer(
d_model=hidden_dim,
nhead=num_heads,
num_encoder_layers=num_layers,
num_decoder_layers=num_layers
)
# 语言模型输出层
self.fc_out = nn.Linear(hidden_dim, vocab_size)
def forward(self, src, tgt):
# 进行嵌入
src_embedding = self.embedding(src)
tgt_embedding = self.embedding(tgt)
# 通过 Transformer 模型
transformer_output = self.transformer(src_embedding, tgt_embedding)
# 输出词的分布
output = self.fc_out(transformer_output)
return output
def generate_text(model, tokenizer, start_token, max_length=20):
# 将起始 token 转化为输入张量
input_tokens = torch.tensor([start_token]).unsqueeze(0) # Shape: (1, 1)
model.eval() # 设置模型为评估模式
for _ in range(max_length):
# 使用模型进行前向传播,预测下一个 token
output = model(input_tokens, input_tokens)
next_token_logits = output[:, -1, :] # 获取最后一个时间步的输出
next_token_probs = F.softmax(next_token_logits, dim=-1)
next_token = torch.argmax(next_token_probs, dim=-1).item() # 选择概率最大的词
# 将下一个 token 添加到输入序列中
input_tokens = torch.cat([input_tokens, torch.tensor([[next_token]])], dim=1)
# 如果生成结束符,停止生成
if next_token == tokenizer['<eos>']:
break
# 返回生成的词序列
generated_text = [tokenizer[i] for i in input_tokens.squeeze().tolist()]
return " ".join(generated_text)
# 模型和词典初始化
vocab_size = 10000 # 假设词汇表大小为 10000
hidden_dim = 512
num_heads = 8
num_layers = 6
model = SimpleTransformer(vocab_size, hidden_dim, num_heads, num_layers)
# 示例的 tokenizer 和起始 token
tokenizer = {i: f"word{i}" for i in range(vocab_size)}
tokenizer['<eos>'] = vocab_size - 1 # 假设词汇表最后一个是结束符
start_token = 0 # 假设 0 号词是起始 token
# 生成文本
generated_text = generate_text(model, tokenizer, start_token)
print("Generated Text:", generated_text)
昨天我看了medusa这个概论并且大致看了源代码
MEDUSA不同于传统的transformer采取的自回归生成方法,采用了一种类似informer的策略,使用多个解码头来并行生成多个后续词的候选项,这样就不必严格按照逐词顺序生成。
MEDUSA 利用了树状注意力机制(tree-based attention)来构建多个解码头,解码头里面有多个生成的候选词,并在每一步解码中同时使用笛卡尔积进行组合并验证这些候选词。验证方法:MEDUSA 引入了一种叫做典型接受(Typical Acceptance)的方案,用于评估候选序列的合理性。这个方案主要是通过设定一个基于熵的阈值,来选择哪些候选词被接受。
这样,通过并行处理,MEDUSA 大大减少了所需的解码步骤数量。
添加 MEDUSA 解码头:
在 LM Head 之后,添加多个解码头(如 Medusa Head 1, Medusa Head 2, ...),每个解码头都是一个用于预测不同位置后续词的前馈网络层(Feed-Forward Layer)。解码头的数量可以根据需要设定,通常为 3-5 个。