李宏毅机器学习2022-HW7-BERT-Question Answering

文章目录

  • Task
  • Baseline
    • Medium
    • Strong
    • Boss
  • Code Link

Task

HW7的任务是通过BERT完成Question Answering。

数据预处理流程梳理

数据解压后包含3个json文件:hw7_train.json, hw7_dev.json, hw7_test.json。

DRCD: 台達閱讀理解資料集 Delta Reading Comprehension Dataset

ODSQA: Open-Domain Spoken Question Answering Dataset

  • train: DRCD + DRCD-TTS
    • 10524 paragraphs, 31690 questions
  • dev: DRCD + DRCD-TTS
    • 1490 paragraphs, 4131 questions
  • test: DRCD + ODSQA
    • 1586 paragraphs, 4957 questions

{train/dev/test}_questions:

  • List of dicts with the following keys:
  • id (int)
  • paragraph_id (int)
  • question_text (string)
  • answer_text (string)
  • answer_start (int)
  • answer_end (int)

{train/dev/test}_paragraphs:

  • List of strings
  • paragraph_ids in questions correspond to indexs in paragraphs
  • A paragraph may be used by several questions

读取这三个文件,每个文件返回相应的question数据和paragraph数据,都是文本数据,不能作为模型的输入。

利用Tokenization将question和paragraph文本数据先按token为单位分开,再转换为tokens_to_ids数字数据。Dataset选取paragraph中固定长度的片段(固定长度为150),片段需包含answer部分,然后使用Tokenization 以CLS + question + SEP + document+ CLS + padding(不足的补0)的形式作为训练输入。

Total sequence length = question length + paragraph length + 3 (special tokens)
Maximum input sequence length of BERT is restricted to 512

在这里插入图片描述
在这里插入图片描述

training

在这里插入图片描述

testing

对于每个窗口,模型预测一个开始分数和一个结束分数,取最大值作为答案

在这里插入图片描述

Baseline

Medium

应用linear learning rate decay+change doc_stride

这里linear learning rate decay选用了两种方法

  • 手动调整学习率

    假设初始学习率为 η 0 η_0 η0,总的步骤数为 T T T,那么在第 t t t步时的学习率 η t η_t ηt 可以表示为:

    η t = η 0 − η 0 T × t η_t=η_0−\frac{η_0}{T}×t ηt=η0Tη0×t

    其中:

    • η 0 η_0 η0 是初始学习率。
    • T T T是总的步骤数(total_step)。
    • t t t 是当前的步骤数(从 0 开始计数)。

    optimizer.param_groups[0]["lr"] -= learning_rate / total_step η t = η t − η 0 T η t η_t=η_t−\frac{η_0}{T}η_t ηt=ηtTη0ηt

    • optimizer.param_groups[0]["lr"] 对应 η t η_t ηt
    • learning_rate 对应 η 0 η_0 η0
    • total_step 对应 T T T
    • i 对应 t t t
    # Medium--Learning rate dacay
    # Method 1: adjust learning rate manually
    total_step = 1000
    for i in range(total_step):
        optimizer.param_groups[0]["lr"] -= learning_rate / total_step
    
  • 通过scheduler自动调整学习率

    • (recommend) transformer
    • torch.optim
    # Method 2: Adjust learning rate automatically by scheduler
    
    # (Recommend) https://huggingface.co/transformers/main_classes/optimizer_schedules.html#transformers.get_linear_schedule_with_warmup
    from transformers import get_linear_schedule_with_warmup
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=100, num_training_steps=1000)
    
    # https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
    # 这里如果要用pytorch的ExponentialLR,一定要导入optim模块,并且前面的AdamW是从transformers中import的这里要重新import
    import torch.optim as optim
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    

change doc_stride在QA_Dataset的时候修改段落滑动窗口的步长

##### TODO: Change value of doc_stride #####
# 段落滑动窗口的步长
self.doc_stride = 30  # Medium

Strong

应用➢ Improve preprocessing ➢ Try other pretrained models

  • 尝试其他预训练模型

比如bert-base-multilingual-case,因为它可以避免英文无法tokenization输出[UNK],但是计算量大

model = BertForQuestionAnswering.from_pretrained("hfl/chinese-macbert-large").to(device)
tokenizer = BertTokenizerFast.from_pretrained("hfl/chinese-macbert-large")
  • preprocessing ,在QA_Dataset中修改截取答案的窗口

    1. 随机窗口选择 Random Window Selection
      随机选择窗口的起始位置

      • 随机范围的下界
        start_min = max(0, answer_end_token - self.max_paragraph_len + 1) 答案结束位置向前移动 self.max_paragraph_len - 1 个标记后的位置和 0 较大的那个
      • 随机范围的上界
        start_max = min(answer_start_token, len(tokenized_paragraph) - self.max_paragraph_len)
        • len(tokenized_paragraph) - self.max_paragraph_len:计算段落长度减去窗口长度后的位置,确保窗口不会超出段落末尾。
        • min(answer_start_token, ...):确保上界不超过答案开始位置,避免答案被截断。
      • 随机选择
        paragraph_start = random.randint(start_min, start_max)在计算出的下界和上界之间随机选择一个整数作为窗口的起始位置。
      • 计算窗口结束位置
        paragraph_end = paragraph_start + self.max_paragraph_len确保窗口长度为 self.max_paragraph_len
    2. 滑动窗口大小 Dynamic window size

            ##### TODO: Preprocessing Strong #####
            # Hint: How to prevent model from learning something it should not learn
    
            if self.split == "train":
                # Convert answer's start/end positions in paragraph_text to start/end positions in tokenized_paragraph
                answer_start_token = tokenized_paragraph.char_to_token(question["answer_start"])
                answer_end_token = tokenized_paragraph.char_to_token(question["answer_end"])
    
                # A single window is obtained by slicing the portion of paragraph containing the answer
                # 在training中paragraph的截取依据的是answer的position id
                """
                mid = (answer_start_token + answer_end_token) // 2
                paragraph_start = max(0, min(mid - self.max_paragraph_len // 2, len(tokenized_paragraph) - self.max_paragraph_len))
                paragraph_end = paragraph_start + self.max_paragraph_len"""
                # Strong
                # Method 1: Random window selection
                start_min = max(0, answer_end_token - self.max_paragraph_len + 1)  # 计算答案结束位置向前移动 self.max_paragraph_len - 1 个标记后的位置
                start_max = min(answer_start_token, len(tokenized_paragraph) - self.max_paragraph_len)
                start_max = max(start_min, start_max)
                paragraph_start = random.randint(start_min, start_max + 1)
                paragraph_end = paragraph_start + self.max_paragraph_len
                
                """
                # Method 2: Dynamic window size 
                # 这个会造成窗口的大小大于max_paragraph_len,那么会造成输入序列的长度不一致,后面padding也要改,这里暂不采用
                answer_length = answer_end_token - answer_start_token
                dynamic_window_size = max(self.max_paragraph_len, answer_length + 20)  # 添加一些额外的空间
                paragraph_start = max(0, min(answer_start_token - dynamic_window_size // 2, len(tokenized_paragraph) - dynamic_window_size))
                paragraph_end = paragraph_start + dynamic_window_size
                """
    
    

Boss

➢ Improve postprocessing ➢ Further improve the above hints

doc_stride + max_length+ learning rate scheduler + preprocessing+ postprocessing + new model + no validation

与strong baseline相比,最大的改变有两个,一是换pretrain model,在hugging face中搜索chinese + QA的模型,根据model card描述选择最好的模型,使用后大概提升2.5%的精度,二是更近一步的postprocessing,查看提交文件可看到很多answer包含CLS, SEP, UNK等字符,CLS和SEP的出现表示预测位置有误,UNK的出现说明有某些字符无法正常编码解码(例如一些生僻字),错误字符的问题均可在evaluate函数中改进,这个步骤提升了大概1%的精度。其他的修改主要是针对overfitting问题,包括减少了learning rate,提升dataset里面的paragraph max length, 将validation集合和train集合进行合并等。另外可使用的办法有ensemble,大概能提升0.5%的精度,改变random seed,也有提升的可能性。

if start_index > end_index or start_index < paragraph_start or end_index > paragraph_end:
    continue
    
if '[UNK]' in answer:
    print('发现 [UNK],这表明有文字无法编码, 使用原始文本')
    #print("Paragraph:", paragraph)
    #print("Paragraph:", paragraph_tokenized.tokens)
    print('--直接解码预测:', answer)
    #找到原始文本中对应的位置
    raw_start =  paragraph_tokenized.token_to_chars(origin_start)[0]
    raw_end = paragraph_tokenized.token_to_chars(origin_end)[1]
    answer = paragraph[raw_start:raw_end]
    print('--原始文本预测:',answer)

Code Link

github

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/895036.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

衡石分析平台系统分析人员手册-可视化报表仪表盘

仪表盘​ 仪表盘是数据分析最终展现形式&#xff0c;是数据分析的终极展现。 应用由一个或多个仪表盘展示&#xff0c;多个仪表盘之间有业务关联。 仪表盘编辑​ 图表列表​ 打开仪表盘后&#xff0c;就会看到该仪表盘中所有的图表。 调整图表布局​ 将鼠标移动到图表上拖动…

设计模式:类与类之间关系的表示方式(聚合,组合,依赖,继承,实现)

目录 聚合关系 组合关系 依赖关系 继承关系 实现关系 聚合关系 聚合是一种较弱的“拥有”关系&#xff0c;表示整体与部分的关系&#xff0c;但部分可以独立于整体存在。例如&#xff0c;部门和员工之间的关系&#xff0c;一个部门可以包含多个员工&#xff0c;但员工可以…

【大数据技术基础 | 实验四】HDFS实验:读写HDFS文件

文章目录 一、实验目的二、实验要求三、实验原理&#xff08;一&#xff09;Java Classpath&#xff08;二&#xff09;Eclipse Hadoop插件 四、实验环境五、实验内容和步骤&#xff08;一&#xff09;配置master服务器classpath&#xff08;二&#xff09;使用master服务器编写…

JOIN 表连接

1. 插入表测试数据 分别清空学生信息表 student、教师信息表 teacher、课程表 course、学生选课关联表 student_course 数据&#xff0c;并分别插入测试数据。 1.1 清空表数据 分别清空学生信息表 student、教师信息表 teacher、课程表 course、学生选课关联表 student_cours…

第8篇:网络安全基础

目录 引言 8.1 网络安全的基本概念 8.2 网络威胁与攻击类型 8.3 密码学的基本思想与加密算法 8.4 消息认证与数字签名 8.5 网络安全技术与协议 8.6 总结 第8篇&#xff1a;网络安全基础 引言 在现代信息社会中&#xff0c;计算机网络无处不在&#xff0c;从互联网到局…

Atlas800昇腾服务器(型号:3000)—驱动与固件安装(一)

服务器配置如下&#xff1a; CPU/NPU&#xff1a;鲲鹏 CPU&#xff08;ARM64&#xff09;A300I pro推理卡 系统&#xff1a;Kylin V10 SP1【下载链接】【安装链接】 驱动与固件版本版本&#xff1a; Ascend-hdk-310p-npu-driver_23.0.1_linux-aarch64.run【下载链接】 Ascend-…

0x3D service

0x3D service 1. 概念2. Request message 数据格式3. Respone message 数据格式3.1 正响应格式3.2 negative respone codes(NRC)4. 示例4.1 正响应示例:4.2 NRC 示例1. 概念 UDS(统一诊断服务)中的0x3D服务,即Write Memory By Address(按地址写内存)服务,允许客户端向服…

艺术家杨烨炘厦门开展,49只“鞋底倒计时”引轰动

艺术家杨烨炘 9&#xff0c;8&#xff0c;7&#xff0c;6&#xff0c;5&#xff0c;4&#xff0c;3&#xff0c;2&#xff0c;1&#xff0c;0………近日&#xff0c;一件名为《走向倒计时》的艺术作品在厦门引发讨论。艺术家杨烨炘邀请49位台湾同胞&#xff0c;将他们的鞋底拼成…

51单片机快速入门之 LCD1602 液晶显示屏2024/10/19

51单片机快速入门之 LCD1602 液晶显示屏 Proteus 电路图 : 74HC595 拓展电路可以不用,给 p0-p17 添加上拉电阻也可以!,我这里是方便读取和节省电阻线路 (因为之前不知道 在没有明确循环的情况下&#xff0c;Keil编译器可能会在main()中自动添加类似以下的汇编代码&#xff1a…

基于SpringBoot中药材进存销管理系统【附源码】

基于SpringBoot中药材进存销管理系统 效果如下&#xff1a; 系统注册界面 管理员主界面 员工界面 供应商界面 中药材类型界面 中药材界面 员工主界面 研究背景 随着中医药产业的快速发展&#xff0c;传统的管理方式已难以满足现代化、规模化的药材管理需求。中药材种类繁多&…

Vulnhub打靶-matrix-breakout-2-morpheus

基本信息 靶机下载&#xff1a;https://pan.baidu.com/s/1kz6ei5hNomFK44p1QT0xzQ?pwdy5qh 提取码: y5qh 攻击机器&#xff1a;192.168.20.128&#xff08;Windows操作系统&#xff09; 靶机&#xff1a;192.168.20.0/24 目标&#xff1a;获取2个flagroot权限 具体流程 …

基于FreeRTOS的LWIP移植

目录 前言一、移植准备工作二、以太网固件库与驱动2.1 固件库文件添加2.2 库文件修改2.3 添加网卡驱动 三、LWIP 数据包和网络接口管理3.1 添加LWIP源文件3.2 Lwip文件修改3.2.1 修改cc.h3.2.2 修改lwipopts.h3.2.3 修改icmp.c3.2.4 修改sys_arch.h和sys_arch.c3.2.5 修改ether…

【NOIP提高组】一元三次方程求解

【NOIP提高组】一元三次方程求解 &#x1f490;The Begin&#x1f490;点点关注&#xff0c;收藏不迷路&#x1f490; 有形如&#xff1a;ax3bx2cxd0 这样的一个一元三次方程。给出该方程中各项的系数(a&#xff0c;b&#xff0c;c&#xff0c;d均为实数)&#xff0c;并约定该方…

图像梯度-Sobel算子、scharrx算子和lapkacian算子

文章目录 一、认识什么是图像梯度和Sobel算子二、Sobel算子的具体使用三、scharrx算子与lapkacian(拉普拉斯)算子 一、认识什么是图像梯度和Sobel算子 图像的梯度是指图像亮度变化的空间导数&#xff0c;它描述了图像在不同方向上的强度变化。在图像处理和计算机视觉中&#x…

Unity使用Git及GitHub进行项目管理

git: 工作区,暂存区(存放临时要存放的内容),代码仓库区1.初始化 git init 此时展开隐藏项目,会出现.git文件夹 2.减小项目体积 touch .gitignore命令 创建.gitignore文件夹 gitignore文件夹的内容 gitignore中添加一下内容 # This .gitignore file should be place…

2025秋招八股文--网络原理篇

前言 1.本系列面试八股文的题目及答案均来自于网络平台的内容整理&#xff0c;对其进行了归类整理&#xff0c;在格式和内容上或许会存在一定错误&#xff0c;大家自行理解。内容涵盖部分若有侵权部分&#xff0c;请后台联系&#xff0c;及时删除。 2.本系列发布内容分为12篇…

《Linux从小白到高手》综合应用篇:深入理解Linux常用关键内核参数及其调优

1. 题记 有关Linux关键内核参数的调整&#xff0c;我前面的调优文章其实就有涉及到&#xff0c;只是比较零散&#xff0c;本篇集中深入介绍Linux常用关键内核参数及其调优&#xff0c;Linux调优80%以上都涉及到内核的这些参数的调整。 2. 文件系统相关参数 fs.file-max 参数…

VMware虚拟机的下载安装与使用

目录 一、下载VMware虚拟机 步骤1---找到需要的虚拟机下载位置 步骤2---找到下载的安装包安装程序 二、删除干净VMware虚拟机 步骤1--进去服务里面关闭虚拟机服务 步骤2---控制面板里面删除 步骤3---注册表删除HKEY_CURRENT_USER\Software\VMware, Inc. 步骤4---安装在…

政安晨【零基础玩转各类开源AI项目】基于本地Ubuntu (Linux ) 系统应用Gradio-Lite:无服务器 Gradio 完全在浏览器中运行

目录 简介 什么是gradio/lite&#xff1f; 入门 1.导入 JS 和 CSS 2. 创建标签 3. 在标签内编写你的 Gradio 应用程序 更多示例&#xff1a;添加其他文件和要求 多个文件 其他要求 SharedWorker 模式 代码和演示playground 1.无服务器部署 2.低延迟 3. 隐私和安全…

【C语言】文件操作(1)(文件打开关闭和顺序读写函数的万字笔记)

文章目录 一、什么是文件1.程序文件2.数据文件 二、数据文件1.文件名2.数据文件的分类文本文件二进制文件 三、文件的打开和关闭1.流和标准流流标准流 2.文件指针3.文件的打开和关闭文件的打开文件的关闭 四、文件的顺序读写1.fgetc函数2.fputc函数3.fgets函数4.fputs函数5.fsc…