在一些模型参数量比较大的llm和多模态网络中,比如
- llama系列70B
- MOE系列
- VL系列(QWenVL, GLM4V)
等,在推理的时候,对显存要求比较大,不是H800这种巨无霸显存很难跑起来。
像llama系列和MOE系列还好,可以借助deepseed等加速框架对齐进行TP切分,从而达到多卡切分参数的效果,但是像VL系列,TP等策略就不太好使了。
transformers框架提供了多设备load模型的方式,通过设置device_map,让模型均匀的分布在多卡,从而以类模型并行的方式,比如用上4-6个8g-24g显存的设备就可以跑起来70B, moe, vl这些。
具体代码如下,以GLM-4v为例:
from transformers import LlamaConfig,LlamaForCausalLM,AutoTokenizer,AutoModel, AutoConfig
from accelerate import init_empty_weights,infer_auto_device_map,load_checkpoint_in_model,dispatch_model
import torch
cuda_list = '0,1,2,3'.split(',')
memory = '20GiB'
model_path = '/home/root/.cache/hub/ZhipuAI/glm-4v-9b'
tokenizer = AutoTokenizer.from_pretrained(
model_path,
trust_remote_code=True,
encode_special_tokens=True,
use_fast=False,
)
model = AutoModel.from_pretrained(
model_path,
trust_remote_code=True,
device_map="auto",
torch_dtype=torch.bfloat16,
max_memory={0: "20GiB", 1: "20GiB", 2: "20GiB", 3: "20GiB"}
)
torch.set_grad_enabled(False)
model.eval()
from PIL import Image
messages = [
{
"role": "user",
"content": "图片中可以看到多少人玩滑板?",
"image": Image.open("/home/data/CogVLM-SFT-311K/llava_instruction_multi_conversations_formate/images/000000000.jpg").convert("RGB")
}
]
inputs = tokenizer.apply_chat_template(
messages,
add_generation_prompt=True,
tokenize=True,
return_tensors="pt",
return_dict=True,
).to(model.device)
generate_kwargs = {
"max_new_tokens": 128,
"do_sample": True,
"top_p": 0.8,
"temperature": 0.8,
"repetition_penalty": 1.2,
"eos_token_id": model.config.eos_token_id,
}
outputs = model.generate(**inputs, **generate_kwargs)
response = tokenizer.decode(
outputs[0][len(inputs["input_ids"][0]) :], skip_special_tokens=True
).strip()
print("=========")
print(response)
最终,笔者利用4个32G的设备,成功推理了GLM-4V的模型,每个仅用了30%的显存。
显存占用效果为:
Device Monitor of AIC
AIC Pwr(W) | Die Temp(C) Oclk(MHz) Dclk(MHz) Eclk(MHz) %Mem %Dec %Enc %AI %Dsp
--------------*-------------------------------------------------------------------------------------------
0 58.8 | 0 60.5 880 20 20 21.10 0.00 0.00 0.00 0.00
| 1 62.5 880 20 20 24.89 0.00 0.00 0.00 0.00
| 2 61.2 880 20 20 34.92 0.00 0.00 0.00 0.00
| 3 63.0 880 20 20 37.71 0.00 0.00 87.61 0.00
=========
从拍摄者的角度看,有两个人的脚踩在滑板上。但是由于视角问题只能看到两个轮子以及他们脚下的一部分木板并不能完全确认是否有人在上面滑行。
所以无法准确回答这个问题需要更多信息来判断哪些人正在积极地参与滑板运动而不是仅仅站在附近或等待机会使用他们的设备。通常来说
/home/zyhan/anaconda3/envs/py310/lib/python3.10/tempfile.py:860: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmpcqkta1ax'>
_warnings.warn(warn_message, ResourceWarning)