Inference stage.
Behind the scenes: (1) DeepSpeed injects high-performance kernels (kernel injection) to speed up inference, transparently to the user; (2) based on mp_size, DeepSpeed places the model across multiple GPUs, i.e. automatic model parallelism.
import os
import torch
import transformers
import deepspeed

local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))

# create the model pipeline
pipe = transformers.pipeline(task="text2text-generation",
                             model="google/t5-v1_1-small",
                             device=local_rank)

# Initialize the DeepSpeed-Inference engine
pipe.model = deepspeed.init_inference(pipe.model,
                                      mp_size=world_size,
                                      dtype=torch.float)

output = pipe('Input String')
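LOCAL_RANK and WORLD_SIZE are set by the deepspeed launcher, so a script like this is started as, e.g., deepspeed --num_gpus 2 <my_script.py> rather than with plain python.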
A trained model can be served with a different degree of parallelism: even if it was trained without model parallelism, or with a different mp_size, DeepSpeed supports changing mp_size at inference time.
Quantized inference: dtype also accepts torch.int8.
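A minimal sketch combining the two points above: serving a checkpoint that was trained without model parallelism across two GPUs with int8 weights. The model name (gpt2) and mp_size=2 are illustrative assumptions, not from the original notes.

import os
import torch
import transformers
import deepspeed

# assumed launch: deepspeed --num_gpus 2 <my_script.py>
local_rank = int(os.getenv("LOCAL_RANK", "0"))
world_size = int(os.getenv("WORLD_SIZE", "1"))  # 2 in this example

# gpt2 was trained without model parallelism...
pipe = transformers.pipeline(task="text-generation",
                             model="gpt2",
                             device=local_rank)

# ...but DeepSpeed re-partitions it across world_size GPUs at init time,
# and quantizes the weights to int8.
pipe.model = deepspeed.init_inference(pipe.model,
                                      mp_size=world_size,
                                      dtype=torch.int8)

print(pipe("DeepSpeed is", max_length=30))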
Newer inference stack (DeepSpeed-FastGen):
DeepSpeed/blogs/deepspeed-fastgen at master · microsoft/DeepSpeed · GitHub
Fine-tune BERT:
Uses the SQuAD QA dataset: (passage, question, answer). The answer is a span of the original passage; a prediction that matches any one of the reference answers should count as correct.
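For reference, a sketch of what one SQuAD-style record looks like and of the "matching any reference counts as correct" exact-match rule; the field names follow the public SQuAD JSON format, and the normalization here is simplified (official SQuAD EM also strips punctuation and articles).

# A SQuAD-style record: the answer text is a span of the context.
example = {
    "context": "DeepSpeed is a deep learning optimization library developed by Microsoft.",
    "question": "Who developed DeepSpeed?",
    "answers": ["Microsoft", "microsoft"],  # several reference answers may be given
}

def normalize(s: str) -> str:
    """Light normalization: lowercase and collapse whitespace."""
    return " ".join(s.lower().split())

def exact_match(prediction: str, references: list) -> bool:
    """Correct if the prediction matches ANY reference answer."""
    return any(normalize(prediction) == normalize(ref) for ref in references)

print(exact_match("Microsoft", example["answers"]))  # True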
Forward-pass code:
loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
input_ids: token ids, i.e. the text input after tokenization.
segment_ids: 0 marks question tokens, 1 marks context tokens.
input_mask: 1 for real sentence tokens, 0 for padding tokens.
start_positions: start position of the answer in the passage (not supplied at inference time).
end_positions: end position of the answer in the passage (not supplied at inference time).
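A sketch of how these five tensors can be built for one question/context pair with the Hugging Face BERT tokenizer; transformers' field names token_type_ids and attention_mask correspond to segment_ids and input_mask above, and the answer offsets here are toy values, not derived from real character offsets.

import torch
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

question = "Who developed DeepSpeed?"
context = "DeepSpeed is a library developed by Microsoft."

enc = tokenizer(question, context,
                max_length=32, padding="max_length", truncation=True,
                return_tensors="pt")

input_ids   = enc["input_ids"]        # token ids
segment_ids = enc["token_type_ids"]   # 0 = question, 1 = context
input_mask  = enc["attention_mask"]   # 1 = real token, 0 = padding

# Training: the gold answer span as token positions in the encoded sequence
# (toy values here; in practice computed from character offsets).
start_positions = torch.tensor([9])
end_positions   = torch.tensor([9])

# loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
# Inference: omit the two position tensors; the model returns start/end logits,
# and the answer span is read off with argmax.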
--deepspeed_transformer_kernel flag (append this option after deepspeed <my_script.py> on the command line):
Enables the optimized kernel that DeepSpeed implements specifically for the Transformer layer:
import copy
import torch.nn as nn
from deepspeed.ops.transformer import (DeepSpeedTransformerConfig,
                                       DeepSpeedTransformerLayer)

config = DeepSpeedTransformerConfig(batch_size=64,
                                    max_seq_length=128,
                                    hidden_size=1024,
                                    heads=16,
                                    attn_dropout_ratio=0.1,
                                    hidden_dropout_ratio=0.1,
                                    num_hidden_layers=24,
                                    initializer_range=0.02,
                                    local_rank=0,
                                    seed=1234,
                                    fp16=True,
                                    pre_layer_norm=True,
                                    attn_dropout_checkpoint=False,
                                    normalize_invertible=False,
                                    gelu_checkpoint=False)

# Inside the model's __init__: a stack of DeepSpeed transformer layers
self.layer = nn.ModuleList([
    copy.deepcopy(DeepSpeedTransformerLayer(config))
    for _ in range(config.num_hidden_layers)
])
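Note: per the DeepSpeed transformer-kernel tutorial, batch_size and max_seq_length here are the per-GPU limits the kernel is built for, so the actual micro-batches fed through these layers must not exceed them.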
Optimization-related options:
stochastic_mode: roughly 2% speedup; enables faster but slightly non-deterministic computation, so results can differ between runs.
normalize_invertible: discard the input activations of LayerNorm (during backward, the gradient can be computed from the output activations).
attn_dropout_checkpoint: attention dropout is a regularization technique that randomly disables some attention weights during training; storing these attention-dropout masks costs a lot of memory. With this option the masks are not stored and are recomputed during backward (which is fast). Ref: bert 中的四次dropout都在哪些位置?_transformer dropout 放在哪里-CSDN博客
gelu_checkpoint: discard the GeLU output activations; they can be recomputed quickly from the input activations.
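The three memory-saving flags above (normalize_invertible, attn_dropout_checkpoint, gelu_checkpoint) are all instances of the same trade-off: don't store an intermediate activation, recompute it during backward. A generic PyTorch sketch of this pattern with torch.utils.checkpoint (illustrating the idea only, not DeepSpeed's kernel internals):

import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

def gelu_block(x):
    # The GeLU output is NOT stored when this block is checkpointed;
    # it is recomputed from x during the backward pass.
    return F.gelu(x)

x = torch.randn(8, 1024, requires_grad=True)

y = checkpoint(gelu_block, x, use_reentrant=False)  # recompute-in-backward
y.sum().backward()
print(x.grad.shape)  # torch.Size([8, 1024])

The recomputation adds a little backward-time work but removes the stored activation from memory, which is exactly what each flag does for its respective sub-operation.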