训练自己的GPT2

训练自己的GPT2

  • 1.预训练与微调
  • 2.准备工作
  • 2.在自己的数据上进行微调

1.预训练与微调

所谓的预训练,就是在海量的通用数据上训练大模型。比如,我把全世界所有的网页上的文本内容都整理出来,把全人类所有的书籍、论文都整理出来,然后进行训练。这个训练过程代价很大,首先模型很大,同时数据量又很大,比如GPT3参数量达到了175B,训练数据达到了45TB,训练一次就话费上千万美元。如此大代价学出来的是一个通用知识的模型,他确实很强,但是这样一个模型,可能无法在一些专业性很强的领域上取得比较好的表现,因为他没有针对这个领域的数据进行训练过。

因此,大模型火了之后,很多人都开始把大模型用在自己的领域。通常也就是把自己领域的一些数据,比如专业书、论文等等整理出来,使用预训练好的大模型在新的数据集上进行微调。微调的成本相比于预训练就要小得多了。

2.准备工作

首先需要安装第三方库transformerstransformers是一个用于自然语言处理(NLP)的Python第三方库,实现Bert、GPT-2和XLNET等比较新的模型,支持TensorFlow和PyTorch。以及下载预训练好的模型权重。

pip install transformers

安装完成之后,我们可以直接使用下面的代码,来构造一个预训练的GPT2

from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

当运行的时候,代码会自动从hugging face上下载模型。但是由于hugging face是国外网站,可能下载起来很慢或者无法下载,因此我们也可以自己手动下载之后在本地读取。

打开hugging face的网站,搜索GPT2。或者直接进入GPT2的页面。

下载上图中的几个文件到本地,假设下载到./gpt2文件夹

然后就可以使用下面的代码来尝试预训练的模型直接生成文本你的效果。

from transformers import GPT2Tokenizer, GPT2LMHeadModel


tokenizer = GPT2Tokenizer.from_pretrained("./gpt2")
model = GPT2LMHeadModel.from_pretrained("./gpt2")

q = "tell me a fairy story"

ids = tokenizer.encode(q, return_tensors='pt')
final_outputs = model.generate(
    ids,
    do_sample=True,
    max_length=100,
    pad_token_id=model.config.eos_token_id,
    top_k=50,
    top_p=0.95,
)

print(tokenizer.decode(final_outputs[0], skip_special_tokens=True))

回答如下:

2.在自己的数据上进行微调

首先把我们的数据,也就是文本,全部整理到一起。比如可以把所有文本拼接到一起。

假设所有的文本数据都存到一个文件中。那么可以直接使用下面的代码进行训练。

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2ForSequenceClassification, AdamW, GPT2LMHeadModel
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, TextDataset

def load_data_collator(tokenizer, mlm = False):
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, 
        mlm=mlm,
    )
    return data_collator

def load_dataset(file_path, tokenizer, block_size = 128):
    dataset = TextDataset(
        tokenizer = tokenizer,
        file_path = file_path,
        block_size = block_size,
    )
    return dataset


def train(train_file_path, model_name,
          output_dir,
          overwrite_output_dir,
          per_device_train_batch_size,
          num_train_epochs,
          save_steps):
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    train_dataset = load_dataset(train_file_path, tokenizer)
    data_collator = load_data_collator(tokenizer)

    tokenizer.save_pretrained(output_dir)

    model = GPT2LMHeadModel.from_pretrained(model_name)

    model.save_pretrained(output_dir)

    training_args = TrainingArguments(
          output_dir=output_dir,
          overwrite_output_dir=overwrite_output_dir,
          per_device_train_batch_size=per_device_train_batch_size,
          num_train_epochs=num_train_epochs,
      )

    trainer = Trainer(
          model=model,
          args=training_args,
          data_collator=data_collator,
          train_dataset=train_dataset,
    )

    trainer.train()
    trainer.save_model()


train_file_path = "./train.txt"   # 你自己的训练文本
model_name = './gpt2'  # 预训练的模型路径
output_dir = './custom_data'  # 你自己设定的模型保存路径
overwrite_output_dir = False
per_device_train_batch_size = 96  # 每一台机器上的batch size。
num_train_epochs = 50   
save_steps = 50000

# Train
train(
    train_file_path=train_file_path,
    model_name=model_name,
    output_dir=output_dir,
    overwrite_output_dir=overwrite_output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    num_train_epochs=num_train_epochs,
    save_steps=save_steps
)     

训练完成之后,推理的话,直接使用第二节里的代码,将预训练模型路径换成自己训练的模型路径就行了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/308355.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

从零学Java 集合概述

Java 集合概述 文章目录 Java 集合概述1 什么是集合?2 Collection体系集合2.1 Collection父接口2.1.1 常用方法2.1.2 Iterator 接口 1 什么是集合? 概念:对象的容器,定义了对多个对象进行操作的常用方法;可实现数组的功能。 和数组区别&…

中小企业实施了MES系统后,同样具备大企业的生产能力

工业4.0、智能制造是当前制造业最热门的话题。数字化工厂是实现智能制造的基础,在建设数字化工厂的过程中,MES系统是核心也是最重要的一环。万界星空MES系统是企业信息数据集成的纽带,可帮助企业实现监控与实际生产过程的同步化,全…

基于JavaWeb+BS架构+SpringBoot+Vue校车调度管理系统的设计和实现

基于JavaWebBS架构SpringBootVue校车调度管理系统的设计和实现 文末获取源码Lun文目录前言主要技术系统设计功能截图订阅经典源码专栏Java项目精品实战案例《500套》 源码获取 文末获取源码 Lun文目录 摘 要 1 Abstract 1 目 录 2 1 绪 论 1 1.1研究背景 1 1.2 研究意义 1 1.…

软件测试|Python openpyxl库使用指南

简介 我们之前介绍过,python在自动化办公方面可以大放异彩,因为Python有许多的第三方库,其中有很多库就支持我们对office软件进行操作,熟练的使用Python对office进行操作,可以实现自动化办公,极大提升我们…

【博士每天一篇论文-算法】Optimal modularity and memory capacity of neural reservoirs

阅读时间:2023-11-15 1 介绍 年份:2019 作者:Nathaniel Rodriguez 印第安纳大学信息学、计算和工程学院,美国印第安纳州布卢明顿 期刊: Network Neuroscience 引用量:39 这篇论文主要研究了神经网络的模块…

WEB 3D技术 three.js 光照与阴影

本文 我们来说 灯光与阴影 之前 我们有接触到光照类的知识 但是阴影应该都是第一次接触 那么 我们先来看光 首先是 AmbientLight 环境光 你在官网中搜索 AmbientLight 官方是就写明了 环境光是不会产生阴影的 因为 它没有反向 然后是 DirectionalLight 平行光 它是可以投射阴…

Java建筑工程建设智慧工地源码

智慧工地管理平台依托物联网、互联网,建立云端大数据管理平台,形成“端云大数据”的业务体系和新的管理模式,从施工现场源头抓起,最大程度的收集人员、安全、环境、材料等关键业务数据,打通从一线操作与远程监管的数据…

代理IP连接不上/网速过慢?如何应对?

当您使用代理时,您可能会遇到不同的代理错误代码显示代理IP连不通、访问失败、网速过慢等种种问题。 在本文中中,我们将讨论您在使用代理IP时可能遇到的常见错误、发生这些错误的原因以及解决方法。 一、常见代理服务器错误 当您尝试访问网站时&#…

MySQL 存储引擎全攻略:选择最适合你的数据库引擎

1. MySQL的支持的存储引擎有哪些 官方文档给出的有以下几种: 我们也可以通过SHOW ENGINES命令来查看: 还可以通过ENGINES表查看 2. 存储引擎比较 我们通过存储引擎表来看各自的优点: InnoDB 默认的存储引擎(SUPPORT字段为D…

LeetCode 36 有效的数独

题目描述 有效的数独 请你判断一个 9 x 9 的数独是否有效。只需要 根据以下规则 ,验证已经填入的数字是否有效即可。 数字 1-9 在每一行只能出现一次。数字 1-9 在每一列只能出现一次。数字 1-9 在每一个以粗实线分隔的 3x3 宫内只能出现一次。(请参考…

[openGL]在ubuntu20.06上搭建openGL环境

就在刚刚, 我跑上了一个6小时后出结果的测试程序. 离下班还有很久, 于是我打开了接单群 , 发现了很多可以写的openGL项目. 但是!!我的电脑现在是ubuntu呀, 但是不要慌!!!接下来我一步一步教你如何完美搭建一个ubuntu上的openGL环境. 保证一个坑也不会踩! 文章目录 创建项目工作…

借助Gitee将typora图片上传CSDN

概述 前面已经发了一个如何借助Github将typora上的图片上传到csdn上,但这有个缺陷:需要科学上网才能加速查看已经上传到github上的图片,否则就会出现已经上传的图片,无法正常查看的问题 如何解决? 那就可以使用Gite…

前端(angular)在谷歌(chrome)浏览器使用高德地图api定位报错超时geolocation time out ,能定位但不安全的方法

已知信息整合 正如大家搜到的大佬说的原因是chrome浏览器本身的问题。我换成edge就可以。高德地图给出的地图定位api的常见问题,这是另外还有个别浏览器(如google Chrome浏览器等)本身的定位接口是黑洞 以下是能定位但不安全的方法 连接上…

Java面试之集合篇

前言 本篇主要总结JAVA面试中关于集合相关的高频面试题。本篇的面试题基于网络整理以及自己的总结编辑。在不断的完善补充哦。欢迎小伙伴们在评论区发表留言哦! 1、基础 1.1、Java 集合框架有哪些? Java 集合框架,大家可以看看 《Java 集…

Excel·VBA按指定顺序排序函数

与之前写过的《ExcelVBA数组冒泡排序函数》不同,不是按照数值大小的升序/降序对数组进行排序,而是按照指定数组的顺序,对另一个数组进行排序 以下代码调用了《ExcelVBA数组冒泡排序函数》bubble_sort_arr函数(如需使用代码需复制…

18张AI电脑动漫超清壁纸免费分享

18张AI电脑动漫壁纸,紫色系和暗黑系,都很不错,喜欢的朋友可以拿去 CSDN免积分下载

【云计算】云计算概述

1. 云计算概述 1.1 云计算的定义 美国国家标准与技术研究院(NIST)定义 云计算是一种按使用量付费的模式,这种模式提供可用的、便捷的、按需的网络访问,进入可配置的计算资源共享池(资源包括网络,服务器,存储,应用软件…

AI墨墨交流群正式成立:探索科技前沿,共建智能未来

在这个充满变革的时代,AI技术正如涌泉般迸发,带来无限可能。我们深感,唯有汇聚智慧,方能更好地驾驭这股前沿科技的潮流。因此,我们自豪地宣布:AI墨墨交流群正式成立了!这不仅是一个交流群&#…

小白苦恼:电脑那么多USB口,怎么知道哪个读写更快?

前言 最近有个朋友和小白抱怨:电脑那么多USB接口,有些接口在传输文件的时候实在慢的很。 电脑诞生以来,USB接口就一直存在。但是USB接口还是长得几乎一样,不仔细去研究都不知道哪个USB会更快。 许多小伙伴就会直接放弃辨认&…

阿里云服务器新购、续费、升级优惠活动及代金券领取入口汇总

阿里云作为国内领先的云计算服务提供商,一直以来都为广大的用户提供了优质、稳定、高效的服务。为了更好地满足用户的需求,阿里云会不定期地推出各种优惠活动,包括新购、续费、升级优惠活动以及代金券领取等。本文将为大家详细介绍这些优惠活…