·请参考本系列目录:【mT5多语言翻译】之一——实战项目总览
[1] 模型翻译推理
在分别使用全量参数微调和PEFT微调训练完模型之后,我们来测试模型的翻译效果。推理代码如下:
# 导入模型
if conf.is_peft:
model = AutoModelForSeq2SeqLM.from_pretrained(conf.peft_save)
else:
model = AutoModelForSeq2SeqLM.from_pretrained(conf.pretrained_path)
model.load_state_dict(torch.load(conf.save_path))
model.to(conf.device)
model.eval()
sentences = [
"kor:我要去健身了",
"jpn:我要去健身了",
"kor:他说他会爱我一辈子",
"jpn:他说他会爱我一辈子",
]
tokenizer = AutoTokenizer.from_pretrained(conf.pretrained_path)
ids = tokenizer.batch_encode_plus(
batch_text_or_text_pairs=sentences,
return_tensors='pt',
padding='max_length',
truncation=True,
max_length=conf.max_seq_len,
return_attention_mask=False
)
input_ids = ids['input_ids'].to(conf.device)
output_tokens = model.generate(input_ids, num_beams=10, num_return_sequences=3)
for token_set in output_tokens:
print(tokenizer.decode(token_set, skip_special_tokens=True))
因为训练方式有全量参数微调和PEFT微调两种,不同方式保存的模型不同。前者是全量参数,后者是PEFT添加的少量参数。
【注】直接加载PEFT保存的少量参数,也可以加载到mT5模型本身的预训练参数。这是因为在peft模型保存的文件夹中有一个
adapter_config.json
文件,里面保存了基座模型的地址。
最终,可以观察到上述代码的输出为:
나는 피트니스에 가고 싶
나는 피트니스 클럽에 가
나는 피트니스 센터에 가
ジムに行きます。
ジムに行きたいです。
ジムに行くわ
그는 평생을 나를 사랑할
그는 평생 나를 사랑할 것
그는 평생 나를 사랑할 거
彼は私を愛してくれると言っていた。
彼は私を愛してくれると言った。
彼は私を愛してくれると言っていました。
[2] 第三方接口设计
我们把模型推理简单地设计成一个GET请求的接口,代码如下:
# coding: UTF-8
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BertModel, T5Model
from conf import conf
from flask import Flask, request, jsonify
app = Flask(__name__)
# 导入模型
if conf.is_peft:
model = AutoModelForSeq2SeqLM.from_pretrained(conf.peft_save)
else:
model = AutoModelForSeq2SeqLM.from_pretrained(conf.pretrained_path)
model.load_state_dict(torch.load(conf.save_path))
model.to(conf.device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(conf.pretrained_path)
@app.route('/translate', methods=['GET'])
def translate():
# 从GET请求中获取参数
sentences = request.args.getlist('sentence')
if not sentences:
return jsonify({"error": "No sentences provided."}), 400
# 对句子进行编码
ids = tokenizer.batch_encode_plus(
batch_text_or_text_pairs=sentences,
return_tensors='pt',
padding='max_length',
truncation=True,
max_length=conf.max_seq_len,
return_attention_mask=False
)
input_ids = ids['input_ids'].to(conf.device)
# 生成翻译结果
output_tokens = model.generate(input_ids, num_beams=10, num_return_sequences=3)
# 解码翻译结果
translations = [tokenizer.decode(token_set, skip_special_tokens=True) for token_set in output_tokens]
# 返回结果
return jsonify({"translations": translations})
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000)
然后就能去浏览器快乐地测试玩耍了。