【机器学习】—机器学习和NLP预训练模型探索之旅

目录

一.预训练模型的基本概念

1.BERT模型

2 .GPT模型

二、预训练模型的应用

1.文本分类

使用BERT进行文本分类

2. 问答系统

使用BERT进行问答

三、预训练模型的优化

 1.模型压缩

1.1 剪枝

权重剪枝

2.模型量化

2.1 定点量化

使用PyTorch进行定点量化

3. 知识蒸馏

3.1 知识蒸馏的基本原理

3.2 实例代码:使用知识蒸馏训练学生模型

四、结论


随着数据量的增加和计算能力的提升,机器学习和自然语言处理技术得到了飞速发展。预训练模型作为其中的重要组成部分,通过在大规模数据集上进行预训练,使得模型可以捕捉到丰富的语义信息,从而在下游任务中表现出色。

一.预训练模型的基本概念

预训练模型是一种在大规模数据集上预先训练好的模型,可以作为其他任务的基础。预训练模型的优势在于其能够利用大规模数据集中的知识,提高模型的泛化能力和准确性。常见的预训练模型包括BERT(Bidirectional Encoder Representations from Transformers)、GPT(Generative Pre-trained Transformer)等。

1.BERT模型

BERT是由Google提出的一种双向编码器表示模型。BERT通过在大规模文本数据上进行掩码语言模型(Masked Language Model, MLM)和下一句预测(Next Sentence Prediction, NSP)的预训练,使得模型可以学习到深层次的语言表示。

2 .GPT模型

GPT由OpenAI提出,是一种基于Transformer的生成式预训练模型。GPT通过在大规模文本数据上进行自回归语言模型的预训练,使得模型可以生成连贯的文本。

二、预训练模型的应用

预训练模型在NLP领域有广泛的应用,包括但不限于文本分类、问答系统、机器翻译等。以下将介绍几个具体的应用实例。

1.文本分类

文本分类是将文本数据按照预定义的类别进行分类的任务。预训练模型可以通过在大规模文本数据上进行预训练,从而捕捉到丰富的语义信息,提高文本分类的准确性。

使用BERT进行文本分类

import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split

# 加载预训练的BERT模型和分词器
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)

# 定义数据集
class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            return_token_type_ids=False,
            padding='max_length',
            return_attention_mask=True,
            return_tensors='pt',
        )
        return {
            'text': text,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# 准备数据
texts = ["I love this!", "I hate this!"]
labels = [1, 0]
train_texts, val_texts, train_labels, val_labels = train_test_split(texts, labels, test_size=0.1)

train_dataset = TextDataset(train_texts, train_labels, tokenizer, max_len=32)
val_dataset = TextDataset(val_texts, val_labels, tokenizer, max_len=32)

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False)

# 训练模型
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)
for epoch in range(3):
    model.train()
    for batch in train_loader:
        optimizer.zero_grad()
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

# 验证模型
model.eval()
correct = 0
total = 0
with torch.no_grad():
    for batch in val_loader:
        input_ids = batch['input_ids']
        attention_mask = batch['attention_mask']
        labels = batch['label']
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        _, predicted = torch.max(outputs.logits, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Validation Accuracy: {correct / total:.2f}')

2. 问答系统

问答系统是从文本中自动提取答案的任务。预训练模型可以通过在大规模问答数据上进行预训练,从而提高答案的准确性和相关性。

使用BERT进行问答

from transformers import BertForQuestionAnswering

# 加载预训练的BERT问答模型
model = BertForQuestionAnswering.from_pretrained('bert-large-uncased-whole-word-masking-finetuned-squad')

# 输入问题和上下文
question = "What is the capital of France?"
context = "Paris is the capital of France."

# 编码输入
inputs = tokenizer.encode_plus(question, context, return_tensors='pt')

# 模型预测
outputs = model(**inputs)
start_scores = outputs.start_logits
end_scores = outputs.end_logits

# 获取答案的起始和结束位置
start_idx = torch.argmax(start_scores)
end_idx = torch.argmax(end_scores) + 1

# 解码答案
answer = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(inputs['input_ids'][0][start_idx:end_idx]))
print(f'Answer: {answer}')

三、预训练模型的优化

在实际应用中,预训练模型的优化至关重要。常见的优化方法包括模型压缩、量化和蒸馏等。

 1.模型压缩

模型压缩是通过减少模型参数数量和计算量来提高模型效率的方法。压缩后的模型不仅运行速度更快,还能减少存储空间和内存占用。常见的模型压缩技术包括剪枝、量化和知识蒸馏等。

1.1 剪枝

剪枝(Pruning)是一种通过删除模型中冗余或不重要的参数来减小模型大小的方法。剪枝可以在训练过程中或训练完成后进行。常见的剪枝方法包括:

  • 权重剪枝(Weight Pruning):删除绝对值较小的权重,认为这些权重对模型输出影响不大。
  • 结构剪枝(Structured Pruning):删除整个神经元或卷积核,减少模型的计算量和存储需求。

剪枝后的模型通常需要重新训练,以恢复或接近原始模型的性能。

权重剪枝
import torch
import torch.nn.utils.prune as prune

# 定义一个简单的模型
class SimpleModel(torch.nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = torch.nn.Linear(10, 10)
    
    def forward(self, x):
        return self.fc(x)

model = SimpleModel()

# 对模型的全连接层进行权重剪枝
prune.l1_unstructured(model.fc, name='weight', amount=0.5)

# 查看剪枝后的权重
print(model.fc.weight)

2.模型量化

模型量化是通过降低模型参数的精度来减少计算量的方法。量化通常通过将浮点数表示的权重和激活值转换为低精度表示(如8位整数)来实现。这可以显著减少模型的存储空间和计算开销,同时在硬件上加速模型推理。

2.1 定点量化

定点量化(Fixed-point Quantization)是将浮点数表示的权重和激活值转换为固定精度的整数表示。常见的定点量化包括8位整数量化(INT8),这种量化方法在不显著降低模型精度的情况下,可以大幅提升计算效率。

使用PyTorch进行定点量化
import torch
import torch.quantization

# 加载预训练模型
model = SimpleModel()

# 定义量化配置
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')

# 准备量化模型
model = torch.quantization.prepare(model, inplace=True)

# 模拟量化后的推理过程
# 这里应该使用训练数据对模型进行微调,但为了简单起见,省略此步骤
model = torch.quantization.convert(model, inplace=True)

# 查看量化后的模型
print(model)

3. 知识蒸馏

知识蒸馏(Knowledge Distillation)是通过将大模型(教师模型,Teacher Model)的知识转移到小模型(学生模型,Student Model)的方法,从而提高小模型的性能和效率。知识蒸馏的核心思想是通过教师模型的软标签(soft labels)指导学生模型的训练。

3.1 知识蒸馏的基本原理

在知识蒸馏过程中,学生模型不仅学习训练数据的真实标签,还学习教师模型对训练数据的输出,即软标签。软标签包含了更多的信息,比如类别之间的相似性,使学生模型能够更好地泛化。

蒸馏损失函数通常由两部分组成:

  • 交叉熵损失:衡量学生模型输出与真实标签之间的差异。
  • 蒸馏损失:衡量学生模型输出与教师模型软标签之间的差异。

总体损失函数为这两部分的加权和。

3.2 实例代码:使用知识蒸馏训练学生模型

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

# 定义教师模型和学生模型
teacher_model = SimpleModel()
student_model = SimpleModel()

# 加载示例数据
data = torch.randn(100, 10)
labels = torch.randint(0, 10, (100,))
dataset = TensorDataset(data, labels)
data_loader = DataLoader(dataset, batch_size=10, shuffle=True)

# 定义蒸馏训练函数
def distillation_train(student_model, teacher_model, data_loader, optimizer, temperature=2.0, alpha=0.5):
    teacher_model.eval()
    student_model.train()
    for data, labels in data_loader:
        optimizer.zero_grad()
        
        # 教师模型输出
        with torch.no_grad():
            teacher_logits = teacher_model(data)
        
        # 学生模型输出
        student_logits = student_model(data)
        
        # 计算蒸馏损失
        loss_ce = F.cross_entropy(student_logits, labels)
        loss_kl = F.kl_div(
            F.log_softmax(student_logits / temperature, dim=1),
            F.softmax(teacher_logits / temperature, dim=1),
            reduction='batchmean'
        ) * (temperature ** 2)
        
        loss = alpha * loss_ce + (1.0 - alpha) * loss_kl
        loss.backward()
        optimizer.step()

# 定义优化器
optimizer = torch.optim.Adam(student_model.parameters(), lr=1e-3)

# 进行蒸馏训练
for epoch in range(10):
    distillation_train(student_model, teacher_model, data_loader, optimizer)

# 验证学生模型
student_model.eval()
correct = 0
total = 0
with torch.no_grad():
    for data, labels in data_loader:
        outputs = student_model(data)
        _, predicted = torch.max(outputs, dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f'Student Model Accuracy: {correct / total:.2f}')

四、结论

预训练模型在机器学习和自然语言处理领域具有重要意义。通过在大规模数据集上进行预训练,模型可以捕捉到丰富的语义信息,从而在下游任务中表现出色。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/635303.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[emailprotected](7)父子通信,传递元素内容

目录 1,children 属性2,多个属性 普通对象等,可以通过变量直接传递,那类似 vue 中的 slot 插槽,如何传递元素内容? 1,children 属性 实际上,写在自定义组件标签的内部代码&#xf…

【再探】Java—泛型

Java 泛型本质是参数化类型,可以用在类、接口和方法的创建中。 1 “擦除式”泛型 Java的“擦除式”的泛型实现一直受到开发者的诟病。 “擦除式”的实现几乎只需要在Javac编译器上做出改进即可,不要改动字节码、虚拟机,也保证了以前没有使…

k8s pv 一直是release状态

如下图所示,pv 一直是release状态 这个时候大家可能就会想到现在我的 PVC 被删除了,PV 也变成了 Released 状态,那么我重建之前的 PVC 他们不就可以重新绑定了,事实并不会,PVC 只能和 Available 状态的 PV 进行绑定。…

【华为】将eNSP导入CRT,并解决不能敲Tab问题

华为】将eNSP导入CRT,并解决不能敲Tab问题 eNSP导入CRT打开eNSP,新建一个拓扑右键启动查看串口号关联CRT成功界面 SecureCRT连接华为模拟器ensp,Tab键不能补全问题选择Options(选项)-- Global Options (全局选项&#…

ORB-SLAM2从理论到代码实现(六):Tracking程序详解(上)

1. Tracking框架 Tracking线程流程框图: 各流程对应的主要函数 2. Tracking整体流程图 上面这张图把Tracking.cc讲的特别明白。 tracking线程在获取图像数据后,会传给函数GrabImageStereo、GrabImageRGBD或GrabImageMonocular进行预处理,这…

wordpress主题 ACG美化插件v3.4.2支持zibll主题7b2主题美化

独具一格的二次元风格,打造全新的子比美化方向 大部分代码均为CSS、JS做成插件只是为了方便懒人小白站长 后台全功能一览,大部分美化均为网上通用流传,

基于ucos-ii操作系统的生产者消费者-问题

目 录 第1章 题目分析. 1 1.1 生产者线程... 1 1.2 消费者线程... 1 1.3 缓冲区... 1 1.4 进程的同步与互斥... 1 第2章 解决方案. 2 2.1 总体方案... 2 2.2 生产者问题... 2 2.3 消费者问题... 3 2.4 进程问题... 5 第3章 实验结果. 6 3.1 运行结果... 6 3.2 结果分析... 8 第…

用kimi一键绘制《庆余年》人物关系图谱

《庆余年》里面人物关系复杂,如果能画出一个人物关系图谱,可以直观的理解其中人物关系,更好的追剧。 首先,用kimi下载庆余年的分集剧情,常见文章《AI网络爬虫:批量爬取电视猫上面的《庆余年》分集剧情》&am…

【Java面试】三、Redis篇(下)

文章目录 1、抢券场景2、Redis分布式锁3、Redisson实现分布式锁4、Redisson实现的分布式锁是可重入锁5、Redisson实现分布式锁下的主从一致性6、面试 1、抢券场景 正常思路: 代码实现: 比如优惠券数量为1。正常情况下:用户A的请求过来&a…

Centos7.9上安装Oracle 11gR2 RAC 三节点(ASMlib管理asm磁盘)

服务器规划 OS 规格 主机名 IP VIP private IP scanip centos 7.9 1C4G racdb01 192.168.40.165 192.168.183.165 192.168.40.16 192.168.40.200 centos 7.9 1C4G racdb02 192.168.40.175 192.168.183.175 192.168.40.17 192.168.40.200 centos 7.9 1C4G…

目前流行的前端框架有哪些?

目前流行的前端框架有很多,它们可以帮助开发者快速构建高质量的前端应用程序。本文将介绍一些目前比较受欢迎的前端框架,并分析它们的优缺点。 React React 是一个由 Facebook 开发的开源前端JavaScript库,用于构建用户界面,尤其…

基于Vue的图片文件上传与压缩组件的设计与实现

摘要 随着前端技术的发展,系统开发的复杂度不断提升,传统开发方式将整个系统做成整块应用,导致修改和维护成本高昂。组件化开发作为一种解决方案,能够实现单独开发、单独维护,并能灵活组合组件,从而提升开…

OSPF多区域组网实验(华为)

思科设备参考:OSPF多区域组网实验(思科) 技术简介 OSPF多区域功能通过划分网络为多个逻辑区域来提高网络的可扩展性和管理性能。每个区域内部运行独立的SPF计算,而区域之间通过区域边界路由器进行路由信息交换。这种划分策略适用…

Python 机器学习 基础 之 数据表示与特征工程 【分类变量】的简单说明

Python 机器学习 基础 之 数据表示与特征工程 【分类变量】的简单说明 目录 Python 机器学习 基础 之 数据表示与特征工程 【分类变量】的简单说明 一、简单介绍 二、数据表示与特征工程 数据表示 特征工程 三、分类变量 1、One-Hot编码(虚拟变量&#xff09…

【ArcGIS微课1000例】0112:沿线(面)按距离或百分比生成点

文章目录 一、沿线生成点工具介绍二、线状案例三、面状案例一、沿线生成点工具介绍 位置:工具箱→数据管理工具→采样→沿线生成点 摘要:沿线或面以固定间隔或百分比创建点要素。 用法:输入要素的属性将保留在输出要素类中。向输出要素类添加新字段 ORIG_FID,并设置为输…

Vue进阶之Vue项目实战(三)

Vue项目实战 图表渲染安装echarts图表渲染器(图表组件)图表举例:创建 ChartsRenderer.vue创建 ChartsDataTransformer.ts基于 zrender 开发可视化物料安装 zrender画一个矩形画一个柱状图基于svg开发可视化物料svg小示例使用d3进行图表渲染安装d3基本使用地图绘制本地持久化拓…

Leetcode861. 翻转矩阵后的得分

Every day a Leetcode 题目来源:861. 翻转矩阵后的得分 解法1:贪心 对于二进制数来说,我们只要保证最高位是1,就可以保证这个数是最大的,因为移动操作会使得它取反,因此我们进行行变化的时候只需要考虑首…

深度学习:手撕 RNN(2)-RNN 的常见模型架构

本文首次发表于知乎,欢迎关注作者。 上一篇文章我们介绍了一个基本的 RNN 模块。有了 这个 RNN 模块后,就像搭积木一样,以 RNN 为基本单元,根据不同的任务或者需求,可以构建不同的模型架构。本节介绍的所有结构&#…

Glassnode 内容主管:「减半」后的市场「抑郁」

原文标题:《Finance Bridge: Post-Halving Blues》撰文:Marcin Miłosierny,Glassnode 内容主管编译:Chris,Techub News 文章来源香港Web3媒体Techun News 摘要: 每月简报:4 月,尽…

前端自动将 HTTP 请求升级为 HTTPS 请求

前端将HTTP请求升级为HTTPS请求有两种方式&#xff1a; 一、index.html 中插入meta 直接在首页 index.html 的 head 中加入一条 meta 即可&#xff0c;如下所示&#xff1a; <meta http-equiv"Content-Security-Policy" content"upgrade-insecure-requests&…