【Transformer从零开始代码实现 pytoch版】(六)模型基本测试运行

模型基本测试及运行

在这里插入图片描述

(1)构建数据生成器

def data_generator(V, batch, num_batch):
    """ 用于随机生成copy任务的数据

    :param V: 随机生成数字的最大值+1
    :param batch: 每次输送给模型更新一次参数的数据量
    :param num_batch: 输送多少次完成一轮
    :return:
    """
    # 遍历nbatches
    for i in range(num_batch):
        # 在循环中使用np的random.randint方法随机生成[1, v)的整数
        # 每批次10个样本,分布在(batch, 10)形状的矩阵中,然后再把numpy形式转换成torch中的tensor
        data = torch.from_numpy(np.random.randint(1, V, size=(batch, 10)))

		# 生成起始标志,使数据矩阵中的第一列数字都为1,这一列也就成为了起始标志列
        # 当解码器进行第一次解码的时候,会使用起始标志列作为输入
        data[:, 0] = 1

        # 因为是copy任务,所有source与target是完全相同的,且数据样本作用变量不需要求梯度
        # 因此requires_grad设置为False
        with torch.no_grad():
            target = source = data

        # 使用Batch对source和target进行对应批次的掩码张量生成,最后使用yield返回
        yield Batch(source, target)

示例

V = 11                  # 将生成0-10的整数
batch = 20              # 每次喂给模型20个数据进行参数更新
num_batch = 30          # 连续喂30次完成全部数据的遍历

res = data_generator(V, batch, num_batch)
print(f"res {res}")


res <generator object data_generator at 0x000001BD670E4D60>

(2) 获得Transformer模型及其优化器和损失函数

# 获得Transformer模型机及其优化器和损失函数
from pyitcast.transformer_utils import get_std_opt          # 导入优化器工具包,用于获得标准的针对Transformer模型的优化器
from pyitcast.transformer_utils import LabelSmoothing       # 导入标签平滑工具包,用于标签平滑(小幅度的改变原有标签值的值域)
from pyitcast.transformer_utils import SimpleLossCompute    # 导入损失计算工具包,能够使标签平滑后的结果进行损失计算

# 使用make_mode获得model
model = make_model(V, V, N=2)
# 使用get_std_opt获得模拟优化器
model_optimizer = get_std_opt(model)
# 使用LabelSmoothing获得平滑对象
criterion = LabelSmoothing(size=V, padding_idx=0, smoothing=0.0)     # 输入目标词汇的总数
# 使用SimpleLossCompute获得利用标签平滑结果的损失计算方法
loss = SimpleLossCompute(model.generator, criterion, model_optimizer)

  • from pyitcast.transformer_utils import get_std_opt :该标准优化器基于Adam优化器,使其对序列到序列的任务更有效
  • from pyitcast.transformer_utils import LabelSmoothing:因为在理论上人工标注的数据可能并非完全正确,会受一些外界隐私影响而产涩会给你一些微笑的偏差,因此使用标签平滑来弥补这种偏差,减少模型对某一条规律的绝对认知,以防过拟合。
  • from pyitcast.transformer_utils import SimpleLossCompute:损失的计算方法可以认为使交叉熵损失函数。

在这里插入图片描述
在这里插入图片描述

(3)运行模型进行训练和评估

# 导入模型单轮训练工具包run_epoch,该工具将对模型使用给定的损失函数计算方法进行单轮参数更新,同时,打印每轮参数更新的损失结果
from pyitcast.transformer_utils import run_epoch

def run(model, loss, epochs=10):
    """ 模型训练函数

    :param model: 要进行训练的模型
    :param loss: 使用的损失计算方法
    :param epochs: 模型的训练轮数
    :return:
    """
    for epoch in range(epochs):
        # 使用训练模式,进行反向传播,所有参数将被更新
        model.train()
        run_epoch(data_generator(V, 8, 20), model, loss)        # batch_size = 20

        # 使用评估模型,不进行反向传播,所有参数不会被更新
        model.eval()
        run_epoch(data_generator(V, 8, 5), model, loss)        # batch_size = 5

示例

run(model, loss)

(4)使用模型进行贪婪解码

# 贪婪解码
from pyitcast.transformer_utils import greedy_decode        # 导入贪婪解码工具包greedy_decode,每次预测都是选择概率最大的结果作为输出

def greedy_run(model, loss, epochs=10):
    for epoch in range(epochs):
        model.train()
        run_epoch(data_generator(V, 8, 20), model, loss)

        model.eval()
        run_epoch(data_generator(V, 8, 5), model, loss)

    # 模型训练结束,进入评估模式
    model.eval()

    # 初始化一个输入张量
    source = torch.LongTensor([[1,3,2,5,4,6,7,8,9,10]])

    # 定义源数据掩码张量,因为元素都是1,这里1代表不遮掩,因此相当于对数据源没有遮掩
    source_mask = torch.ones(1, 1, 10)

    # 起始标志默认为1
    result = greedy_decode(model, source, source_mask, max_len=10, start_symbol=1)
    print(result)

示例

greedy_run(model, loss)

(5)小结

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/146111.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

第四天课程 分布式搜索引擎1

分布式搜索引擎01 – elasticsearch基础 0.学习目标 1.初识elasticsearch 1.1.了解ES 1.1.1.elasticsearch的作用 elasticsearch是一款非常强大的开源搜索引擎&#xff0c;具备非常多强大功能&#xff0c;可以帮助我们从海量数据中快速找到需要的内容 例如&#xff1a; …

第十九章总结:Java绘图

19.1&#xff1a;Java绘图类 19.2&#xff1a;绘制图形 package nineteentn; import java.awt.*; import javax.swing.*;public class DrawCircle extends JFrame {private final int OVAL_WIDTH 80; // 圆形的宽private final int OVAL_HEIGHT 80; // 圆形的高public DrawC…

nginx-静态资源实践(压缩配置,常见静态资源配置)

Nginx 实战搭建一个静态资源web服务器 第一个阶段访问单个文件 listen 80; #监听的端口 server_name localhost; #服务名称 #配置路径映射 location /geotools/ {alias geotools/; #将geotools/ 和/geotools/路径一一对应起来 }目录关系如下&#xff1a; nginx下面建一个g…

软考系统分析师知识点集锦一:企业信息化战略与实施

一、信息化战略体系(★★★★★) 1、信息资源规划ISP 信息资源规划是信息化建设的基础工程&#xff0c;是指对企业生产经营活动所需要的信息&#xff0c;对产生、获取、处理、存储、传输和利用等方面进行全面的规划。 IRP强调将需求分析与系统建模紧密结合起来&#xff0c;需…

酷柚易汛ERP - 其他入库单操作指南

1、应用场景 处理其他非采购类型的入库单据&#xff0c;比如赠品、获赔商品、以货抵债、借入、接受捐赠等不参与采购管理的入库类业务。 2、主要操作 2.1 新增其他入库单 打开【仓库】-【其他入库单】&#xff0c;选择商品后&#xff0c;根据存货核算方法自动计算出单位成本…

栈的实现---超详细

栈的实现 文章目录 栈的实现一、栈的模型二、栈的代码实现以及测试用例①栈的初始化②入栈③出栈④弹出栈顶⑤判断栈空间是否为空⑥计算栈空间长度⑦销毁栈⑧测试用例 一、栈的模型 首先栈有两个概念 1.数据结构里的栈。2.语言/操作系统中的栈(内存空间)&#xff0c;可能会在递…

JUC工具包介绍

目录 1. 引言 2. 介绍JUC工具包 2.1. JUC工具包的概述和作用 2.2. 什么是JUC工具包&#xff1f; 2.2.1. JUC工具包与传统线程编程的区别和优势 3. 线程池&#xff08;Executor&#xff09; 3.1. 线程池的概念和优势 3.1.1. ThreadPoolExecutor类的介绍和使用示例 3.1.…

Docker安装Octoprint 3D打印控制软件

Octoprint简介 Octoprint是一个运行在Linux系统上的开源套件&#xff0c;可以为普通的3D打印机添加强大的外围管理功能。 web管理界面远程操控摄像头实时监控视频录制、延时摄影在线切片图形化的温度曲线显示手机监控操作免SD卡和U盘通过插件和USB/GPIO接口实现更多功能 Oct…

第三天课程 RabbitMQ

RabbitMQ 1.初识MQ 1.1.同步和异步通讯 微服务间通讯有同步和异步两种方式&#xff1a; 同步通讯&#xff1a;就像打电话&#xff0c;需要实时响应。 异步通讯&#xff1a;就像发邮件&#xff0c;不需要马上回复。 两种方式各有优劣&#xff0c;打电话可以立即得到响应&am…

分布式锁介绍

为什么需要分布式锁 在单机部署的系统中&#xff0c;使用线程锁来解决高并发的问题&#xff0c;多线程访问共享变量的问题达到数据一致性&#xff0c;如使用synchornized、ReentrantLock等。 但是在后端集群部署的系统中&#xff0c;程序在不同的JVM虚拟机中运行&#xff0c;且…

PostGIS学习教程五:数据

教程的数据是有关纽约市的四个shapefile文件和一个包含社会人口经济数据的数据表。在前面一节我们已经将shapefile加载为PostGIS表&#xff0c;在后面我们将添加社会人口经济数据。 下面描述了每个数据集的记录数量和表属性。这些属性值和关系是我们以后分析的基础。 要在pgAdm…

fileread任意文件读取学习笔记

任意文件读取概述 一些网站的需求&#xff0c;可能会提供文件查看与下载的功能。如果对用户查看或下载的文件没有限制或者限制绕过&#xff0c;就可以查看或下载任意文件。这些文件可以是源代码文件&#xff0c;配置文件&#xff0c;敏感文件等等。 任意文件读取会造成&#x…

Python学习:同步异步阻塞与非阻塞

嗨喽&#xff0c;大家好呀~这里是爱看美女的茜茜呐 一、状态介绍 在了解其他概念之前&#xff0c;我们首先要了解进程的几个状态。 在程序运行的过程中&#xff0c;由于被操作系统的调度算法控制&#xff0c;程序会进入几个状态&#xff1a;就绪&#xff0c;运行和阻塞。 就绪…

LeetCode | 225. 用队列实现栈

LeetCode | 225. 用队列实现栈 OJ链接 此题可以用两个队列去实现一个栈&#xff0c;每次始终保持一个队列为空&#xff0c; 入栈操作相当于给非空队列进行入队操作 入数据&#xff0c;把不为空的队列入 出数据&#xff0c;把不为空的队列数据导入为空&#xff0c;直到最后一…

1334. 阈值距离内邻居最少的城市

分析题目两点“阈值距离”、“邻居最少”。 “阈值距离”相当于定了个上界&#xff0c;求节点之间的最短距离。 “邻居最少”相当于能连接的点的数量。 求节点之间的最短距离有以下几种方法&#xff1a; 在这道题当中&#xff0c;n的范围是100以内&#xff0c;所以可以考虑O(n…

超详细!!新手必看!STM32--独立看门狗IWBG

一、看门狗是什么&#xff1f; 答&#xff1a;看门狗是一个12bit的递减计数器。当计数器的值从某个值一直减到0的时候&#xff0c;系统就会产生一个复位信号&#xff0c;CPU收到复位信号&#xff0c;系统复位重新运行。在计数没减到0之前&#xff0c;重置了计数器的值的话&…

降水短临预报模型trajGRU简介

1 前言 trajGRU 是在对 convLSTM 的改进&#xff0c;且这两个模型是同一个作者。 convLSTM 在降水短临预报这块已经超越传统模型&#xff0c;但其是局部不变性的(location-invariant)&#xff0c;而自然的运动和转换(如旋转)是局部变化的(location-invariant)。作者为了能够使…

【python 生成器 面试必备】yield关键字,协程必知必会系列文章--自己控制程序调度,体验做上帝的感觉 2

这篇文章要解决的问题&#xff1a;How to Pass Value to Generators Using the “yield” Expression in Python ref:https://python.plainenglish.io/yield-python-part-ii-e93abb619a16 1.如何传值 yield 是一个表达式&#xff01;&#xff01;&#xff01;&#xff01; yi…

⑤ 【MySQL】DCL语句 —— 用户管理、权限控制

个人简介&#xff1a;Java领域新星创作者&#xff1b;阿里云技术博主、星级博主、专家博主&#xff1b;正在Java学习的路上摸爬滚打&#xff0c;记录学习的过程~ 个人主页&#xff1a;.29.的博客 学习社区&#xff1a;进去逛一逛~ MySQL用户与权限 ⑤ 【MySQL】DCL语句 —— 用…

AWTK实现汽车仪表Cluster/DashBoard嵌入式GUI开发(七):FREERTOS移植

前言: 一般的GUI工程都需要一个操作系统,可能是linux,重量级的,也可能是FreeRTOS,轻量级的。 一句话理解那就是工程就是FreeRTOS task任务的集合。 一个main函数可以看到大框架: 很显然,除了第一个是硬件配置的初始化,中间最重要的部分就是要创建任务,把AWTK的应用…