MBART: Multilingual Denoising Pre-training for Neural Machine Translation
模型下载
mBART 是一个基于序列到序列的去噪自编码器,使用 BART 目标在多种语言的大规模单语语料库上进行预训练。mBART 是首批通过去噪完整文本在多种语言上预训练序列到序列模型的方法之一,而以往的方法则仅集中在编码器、解码器或重构文本的部分内容。
首先需要在github上下载mbart的预训练模型,我们要完成的任务是微调:README
下载mbart.CC25模型,下载解压后的目录如下:
├── mbart.cc25.v2
└── dict.txt
└── model.pt
└── sentence.bpe.model
dict是模型的词典文件,model是mbart的预训练模型,sentence.bpe.model是sentencepiece训练分词的模型。
我们在en->vi
双语对上对预训练模型进行微调:
数据为IWSLT15的双语对,下载好数据之后对训练集、验证集和测试集重新命名,整理后的目录如下
├── en_vi
└── test.en_XX
└── test.vi_VN
└── train.en_XX
└── train.vi_VN
└── valid.en_XX
└── valid.vi_VN
环境准备
这里需要特定的fairseq版本来完成以下的这些命令,因此推荐新创建一个conda环境来隔离版本,这里我们命名为mbart_ft.
fairseq=0.10.2
python=3.8
numpy=1.19.5
数据分词
首先使用sentencepiece模型对数据进行分词,由于原文中描述没有额外的 true-casing、normalizing punctuation / characters.因此我们直接分词即可。
#!/bin/bash
SPM=/path/to/sentencepiece/build/src/spm_encode
MODEL=/mbart/mbart.cc25.v2/sentence.bpe.model
DATA=/mbart/en_vi
DEST=/mbart/en_vi/spm
TRAIN=train
VALID=valid
TEST=test
SRC=en_XX
TGT=vi_VN
${SPM} --model=${MODEL} < ${DATA}/${TRAIN}.${SRC} > ${DEST}/${TRAIN}.spm.${SRC} &
${SPM} --model=${MODEL} < ${DATA}/${TRAIN}.${TGT} > ${DEST}/${TRAIN}.spm.${TGT} &
${SPM} --model=${MODEL} < ${DATA}/${VALID}.${SRC} > ${DEST}/${VALID}.spm.${SRC} &
${SPM} --model=${MODEL} < ${DATA}/${VALID}.${TGT} > ${DEST}/${VALID}.spm.${TGT} &
${SPM} --model=${MODEL} < ${DATA}/${TEST}.${SRC} > ${DEST}/${TEST}.spm.${SRC} &
${SPM} --model=${MODEL} < ${DATA}/${TEST}.${TGT} > ${DEST}/${TEST}.spm.${TGT} &
wait
echo "SentencePiece encoding completed!"
我们创建spm目录,分词后的目录文件为:
├── en_vi
├── spm
└── test.spm.en_XX
└── test.spm.vi_VN
└── train.spm.en_XX
└── train.spm.vi_VN
└── valid.spm.en_XX
└── valid.spm.vi_VN
数据预处理
使用fairseq将数据处理为满足训练的、输入模型的格式。包含两种语言的词典文件、二进制格式和分词转换为id的文件。
#!/bin/bash
BASEDIR=/mbart/en_vi
DATA=${BASEDIR}/spm
DEST=${BASEDIR}/ids
DICT=/mbart/mbart.cc25.v2/dict.txt
SRC=en_XX
TGT=en_vi
TRAIN=train
VALID=valid
TEST=test
fairseq-preprocess \
--source-lang ${SRC} \
--target-lang ${TGT} \
--trainpref ${DATA}/${TRAIN}.spm \
--validpref ${DATA}/${VALID}.spm \
--testpref ${DATA}/${TEST}.spm \
--destdir ${DEST}/ \
--thresholdtgt 0 \
--thresholdsrc 0 \
--srcdict ${DICT} \
--tgtdict ${DICT} \
--workers 70
预处理后的模型数据准备的目录为:
├── en_vi
├── ids
└── dict.en_XX.txt
└── dict.vi_VN.txt
└── preprocess.log
└── test.en_XX-vi_VN.en_XX.bin
└── test.en_XX-vi_VN.en_XX.idx
└── test.en_XX-vi_VN.vi_VN.bin
└── test.en_XX-vi_VN.vi_VN.idx
└── train.en_XX-vi_VN.en_XX.bin
└── train.en_XX-vi_VN.en_XX.idx
└── train.en_XX-vi_VN.vi_VN.bin
└── train.en_XX-vi_VN.vi_VN.idx
└── valid.en_XX-vi_VN.en_XX.bin
└── valid.en_XX-vi_VN.en_XX.idx
└── valid.en_XX-vi_VN.vi_VN.bin
└── valid.en_XX-vi_VN.vi_VN.idx
训练集和验证集在训练过程中被用到,而测试集只在评价生成中被用到。
模型的训练
需要注意的是,与官方文档设置的参数相比,有一处需要修改。
--max-update 40000
:模型参数更新次数。
----total-num-update 40000
这是设置学习率调度器的更新次数,即学习率更新40k次训练停止。
在mbart的原文中:
We use a maximum of 40K training updates for all low and medium resource pairs and 100K for high resource pairs.
我们的en-vi双语数据属于低资源语言对,因此参数更新次数40K次,即应该设置--max-update 40000
#!/bin/bash
source /path/to/conda/etc/profile.d/conda.sh
conda activate mbart_ft
BASEDIR=/mbart/en_vi
PRETRAIN=/mbart/mbart.cc25.v2/model.pt # 已下载的预训练模型路径
DATA=${BASEDIR}/ids # 预处理后的二进制数据路径
SRC=en_XX
TGT=vi_VN
langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN
fairseq-train ${DATA} \
--encoder-normalize-before --decoder-normalize-before \
--arch mbart_large --layernorm-embedding \
--task translation_from_pretrained_bart \
--source-lang ${SRC} --target-lang ${TGT} \
--criterion label_smoothed_cross_entropy --label-smoothing 0.2 \
--optimizer adam --adam-eps 1e-06 --adam-betas '(0.9, 0.98)' \
--lr-scheduler polynomial_decay --lr 3e-05 --warmup-updates 2500 --max-update 40000 \
--dropout 0.3 --attention-dropout 0.1 --weight-decay 0.0 \
--max-tokens 1024 --update-freq 2 \
--save-interval 1 --save-interval-updates 5000 --keep-interval-updates 10 --no-epoch-checkpoints \
--seed 222 --log-format simple --log-interval 2 \
--restore-file $PRETRAIN \
--reset-optimizer --reset-meters --reset-dataloader --reset-lr-scheduler \
--langs $langs \
--save-dir ${BASEDIR} \
--ddp-backend no_c10d
在一块RTX 4090显卡上,运行3个小时后,训练结束。我们设置了每5000次更新保存一次检查点,微调模型保存的文件位置为 --save-dir ${BASEDIR}
微调后的目录文件为:
├── en_vi
├── ids
├── spm
└── dict.en_XX.txt
└── checkpoint_2_5000.pt
└── checkpoint_4_10000.pt
└── checkpoint_6_15000.pt
└── checkpoint_8_20000.pt
└── checkpoint_10_25000.pt
└── checkpoint_12_30000.pt
└── checkpoint_14_35000.pt
└── checkpoint_16_40000.pt
└── checkpoint_best.pt
└── checkpoint_last.pt
└── ...train...test...valid
模型的解码
我们使用checkpoint_best.pt
对其进行解码以及测BLEU分数。
这里我将分词模型复制到了en_vi文件夹中,并且添加--cpu
使得解码在cpu上运行。解码生成的文件为/mbart/en_vi/ids/en_vi
2025.1.13修订:
需要注意的是,相比于官方文档,这里删除了--bpe "sentencepiece"
、--sentencepiece-model $model_dir/sentence.bpe.model
以及--sacrebleu
若保留--sacrebleu
,由于版本间不匹配会报错
若保留--bpe "sentencepiece"
,则除了模型推理行“H”,其他源句子、目标句子和行“D”均没有空格出现。说明解码过程中并不需要此参数。
--remove-bpe 'sentencepiece'
:用于去除分词过程中产生的spm标记。
#!/bin/bash
source /path/to/conda/etc/profile.d/conda.sh
conda activate mbart_ft
model_dir=/mbart/en_vi/ids
langs=ar_AR,cs_CZ,de_DE,en_XX,es_XX,et_EE,fi_FI,fr_XX,gu_IN,hi_IN,it_IT,ja_XX,kk_KZ,ko_KR,lt_LT,lv_LV,my_MM,ne_NP,nl_XX,ro_RO,ru_RU,si_LK,tr_TR,vi_VN,zh_CN
TOKENIZER=${model_dir}/sentence.bpe.model
fairseq-generate ${model_dir} \
--path $model_dir/../checkpoint_best.pt \
--task translation_from_pretrained_bart \
--gen-subset test \
--cpu \
-t vi_VN -s en_XX \
--remove-bpe 'sentencepiece' \
--batch-size 32 \
--langs $langs > ${model_dir}/en_vi
查看生成文件en_vi的片段:
S-74 I lost all hope .[en_XX]
T-74 Tôi hoàn toàn tuyệt vọng .
H-74 -0.4808153808116913 Tôi đã mất hết hy vọng .
D-74 -0.4808153808116913 Tôi đã mất hết hy vọng .
P-74 -0.3194 -0.9490 -0.5736 -0.8777 -0.7397 -0.0389 -0.4746 -0.2814 -0.2920 -0.2618
S-372 Today I am 22 .[en_XX]
T-372 Hôm nay tôi 22 tuổi .
H-372 -0.3478223383426666 Hôm nay tôi 22 tuổi .
D-372 -0.3478223383426666 Hôm nay tôi 22 tuổi .
P-372 -0.5605 -0.0631 -0.4549 -0.2989 -0.4617 -0.4079 -0.3166 -0.3061 -0.2606
S-336 Thank you very much .[en_XX]
T-336 Cám ơn rất nhiều .
H-336 -0.46486935019493103 Cám ơn các bạn rất nhiều .
D-336 -0.46486935019493103 Cám ơn các bạn rất nhiều .
P-336 -1.8484 -0.0979 -0.1278 -0.9053 -0.2160 -0.4894 -0.1446 -0.4404 -0.2856 -0.3061 -0.2521
S-1267 Thank you very much .[en_XX]
T-1267 Cảm ơn rất nhiều .
H-1267 -0.46486935019493103 Cám ơn các bạn rất nhiều .
D-1267 -0.46486935019493103 Cám ơn các bạn rất nhiều .
P-1267 -1.8484 -0.0979 -0.1278 -0.9053 -0.2160 -0.4894 -0.1446 -0.4404 -0.2856 -0.3061 -0.2521
S-21 But many die .[en_XX]
T-21 Nhưng rất nhiều người đã chết .
H-21 -0.5680863261222839 Nhưng nhiều người chết .
D-21 -0.5680863261222839 Nhưng nhiều người chết .
P-21 -0.3266 -1.4395 -0.1804 -1.2362 -0.5122 -0.2999 -0.2973 -0.2526
S:是源句子,在en->vi
双语对上,源语言是英语。
T:是人工翻译句子,即测试集中的句子;
H:是模型输出的解码句子,第一个数字为其得分;
D:第一个数字为得分和H一致,但相比于H去掉了所有的空格,和S、T格式相同;
P:翻译过程中每个单词的预测概率。
运行解码脚本后,在ids目录中会生成 en_vi 文件。
├── en_vi
├── ids
└── sentence.bpe.model
└── en_vi
└── train...test...valid...dict...
├── spm
├── train...valid...test...
模型的评价
cat en_vi | grep -P "^H" |sort -V |cut -f 3- | sed 's/\[vi_VN\]//g' > en_vi.hyp
cat en_vi | grep -P "^T" |sort -V |cut -f 2- | sed 's/\[vi_VN\]//g' > en_vi.ref
sacrebleu en_vi.ref -i en_vi.hyp -m bleu -tok "spm"
这里将 H 开头的行提取,并去掉前两个字段,仅保留模型输出的解码句子,将他们合成 en_vi.hyp文件;
将 T 开头的行提取,并去掉第一个字段,保留test文件中的目标句子,将他们合成 en_vi.ref 文件。
这两行代码运行后,目录ids中应该多出两个文件。
├── en_vi
├── ids
└── en_vi
└── en_vi.hyp
└── en_vi.ref
└── train...test...valid...dict...
├── spm
├── train...valid...test...
这两个文件的行数应该一致,使用sacrebleu来测bleu的分数,指定 -tok 分词方式是 “spm” 即sentencepiece。
我们测试的模型评价结果为:
{
"name": "BLEU",
"score": 35.4,
"signature": "nrefs:1|case:mixed|eff:no|tok:flores101|smooth:exp|version:2.4.3",
"verbose_score": "66.3/42.8/28.8/19.5 (BP = 0.997 ratio = 0.997 hyp_len = 34971 ref_len = 35063)",
"nrefs": "1",
"case": "mixed",
"eff": "no",
"tok": "flores101",
"smooth": "exp",
"version": "2.4.3"
}
可以看到,我们复现了原文中在en->vi
方向的 BLEU 评分。
附录
原版本未删除--bpe "sentencepiece"
、--sentencepiece-model $model_dir/sentence.bpe.model
参数,fairseq推理后生成的en_vi文件为:
S-74 Ilostallhope.[en_XX]
T-74 Tôihoàntoàntuyệtvọng.
H-74 -0.4808153808116913 Tôi đã mất hết hy vọng .
D-74 -0.4808153808116913 Tôiđãmấthếthyvọng.
P-74 -0.3194 -0.9490 -0.5736 -0.8777 -0.7397 -0.0389 -0.4746 -0.2814 -0.2920 -0.2618
S-372 TodayIam22.[en_XX]
T-372 Hômnaytôi22tuổi.
H-372 -0.3478223383426666 Hôm nay tôi 22 tuổi .
D-372 -0.3478223383426666 Hômnaytôi22tuổi.
P-372 -0.5605 -0.0631 -0.4549 -0.2989 -0.4617 -0.4079 -0.3166 -0.3061 -0.2606
S-336 Thankyouverymuch.[en_XX]
T-336 Cámơnrấtnhiều.
H-336 -0.46486935019493103 Cám ơn các bạn rất nhiều .
D-336 -0.46486935019493103 Cámơncácbạnrấtnhiều.
P-336 -1.8484 -0.0979 -0.1278 -0.9053 -0.2160 -0.4894 -0.1446 -0.4404 -0.2856 -0.3061 -0.2521
S-1267 Thankyouverymuch.[en_XX]
T-1267 Cảmơnrấtnhiều.
H-1267 -0.46486935019493103 Cám ơn các bạn rất nhiều .
D-1267 -0.46486935019493103 Cámơncácbạnrấtnhiều.
P-1267 -1.8484 -0.0979 -0.1278 -0.9053 -0.2160 -0.4894 -0.1446 -0.4404 -0.2856 -0.3061 -0.2521
S-21 Butmanydie.[en_XX]
T-21 Nhưngrấtnhiềungườiđãchết.
H-21 -0.5680863261222839 Nhưng nhiều người chết .
D-21 -0.5680863261222839 Nhưngnhiềungườichết.
P-21 -0.3266 -1.4395 -0.1804 -1.2362 -0.5122 -0.2999 -0.2973 -0.2526
可以看到,测试集源句子S以及目标句子T的空格被误删除。由此提取的模型生成文件en_vi.hyp和翻译参考文件en_vi.ref同样误删空格。且模型输出句子H是正常的,这就说明是在解码过程中出现的问题。