【NLP 9、实践 ① 五维随机向量交叉熵多分类】

目录

五维向量交叉熵多分类

规律:

实现:

1.设计模型

2.生成数据集

3.模型测试

4.模型训练

5.对训练的模型进行验证

调用模型


你的平静,是你最强的力量

                                —— 24.12.6

五维向量交叉熵多分类

规律:

x是一个五维(索引)向量,对x做五分类任务

改用交叉熵实现一个多分类任务,五维随机向量中最大的数字在哪维就属于哪一类


实现:

1.设计模型

Linear():模型函数中定义线性层

activation = nn.Softmax(dim=1):定义激活层为softmax激活函数

nn.CrossEntropyLoss() / nn.functional.cross_entropy:定义交叉熵损失函数

pyTorch中定义的交叉熵损失函数内部封装了softMax函数, 而使用交叉熵必须使用softMax函数,对数据进行归一化

经过 Softmax 归一化后,输出向量的每个元素可以被解释为样本属于相应类别的概率。这使得我们能够直接比较不同类别上的概率大小,并且与真实的类别概率分布(如one-hot编码)进行合理的对比。

例如,在一个三分类问题中,经过 Softmax 后的输出可能是[0.2,0.3,0.5],我们可以直观地说样本属于第三类的概率是 0.5,这是一个符合概率意义的解释

forward函数,前向计算,定义网络的使用方式,声明模型计算过程

# 1.设计模型
class TorchModel(nn.Module):
    def __init__(self, input_size):
        super(TorchModel, self).__init__()
        # 预测出一个五维的向量,五维向量代表五个类别上的概率分布
        self.linear = nn.Linear(input_size, 5)  # 线性层
        # 类交叉熵写法:CrossEntropyLoss()     函数交叉熵写法:cross_entropy
        # nn.CrossEntropyLoss() pycharm交叉的熵损失函数内部封装了softMax函数, 而使用交叉熵必须使用softMax函数
        self.loss = nn.functional.cross_entropy # loss函数采用交叉熵损失
        self.activation = nn.Softmax(dim=1)

    # 当输入真实标签,返回loss值;无真实标签,返回预测值
    def forward(self, x, y=None):
        # 输入过第一个网络层
        y_pred = self.linear(x)  # (batch_size, input_size) -> (batch_size, 1)
        if y is not None:
            return self.loss(y_pred, y)  # 预测值和真实值计算损失
        else:
            return self.activation(y_pred)  # 输出预测结果
            # return y_pred

2.生成数据集

由于题目要求,要在一个五维随机向量中查找标量最大的数所在维度,所以用np.random函数随机生成一个五维向量,然后通过np.argmax函数找出生成向量中最大标量所对应的维度,并将其作为数据 x标注 y 返回

当我们输出一串数字,要告诉模型输出的是一串单独的数而不是一串样本时,需要用到 "[ ]",换句话说当y是单独的一个数(标量)时,才需要加“[ ]”

而该模型输出的预测结果是一个向量,而不是一个数(标量的概率)时,不需要拼在一起

# 2.生成数据集标签label   数据构建
# 生成一个样本, 样本的生成方法,代表了我们要学习的规律,随机生成一个5维向量,如果第一个值大于第五个值,认为是正样本,反之为负样本
def build_sample():
    x = np.random.random(5)
    # 获取最大值对应的索引
    max_index = np.argmax(x)
    return x, max_index


# 随机生成一批样本
# 正负样本均匀生成
def build_dataset(total_sample_num):
    X = []
    Y = []
    # 随机生成样本,total_sample_num 生成的随机样本数
    for i in range(total_sample_num):
        x, y = build_sample()
        X.append(x)
        # 当我们输出一串数字,要告诉模型输出的是一串单独的数而不是一串样本时,需要用到"[]",换句话说当y是单独得一个数(标量)时,才需要加“[]”
        # 而该模型输出的预测结果是一个向量,而不是一个数(标量的概率)时,不需要拼在一起
        Y.append(y)
        X_array = np.array(X)
        Y_array = np.array(Y)
        # 一般torch中的Long整形类型用来判定类型
    return torch.FloatTensor(X_array), torch.LongTensor(Y_array)

3.模型测试

用来测试每轮模型预测的精确度

model.eval():声明模型框架在这个函数中不做训练

with torch.no_grad():在模型测试的部分中,声明是测试函数,不计算梯度,增加模型训练效率

zip():zip 函数是一个内置函数,用于将多个可迭代对象(如列表、元组、字符串等)中对应的元素打包成一个个元组,然后返回由这些元组组成的可迭代对象(通常是一个 zip 对象)。如果各个可迭代对象的长度不一致,那么 zip 操作会以最短的可迭代对象长度为准。

# 3.模型测试
# 用来测试每轮模型的准确率
def evaluate(model):
    model.eval()
    test_sample_num = 100
    x, y = build_dataset(test_sample_num)
    print("本次预测集中共有%d个正样本,%d个负样本" % (sum(y), test_sample_num - sum(y)))
    correct, wrong = 0, 0
    with torch.no_grad():
        y_pred = model(x)  # 模型预测 model.forward(x)
        for y_p, y_t in zip(y_pred, y):  # 与真实标签进行对比
            # np.argmax是求最大数所在维,max求最大数,torch.argmax是求最大数所在维
            if torch.argmax(y_p) == int(y_t):
                correct += 1  # 正确预测加一
            else:
                wrong += 1  # 错误预测加一

    print("正确预测个数:%d, 正确率:%f" % (correct, correct / (correct + wrong)))
    return correct / (correct + wrong)

4.模型训练

① 配置参数        

② 建立模型

③ 选择优化器(Adam)

④ 读取训练集

⑤ 训练过程

        Ⅰ、model.train():设置训练模式

        Ⅱ、对训练集样本开始循环训练(循环取出训练数据)

        Ⅲ、根据模型函数和损失函数的定义计算模型损失

        Ⅳ、计算梯度

        Ⅴ、通过梯度用优化器更新权重

        Ⅵ、计算完一轮训练数据后梯度进行归零,下一轮重新计算

torch.save(model.state_dict(), "model.pt"):模型保存model.pt文件

一般任务不同只需更改数据读取(步骤③)模型构建(步骤①)内容,训练过程一般无需更改,evaluate测试代码可能也需更改,因为不同模型测试正确率的方式不同

# 4.模型训练
def main():
    # 配置参数
    epoch_num = 20  # 训练轮数
    batch_size = 20  # 每次训练样本个数
    train_sample = 5000  # 每轮训练总共训练的样本总数
    input_size = 5  # 输入向量维度
    learning_rate = 0.001  # 学习率
    # ① 建立模型
    model = TorchModel(input_size)
    # ② 选择优化器
    optim = torch.optim.Adam(model.parameters(), lr=learning_rate)
    log = []
    # ③ 创建训练集,正常任务是读取训练集
    train_x, train_y = build_dataset(train_sample)
    # 训练过程
    # 轮数进行自定义
    for epoch in range(epoch_num):
        model.train()
        watch_loss = []
        # ④ 读取数据集
        for batch_index in range(train_sample // batch_size):
            x = train_x[batch_index * batch_size : (batch_index + 1) * batch_size]
            y = train_y[batch_index * batch_size : (batch_index + 1) * batch_size]
            # ⑤ 计算loss
            loss = model(x, y)  # 计算loss  model.forward(x,y)
            # ⑥ 计算梯度
            loss.backward()  # 计算梯度
            # ⑦ 权重更新
            optim.step()  # 更新权重
            # ⑧ 梯度归零
            optim.zero_grad()  # 梯度归零
            watch_loss.append(loss.item())
            # 一般任务不同只需更改数据读取(步骤③)和模型构建(步骤①)内容,训练过程一般无需更改,evaluate测试代码可能也需更改,因为不同模型测试正确率的方式不同
        print("=========\n第%d轮平均loss:%f" % (epoch + 1, np.mean(watch_loss)))
        acc = evaluate(model)  # 测试本轮模型结果
        log.append([acc, float(np.mean(watch_loss))])
    # 保存模型
    torch.save(model.state_dict(), "model.pt")
    # 画图
    print(log)
    plt.plot(range(len(log)), [l[0] for l in log], label="acc")  # 画acc曲线
    plt.plot(range(len(log)), [l[1] for l in log], label="loss")  # 画loss曲线
    plt.legend()
    plt.show()
    return

5.对训练的模型进行验证

调用main函数

if __name__ == "__main__":
    main()


调用模型

model.eval():声明模型框架在这个函数中不做训练

predict("model.pt", test_vec):调用模型存储的文件model.pt,通过调用模型对数据进行预测

# 使用训练好的模型做预测
def predict(model_path, input_vec):
    input_size = 5
    model = TorchModel(input_size)
    # 加载训练好的权重
    model.load_state_dict(torch.load(model_path, weights_only=True))
    # print(model.state_dict())

    model.eval()  # 测试模式,不计算梯度
    with torch.no_grad():
        # 输入一个真实向量转成Tensor,让模型forward一下
        result = model.forward(torch.FloatTensor(input_vec))  # 模型预测
    for vec, res in zip(input_vec, result):
        # python中,round函数是对浮点数进行四舍五入
        print("输入:%s, 预测类别:%s, 概率值:%s" % (vec, torch.argmax(res), res))  # 打印结果

if __name__ == "__main__":
    test_vec = [[0.97889086,0.15229675,0.31082123,0.03504317,0.88920843],
                [0.74963533,0.5524256,0.95758807,0.95520434,0.84890681],
                [0.00797868,0.67482528,0.13625847,0.34675372,0.19871392],
                [0.09349776,0.59416669,0.92579291,0.41567412,0.1358894]]
    predict("model.pt", test_vec)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/930241.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

STM32使用RCC(Reset Clock Contorl,复位时钟控制器)配置时钟以及时钟树

RCC主要作用 设置系统时钟SYSCLK(System Clock)频率;设置AHB、APB2、APB1以及各个外设分频因子,从而设置HCLK、PCLK2、PCLK1以及各个外设的时钟频率;控制AHB、APB2、APB1这三条总线时钟以及每个外设的时钟开启&#xf…

使用mtools搭建MongoDB复制集和分片集群

mtools介绍 mtools是一套基于Python实现的MongoDB工具集,其包括MongoDB日志分析、报表生成及简易的数据库安装等功能。它由MongoDB原生的工程师单独发起并做开源维护,目前已经有大量的使用者。 mtools所包含的一些常用组件如下: mlaunch支…

随记:win11 win+g 捕获 不能录视频 不用下载注册表修复工具

问题: 我解决的方法: win R 打开 再去 计算机\HKEY_CURRENT_USER\Software\Microsoft\Windows\CurrentVersion\GameDVR 这是我的问题,要是没有这个,可能是其他原因了 还有就是我还看到一个 上面那个不能解决的可以试试这个

算法刷题Day11: BM33 二叉树的镜像

点击题目链接 思路 转换为子问题:左右子树相反转。遍历手法:后序遍历 代码 class Solution:def Transverse(self,root: TreeNode):if root None:return rootnewleft self.Transverse(root.left)newright self.Transverse(root.right)# 对root节点…

【赵渝强老师】PostgreSQL的服务器日志文件

PostgreSQL数据库的物理存储结构主要是指硬盘上存储的文件,包括:数据文件、日志文件、参数文件、控制文件、WAL预写日志文件等等。下面重点讨论一下PostgreSQL的服务器日志文件。 视频讲解如下 【赵渝强老师】PostgreSQL的服务器日志文件 通过使用pg_ct…

【人工智能】深度解剖利用人工智能MSA模型

目录 情感分析的应用一、概述二、研究背景三、主要贡献四、模型结构和代码五、数据集介绍六、性能展示七、复现过程 情感分析的应用 近年来社交媒体的空前发展以及配备高质量摄像头的智能手机的出现,我们见证了多模态数据的爆炸性增长,如电影、短视频等…

行业标杆!鸿翼入选WAIC 2024《2024大模型典型示范应用案例集》

​7月5日,在2024世界人工智能大会“迈向 AGI:大模型焕新与产业赋能”论坛上,《2024大模型典型示范应用案例集》(以下简称《案例集》)重磅发布!鸿翼AI项目成功入选,彰显了鸿翼在大模型应用领域的…

nodejs循环导出多个word表格文档

文章目录 nodejs循环导出多个word表格文档一、文档模板编辑二、安装依赖三、创建导出工具类exportWord.js四、调用五、效果图nodejs循环导出多个word表格文档 结果案例: 一、文档模板编辑 二、安装依赖 // 实现word下载的主要依赖 npm install docxtemplater pizzip --save/…

ABAP DIALOG屏幕编程1

一、DIALOG屏幕编程 DIALOG屏幕编程是SAP ABAP中用于创建用户交互界面的一种技术,主要用于开发事务性应用程序。它允许用户通过屏幕输入或操作数据,程序根据用户的操作执行逻辑处理。 1、DIALOG编程的主要组件 a、屏幕 (Screen) DIALOG程序的核心部分…

Shell免交互

Shell免交互 一. 变量配置1.1 在E0F外面的变量可以直接传入使用1.2 EOF的输入内容可以直接赋值给变量 二. expect语句2.1 转义符2.2 expect的语法2.3 格式2.4 脚本外传参2.5 嵌套 三. 访问其它主机 交互:当我们使用程序时,需要进入程序发出对应的指令&am…

清风数学建模学习笔记——Topsis法

数模评价类(2)——Topsis法 概述 Topsis:Technique for Order Preference by Similarity to Ideal Solution 也称优劣解距离法,该方法的基本思想是,通过计算每个备选方案与理想解和负理想解之间的距离,从而评估每个…

【认证法规】安全隔离变压器

文章目录 定义反激电源变压器 定义 安全隔离变压器(safety isolating transformer),通过至少相当于双重绝缘或加强绝缘的绝缘使输入绕组与输出绕组在电气上分开的变压器。这种变压器是为以安全特低电压向配电电路、电器或其它设备供电而设计…

喆塔科技携手国家级创新中心,共建高性能集成电路数智化未来

集创新之力成数智之塔 近日,喆塔科技与国家集成电路创新中心携手共建“高性能集成电路数智化联合工程中心”并举行签约揭牌仪式。出席此次活动的领导嘉宾包含:上海市经济和信息化委员会、上海市集成电路行业协会、复旦大学微电子学院、国家集成电路创新中…

OpenCV-图像阈值

简单阈值法 此方法是直截了当的。如果像素值大于阈值,则会被赋为一个值(可能为白色),否则会赋为另一个值(可能为黑色)。使用的函数是 cv.threshold。第一个参数是源图像,它应该是灰度图像。第二…

手游和应用出海资讯:怪物猎人AR手游累计总收入已超过2.5亿美元、SuperPlay获得迪士尼纸牌游戏发行许可

NetMarvel帮助游戏和应用广告主洞察全球市场、获取行业信息,以下为12月第一周资讯: ● 怪物猎人AR手游累计总收入已超过 2.5 亿美元 ● SuperPlay获得迪士尼纸牌游戏发行许可 ● 腾讯混元大模型上线文生视频能力 ● 网易天下事业部一拆三,蛋仔…

ARINC 标准全解析:航空电子领域多系列标准的核心内容、应用与重要意义

ARINC标准概述 ARINC标准是航空电子领域一系列重要的标准规范,由航空电子工程委员会(AEEC)编制,众多航空公司等参与支持。这些标准涵盖了从飞机设备安装、数据传输到航空电子设备功能等众多方面,确保航空电子系统的兼…

代码随想录Day35 本周小结动态规划,动态规划:01背包理论基础,动态规划:01背包理论基础(滚动数组),416. 分割等和子集。

1.本周小结动态规划 周一 动态规划:不同路径 (opens new window)中求从出发点到终点有几种路径,只能向下或者向右移动一步。 我们提供了三种方法,但重点讲解的还是动规,也是需要重点掌握的。 dp[i][j]定义 :表示从…

tomcat+jdbc报错怎么办?

1. 虽然mysql8.0以上的不用手动添加driver类,但是一旦加上driver类,就要手动添加了 不然会报找不到driver类的错误 2. java.lang.RuntimeException: java.sql.SQLException: No suitable driver found for jdbc:mysql://localhost:xXX?serverTimezoneU…

【电子仪器】蓝牙测试仪的使用

大家好,我是山羊君Goat。 蓝牙测试仪是专门对于蓝牙信号,RF射频等进行综合测试的电子仪器。 掌握蓝牙测试仪是部分硬件工程师,特别是RF射频硬件工程师一个必备的技能了。 那要如何操作蓝牙测试仪呢? 这里以一款以前市场上比较流…

ansible基础教程(上)

一、介绍: Ansible是一款用于软件配置、配置管理和软件部署的开源自动化和编排工具。相比于其它的工具,Ansible的安装更加简单、易于使用。通过SSH到客户端的方式进行连接,因此它不需要在客户端有特殊的代理,并且通过将模块推送到…