李沐70_bert微调——自学笔记

微调BERT

1.BERT滴哦每一个词元返回抽取了上下文信息的特征向量

2.不同的任务使用不同的特性

句子分类

将cls对应的向量输入到全连接层分类

命名实体识别

1.识别应该词元是不是命名实体,例如人名、机构、位置

2.将非特殊词元放进全连接层分类

问题回答

1.给定应该问题和描述文字,找出一个片段作为回答

2.对片段中的每个词元预测它是不是回答的开头或结束

总结

1.即使下游任务各有不同,使用BERT微调时君只需要增加输出层

2.但根据任务的不同,输入的表示,和使用Bert特征也会不一样

!pip install d2l==0.17.6  ### 很重要,不要下载错了,对于colab

Natural Language Inference and Dataset

斯坦福自然语言推断(SNLI)数据集

import os
import re
import torch
from torch import nn
from d2l import torch as d2l

d2l.DATA_HUB['SNLI'] = (
    'https://nlp.stanford.edu/projects/snli/snli_1.0.zip',
    '9fcde07509c7e87ec61c640c1b2753d9041758e4')

data_dir = d2l.download_extract('SNLI')

读取数据集

定义函数read_snli以仅提取数据集的一部分,然后返回前提、假设及其标签的列表。*

def read_snli(data_dir, is_train):
    """将SNLI数据集解析为前提、假设和标签"""
    def extract_text(s):
        # 删除我们不会使用的信息
        s = re.sub('\\(', '', s)
        s = re.sub('\\)', '', s)
        # 用一个空格替换两个或多个连续的空格
        s = re.sub('\\s{2,}', ' ', s)
        return s.strip()
    label_set = {'entailment': 0, 'contradiction': 1, 'neutral': 2}
    file_name = os.path.join(data_dir, 'snli_1.0_train.txt'
                             if is_train else 'snli_1.0_test.txt')
    with open(file_name, 'r') as f:
        rows = [row.split('\t') for row in f.readlines()[1:]]
    premises = [extract_text(row[1]) for row in rows if row[0] in label_set]
    hypotheses = [extract_text(row[2]) for row in rows if row[0] \
                in label_set]
    labels = [label_set[row[0]] for row in rows if row[0] in label_set]
    return premises, hypotheses, labels

打印前3对前提和假设,以及它们的标签(“0”“1”和“2”分别对应于“蕴涵”“矛盾”和“中性”)。

train_data = read_snli(data_dir, is_train=True)
for x0, x1, y in zip(train_data[0][:3], train_data[1][:3], train_data[2][:3]):
    print('前提:', x0)
    print('假设:', x1)
    print('标签:', y)
前提: A person on a horse jumps over a broken down airplane .
假设: A person is training his horse for a competition .
标签: 2
前提: A person on a horse jumps over a broken down airplane .
假设: A person is at a diner , ordering an omelette .
标签: 1
前提: A person on a horse jumps over a broken down airplane .
假设: A person is outdoors , on a horse .
标签: 0

训练集约有550000对,测试集约有10000对。下面显示了训练集和测试集中的三个标签“蕴涵”“矛盾”和“中性”是平衡的。

test_data = read_snli(data_dir, is_train=False)
for data in [train_data, test_data]:
    print([[row for row in data[2]].count(i) for i in range(3)])
[183416, 183187, 182764]
[3368, 3237, 3219]

定义一个用于加载SNLI数据集的类。类构造函数中的变量num_steps指定文本序列的长度,使得每个小批量序列将具有相同的形状。

class SNLIDataset(torch.utils.data.Dataset):
    """用于加载SNLI数据集的自定义数据集"""
    def __init__(self, dataset, num_steps, vocab=None):
        self.num_steps = num_steps
        all_premise_tokens = d2l.tokenize(dataset[0])
        all_hypothesis_tokens = d2l.tokenize(dataset[1])
        if vocab is None:
            self.vocab = d2l.Vocab(all_premise_tokens + \
                all_hypothesis_tokens, min_freq=5, reserved_tokens=['<pad>'])
        else:
            self.vocab = vocab
        self.premises = self._pad(all_premise_tokens)
        self.hypotheses = self._pad(all_hypothesis_tokens)
        self.labels = torch.tensor(dataset[2])
        print('read ' + str(len(self.premises)) + ' examples')

    def _pad(self, lines):
        return torch.tensor([d2l.truncate_pad(
            self.vocab[line], self.num_steps, self.vocab['<pad>'])
                         for line in lines])

    def __getitem__(self, idx):
        return (self.premises[idx], self.hypotheses[idx]), self.labels[idx]

    def __len__(self):
        return len(self.premises)

我们可以调用read_snli函数和SNLIDataset类来下载SNLI数据集,并返回训练集和测试集的DataLoader实例,以及训练集的词表。

def load_data_snli(batch_size, num_steps=50):
    """下载SNLI数据集并返回数据迭代器和词表"""
    num_workers = d2l.get_dataloader_workers()
    data_dir = d2l.download_extract('SNLI')
    train_data = read_snli(data_dir, True)
    test_data = read_snli(data_dir, False)
    train_set = SNLIDataset(train_data, num_steps)
    test_set = SNLIDataset(test_data, num_steps, train_set.vocab)
    train_iter = torch.utils.data.DataLoader(train_set, batch_size,
                                             shuffle=True,
                                             num_workers=num_workers)
    test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                            shuffle=False,
                                            num_workers=num_workers)
    return train_iter, test_iter, train_set.vocab

在这里,我们将批量大小设置为128时,将序列长度设置为50,并调用load_data_snli函数来获取数据迭代器和词表。然后我们打印词表大小。

train_iter, test_iter, vocab = load_data_snli(128, 50)
len(vocab)
read 549367 examples
read 9824 examples


/usr/local/lib/python3.10/dist-packages/torch/utils/data/dataloader.py:558: UserWarning: This DataLoader will create 4 worker processes in total. Our suggested max number of worker in current system is 2, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(





18678

现在我们打印第一个小批量的形状。与情感分析相反,我们有分别代表前提和假设的两个输入X[0]和X[1]。

for X, Y in train_iter:
    print(X[0].shape)
    print(X[1].shape)
    print(Y.shape)
    break
/usr/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()


torch.Size([128, 50])
torch.Size([128, 50])
torch.Size([128])

Fine-Tuning BERT

import json
import multiprocessing
import os
import torch
from torch import nn
from d2l import torch as d2l

加载预训练的BERT

提供了两个版本的预训练的BERT:“bert.base”与原始的BERT基础模型一样大,需要大量的计算资源才能进行微调,而“bert.small”是一个小版本,以便于演示。

d2l.DATA_HUB['bert.base'] = (d2l.DATA_URL + 'bert.base.torch.zip',
                             '225d66f04cae318b841a13d32af3acc165f253ac')
d2l.DATA_HUB['bert.small'] = (d2l.DATA_URL + 'bert.small.torch.zip',
                              'c72329e68a732bef0452e4b96a1c341c8910f81f')

两个预训练好的BERT模型都包含一个定义词表的“vocab.json”文件和一个预训练参数的“pretrained.params”文件。我们实现了以下load_pretrained_model函数来加载预先训练好的BERT参数。

def load_pretrained_model(pretrained_model, num_hiddens, ffn_num_hiddens,
                          num_heads, num_layers, dropout, max_len, devices):
    data_dir = d2l.download_extract(pretrained_model)
    # 定义空词表以加载预定义词表
    vocab = d2l.Vocab()
    vocab.idx_to_token = json.load(open(os.path.join(data_dir,
        'vocab.json')))
    vocab.token_to_idx = {token: idx for idx, token in enumerate(
        vocab.idx_to_token)}
    bert = d2l.BERTModel(len(vocab), num_hiddens, norm_shape=[256],
                         ffn_num_input=256, ffn_num_hiddens=ffn_num_hiddens,
                         num_heads=4, num_layers=2, dropout=0.2,
                         max_len=max_len, key_size=256, query_size=256,
                         value_size=256, hid_in_features=256,
                         mlm_in_features=256, nsp_in_features=256)
    # 加载预训练BERT参数
    bert.load_state_dict(torch.load(os.path.join(data_dir,
                                                 'pretrained.params')))
    return bert, vocab

加载和微调经过预训练BERT的小版本(“bert.small”)

devices = d2l.try_all_gpus()
bert, vocab = load_pretrained_model(
    'bert.small', num_hiddens=256, ffn_num_hiddens=512, num_heads=4,
    num_layers=2, dropout=0.1, max_len=512, devices=devices)
Downloading ../data/bert.small.torch.zip from http://d2l-data.s3-accelerate.amazonaws.com/bert.small.torch.zip...

微调BERT数据集

定义了一个定制的数据集类SNLIBERTDataset。在每个样本中,前提和假设形成一对文本序列,并被打包成一个BERT输入序列,如 图15.6.2所示。回想 14.8.4节,片段索引用于区分BERT输入序列中的前提和假设。利用预定义的BERT输入序列的最大长度(max_len),持续移除输入文本对中较长文本的最后一个标记,直到满足max_len。为了加速生成用于微调BERT的SNLI数据集,我们使用4个工作进程并行生成训练或测试样本。

class SNLIBERTDataset(torch.utils.data.Dataset):
    def __init__(self, dataset, max_len, vocab=None):
        all_premise_hypothesis_tokens = [[
            p_tokens, h_tokens] for p_tokens, h_tokens in zip(
            *[d2l.tokenize([s.lower() for s in sentences])
              for sentences in dataset[:2]])]

        self.labels = torch.tensor(dataset[2])
        self.vocab = vocab
        self.max_len = max_len
        (self.all_token_ids, self.all_segments,
         self.valid_lens) = self._preprocess(all_premise_hypothesis_tokens)
        print('read ' + str(len(self.all_token_ids)) + ' examples')

    def _preprocess(self, all_premise_hypothesis_tokens):
        pool = multiprocessing.Pool(4)  # 使用4个进程
        out = pool.map(self._mp_worker, all_premise_hypothesis_tokens)
        all_token_ids = [
            token_ids for token_ids, segments, valid_len in out]
        all_segments = [segments for token_ids, segments, valid_len in out]
        valid_lens = [valid_len for token_ids, segments, valid_len in out]
        return (torch.tensor(all_token_ids, dtype=torch.long),
                torch.tensor(all_segments, dtype=torch.long),
                torch.tensor(valid_lens))

    def _mp_worker(self, premise_hypothesis_tokens):
        p_tokens, h_tokens = premise_hypothesis_tokens
        self._truncate_pair_of_tokens(p_tokens, h_tokens)
        tokens, segments = d2l.get_tokens_and_segments(p_tokens, h_tokens)
        token_ids = self.vocab[tokens] + [self.vocab['<pad>']] \
                             * (self.max_len - len(tokens))
        segments = segments + [0] * (self.max_len - len(segments))
        valid_len = len(tokens)
        return token_ids, segments, valid_len

    def _truncate_pair_of_tokens(self, p_tokens, h_tokens):
        # 为BERT输入中的'<CLS>'、'<SEP>'和'<SEP>'词元保留位置
        while len(p_tokens) + len(h_tokens) > self.max_len - 3:
            if len(p_tokens) > len(h_tokens):
                p_tokens.pop()
            else:
                h_tokens.pop()

    def __getitem__(self, idx):
        return (self.all_token_ids[idx], self.all_segments[idx],
                self.valid_lens[idx]), self.labels[idx]

    def __len__(self):
        return len(self.all_token_ids)

下载完SNLI数据集后,我们通过实例化SNLIBERTDataset类来生成训练和测试样本。这些样本将在自然语言推断的训练和测试期间进行小批量读取。

# 如果出现显存不足错误,请减少“batch_size”。在原始的BERT模型中,max_len=512
batch_size, max_len, num_workers = 512, 128, d2l.get_dataloader_workers()
data_dir = d2l.download_extract('SNLI')
train_set = SNLIBERTDataset(d2l.read_snli(data_dir, True), max_len, vocab)
test_set = SNLIBERTDataset(d2l.read_snli(data_dir, False), max_len, vocab)
train_iter = torch.utils.data.DataLoader(train_set, batch_size, shuffle=True,
                                   num_workers=num_workers)
test_iter = torch.utils.data.DataLoader(test_set, batch_size,
                                  num_workers=num_workers)
read 549367 examples
read 9824 examples

微调BERT

用于自然语言推断的微调BERT只需要一个额外的多层感知机,该多层感知机由两个全连接层组成(请参见下面BERTClassifier类中的self.hidden和self.output)。这个多层感知机将特殊的“”词元的BERT表示进行了转换,该词元同时编码前提和假设的信息为自然语言推断的三个输出:蕴涵、矛盾和中性。

class BERTClassifier(nn.Module):
    def __init__(self, bert):
        super(BERTClassifier, self).__init__()
        self.encoder = bert.encoder
        self.hidden = bert.hidden
        self.output = nn.Linear(256, 3)

    def forward(self, inputs):
        tokens_X, segments_X, valid_lens_x = inputs
        encoded_X = self.encoder(tokens_X, segments_X, valid_lens_x)
        return self.output(self.hidden(encoded_X[:, 0, :]))

预训练的BERT模型bert被送到用于下游应用的BERTClassifier实例net中。

net = BERTClassifier(bert)

为了允许具有陈旧梯度的参数,标志ignore_stale_grad=True在step函数d2l.train_batch_ch13中被设置。我们通过该函数使用SNLI的训练集(train_iter)和测试集(test_iter)对net模型进行训练和评估。

lr, num_epochs = 1e-4, 5
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction='none')
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
    devices)
loss 0.520, train acc 0.791, test acc 0.780
2536.5 examples/sec on [device(type='cuda', index=0)]

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/581845.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

fetch请求后端返回文件流,并下载。

前端&#xff1a; <script src"~/layui/layui.js"></script> <script src"~/Content/js/common/js/vue.min.js"></script> <script src"~/Content/js/common/js/jquery-1.10.2.min.js"></script><styl…

[论文笔记]GAUSSIAN ERROR LINEAR UNITS (GELUS)

引言 今天来看一下GELU的原始论文。 作者提出了GELU(Gaussian Error Linear Unit,高斯误差线性单元)非线性激活函数&#xff1a; GELU x Φ ( x ) \text{GELU} x\Phi(x) GELUxΦ(x)&#xff0c;其中 Φ ( x ) \Phi(x) Φ(x)​是标准高斯累积分布函数。与ReLU激活函数通过输入…

pycharm配置wsl开发环境(conda)

背景 在研究qanything项目的过程中&#xff0c;为了进行二次开发&#xff0c;需要在本地搭建开发环境。然后根据文档说明发现该项目并不能直接运行在windows开发环境&#xff0c;但可以运行在wsl环境中。于是我需要先创建wsl环境并配置pycharm。 wsl环境创建 WSL是“Windows Su…

【多模态大模型】AI对视频内容解析问答

文章目录 1. 项目背景2. 直接对视频进行解析进行AI问答&#xff1a;MiniGPT4-Video2.1 MiniGPT4-Video效果 3. 对视频抽帧为图片再进行AI问答3.1 视频抽帧3.2 图片AI问答3.2.1 阿里通义千问大模型 Qwen-vl-plus3.2.2 Moonshot 1. 项目背景 最近在做一个项目,需要使用AI技术对视…

DDP示例

https://zhuanlan.zhihu.com/p/602305591 https://zhuanlan.zhihu.com/p/178402798 关于模型保存与加载 &#xff1a; 其实分为保存 有module和无module2种 &#xff1b; &#xff08;上面知乎这篇文章说带时带module) 关于2种带与不带的说明&#xff1a; https://blog.csdn.…

69、栈-有效的括号

思路&#xff1a; 有效的括号序列是指每个开括号都有一个对应的闭括号&#xff0c;并且括号的配对顺序正确。 比如&#xff1a;({)} 这个就是错误的&#xff0c;({}) 这个就是正确的。所以每一个做括号&#xff0c;必有一个对应的右括号&#xff0c;并且需要顺序正确。这里有…

Meilisearch 快速入门(Windows 环境) 搜索引擎 语义搜索

Meilisearch 快速入门(Windows 环境)# 简介# Meilisearch 是一个基于 rust 开发的,快速的、完全开源的轻量级搜索引擎。它的数据存储基于磁盘与内存映射,不受 RAM 限制。在一定数量级下,搜索速度不逊于 Elasticsearch。 下载# 官方服务端包下载地址:github.com/meili…

常用图像加密技术-流密码异或加密

异或加密是最常用的一种加密方式&#xff0c;广泛的适用于图像处理领域。这种加密方式依据加密密钥生成伪随机序列与图像的像素值进行异或操作&#xff0c;使得原像素值发生变化&#xff0c;进而使得图像内容发生变化&#xff0c;达到保护图像内容的目的。 该加密方法是以图像…

鸿蒙OpenHarmony【小型系统 烧录】(基于Hi3516开发板)

烧录 针对Hi3516DV300开发板&#xff0c;除了DevEco Device Tool&#xff08;操作方法请参考烧录)&#xff09;外&#xff0c;还可以使用HiTool进行烧录。 前提条件 开发板相关源码已编译完成&#xff0c;已形成烧录文件。客户端&#xff08;操作平台&#xff0c;例如Window…

深度学习模型的优化和调优de了解

深度学习模型的优化和调优&#xff1a;随着深度学习应用的广泛&#xff0c;优化和调优神经网络模型成为了一个重要的问题。这包括选择合适的网络架构、调整超参数、应对过拟合等。 深度学习模型的优化和调优是指在训练神经网络模型时&#xff0c;通过一系列技术和方法来提高模型…

FTP 文件传输协议

FTP 文件传输协议 作用 用来传输文件的 FTP协议采用的是TCP作为传输协议&#xff0c; 21号端口用来传输FTP控制命令的&#xff0c; 20号端口用来传输文件数据的 FTP传输模式&#xff1a; 主动模式&#xff1a; FTP服务端接收下载控制命令后&#xff0c;会主动从tcp/20号端口…

C语言之详细讲解文件操作

什么是文件 与普通文件载体不同&#xff0c;文件是以硬盘为载体存储在计算机上的信息集合&#xff0c;文件可以是文本文档、图片、程序等等。文件通常具有点三个字母的文件扩展名&#xff0c;用于指示文件类型&#xff08;例如&#xff0c;图片文件常常以KPEG格式保存并且文件…

修改word文件的创作者方法有哪些?如何修改文档的作者 这两个方法你一定要知道

在数字化时代&#xff0c;文件创作者的信息往往嵌入在文件的元数据中&#xff0c;这些元数据包括创作者的姓名、创建日期以及其他相关信息。然而&#xff0c;有时候我们可能需要修改这些创作者信息&#xff0c;出于隐私保护、版权调整或者其他实际需求。那么&#xff0c;有没有…

短信验证码绕过漏洞(一)

短信验证码绕过漏洞 0x01原理&#xff1a; 服务器端返回的相关参数作为最终登录凭证&#xff0c;导致可绕过登录限制。 危害&#xff1a;在相关业务中危害也不同&#xff0c;如找回密码&#xff0c;注册&#xff0c;电话换绑等地方即可形成高危漏洞&#xff0c;如果是一些普…

常用算法代码模板 (3) :搜索与图论

AcWing算法基础课笔记与常用算法模板 (3) ——搜索与图论 常用算法代码模板 (1) &#xff1a;基础算法 常用算法代码模板 (2) &#xff1a;数据结构 常用算法代码模板 (3) &#xff1a;搜索与图论 常用算法代码模板 (4) &#xff1a;数学知识 文章目录 0 搜索技巧1 树与图的存…

【Scala---01】Scala『 Scala简介 | 函数式编程简介 | Scala VS Java | 安装与部署』

文章目录 1. Scala简介2. 函数式编程简介3. Scala VS Java4. 安装与部署 1. Scala简介 Scala是由于Spark的流行而兴起的。Scala是高级语言&#xff0c;Scala底层使用的是Java&#xff0c;可以看做是对Java的进一步封装&#xff0c;更加简洁&#xff0c;代码量是Java的一半。 因…

MATLAB语音信号分析与合成——MATLAB语音信号分析学习资料汇总(图书、代码和视频)

教科书&#xff1a;MATLAB语音信号分析与合成&#xff08;第2版&#xff09; 链接&#xff08;含配套源代码&#xff09;&#xff1a;https://pan.baidu.com/s/1pXMPD_9TRpJmubPGaRKANw?pwd32rf 提取码&#xff1a;32rf 基础入门视频&#xff1a; 视频链接&#xff1a; 清…

MCU自动测量单元:自动化数据采集的未来

随着科技的飞速发展&#xff0c;自动化技术在各个领域中的应用日益广泛。其中&#xff0c;MCU(微控制器)自动测量单元以其高效、精准的特性&#xff0c;成为自动化数据采集领域的佼佼者&#xff0c;引领着未来数据采集技术的革新。本文将深入探讨MCU自动测量单元的原理、优势以…

Vue2 - 完成实现ElementUI中el-dialog弹窗的拖拽功能(宽度高度适配,且关闭后打开位置居中)

我们在做后台管理系统时常用到ElementUI 中的 el-Dialog,但是官方文档并未我们提供 el-Dialog弹窗如何实现拖拽功能,我们通常需要思考如何让用户能够自由地拖动弹窗,在页面上调整位置以获得更好的用户体验。在下面的博客文章中,我们将实现如何为 ElementUI 的 el-Dialog 弹…

网络安全 SQLmap-tamper的使用

目录 使用SQLmap Tamper脚本 1. 选择合适的Tamper脚本 2. 在命令行中使用Tamper脚本 3. 组合使用Tamper脚本 4. 注意和考虑 黑客零基础入门学习路线&规划 网络安全学习路线&学习资源 SQLmap是一款强大的自动化SQL注入和数据库取证工具。它用于检测和利用SQL注入漏…