李宏毅2023机器学习作业HW05解析和代码分享

ML2023Spring - HW5 相关信息:
课程主页
课程视频
Sample code
HW05 视频
HW05 PDF

个人完整代码分享: GitHub | Gitee | GitCode
运行日志记录: wandb

P.S. HW05/06 是在 Judgeboi 上提交的,完全遵循 hint 就可以达到预期效果。

因为无法在 Judgeboi 上提交,所以 HW05/06 代码仓库中展示的是在验证集上的分数。

每年的数据集 size 和 feature 并不完全相同,但基本一致,过去的代码仍可用于新一年的 Homework。

仓库中 HW05 的代码分成了英文 EN 和中文 ZH 两个版本。

(碎碎念:翻译比较麻烦,所以之后的 Homework 代码暂只有英文版本)

任务目标(seq2seq)

  • Machine translation 机器翻译,英译中

性能指标(BLEU)

参考链接:

BLEU: a Method for Automatic Evaluation of Machine Translation

Foundations of NLP Explained — Bleu Score and WER Metrics

BLEU(Bilingual Evaluation Understudy) 双语评估替换

公式:
BLEU = B P ⋅ exp ⁡ ( ∑ n = 1 N w n l o g   p n ) 1 N \text{BLEU} = BP \cdot \exp\left( \sum_{n=1}^{N} w_n log\ p_n\right)^{\frac{1}{N}} BLEU=BPexp(n=1Nwnlog pn)N1
首先要明确两个概念

  1. N-gram
    用来描述句子中的一组 n 个连续的单词。比如,“Thank you so much” 中的 n-grams:

    • 1-gram: “Thank”, “you”, “so”, “much”
    • 2-gram: “Thank you”, “you so”, “so much”
    • 3-gram: “Thank you so”, “you so much”
    • 4-gram: “Thank you so much”

    需要注意的一点是,n-gram 中的单词是按顺序排列的,所以 “so much Thank you” 不是一个有效的 4-gram。

  2. 精确度(Precision)
    精确度是 Candidate text 中与 Reference text 相同的单词数占总单词数的比例。 具体公式如下:
    $ \text{Precision} = \frac{\text{Number of overlapping words}}{\text{Total number of words in candidate text}} $
    比如:
    Candidate: Thank you so much, Chris
    Reference: Thank you so much, my brother
    这里相同的单词数为4,总单词数为5,所以 Precision = 4 5 \text{Precision} = \frac{{4}}{{5}} Precision=54
    但存在一个问题:

    • Repetition 重复

      Candidate: Thank Thank Thank
      Reference: Thank you so much, my brother

      此时的 Precision = 3 3 \text{Precision} = \frac{{3}}{{3}} Precision=33

解决方法:Modified Precision

很简单的思想,就是匹配过的不再进行匹配。

Candidate: Thank Thank Thank
Reference: Thank you so much, my brother

Precision 1 = 1 3 \text{Precision}_1 = \frac{{1}}{{3}} Precision1=31

  • 具体计算如下:

    C o u n t c l i p = min ⁡ ( C o u n t ,   M a x _ R e f _ C o u n t ) = min ⁡ ( 3 ,   1 ) = 1 Count_{clip} = \min(Count,\ Max\_Ref\_Count)=\min(3,\ 1)=1 Countclip=min(Count, Max_Ref_Count)=min(3, 1)=1
    $ p_n = \frac{\sum_{\text{n-gram}} Count_{clip}}{\sum_{\text{n-gram}} Count} = \frac{1}{3}$

现在还存在一个问题:译文过短

Candidate: Thank you
Reference: Thank you so much, my brother

p 1 = 2 2 = 1 p_1 = \frac{{2}}{{2}} = 1 p1=22=1

这里引出了 brevity penalty,这是一个惩罚因子,公式如下:

B P = { 1 if  c > r e 1 − r c if  c ≤ r BP = \begin{cases} 1& \text{if}\ c>r\\ e^{1-\frac{r}{c}}& \text{if}\ c \leq r \end{cases} BP={1e1crif c>rif cr

其中 c 是 candidate 的长度,r 是 reference 的长度。

当候选译文的长度 c 等于参考译文的长度 r 的时候,BP = 1,当候选翻译的文本长度较短的时候,用 e 1 − r c e^{1-\frac{r}{c}} e1cr 作为 BP 值。

回到原来的公式:$ \text{BLEU} = BP \cdot \exp\left( \sum_{n=1}^{N} w_n log\ p_n\right)^{\frac{1}{N}}$,汇总一下符号定义:

  • B P BP BP 文本长度的惩罚因子
  • N N N n-gram 中 n 的最大值,作业中设置为 4。
  • w n w_n wn 权重
  • p n p_n pn n-gram 的精度 (precision)

数据解析

  • Paired data
    • TED2020: 演讲
      • Raw: 400,726 (sentences)
      • Processed: 394, 052 (sentences)
    • 英文和中文两个版本
  • Monolingual data
    • 只有中文版本的 TED 演讲数据

Baselines

这里存在一个问题,就是HW05是在 Judgeboi 上进行提交的,所以没办法获取最终的分数,所以简单的使用 simple baseline 对应的 validate BLEU 来做个映射。

因为有 EN / ZH 两个版本,对于每个 hint 我会给出代码的修改位置方便大家索引。

Simple baseline (15.05)

  • 运行所给的 sample code

Medium baseline (18.44)

  • 增加学习率的调度 (Optimizer: Adam + lr scheduling / 优化器: Adam + 学习率调度)
  • 训练得更久 (Configuration for experiments / 实验配置)
    这里根据预估的时间,可以简单的将 epoch 设置为原来的两倍。

Strong baseline (23.57)

  • 将模型架构转变为 Transformer (Model Initialization / 模型初始化)
  • 调整超参数 (Architecture Related Configuration / 架构相关配置)
    这里需要参考 Attention is all you need 论文中 table 3 的 transformer-base 超参数设置。
    image-20231115135033382

你可以仅遵循 sample code 的注释,将 encoder_layer 和 decoder_layer 改为 4(简单的将这一个改动称之为 transformer_4layer),此时模型的参数数量会和之前的 RNN 差不多,在 max_epoch =30 的情况下,Bleu 可以达到 23.59。

代码仓库中分享的 Strong 代码完全遵循了 transformer-base 的超参数设置,此时的模型参数将约为之前 RNN 的 5 倍,每一轮训练的时间约为 transform_4layer 的三倍,所以我将 max_epoch 设置为了 10,让其能够匹配上预估的时间,此时的 Bleu 为 24.91。如果将 max_epoch 设置为 30,最终的 Bleu 可以达到 27.48。

下面是二者实验对比。

Boss baseline (30.08)

  • 应用 back-translation (TODO)

    这里我们需要交换实验配置 config 中的 source_lang 和 target_lang,并修改 savedir,训练一个 back-translation 模型后再修改回原来的 config。

    然后你需要将 TODO 的部分完善,修改并复用之前的函数就可以达到目的。

    (为了与预估时间匹配,这里将 max_epoch 设置为 30 进行实验。)

代码仓库中分享的 Boss 代码展示的是最终训练的结果,完整的运行流程是:

  1. 实验配置中 / Configuration for experimentsBACK_TRANSLATION 设置为 True 运行
    训练一个 back-translation 模型,并处理好对应的语料。
  2. 实验配置 / Configuration for experiments 中的 BACK_TRANSLATION 设置为 False 运行
    结合 ted2020 和 mono (back-translation) 的语料进行训练。

Gradescope

Visualize Positional Embedding

你可以直接在 确定用于生成 submission 的模型权重 / Confirm model weights used to generate submission 后进行处理,在仓库的代码中我已经提前注释掉了 训练循环 / Training loop 中的训练部分,如果在之前,模型没有训练,直接运行代码会报错。

image-20231119122408389

添加的处理代码如下(可以复制下面的处理代码放到你的 submission 模块之后):

推荐阅读:All Pairs Cosine Similarity in PyTorch

pos_emb = model.decoder.embed_positions.weights.cpu().detach()

# 计算余弦相似度矩阵
def get_cosine_similarity_matrix(x):
    x = x / x.norm(dim=1, keepdim=True)
    sim = torch.mm(x, x.t())
    return sim
    
sim = get_cosine_similarity_matrix(pos_emb)
#sim = F.cosine_similarity(pos_emb.unsqueeze(1), pos_emb.unsqueeze(0), dim=2) # 一样的

# 绘制位置向量的余弦相似度矩阵的热力图
plt.imshow(sim, cmap="hot", vmin=0, vmax=1)
plt.colorbar()

plt.show()

Clipping Gradient Norm

只需要将 config.wandb 设置为 True 即可,此时可以在 wandb 上查看。

image-20231119183413555

或者直接在 train_one_epoch 添加一下处理代码,记录 gnorm。

image-20231119183535765

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/165670.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

git常用命令和参数有哪些?【git看这一篇就够了】

文章目录 前言常用命令有哪些git速查表奉上常用参数后言 前言 hello world欢迎来到前端的新世界 😜当前文章系列专栏:git操作相关 🐱‍👓博主在前端领域还有很多知识和技术需要掌握,正在不断努力填补技术短板。(如果出…

MIB 6.S081 System calls(1)using gdb

难度:easy In many cases, print statements will be sufficient to debug your kernel, but sometimes being able to single step through some assembly code or inspecting the variables on the stack is helpful. To learn more about how to run GDB and the common iss…

每天一道算法题(六)——返回一组数字中所有和为 0 且不重复的三元组

文章目录 前言1、问题2、示例3、解决方法4、效果5、注意点 前言 注意:答案中不可以包含重复的三元组。 1、问题 给你一个整数数组 nums ,判断是否存在三元组 [nums[i], nums[j], nums[k]] 满足 i ! j、i ! k 且 j ! k ,同时还满足 nums[i] n…

23.11.19日总结

经过昨天的中期答辩,其实可以看出来项目进度太慢了,现在是第十周,预计第十四周是终级答辩,在这段时间要把项目写完。 前端要加上一个未登录的拦截器,后端加上全局的异常处理。对于饿了么项目的商品建表,之前…

redis问题归纳

1.redis为什么这么快? (1)基于内存操作:redis的所有数据都存在内存中,因此所有的运算都是内存级别的,所以性能比较高 (2)数据结构简单:redis的数据结构是专门设计的&…

系列五、怎么查看默认的垃圾收集器是哪个?

一、怎么查看默认的垃圾收集器是哪个 java -XX:PrintCommandLineFlags -version

python-opencv 培训课程作业

python-opencv 培训课程作业 作业一: 第一步:读取 res 下面的 flower.jpg,读取彩图,并用 opencv 展示 第二步:彩图 -> 灰度图 第三步:反转图像:最大图像灰度值减去原图像,即可得…

整数转罗马数字

罗马数字包含以下七种字符: I, V, X, L,C,D 和 M。 字符 数值 I 1 V 5 X 10 L 50 C 100 D 500 M 1000 例如, 罗马数字 2 写做 II ,即为两个并列的 1。12 写做 XII ,即为…

【Go入门】 Go搭建一个Web服务器

【Go入门】 Go搭建一个Web服务器 前面小节已经介绍了Web是基于http协议的一个服务,Go语言里面提供了一个完善的net/http包,通过http包可以很方便的搭建起来一个可以运行的Web服务。同时使用这个包能很简单地对Web的路由,静态文件&#xff0c…

线性表--链表-1

文章目录 主要内容一.链表练习题1.设计一个递归算法,删除不带头结点的单链表 L 中所有值为 X 的结点代码如下(示例): 2.设 L为带头结点的单链表,编写算法实现从尾到头反向输出每个结点的值代码如下(示例): …

vite+vue3+ts项目,使用语法糖unplugin-auto-import插件的步骤

1. 安装插件 npm install unplugin-auto-import vitejs/plugin-vue -D2. vite.config.ts中引入插件 import AutoImport from "unplugin-auto-import/vite"export default defineConfig({plugins: [vue(), AutoImport({imports: ["vue", "vue-router…

C语言:动态内存管理

目录 为什么存在动态内存分配 动态内存函数 malloc和free 示例 calloc 示例 realloc 示例 常见的动态内存错误 对NULL指针的解引用操作 对动态开辟的空间进行越界访问 对于非动态开辟内存使用free释放 使用free释放一块动态开辟内存的一部分 对同一块内存多次释…

lv11 嵌入式开发 ARM指令集中(伪操作与混合编程) 7

目录 1 伪指令 2 伪操作 3 C和汇编的混合编程 4 ATPCS协议 1 伪指令 本身不是指令,编译器可以将其替换成若干条等效指令 空指令NOP 指令LDR R1, [R2] 将R2指向的内存空间中的数据读取到R1寄存器 伪指令LDR R1, 0x12345678 R1 0x12345678 LDR伪指令可以将任…

深度学习:欠拟合与过拟合

1 定义 1.1 模型欠拟合 AI模型的欠拟合(Underfitting)发生在模型未能充分学习训练数据中的模式和结构时,导致它在训练集和验证集上都表现不佳。欠拟合通常是由于模型太过简单,没有足够的能力捕捉到数据的复杂性和细节。 1.2 模型…

mysql练习1

-- 1.查询出部门编号为BM01的所有员工 SELECT* FROMemp e WHEREe.deptno BM01; -- 2.所有销售人员的姓名、编号和部门编号。 SELECTe.empname,e.empno,e.deptno FROMemp e WHEREe.empstation "销售人员";-- 3.找出奖金高于工资的员工。 SELECT* FROMemp2 WHE…

SpringSecurity6 | 默认登录页

✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客 💞当前专栏: Java从入门到精通 ✨特色专栏&#xf…

SpringBoot项目连接linux服务器数据库两种解决方法(linux直接开放端口访问本机通过SSH协议访问,以mysql为例)

最近找个springboot脚手架重新熟悉一下springboot相关框架的东西,结果发现好像项目还不能直接像数据库GUI工具一样填几个SSH参数就可以了,于是就给他再整一下看看如何解决 linux开放3306(可修改)端口直接访问 此方法较为方便&am…

基于一致性算法的微电网分布式控制MATLAB仿真模型

微❤关注“电气仔推送”获得资料(专享优惠) 本模型主要是基于一致性理论的自适应虚拟阻抗、二次电压补偿以及二次频率补偿,实现功率均分,保证电压以及频率稳定性。 一致性算法 分布式一致性控制主要分为两类:协调同…

ZJU Beamer学习手册(二)

ZJU Beamer学习手册基于 Overleaf 的 ZJU Beamer模板 进行解读,本文则基于该模版进行进一步修改。 参考文献 首先在frame文件夹中增加reference.tex文件,文件内容如下。这段代码对参考文献的引用进行了预处理。 \usepackage[backendbiber]{biblatex} \…

学习网络编程No.10【深入学习HTTPS】

引言: 北京时间:2023/11/14/18:45,因为种种原因,上个月的文章昨天才更新,目前处于刷题前夕,算法课在看了。这次和以前不一样,因为以前对知识框架没有很好的理念,并不清楚相关知识要…