后处理的输入
常规意义上的大模型处理流程
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer
# 加载模型和tokenizer
model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf")
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
# 输入prompt
prompt = "Hello, I'm Claude. How can I assist you today?"
input_ids = tokenizer.encode(prompt, return_tensors="pt")
# 前向传播获取logits
output = model(input_ids=input_ids)
logits = output.logits
# logits形状: (batch_size, sequence_length, vocab_size)
print(logits.shape)
后处理的输入是logits,其实准确说是hidden states,经过embedding table 映射后得到了最终的logits。
# 采样超参数
temperature = 0.7
top_k = 50
top_p = 0.95
repetition_penalty = 1.2
# 对logits进行处理
logits = logits[:, -1, :] / temperature # 应用温度
filtered_logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
filtered_logits = enforce_repetition_penalty(filtered_logits, input_ids, repetition_penalty)
# 从处理后的logits中采样token
probabilities = torch.softmax(filtered_logits, dim=-1)
next_token = torch.multinomial(probabilities, num_samples=1)
# 将新token添加到输入中,继续生成
input_ids = torch.cat((input_ids, next_token), dim=-1)
- 首先定义了一些采样超参数,如温度(temperature)、top-k、top-p 和重复惩罚系数(repetition_penalty)。
- 接下来, 对 logits 进行处理:
- 应用温度缩放,控制输出的随机性。
- 使用 top-k 和 top-p 过滤,保留概率最高的 k 个 token 或累积概率达到 p 的 token。
- 应用重复惩罚,降低已生成 token 的概率,避免重复。
VLLM Sampler的处理
我们默认跑vllm benchmark test 的时候,sampling 参数配置:
sampling_params = SamplingParams(
n=args.n,
# temperature=0.0 if args.use_beam_search else 1.0,
temperature=0.0,
top_p=1.0,
use_beam_search=args.use_beam_search,
ignore_eos=ignore_eos,
max_tokens=max_tokens,
repetition_penalty=args.repetition_penalty
)
除了这些参数以外,SamplingParams(vllm/sampling_params.py)的默认配置我们主要关注:
其中由于temperature 设置为0,默认使用greedy sampling 方式进行logits 采样。
进入到Sampler 后处理(vllm/model_executor/layers/sampler.py,vllm/model_executor/sampling_metadata.py),do_top_p_top_k
和 do_min_p
采样bypass,最后softmax
的输入shape 没有经过topk/p 的采样,输入shape为[bs, input_size, vocabulary_size]
因此,vocabulary size 如果太大,对softmax 性能的影响是一个很大的挑战。
从性能优化的角度考虑,可以先做一次logit 采样,通过设定合适的p/k 值保证模型输出精度。