从零开始复现GPT2(六):生成代码的实现


源码地址:https://gitee.com/guojialiang2023/gpt2


GPT2

  • 模型
    • 文本生成
      • 配置
      • 生成框架
      • 文本生成类实现
      • 文本生成代码

模型

在这里插入图片描述

文本生成

配置

class GenerateConfig(object):
    def __init__(self,
                 seq_len: int,
                 nucleus_prob: float,
                 use_gpu: bool):
        self.seq_len = seq_len
        self.nucleus_prob = nucleus_prob
        self.use_gpu = use_gpu

生成框架

import torch
import torch.nn as nn
from typing import List


class GenerationSpec(object):
    def initialize(self):
        pass

    def construct_model(self) -> nn.Module:
        raise NotImplementedError()

    def encode_context(self, context: str) -> List[int]:
        raise NotImplementedError()

    def decode_tokens(self, tokens: List[int]) -> str:
        raise NotImplementedError()

    def decorate_sequence(self, sequence: torch.Tensor, offset: int
                          ) -> torch.Tensor:
        return sequence

文本生成类实现

import torch
from model import Past
from generation import GenerationSpec, GenerateConfig
from typing import List, Optional, Tuple


class Generator(object):
    def __init__(self, spec: GenerationSpec, config: GenerateConfig):
        self.spec = spec
        self.config = config

    def initialize(self, from_model: Optional[str] = None):
        # Initialize generation environment and construct a model.
        self.spec.initialize()
        self.model = self.spec.construct_model().eval()

        # Load trained model parameters.
        if from_model:
            ckpt = torch.load(from_model, map_location='cpu')
            self.model.load_state_dict(ckpt['model'])

        # Move the model to GPU device and convert the data type to half
        # precision.
        if self.config.use_gpu:
            self.model.cuda().half()

    def generate(self, context: str) -> str:
        words = self.spec.encode_context(context)

        current, past = words, None
        while len(words) < self.config.seq_len:
            # Predict the next word token from the given context.
            probs, past = self._predict_probs(current, past)
            next_word = self._sample_from_top_p(probs)

            # Change the context to the predicted word.
            words.append(next_word)
            current = [next_word]

        return self.spec.decode_tokens(words)

    @torch.no_grad()
    def _predict_probs(self,
                       words: List[int],
                       past: Optional[List[Past]] = None
                       ) -> Tuple[torch.Tensor, List[Past]]:
        x = torch.tensor(words, dtype=torch.long)
        x = self.spec.decorate_sequence(
            x, offset=past[0][0].size(-2) if past is not None else 0)

        if self.config.use_gpu:
            logits, past = self.model(x.cuda(), past)
            logits = logits.cpu().float()
        else:
            logits, past = self.model(x, past)

        return logits[-1, :].softmax(-1), past

    def _sample_from_top_p(self, probs: torch.Tensor) -> int:
        probs, indices = probs.sort(descending=True)

        mask = probs.cumsum(-1) > self.config.nucleus_prob
        mask[0] = False
        probs.masked_fill_(mask, 0)

        # Sample from filtered distribution.
        return indices[probs.multinomial(1)[0]].item()

代码定义了一个用于文本生成的Generator类,它使用了GPT-2模型。这个类能够基于给定的上下文生成文本。下面是对代码中关键部分的详细解释:

__init__ 方法

  • __init__是类的构造函数,它接受两个参数:specconfig
  • specGenerationSpec类型)包含了模型构造和上下文编码/解码的规范,这是一个抽象定义,用于处理特定于模型的逻辑。
  • configGenerateConfig类型)包含了生成过程的配置,如是否使用GPU、生成序列的长度等。

initialize 方法

  • initialize方法用于初始化生成环境,构造模型,并可选地从已训练的模型中加载参数。
  • 如果from_model参数被提供,方法会加载这个模型的参数。
  • 如果配置为使用GPU,模型会被移动到GPU上,并转换为半精度浮点数以提高性能。

generate 方法

  • generate方法接受一个字符串类型的context作为参数,这个上下文用作文本生成的起点。
  • 方法首先将上下文编码为模型能理解的形式(通常是一系列token的ID)。
  • 然后,它在给定的上下文基础上循环生成文本,直到达到配置的序列长度seq_len
  • 在每一步中,它都会调用_predict_probs方法来预测下一个单词的概率分布,并通过_sample_from_top_p方法从这个分布中采样一个单词。
  • 生成的单词被添加到上下文中,用作下一次预测的输入。

_predict_probs 方法

  • 这是一个私有方法,用于基于当前的单词(或单词序列)和过去的状态(如果有的话)来预测下一个单词的概率分布。
  • 它接受当前的单词序列和可选的过去状态作为输入,并返回下一个单词的概率分布和更新后的状态。
  • 如果配置为使用GPU,输入和输出会相应地移动到GPU或CPU上,并且输出的logits会被转换为浮点数。

_sample_from_top_p 方法

  • 这个私有方法用于实现“nucleus sampling”(也称为“top-p sampling”),这是一种从概率分布中采样单词的方法,它仅考虑累积概率超过某个阈值(nucleus_prob)的最高概率单词。
  • 通过这种方式,它有效地过滤掉了低概率的单词,减少了生成随机无关文本的可能性,同时保留了一定程度的随机性以增加文本的多样性。

文本生成代码

这段代码是一个完整的Python脚本,用于通过命令行界面生成使用GPT-2模型训练的文本。它主要包括定义GPT2GenerationSpec类、generate_sentence_with_gpt2_model函数和add_subparser函数,以及如何使用argparse库来解析命令行参数。下面是对这些关键部分的详细解释:

GPT2GenerationSpec

  • GPT2GenerationSpec类继承自GenerationSpec,专门用于GPT-2模型的文本生成。它通过初始化参数配置词汇表、序列长度、Transformer模型的层次、注意力头数、维度和维度增加率。
  • initialize方法加载词汇表并初始化分词器。
  • construct_model方法构建Transformer模型实例,这里的模型配置是根据初始化时传入的参数定制的。
  • encode_context方法将文本上下文编码为模型能够处理的token序列。
  • decode_tokens方法将token序列解码回文本字符串,如果遇到结束符eos_idx,则只解码到该符号为止。

generate_sentence_with_gpt2_model 函数

  • 这个函数是脚本的主要执行点,用于生成文本。它首先根据命令行参数创建GPT2GenerationSpecGenerateConfig实例。
  • 接着,它初始化Generator实例,并可选地从文件中加载预训练的模型参数。
  • 使用while True循环不断读取用户输入,并生成相应的文本输出。

add_subparser 函数

  • 这个函数用于argparse库,以定义命令行参数和子命令。它允许用户通过命令行指定词汇表文件路径、模型文件路径、模型配置(如序列长度、层数、注意力头数等)和生成选项(如nucleus采样概率和是否使用GPU)。
  • 它使得脚本的使用更加灵活,用户可以根据需要调整生成文本的配置。

至此全部从零开始实现GPT2已全部完成

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/367777.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【C/C++ 10】扫雷小游戏

一、题目 写一个扫雷小游戏&#xff0c;每次输入一个坐标&#xff0c;若该处是地雷&#xff0c;则游戏失败&#xff0c;若该处不是地雷&#xff0c;则显示周围地雷数量&#xff0c;若扫除全部非地雷区域&#xff0c;则扫雷成功。 二、算法 设置两张地图&#xff08;二维数组&…

手把手教你开发Python桌面应用-PyQt6图书管理系统-主界面UI设计实现

锋哥原创的PyQt6图书管理系统视频教程&#xff1a; PyQt6图书管理系统视频教程 Python桌面开发 Python入门级项目实战 (无废话版) 火爆连载更新中~_哔哩哔哩_bilibiliPyQt6图书管理系统视频教程 Python桌面开发 Python入门级项目实战 (无废话版) 火爆连载更新中~共计24条视频&…

移远(Quectel)物联网通信解决方案

一、方案简介 无线通信模块是具备无线通信的电路模块&#xff0c;它能通过无线连接传输数据&#xff0c;能识别分析主控制器发来的命令&#xff0c;控制节点设备的工作&#xff0c;或者向主控制器发送当前节点设备的工作状态。 市面上常用的无线通信模组包括蓝牙模组、WLAN模…

2024年【上海市安全员B证】最新解析及上海市安全员B证复审考试

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 上海市安全员B证最新解析根据新上海市安全员B证考试大纲要求&#xff0c;安全生产模拟考试一点通将上海市安全员B证模拟考试试题进行汇编&#xff0c;组成一套上海市安全员B证全真模拟考试试题&#xff0c;学员可通过…

算法练习-二叉树的节点个数【完全/普通二叉树】(思路+流程图+代码)

难度参考 难度&#xff1a;中等 分类&#xff1a;二叉树 难度与分类由我所参与的培训课程提供&#xff0c;但需要注意的是&#xff0c;难度与分类仅供参考。且所在课程未提供测试平台&#xff0c;故实现代码主要为自行测试的那种&#xff0c;以下内容均为个人笔记&#xff0c;旨…

ubuntu22.04 安装部署01:禁用内核更新

一、前言 ubunut22.04系统安装以后&#xff0c;内核更新会导致各种各样的问题&#xff0c;因此锁定初始安装环境特别重要&#xff0c;下面介绍如何锁定内核更新。 二、操作方法 2.1 查看可用内核 dpkg --list | grep linux-image dpkg --list | grep linux-headers dpkg --…

STM32--USART串口(2)串口外设

一、USART简介 可配置数据位&#xff1a;不需要校验就是8位&#xff0c;需要校验就选9位&#xff1b; 停止位&#xff1a;决定了帧的间隔; STM32F103C8T6USART&#xff1a;USART1挂载在APB2总线上&#xff0c;USART2和USART3挂载在APB1总线上&#xff1b; 二、USART框图 TXE…

Python中使用Opencv-python库绘制直线、矩形、圆、文本

Python中使用Opencv-python库绘制直线、矩形、圆、文字 在Python中使用Opencv-python绘制直线、矩形、圆、文本非常简单&#xff0c;分别使用到line、rectangle、circle、putText这几个函数&#xff0c;具体可以参考https://docs.opencv.org/4.9.0/d6/d6e/group__imgproc__dra…

如何部署Node.js服务并实现无公网ip远程访问本地项目【内网穿透】

文章目录 前言1.安装Node.js环境2.创建node.js服务3. 访问node.js 服务4.内网穿透4.1 安装配置cpolar内网穿透4.2 创建隧道映射本地端口 5.固定公网地址 前言 Node.js 是能够在服务器端运行 JavaScript 的开放源代码、跨平台运行环境。Node.js 由 OpenJS Foundation&#xff0…

【C++】类与对象(三)—运算符重载|const成员函数|取地址及const取地址操作符重载

前言 运算符重载&#xff0c;自增自减运算符重载&#xff0c;const成员函数&#xff0c;取地址及const取地址操作符重载 文章目录 一、运算符重载自增和自减运算符重载 二、const 成员函数三、取地址及const取地址操作符重载&#xff08;了解即可&#xff09; 一、运算符重载 运…

P1967 [NOIP2013 提高组] 货车运输

[NOIP2013 提高组] 货车运输 题目背景 NOIP2013 提高组 D1T3 题目描述 A 国有 n n n 座城市&#xff0c;编号从 1 1 1 到 n n n&#xff0c;城市之间有 m m m 条双向道路。每一条道路对车辆都有重量限制&#xff0c;简称限重。 现在有 q q q 辆货车在运输货物&#x…

Unity Meta Quest MR 开发(三):Scene API 配置+实现虚拟与现实之间的碰撞

文章目录 &#x1f4d5;教程说明&#x1f4d5; Scene 配置⭐开启场景理解功能和应用访问空间数据的权限⭐OVRSceneManager⭐制作 Plane Prefab 和 Volume Prefab⭐运行场景⭐添加透视材质 &#x1f4d5;虚拟与现实物体的碰撞&#xff08;弹球 Demo&#xff09;&#x1f4d5;Mes…

【JavaSE篇】——继承

目录 &#x1f393;继承 ✅为什么需要继承 ✅继承概念 ✅继承的语法 ✅父类成员访问 &#x1f6a9;子类中访问父类的成员变量 1. 子类和父类不存在同名成员变量的情况 2. 子类和父类成员变量同名 &#x1f6a9;子类中访问父类的成员方法 1. 成员方法名字不同 2. 成员…

MyBatis常见面试题汇总

说一下MyBatis执行流程&#xff1f; MyBatis是一款优秀的基于Java的持久层框架&#xff0c;它内部封装了JDBC&#xff0c;使开发者只需要关注SQL语句本身&#xff0c;而不需要花费精力去处理加载驱动、创建连接等的过程&#xff0c;MyBatis的执行流程如下&#xff1a; 加载配…

车载测试Vector工具——基于DoIP的ECU/车辆的连接故障排除

车载测试Vector工具——基于DoIP的ECU/车辆的连接故障排除 我是穿拖鞋的汉子,魔都中坚持长期主义的汽车电子工程师(Wechat:gongkenan2013)。 老规矩,分享一段喜欢的文字,避免自己成为高知识低文化的工程师: 屏蔽力是信息过载时代一个人的特殊竞争力,任何消耗你的人和…

计算huggingface模型占用硬盘空间的实战代码

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于机器学习算法研究与应用。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。拥有多项发明专利。对机器学习和深度学习拥有自己独到的见解。曾经辅导过若干个非计算机专业的学生进入到算法…

景联文科技受邀出席全国信标委生物特征识别分委会二届五次全会

全国信息技术标准化技术委员会生物特征识别分技术委员会&#xff08;SAC/TC28/SC37&#xff0c;以下简称“分委会”&#xff09;二届五次全会于2024年1月30日在北京顺利召开&#xff0c;会议由分委员秘书长王文峰主持。 分委会由国家标准化管理委员会批准成立&#xff0c;主要负…

N 叉树的层序遍历

给定一个 N 叉树&#xff0c;返回其节点值的层序遍历。&#xff08;即从左到右&#xff0c;逐层遍历&#xff09;。 树的序列化输入是用层序遍历&#xff0c;每组子节点都由 null 值分隔&#xff08;参见示例&#xff09;。 示例 1&#xff1a; 输入&#xff1a;root [1,null…

配置实例—VLAN间跨三层通信的交换机配置实例

一、组网需求 企业的不同用户拥有相同的业务&#xff0c;且位于不同的网段。现在相同业务的用户所属的VLAN不相同&#xff0c;需要实现不同VLAN中的用户相互通信。 如图1所示&#xff0c;User1和User2中拥有相同的业务&#xff0c;但是属于不同的VLAN且位于不同的网段。现需要…

【笔记】React Native实战练习(仿网易云游戏网页移动端)

/** * 如果系统看一遍RN相关官方文档&#xff0c;可能很快就忘记了。一味看文档也很枯燥无味&#xff0c; * 于是大概看了关键文档后&#xff0c;想着直接开发一个Demo出来&#xff0c;边学边写&#xff0c;对往后工作 * 开发衔接上能够更顺。这期间肯定会遇到各种各样的问题&a…