GPT建模与预测实战

代码链接见文末

效果图:

1.数据样本生成方法

训练配置参数:

--epochs 40 --batch_size 8 --device 0 --train_path data/train.pkl

其中train.pkl是处理后的文件

因此,我们首先需要执行preprocess.py进行预处理操作,配置参数:

--data_path data/novel --save_path data/train.pkl --win_size 200 --step 200

其中--vocab_file是语料表,一般不用修改,--log_path是日志路径

预处理流程如下:

  • 首先,初始化tokenizer
  • 读取作文数据集目录下的所有文件,预处理后,对于每条数据,使用滑动窗口对其进行截断
  • 最后,序列化训练数据 

代码如下:

# 初始化tokenizer
    tokenizer = CpmTokenizer(vocab_file="vocab/chinese_vocab.model")#pip install jieba
    eod_id = tokenizer.convert_tokens_to_ids("<eod>")   # 文档结束符
    sep_id = tokenizer.sep_token_id

    # 读取作文数据集目录下的所有文件
    train_list = []
    logger.info("start tokenizing data")
    for file in tqdm(os.listdir(args.data_path)):
        file = os.path.join(args.data_path, file)
        with open(file, "r", encoding="utf8")as reader:
            lines = reader.readlines()
            title = lines[1][3:].strip()    # 取出标题
            lines = lines[7:]   # 取出正文内容
            article = ""
            for line in lines:
                if line.strip() != "":  # 去除换行
                    article += line
            title_ids = tokenizer.encode(title, add_special_tokens=False)
            article_ids = tokenizer.encode(article, add_special_tokens=False)
            token_ids = title_ids + [sep_id] + article_ids + [eod_id]
            # train_list.append(token_ids)

            # 对于每条数据,使用滑动窗口对其进行截断
            win_size = args.win_size
            step = args.step
            start_index = 0
            end_index = win_size
            data = token_ids[start_index:end_index]
            train_list.append(data)
            start_index += step
            end_index += step
            while end_index+50 < len(token_ids):  # 剩下的数据长度,大于或等于50,才加入训练数据集
                data = token_ids[start_index:end_index]
                train_list.append(data)
                start_index += step
                end_index += step

    # 序列化训练数据
    with open(args.save_path, "wb") as f:
        pickle.dump(train_list, f)

2.模型训练过程

 (1) 数据与标签

        在训练过程中,我们需要根据前面的内容预测后面的内容,因此,对于每一个词的标签需要向后错一位。最终预测的是每一个位置的下一个词的token_id的概率。

(2)训练过程

        对于每一轮epoch,我们需要统计该batch的预测token的正确数与总数,并计算损失,更新梯度。

训练配置参数:

--epochs 40 --batch_size 8 --device 0 --train_path data/train.pkl
def train_epoch(model, train_dataloader, optimizer, scheduler, logger,
                epoch, args):
    model.train()
    device = args.device
    ignore_index = args.ignore_index
    epoch_start_time = datetime.now()

    total_loss = 0  # 记录下整个epoch的loss的总和
    epoch_correct_num = 0   # 每个epoch中,预测正确的word的数量
    epoch_total_num = 0  # 每个epoch中,预测的word的总数量

    for batch_idx, (input_ids, labels) in enumerate(train_dataloader):
        # 捕获cuda out of memory exception
        try:
            input_ids = input_ids.to(device)
            labels = labels.to(device)
            outputs = model.forward(input_ids, labels=labels)
            logits = outputs.logits
            loss = outputs.loss
            loss = loss.mean()

            # 统计该batch的预测token的正确数与总数
            batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)
            # 统计该epoch的预测token的正确数与总数
            epoch_correct_num += batch_correct_num
            epoch_total_num += batch_total_num
            # 计算该batch的accuracy
            batch_acc = batch_correct_num / batch_total_num

            total_loss += loss.item()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            loss.backward()
            # 梯度裁剪
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            # 进行一定step的梯度累计之后,更新参数
            if (batch_idx + 1) % args.gradient_accumulation_steps == 0:
                # 更新参数
                optimizer.step()
                # 更新学习率
                scheduler.step()
                # 清空梯度信息
                optimizer.zero_grad()

            if (batch_idx + 1) % args.log_step == 0:
                logger.info(
                    "batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(
                        batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr()))

            del input_ids, outputs

        except RuntimeError as exception:
            if "out of memory" in str(exception):
                logger.info("WARNING: ran out of memory")
                if hasattr(torch.cuda, 'empty_cache'):
                    torch.cuda.empty_cache()
            else:
                logger.info(str(exception))
                raise exception

    # 记录当前epoch的平均loss与accuracy
    epoch_mean_loss = total_loss / len(train_dataloader)
    epoch_mean_acc = epoch_correct_num / epoch_total_num
    logger.info(
        "epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))

    # save model
    logger.info('saving model for epoch {}'.format(epoch + 1))
    model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1))
    if not os.path.exists(model_path):
        os.mkdir(model_path)
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(model_path)
    logger.info('epoch {} finished'.format(epoch + 1))
    epoch_finish_time = datetime.now()
    logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))

    return epoch_mean_loss

(3)部署与网页预测展示

        app.py既是模型预测文件,又能够在网页中展示,这需要我们下载一个依赖库:

pip install streamlit

        

生成下一个词流程,每次只根据当前位置的前context_len个token进行生成:

  • 第一步,先将输入文本截断成训练的token大小,训练时我们采用的200,截断为后200个词
  • 第二步,预测的下一个token的概率,采用温度采样和topk/topp采样

最终,我们不断的以自回归的方式不断生成预测结果

这里指定模型目录 

进入项目路径

执行streamlit run app.py 

 生成效果:

 数据与代码链接:https://pan.baidu.com/s/1XmurJn3k_VI5OR3JsFJgTQ?pwd=x3ci 
提取码:x3ci 

 

         

      

 

         

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/537844.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SpringBoot入门(Hello World 项目)

SpringBoot关键结构 1.2.1 Core Container The Core Container consists of the Core, Beans, Context, and Expression Language modules. The Core and Beans modules provide the fundamental parts of the framework, including the IoC and Dependency Injection featur…

【嵌入式日志调试】嵌入式系统限制打印后使用echo定向到串口节点实现日志输出

背景 系统在启动业务进程时把输出定向到NULL&#xff0c;如./sample > /dev/null&#xff0c;正式版本的系统又是只读系统&#xff0c;不方便开放日志。然后又需要输出日志进行分析问题&#xff0c;系统不支持的情况&#xff0c;只改自己负责的进程实现日志打印 方案 步骤…

书生·浦语大模型实战营之XTuner 微调个人小助手认知

书生浦语大模型实战营之XTuner 微调个人小助手认知 在本节课中讲一步步带领大家体验如何利用 XTuner 完成个人小助手的微调&#xff01; 为了能够让大家更加快速的上手并看到微调前后对比的效果&#xff0c; 用 QLoRA 的方式来微调一个自己的小助手&#xff01; 可以通过下面两…

通过ckeditor组件在vue2中实现上传图片

1&#xff0c;开始实现逻辑前&#xff0c;优先启项目&#xff0c;然后将ckeditor引入&#xff0c;大概如下&#xff1a; 1&#xff0c;npm i ckeditor/ckeditor5-vue2 2&#xff0c;下载sdk&#xff0c;https://ckeditor.com/ckeditor-5/online-builder/#&#xff0c;打开这个地…

Linux——十个槽位,RWX

Linux——RWX 十个槽位 - 表示文件 d 表示文件夹 l 表示软链接 r权&#xff0c;针对文件可以查看文件内容 针对文件夹&#xff0c;可以查看文件夹内容&#xff0c;如ls命令 w权&#xff0c;针对表示可以修改此文件 针对文件夹&#xff0c;可以在文件夹内&#…

只需5分钟,利用Python掌握SQLite3

在数据涌现的今天&#xff0c;数据库已成为生活中不可或缺的工具。Python作为一种流行的编程语言&#xff0c;内置了多种用于操作数据库的库&#xff0c;其中之一就是SQLite。SQLite是一种轻量级的关系型数据库管理系统&#xff0c;它在Python中的应用非常广泛。本文介绍如何使…

如何快速制作问卷?时间省略必备技能

我们可以采用“提出重要问题、简化问题长度、使用调查逻辑、预填答案、避免询问技术性问题、分次调查、问题模块化、增加问题的多样性”等方式来缩短问卷制作的时间。 高回复率对于问卷调查的最终结果至关重要。就像一支强壮而细长的箭头可以走更远的距离一样&#xff0c;清晰而…

matlab 安装 mingw64(6.3.0),OPENEXR

matlab安装openexr 1. matlab版本与对应的mingw版本选择2. mingw&#xff08;6.3.0&#xff09;下载地址&#xff1a;3. matlab2020a配置mingw&#xff08;6.3.0&#xff09;流程“4. matlab 安装openexr方法一&#xff1a;更新matlab版本方法二&#xff1a;其他博文方法方法三…

详解Qt添加外部库

在Qt项目中添加外部库是一项常见任务&#xff0c;无论是静态库还是动态库都需要正确的配置才能让项目顺利编译链接。以下是详细步骤和不同场景下的配置方法&#xff1a; 方法一&#xff1a;手动编辑.pro文件 添加头文件路径&#xff1a; 在Qt项目中的.pro文件中使用INCLUDEPAT…

VsCode 安装Jupyter Notebook

VsCode 安装Jupyter Notebook 安装 1、打开 VSCode 编辑器&#xff0c;点击界面左端的【扩展】栏&#xff1b; 2、在【搜索框】中输入python&#xff0c;点击第一个Python&#xff0c;检查是否已经安装 python 插件&#xff0c;没安装的点击安装&#xff1b;已安装的继续第3步…

AI预测体彩排3第1弹【2024年4月12日预测--第1套算法开始计算第1次测试】

前面经过多个模型几十次对福彩3D的预测&#xff0c;积累了一定的经验&#xff0c;摸索了一些稳定的规律&#xff0c;有很多彩友让我也出一下排列3的预测结果&#xff0c;我认为目前时机已成熟&#xff0c;且由于福彩3D和体彩排列3的玩法完全一样&#xff0c;我认为3D的规律和模…

大文件传输之为啥传输过程中出现宽带不足的情况

在当今数字化时代&#xff0c;大文件传输已成为企业日常运营的关键环节。然而&#xff0c;许多企业在传输大文件时经常面临宽带不足的问题&#xff0c;这不仅影响了工作效率&#xff0c;还可能导致业务机会的丧失。本文将探讨大文件传输过程中宽带不足的原因&#xff0c;以及镭…

目前电视盒子哪个最好?测评工作室盘点超强电视盒子推荐

每年我们会进行数十次电视盒子测评&#xff0c;对各个品牌的产品都有深入的了解&#xff0c;最近我们收到了很多私信不知道目前电视盒子哪个最好&#xff0c;这次我们结合配置总结了五款性能超强的电视盒子推荐给各位&#xff0c;预算足够想买高配电视盒子的朋友们可不要错过啦…

视频评论ID提取工具|视频关键词评论批量采集软件

视频评论ID提取工具&#xff1a;批量抓取视频评论 视频评论ID提取工具是一款功能强大的软件&#xff0c;可以帮助您批量抓取视频视频下的评论信息。通过输入关键词和评论监控词&#xff0c;即可进行评论的抓取&#xff0c;并提供评论昵称、评论日期、评论内容、命中关键词以及所…

机器学习和深度学习 -- 李宏毅(笔记与个人理解)Day 13

Day13 Error surface is rugged…… Tips for training :Adaptive Learning Rate critical point is not the difficult Root mean Square --used in Adagrad 这里为啥是前面的g的和而不是直接只除以当前呢? 这种方法的目的是防止学习率在训练过程中快速衰减。如果只用当前的…

pip install tensorflow-gpu 报错

查阅官网后可知&#xff0c;该命令已经被删除掉了。 tensorflow-gpu PyPI 【解决办法】 我直接安装了其他版本的包 pip install tensorflow-gpu2.10.0 测试 import tensorflow as tfprint("tf.__version__: ", tf.__version__)# print("tf.test.is_gpu_av…

String类(2)

❤️❤️前言~&#x1f973;&#x1f389;&#x1f389;&#x1f389; hellohello~&#xff0c;大家好&#x1f495;&#x1f495;&#xff0c;这里是E绵绵呀✋✋ &#xff0c;如果觉得这篇文章还不错的话还请点赞❤️❤️收藏&#x1f49e; &#x1f49e; 关注&#x1f4a5;&…

计算机视觉:技术与应用的深度解析

一、引言 计算机视觉&#xff0c;作为人工智能的一个重要分支&#xff0c;旨在通过计算机模拟人类的视觉系统&#xff0c;实现对图像或视频信息的自动分析和理解。随着深度学习、神经网络等技术的快速发展&#xff0c;计算机视觉在各个领域的应用日益广泛&#xff0c;包括安全…

MongoDB数据库转换为表格文件的Python实现

目录 一、引言 二、转换工具与库的选择 三、转换过程详解 安装必要的库 连接MongoDB数据库 查询并处理数据 将数据写入CSV文件 四、进阶技巧与注意事项 五、总结 一、引言 在当今大数据时代&#xff0c;数据的存储、处理与共享显得尤为重要。MongoDB作为一个面向文档…

Arduino 项目笔记 |TH1621 LCD液晶显示屏驱动(SSOP-24封装)

LCD液晶屏资料 LCD液晶屏资料 重要参数&#xff1a; 工作电压&#xff1a; 3V可视角度&#xff1a;1201/4 &#xff0c;1/3 TH1621 驱动 HT1621 LCD控制驱动芯片介绍 VLCD 和 VCC 电压符合规格书&#xff0c;最好都取3.3V 。电压太高或太低都会出现段码液晶屏乱码的情况&am…