使用无标注的数据训练Bert

文章目录

  • 1、准备用于训练的数据集
  • 2、处理数据集
  • 3、克隆代码
  • 4、运行代码
  • 5、将ckpt模型转为bin模型使其可在pytorch中运用

Bert官方仓库:https://github.com/google-research/bert

1、准备用于训练的数据集

此处准备的是BBC news的数据集,下载链接:https://www.kaggle.com/datasets/gpreda/bbc-news
原数据集格式(.csv):
在这里插入图片描述

2、处理数据集

训练Bert时需要预处理数据,将数据处理成https://github.com/google-research/bert/blob/master/sample_text.txt中所示格式,如下所示:
在这里插入图片描述
数据预处理代码参考:

import pandas as pd

# 读取BBC-news数据集
df = pd.read_csv("../../bbc_news.csv")
# print(df['title'])
l1 = []
l2 = []
cnt = 0
for line in df['title']:
    l1.append(line)

for line in df['description']:
    l2.append(line)
# cnt=0
f = open("test1.txt", 'w+', encoding='utf8')
for i in range(len(l1)):
    s = l1[i] + " " + l2[i] + '\n'
    f.write(s)
    # cnt+=1
    # if cnt>10: break
f.close()
# print(l1)

处理完后的BBC news数据集格式如下所示:
在这里插入图片描述

3、克隆代码

使用git克隆仓库代码
http:

git clone https://github.com/google-research/bert.git

或ssh:

git clone git@github.com:google-research/bert.git

4、运行代码

先下载Bert模型:BERT-Base, Uncased
该文件中有以下文件:
在这里插入图片描述
运行代码:
在Teminal中运行:

python create_pretraining_data.py \
  --input_file=./sample_text.txt(数据集地址) \
  --output_file=/tmp/tf_examples.tfrecord(处理后数据集保存的位置) \
  --vocab_file=$BERT_BASE_DIR/vocab.txt(vocab.txt文件位置) \
  --do_lower_case=True \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --masked_lm_prob=0.15 \
  --random_seed=12345 \
  --dupe_factor=5

训练模型:

python run_pretraining.py \
  --input_file=/tmp/tf_examples.tfrecord(处理后数据集保存的位置) \
  --output_dir=/tmp/pretraining_output(训练后模型保存位置) \
  --do_train=True \
  --do_eval=True \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json(bert_config.json文件位置) \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt(如果要从头开始的预训练,则去掉这行) \
  --train_batch_size=32 \
  --max_seq_length=128 \
  --max_predictions_per_seq=20 \
  --num_train_steps=20 \
  --num_warmup_steps=10 \
  --learning_rate=2e-5

训练完成后模型输出示例:

***** Eval results *****
  global_step = 20
  loss = 0.0979674
  masked_lm_accuracy = 0.985479
  masked_lm_loss = 0.0979328
  next_sentence_accuracy = 1.0
  next_sentence_loss = 3.45724e-05

要注意应该能够在至少具有 12GB RAM 的 GPU 上运行,不然会报错显存不足。
使用未标注数据训练BERT

5、将ckpt模型转为bin模型使其可在pytorch中运用

上一步训练好后准备好训练出来的model.ckpt-20.index文件和Bert模型中的bert_config.json文件

创建python文件convert_bert_original_tf_checkpoint_to_pytorch.py:

# coding=utf-8
# Copyright 2018 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert BERT checkpoint."""


import argparse

import torch

from transformers import BertConfig, BertForPreTraining, load_tf_weights_in_bert
from transformers.utils import logging


logging.set_verbosity_info()


def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = BertConfig.from_json_file(bert_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = BertForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_bert(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # Required parameters
    parser.add_argument(
        "--tf_checkpoint_path", default=None, type=str, required=True, help="Path to the TensorFlow checkpoint path."
    )
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help="The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.",
    )
    parser.add_argument(
        "--pytorch_dump_path", default=None, type=str, required=True, help="Path to the output PyTorch model."
    )
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path, args.bert_config_file, args.pytorch_dump_path)

在Terminal中运行以下命令:

python convert_bert_original_tf_checkpoint_to_pytorch.py \
--tf_checkpoint_path Models/chinese_L-12_H-768_A-12/bert_model.ckpt.index(.ckpt.index文件位置) \
--bert_config_file Models/chinese_L-12_H-768_A-12/bert_config.json(bert_config.json文件位置)  \
--pytorch_dump_path  Models/chinese_L-12_H-768_A-12/pytorch_model.bin(输出的.bin模型文件位置)

以上命令最好在一行中运行:

python convert_bert_original_tf_checkpoint_to_pytorch.py --tf_checkpoint_path bert_model.ckpt.index --bert_config_file bert_config.json  --pytorch_dump_path  pytorch_model.bin

然后就可以得到bin文件了
在这里插入图片描述

【BERT for Tensorflow】本地ckpt文件的BERT使用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/17763.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

camunda表达式如何使用

在Camunda中,表达式是一种灵活的方式,可以用于在流程定义和表单中计算和处理数据。表达式可以在Camunda的各个环节中使用,例如服务任务、网关、表单、条件等。 以下是Camunda表达式的一些常见用途: 1、计算值:表达式可…

卢北辰:数据点亮梦想,能力驱动人生 | 提升之路系列(九)

导读 为了发挥清华大学多学科优势,搭建跨学科交叉融合平台,创新跨学科交叉培养模式,培养具有大数据思维和应用创新的“π”型人才,由清华大学研究生院、清华大学大数据研究中心及相关院系共同设计组织的“清华大学大数据能力提升项…

ipad有必要用手写笔吗?电容笔和Apple pencil区别

与Apple Pencil最大的不同之处,在于普通的电容笔并不具备着重力压感,而是会给人一种倾斜的压感。如果不是频繁作画,那就用一支普通的电容笔。这种电容笔不但可以用于办公室,也可以用于记笔记、做练习。再说了,一支苹果…

【C++修炼之路:二叉搜索树】

目录: 二叉搜索树的概念构建一颗二叉树二叉树的查找二插树的插入 二叉树的删除删除右子树的最小节点 写一个中序来走这个二叉搜索树递归版删除(recursion)递归版插入(recursion)递归版查找(recursion&#…

python+java+nodejs基于vue的企业人事工资管理系统

根据系统功能需求分析,对系统功能的进行设计和分解。功能分解的过程就是一个由抽象到具体的过程。 作为人事数据库系统,其主要实现的功能应包括以下几个模块: 1.登录模块 登录模块是由管理员、员工2种不同身份进行登录。 2.系统管理模块 用户…

中级软件设计师备考---软件工程1

目录 经典的模型敏捷开发方法【的分类】信息系统开发方法【的分类】结构化设计---内聚与耦合结构化设计---系统结构/模块结构 需求的分类 经典的模型 瀑布模型:最早的一类、适用于需求明确的项目、结构化的典型代表 原型模型:先构造一个建议的系统原型再…

【系统集成项目管理工程师】计算题专题一

一、决策树和期望货币值 1、项目经理向客户推荐了四种供应商选择方案。每个方案损益值已标在下面的决策树上。根据预期收益值,应选择设备供应商 A.供应商1B.供应商2C.供应商3D.供应商4 解题: 供应商 1:60% * 10000 (-30000&am…

Oracle SQL执行计划操作(13)——其他相关操作

该类操作主要包括以上未进行讲解的其他相关操作。根据不同的具体SQL语句及其他相关因素,如下各操作可能会出现于相关SQL语句的执行计划。 1)SELECT STATEMENT 检索表中数据。该操作出现于通过select语句检索表中数据时产生的执行计划。该操作具体如图15-1中节点0所示。 图1…

RISC-V OS(老师的OS) 基于 汪辰老师的视频笔记

前言 最后面没写完,以后再补。。。 RISC-V OS RVOS 介绍 操作系统定义 操作系统(英语:Operating System,缩写:OS)是一组系统软件程序: 主管并控制计算机操作、运用和运行硬件、软件资源。提…

SPSS如何进行对应分析之案例实训?

文章目录 0.引言1.对应分析2.多重对应分析 0.引言 因科研等多场景需要进行数据统计分析,笔者对SPSS进行了学习,本文通过《SPSS统计分析从入门到精通》及其配套素材结合网上相关资料进行学习笔记总结,本文对对应分析进行阐述。 1.对应分析 &a…

55、RK3588使用MPP编码yuv到h264、解码h264到yuv模块开发和测试

基本思想:需要使用独立模块代码去实现自己的逻辑功能,所以在基于官方源码基础上,和参考附录几个官方链接,搞出一版rk3588编码测试和解码测试demo 测试视频/生成h264/生成yuv 链接: https://pan.baidu.com/s/1HbpeqMJb8HcgFpzaKh…

【Linux学习】多线程——线程控制 | 线程TCB

🐱作者:一只大喵咪1201 🐱专栏:《Linux学习》 🔥格言:你只管努力,剩下的交给时间! 线程控制 | 线程TCB 🧰线程控制🎴线程创建🎴线程结束&#x1…

写作业用白光还是暖光?盘点色温4000K的护眼台灯

台灯的白光或者暖光指的是台灯的色温,低色温的光线看起来发黄发红,高色温的光线发白发蓝。 如果灯光的光源是高品质光源,本身没有蓝光问题,那么色温的选择对护眼的影响是比较少的,更多的是对人学习工作状态&#xff0c…

Linux 之 vi 文本编辑器(二)

1、文本编辑器简介 Linux 中最常用的文本编辑器: vi:类 Unix 系统中默认的文本编辑器 vim:vi 编辑器的增强版本,习惯上也称 vi vi 文本编辑器的作用和特性: vi 可以执行插入、删除、查找、替换等众多文本操作&…

Leetcode268. 丢失的数字

Every day a leetcode 题目来源&#xff1a;268. 丢失的数字 解法1&#xff1a;排序 代码&#xff1a; /** lc appleetcode.cn id268 langcpp** [268] 丢失的数字*/// lc codestart class Solution { public:int missingNumber(vector<int> &nums){int n nums.s…

ESP32设备驱动-Si1145红外接近-紫外 (UV) 指数和环境光传感器驱动

Si1145红外接近-紫外 (UV) 指数和环境光传感器驱动 文章目录 Si1145红外接近-紫外 (UV) 指数和环境光传感器驱动1、Si1145介绍2、硬件准备3、软件准备4、驱动实现1、Si1145介绍 Si1145/46/47 是一款低功耗、基于反射的红外接近、紫外 (UV) 指数和环境光传感器,具有 I2C 数字接…

【一起撸个DL框架】4 反向传播求梯度

CSDN个人主页&#xff1a;清风莫追 欢迎关注本专栏&#xff1a;《一起撸个DL框架》 文章目录 4 反向传播求梯度&#x1f965;4.1 简介4.2 导数与梯度4.3 链式法则4.4 示例&#xff1a;y2x1的梯度 4 反向传播求梯度&#x1f965; 4.1 简介 上一篇&#xff1a;【一起撸个DL框架】…

【OpenCV】 2D-2D:对极几何算法原理

2D-2D匹配: 对极几何 SLAM十四讲笔记1 1.1 对极几何數學模型 考虑从两张图像上观测到了同一个3D点&#xff0c;如图所示**。**我们希望可以求解相机两个时刻的运动 R , t R,t R,t。 假设我们要求取两帧图像 I 1 , I 2 I_1,I_2 I1​,I2​之间的运动,设第一帧到第二帧的运动为…

全国快递物流 API 实现快递单号自动识别的原理解析

概述 全国快递物流 API 是一种提供快递物流单号查询的接口&#xff0c;涵盖了包括申通、顺丰、圆通、韵达、中通、汇通等600快递公司的数据。该 API 的目标是为快递公司、电商、物流平台等提供便捷、快速、准确的快递物流信息查询服务。 数据采集和处理 全国快递物流 API 的…

自定义控件 (?/N) - 颜料 Paint

参考来源 一、颜色 1.1 直接设置颜色 1.1.1 setColor( ) public void setColor(ColorInt int color) paint.setColor(Color.RED) paint.setColor(Color.parseColor("#009688")) 1.1.2 setARGB( ) public void setARGB(int a, int r, int g, int b) paint.se…