[oneAPI] 使用Bert进行中文文本分类

[oneAPI] 使用Bert进行中文文本分类

  • Intel® Optimization for PyTorch
  • 基于BERT的文本分类模型
    • 数据预处理
    • 数据集
      • 定义tokenize
      • 建立词表
      • 转换为Token序列
      • padding处理与mask
    • 模型
  • 结果
  • OneAPI
  • 参考资料

比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/

Intel® Optimization for PyTorch

在本次实验中,我们利用PyTorch和Intel® Optimization for PyTorch的强大功能,对PyTorch进行了精心的优化和扩展。这些优化举措极大地增强了PyTorch在各种任务中的性能,尤其是在英特尔硬件上的表现更加突出。通过这些优化策略,我们的模型在训练和推断过程中变得更加敏捷和高效,显著地减少了计算时间,提高了整体效能。我们通过深度融合硬件和软件的精巧设计,成功地释放了硬件潜力,使得模型的训练和应用变得更加快速和高效。这一系列优化举措为人工智能应用开辟了新的前景,带来了全新的可能性。
在这里插入图片描述

基于BERT的文本分类模型

基于BERT的文本分类模型就是在原始的BERT模型后再加上一个分类层即可,同时,对于分类层的输入(也就是原始BERT的输出),默认情况下取BERT输出结果中[CLS]位置对于的向量即可,当然也可以修改为其它方式,例如所有位置向量的均值等。因此,对于基于BERT的文本分类模型来说其输入就是BERT的输入,输出则是每个类别对应的logits值。

数据预处理

在构建数据集之前,我们首先需要知道的是模型到底应该接收什么样的输入,然后才能构建出正确的数据形式。在上面我们说到,基于BERT的文本分类模型的输入就等价于BERT模型的输入,同时BERT模型的输入如图1所示:
在这里插入图片描述

数据集

在这里,我们使用到的数据集是今日头条开放的一个新闻分类数据集(https://github.com/aceimnorstuvwxz/toutiao-text-classfication-dataset),一共包含有382688条数据,15个类别,经过处理后数据集格式为:

千万不要乱申请网贷,否则后果很严重_!_4
10年前的今天,纪念5.12汶川大地震10周年_!_11
怎么看待杨毅在一NBA直播比赛中说詹姆斯的球场统治力已经超过乔丹、伯德和科比?_!_3
戴安娜王妃的车祸有什么谜团?_!_2

其中_!_左边为新闻标题,也就是后面需要用到的分类文本,右边为类别标签。

定义tokenize

将输入进来的文本序列tokenize到字符级别。对于中文语料来说就是将每个字和标点符号都给切分开。在这里,我们可以借用transformers包中的BertTokenizer方法来完成,如下所示:

1 if __name__ == '__main__':
2     model_config = ModelConfig()
3     tokenizer = BertTokenizer.from_pretrained(model_config.pretrained_model_dir).tokenize
4     print(tokenizer("青山不改,绿水长流,我们月来客栈见!"))
5     print(tokenizer("10年前的今天,纪念5.12汶川大地震10周年"))
6 
7 # ['青', '山', '不', '改', ',', '绿', '水', '长', '流', ',', '我', '们', '月', '来', '客', '栈', '见', '!']
8 # ['10', '年', '前', '的', '今', '天', ',', '纪', '念', '5', '.', '12', '汶', '川', '大', '地', '震', '10', '周', '年']

建立词表

将vocab.txt中的内容读取进来形成一个词表即可

1 class Vocab:
 2     UNK = '[UNK]'
 3     def __init__(self, vocab_path):
 4         self.stoi = {}
 5         self.itos = []
 6         with open(vocab_path, 'r', encoding='utf-8') as f:
 7             for i, word in enumerate(f):
 8                 w = word.strip('\n')
 9                 self.stoi[w] = i
10                 self.itos.append(w)
11 
12     def __getitem__(self, token):
13         return self.stoi.get(token, self.stoi.get(Vocab.UNK))
14 
15     def __len__(self):
16         return len(self.itos)

转换为Token序列

在得到构建的字典后,便可以通过如下函数来将训练集、验证集和测试集转换成Token序列:

 1 def data_process(self, filepath):
 2     raw_iter = open(filepath, encoding="utf8").readlines()
 3     data = []
 4     max_len = 0
 5     for raw in tqdm(raw_iter, ncols=80):
 6         line = raw.rstrip("\n").split(self.split_sep)
 7         s, l = line[0], line[1]
 8         tmp = [self.CLS_IDX] + [self.vocab[token] for token in self.tokenizer(s)]
 9         if len(tmp) > self.max_position_embeddings - 1:
10             tmp = tmp[:self.max_position_embeddings - 1]  # BERT预训练模型只取前512个字符
11         tmp += [self.SEP_IDX]
12         tensor_ = torch.tensor(tmp, dtype=torch.long)
13         l = torch.tensor(int(l), dtype=torch.long)
14         max_len = max(max_len, tensor_.size(0))
15         data.append((tensor_, l))
16     return data, max_len

padding处理与mask

对原始文本序列tokenize转换为Token ID后还需要对其进行padding处理。对于这一处理过程可以通过如下代码来完成:

 1 def pad_sequence(sequences, batch_first=False, max_len=None, padding_value=0):
 2     if max_len is None:
 3         max_len = max([s.size(0) for s in sequences])
 4     out_tensors = []
 5     for tensor in sequences:
 6         if tensor.size(0) < max_len:
 7             tensor = torch.cat([tensor, torch.tensor(
 8               [padding_value] * (max_len - tensor.size(0)))], dim=0)
 9         else:
10             tensor = tensor[:max_len]
11         out_tensors.append(tensor)
12     out_tensors = torch.stack(out_tensors, dim=1)
13     if batch_first:
14         return out_tensors.transpose(0, 1)
15     return out_tensors

模型

class BertModel(nn.Module):
    """

    """

    def __init__(self, config):
        super().__init__()
        self.bert_embeddings = BertEmbeddings(config)
        self.bert_encoder = BertEncoder(config)
        self.bert_pooler = BertPooler(config)
        self.config = config
        self._reset_parameters()

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None):
        """
        ***** 一定要注意,attention_mask中,被mask的Token用1(True)表示,没有mask的用0(false)表示
        这一点一定一定要注意
        :param input_ids:  [src_len, batch_size]
        :param attention_mask: [batch_size, src_len] mask掉padding部分的内容
        :param token_type_ids: [src_len, batch_size]  # 如果输入模型的只有一个序列,那么这个参数也不用传值
        :param position_ids: [1,src_len] # 在实际建模时这个参数其实可以不用传值
        :return:
        """
        embedding_output = self.bert_embeddings(input_ids=input_ids,
                                                position_ids=position_ids,
                                                token_type_ids=token_type_ids)
        # embedding_output: [src_len, batch_size, hidden_size]
        all_encoder_outputs = self.bert_encoder(embedding_output,
                                                attention_mask=attention_mask)
        # all_encoder_outputs 为一个包含有num_hidden_layers个层的输出
        sequence_output = all_encoder_outputs[-1]  # 取最后一层
        # sequence_output: [src_len, batch_size, hidden_size]
        pooled_output = self.bert_pooler(sequence_output)
        # 默认是最后一层的first token 即[cls]位置经dense + tanh 后的结果
        # pooled_output: [batch_size, hidden_size]
        return pooled_output, all_encoder_outputs

    def _reset_parameters(self):
        r"""Initiate parameters in the transformer model."""
        """
        初始化
        """
        for p in self.parameters():
            if p.dim() > 1:
                normal_(p, mean=0.0, std=self.config.initializer_range)

    @classmethod
    def from_pretrained(cls, config, pretrained_model_dir=None):
        model = cls(config)  # 初始化模型,cls为未实例化的对象,即一个未实例化的BertModel对象
        pretrained_model_path = os.path.join(pretrained_model_dir, "pytorch_model.bin")
        if not os.path.exists(pretrained_model_path):
            raise ValueError(f"<路径:{pretrained_model_path} 中的模型不存在,请仔细检查!>\n"
                             f"中文模型下载地址:https://huggingface.co/bert-base-chinese/tree/main\n"
                             f"英文模型下载地址:https://huggingface.co/bert-base-uncased/tree/main\n")
        loaded_paras = torch.load(pretrained_model_path)
        state_dict = deepcopy(model.state_dict())
        loaded_paras_names = list(loaded_paras.keys())[:-8]
        model_paras_names = list(state_dict.keys())[1:]
        if 'use_torch_multi_head' in config.__dict__ and config.use_torch_multi_head:
            torch_paras = format_paras_for_torch(loaded_paras_names, loaded_paras)
            for i in range(len(model_paras_names)):
                logging.debug(f"## 成功赋值参数:{model_paras_names[i]},形状为: {torch_paras[i].size()}")
                if "position_embeddings" in model_paras_names[i]:
                    # 这部分代码用来消除预训练模型只能输入小于512个字符的限制
                    if config.max_position_embeddings > 512:
                        new_embedding = replace_512_position(state_dict[model_paras_names[i]],
                                                             loaded_paras[loaded_paras_names[i]])
                        state_dict[model_paras_names[i]] = new_embedding
                        continue
                state_dict[model_paras_names[i]] = torch_paras[i]
            logging.info(f"## 注意,正在使用torch框架中的MultiHeadAttention实现")
        else:
            for i in range(len(loaded_paras_names)):
                logging.debug(f"## 成功将参数:{loaded_paras_names[i]}赋值给{model_paras_names[i]},"
                              f"参数形状为:{state_dict[model_paras_names[i]].size()}")
                if "position_embeddings" in model_paras_names[i]:
                    # 这部分代码用来消除预训练模型只能输入小于512个字符的限制
                    if config.max_position_embeddings > 512:
                        new_embedding = replace_512_position(state_dict[model_paras_names[i]],
                                                             loaded_paras[loaded_paras_names[i]])
                        state_dict[model_paras_names[i]] = new_embedding
                        continue
                state_dict[model_paras_names[i]] = loaded_paras[loaded_paras_names[i]]
            logging.info(f"## 注意,正在使用本地MyTransformer中的MyMultiHeadAttention实现,"
                         f"如需使用torch框架中的MultiHeadAttention模块可通过config.__dict__['use_torch_multi_head'] = True实现")
        model.load_state_dict(state_dict)
        return model

结果

在这里插入图片描述

OneAPI

import intel_extension_for_pytorch as ipex

model = model.to(config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)

参考资料

基于BERT预训练模型的中文文本分类任务: https://www.ylkz.life/deeplearning/p10979382/

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/86497.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

el-table根据容器大小自适应滚动条-修改滚动条样式

需求&#xff1a;父容器里有多个容器为上下级&#xff0c;之后浏览器在缩放的时候&#xff0c;上面容器高度改变了&#xff0c;所以el-table被挤压&#xff0c;如果el-table设置的是固定的高度&#xff0c;那么挤压后内容超出父容器&#xff0c;本文章就是解决这个问题 不自适…

2023年7月京东净水器行业品牌销售排行榜(京东数据分析软件)

伴随消费升级及健康生活理念的流行&#xff0c;消费者对饮水健康的关注度也逐步提高。加之经净水器处理的水在安全性、便捷性等方面的优势得到认可&#xff0c;净水器这一电器的市场占比也不断提高。在家电行业整体低迷的环境下&#xff0c;净水器的销量销额仍保持正向增长。 …

阿里云2核4G服务器配置汇总表_轻量和ECS

阿里云2核4G服务器配置价格表&#xff0c;297元一年&#xff0c;配置为轻量应用服务器2核4G、4M带宽、60GB高效云盘&#xff0c;折合24元一个月。 目录 2核4G服务器轻量&#xff1a; 2核4G服务器ECS 关于轻量和ECS的区别&#xff1a; 2核4G服务器轻量&#xff1a; 云服务器…

Docker碎碎念

docker和虚拟机的区别 虚拟机&#xff08;VM&#xff09;是通过在物理硬件上运行一个完整的操作系统来实现的。 每个虚拟机都有自己的内核、设备驱动程序和用户空间&#xff0c;它们是相互独立且完全隔离的。 虚拟机可以在不同的物理服务器之间迁移&#xff0c;因为它们是以整…

快速提高写作生产力——使用PicGo+Github搭建免费图床,并结合Typora

文章目录 简述PicGo下载PicGo获取Token配置PicGo结合Typora总结 简述PicGo PicGo: 一个用于快速上传图片并获取图片 URL 链接的工具 PicGo 本体支持如下图床&#xff1a; 七牛图床 v1.0腾讯云 COS v4\v5 版本 v1.1 & v1.5.0又拍云 v1.2.0GitHub v1.5.0SM.MS V2 v2.3.0-b…

漏洞挖掘和安全审计的技巧与策略

文章目录 漏洞挖掘&#xff1a;发现隐藏的弱点1. 源代码审计&#xff1a;2. 黑盒测试&#xff1a;3. 静态分析工具&#xff1a; 安全审计&#xff1a;系统的全面评估1. 渗透测试&#xff1a;2. 代码审计&#xff1a;3. 安全策略审查&#xff1a; 代码示例&#xff1a;SQL注入漏…

TCP编程流程(补充)

目录 1、listen&#xff1a; 2、listen、tcp三次握手 3、 发送缓冲区和接收缓冲区&#xff1a; 4、tcp编程启用多线程 1、listen&#xff1a; 执行listen会创建一个监听队列 listen(sockfd,5) 2、listen、tcp三次握手 三次握手 3、 发送缓冲区和接收缓冲区&#xff1a;…

Spring事务和事务传播机制(2)

前言&#x1f36d; ❤️❤️❤️SSM专栏更新中&#xff0c;各位大佬觉得写得不错&#xff0c;支持一下&#xff0c;感谢了&#xff01;❤️❤️❤️ Spring Spring MVC MyBatis_冷兮雪的博客-CSDN博客 在Spring框架中&#xff0c;事务管理是一种用于维护数据库操作的一致性和…

Gitlab 安装全流程

Version&#xff1a;gitlab-ce:16.2.4-ce.0 简介 Gitlab 是一个开源的 Git 代码仓库系统&#xff0c;可以实现自托管的 Github 项目&#xff0c;即用于构建私有的代码托管平台和项目管理系统。系统基于 Ruby on Rails 开发&#xff0c;速度快、安全稳定。它拥有与 Github 类似…

Java算法_ BST 中第 k 个最小元素 (LeetCode_Hot100)

题目描述&#xff1a;给定一个二叉搜索树的根节点 &#xff0c;和一个整数 &#xff0c;请你设计一个算法查找其中第 个最小元素&#xff08;从 1 开始计数&#xff09;。 获得更多&#xff1f;算法思路:代码文档&#xff0c;算法解析的私得。 运行效果 完整代码 /*** 2 * Aut…

Linux学习记录——이십오 多线程(2)

文章目录 1、理解原生线程库线程局部存储 2、互斥1、并发代码&#xff08;抢票&#xff09;2、锁3、互斥锁的实现原理 3、线程封装1、线程本体2、封装锁 4、线程安全5、死锁6、线程同步1、条件变量1、接口2、demo代码 1、理解原生线程库 线程库在物理内存中存在&#xff0c;也…

Redis 数据库 NoSQL

目录 一、NoSQL 二、为什么会出现NoSQL技术 三、NoSQL的类别 键值&#xff08;Key-Value&#xff09;存储数据库 列存储数据库 文档型数据库 图形&#xff08;Graph&#xff09;数据库 四、NoSQL适应场景 五、在分布式数据库中CAP原理 1、CAP 2、BASE 一、NoSQL NoS…

低代码开发平台能开发什么类型的系统和软件?

低代码开发平台能开发什么类型的系统和软件&#xff1f; 1、数据分析和报告系统&#xff1a; 使用低代码平台&#xff0c;企业可以创建数据看板&#xff0c;集成不同数据源&#xff0c;自动提取、分析和可视化数据。这种系统适用于监控业务指标、分析趋势&#xff0c;并为决策…

多个微信号怎么快速发圈、自动加好友、自动回复?

一键助你快速发圈、批量自动加好友、自动回复&#xff0c;好用哭了&#xff01; 微信管理系统是一个聚合管理多个微信账号的利器&#xff0c;让你的微信管理变得简单高效。不管你是电商、微商&#xff0c;还是拥有多个微信号的用户&#xff0c;这一款微信管理软件都可以满足你的…

vue2+qrcodejs2+clipboard——实现二维码展示+下载+复制到剪切板——基础积累

最近在写后台管理系统时&#xff0c;遇到一个需求就是要实现二维码的展示下载复制到剪切板。 效果图如下&#xff1a; 1.二维码展示下载功能——qrcodejs20.0.2 我是安装的qrcodejs20.0.2&#xff0c;指定了具体的版本号&#xff0c;也可以安装默认的当前稳定版本&#xff0…

门店数字化店务经营系统怎么做?门店数字化系统推荐

为什么你的门店无人问津&#xff0c;有的门店却天天都有到店客户&#xff1f;为什么你的门店要花费两三天才能统计好经营情况&#xff0c;有的门店却能够做到“数据实时可查”&#xff1f;经营管理和营销获客是每个门店发展的重中之重&#xff0c;今天也为大家分享一套完善的门…

西门子SCALANCE W744-1PRO 客户端配置

. 安装西门子无线搜索软件PST。 无线SCALANCE W788-1PRO参数设置。 打开PST软件&#xff1a;选择Settings->Network Adapter->2本地连接 输入该无线设置的IP地址&#xff0c;进入网络访问界面。输入密码&#xff1a;admin&#xff0c;点击Log on进入。 填写本无线的SSI…

kaggle推荐系统比赛top方案汇总【附baseline代码】

推荐系统可以很好地解决信息过载以及信息不足等问题&#xff0c;广泛应用与电商、金融、新闻咨询、社交、旅游等行业&#xff0c;其中最典型并具有良好的发展和应用前景的领域就是电子商务领域。 在学术界&#xff0c;推荐系统同样是热门的研究方向&#xff0c;在各大顶会中的…

中文医学知识语言模型:BenTsao

介绍 BenTsao&#xff1a;[原名&#xff1a;华驼(HuaTuo)]: 基于中文医学知识的大语言模型指令微调 本项目开源了经过中文医学指令精调/指令微调(Instruction-tuning) 的大语言模型集&#xff0c;包括LLaMA、Alpaca-Chinese、Bloom、活字模型等。 我们基于医学知识图谱以及医…

LeetCode——二叉树篇(八)

刷题顺序及思路来源于代码随想录&#xff0c;网站地址&#xff1a;https://programmercarl.com 目录 236. 二叉树的最近公共祖先 235. 二叉搜索树的最近公共祖 迭代 递归 701. 二叉搜索树中的插入操作 450. 删除二叉搜索树中的节点 236. 二叉树的最近公共祖先 给定一个二…