深度学习_10_softmax_实战

由于网上代码的画图功能是基于jupyter记事本,而我用的是pycham,这导致画图代码不兼容pycharm,所以删去部分代码,以便能更好的在pycharm上运行

完整代码:

import torch
from d2l import torch as d2l

"创建训练集&创建检测集合"
batch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

"创建模型w, b"
num_inputs = 784
num_outputs = 10

W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True)
b = torch.zeros(num_outputs, requires_grad=True)

"softmax"
def softmax(X):
    X_exp = torch.exp(X)
    partition = X_exp.sum(1, keepdim=True)
    return X_exp / partition  # 这里应用了广播机制

"输出,即传入图片输出"
def net(X):
    return softmax(torch.matmul(X.reshape((-1, W.shape[0])), W) + b)

"交叉熵损失"
def cross_entropy(y_hat, y):
    return - torch.log(y_hat[range(len(y_hat)), y])

"显示预测与估计相对应下标数量"
def accuracy(y_hat, y):  #@save
    """计算预测正确的数量"""
    if len(y_hat.shape) > 1 and y_hat.shape[1] > 1: # 确定长宽高都大于1
        y_hat = y_hat.argmax(axis=1) # 取出每行中最大值
    cmp = y_hat.type(y.dtype) == y
    return float(cmp.type(y.dtype).sum()) # 返回对应下标数量


"利用优化后的模型计算精度"
def evaluate_accuracy(net, data_iter):  #@save

    if isinstance(net, torch.nn.Module):
        net.eval()  # 将模型设置为评估模式
    metric = Accumulator(2)  # 正确预测数、预测总数
    with torch.no_grad():
        for X, y in data_iter:
            metric.add(accuracy(net(X), y), y.numel()) # 下标相同数量 / 总下标
    return metric[0] / metric[1]


"加法器"
class Accumulator:  #@save

    def __init__(self, n):
        self.data = [0.0] * n

    def add(self, *args):
        self.data = [a + float(b) for a, b in zip(self.data, args)]

    def reset(self):
        self.data = [0.0] * len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

"训练更新模型&返回训练损失与精度函数"
def train_epoch_ch3(net, train_iter, loss, updater):  #@save
    """训练模型一个迭代周期(定义见第3章)"""
    # 将模型设置为训练模式
    if isinstance(net, torch.nn.Module):
        net.train()
    # 训练损失总和、训练准确度总和、样本数
    metric = Accumulator(3)
    for X, y in train_iter:
        # 计算梯度并更新参数
        y_hat = net(X)
        l = loss(y_hat, y)
        if isinstance(updater, torch.optim.Optimizer):
            # 使用PyTorch内置的优化器和损失函数
            updater.zero_grad()
            l.mean().backward()
            updater.step()
        else:
            # 使用定制的优化器和损失函数
            l.sum().backward()
            updater(X.shape[0])
        metric.add(float(l.sum()), accuracy(y_hat, y), y.numel())
    # 返回训练损失和训练精度
    return metric[0] / metric[2], metric[1] / metric[2]

lr = 0.1

"更新模型"
def updater(batch_size):
    return d2l.sgd([W, b], lr, batch_size)

if __name__ == '__main__':
    num_epochs = 10
    cnt = 1
    for i in range(num_epochs):
        X, Y = train_epoch_ch3(net, train_iter, cross_entropy, updater)
        print("训练次数: " + str(cnt))
        cnt += 1
        print("训练损失: {:.4f}".format(X))
        print("训练精度: {:.4f}".format(Y))
        print(".................................")
#        print(W)
#        print(b)

效果:

在这里插入图片描述

训练效果还是和网上一样的,就是缺了画图功能,将就着吧

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/122776.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

使用合成数据训练语义分割模型

计算机视觉应用现在在各个科技领域无处不在。 这些模型高效且有效,研究人员每年都会尝试新想法。 新的前沿是试图消除深度学习的最大负担:需要大量的标记数据。 正如本文所述,此问题的解决方案是使用合成数据。 从这些研究中获益最多的计算机…

ros1 模拟客户端生成小乌龟服务请求生成小乌龟

模拟客户端生成小乌龟服务请求生成小乌龟 一、话题模型二、创建功能包三 创建客户端Client代码四 配置CMakeLists.txt编译规则:五 测试启动ros 主服务启动小乌龟的服务启动模型客户端服务 一、话题模型 Sever端是海龟仿真器/turtlesim,Client端是待实现…

Android 13.0 Launcher3 app图标长按去掉应用信息按钮

1.前言 在13.0的rom定制化开发中,在Launcher3定制化开发中,对Launcher3的定制化功能中,在Launcher3的app列表页会在长按时,弹出微件和应用信息两个按钮,点击对应的按钮跳转到相关的功能页面, 现在由于产品需求要求禁用应用信息,不让进入到应用信息页面所以要去掉应用信息…

BigDecimal使用的时候需要注意什么?

BigDecimal只要涉及到浮点数运算都会用到BigDecimal,并且面试的时候经常会问到,那么BigDecimal使用的时候需要注意什么? 目录 1.为什么不能用浮点数表示金额?2.十进制转换二进制3.科学记数法4.IEEE 7545.在线浮点数转换二进制6.原…

限流式保护器在养老院火灾预防中的应用

安科瑞 华楠 【摘要】老年人是一个庞大特殊的社会群体。随着我国人口的老龄化,老年人口数量断上升。涉及老年人的火灾越来越多,本文从养老院火灾的案例、成因、预防措施等方面对此类火灾进行了深入的探讨。 【关键词】老年公寓;火灾预防&…

VINS-Mono-后端优化 (一:预积分残差计算-IMU预积分约束)

这里先回顾一下预积分是怎么来的 VINS-Mono-IMU预积分 (三:为什么要预积分预积分推导) 这里贴出预积分的公式 具体含义解释看对对应的文章 整个误差函数如下 预积分 α \alpha α β \beta β γ \gamma γ 是用 IMU 预积分获得的增量&a…

CSS关于默认宽度

所谓的默认宽度&#xff0c;就是不设置width属性时&#xff0c;元素所呈现出来的宽度 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title></title><style>* {margin: 0;padding: 0;}.box {/…

Hive 常用存储、压缩格式

1. Hive常用的存储格式 TEXTFI textfile为默认存储格式 存储方式&#xff1a;行存储 磁盘开销大 数据解析开销大 压缩的text文件 hive 无法进行合拆分 SEQUENCEFILE sequencefile二进制文件&#xff0c;以<key,value>的形式序列到文件中 存储方式&#xff1a;行存储 可…

STM32中断简介

中断系统 中断&#xff1a;在主程序运行过程中&#xff0c;出现了特定的中断触发条件&#xff08;中断源&#xff09;&#xff0c;使得CPU暂停当前正在运行的程序&#xff0c;转而去处理中断程序&#xff0c;处理完成后又返回原来被暂停的位置继续运行&#xff1b; 以上是中断的…

Docker - 镜像

Docker - 镜像 镜像是什么 镜像是一种轻量级&#xff0c;可执行的独立软件包&#xff0c;用来打包软件运行环境和基于运行环境开发的软件&#xff0c;它包含运行某个软件所需的所有内容&#xff0c;包括代码&#xff0c;运行时&#xff0c;库&#xff0c;环境变量和配置文件。…

电脑风扇控制软件 Macs Fan Control Pro mac中文版功能介绍

Macs Fan Control mac是一款专门为 Mac 用户设计的软件&#xff0c;它可以帮助用户控制和监控 Mac 设备的风扇速度和温度。这款软件允许用户手动调整风扇速度&#xff0c;以提高设备的散热效果&#xff0c;减少过热造成的风险。 Macs Fan Control 可以在菜单栏上显示当前系统温…

如何在Python爬虫中使用IP代理以避免反爬虫机制

目录 前言 一、IP代理的使用 1. 什么是IP代理&#xff1f; 2. 如何获取IP代理&#xff1f; 3. 如何使用IP代理&#xff1f; 4. 如何避免IP代理失效&#xff1f; 5. 代理IP的匿名性 二、代码示例 总结 前言 在进行爬虫时&#xff0c;我们很容易会遇到反爬虫机制。网站…

聊一聊被人嘲笑的if err!=nil和golang为什么要必须支持多返回值?

golang多返回值演示 我们知道&#xff0c;多返回值是golang的一个特性&#xff0c;比如下面这段代码,里面的参数名我起了几个比较好区分的 package mainfunc main() {Swap(10999, 10888) }func Swap(saaa, sbbb int) (int, int) {return sbbb, saaa }golang为什么要支持多返回…

IP代理识别API:预防欺诈和保护网络安全的必要工具

引言 随着互联网的快速发展&#xff0c;我们的生活变得越来越依赖于网络。然而&#xff0c;随着网络的发展&#xff0c;网络犯罪和网络欺诈也在不断增加。为了保护自己的网站和客户免受网络欺诈的侵害&#xff0c;许多企业和组织开始使用IP代理识别API作为一种必要工具。 什么…

jenkins结合k8s部署动态slave

1、完成k8s连接 在完成jenkins的部署后现安装kubernets的插件 如果jenkins 是部署在k8s集群中只需要填写一下 如果是非本集群的部署则需要填写证书等 cat ./config echo ‘certificate-authority-data-value’ | base64 -d > ./ca.crt echo ‘client-certificate-data’ |…

第二次pta认证P测试C++

#include <iostream> using namespace std; int f(int n){if (n0){return 1;}if (n1){return 3;}return 4*f(n-1)-f(n-2); } int n; int main() {cin>>n;cout<<f(n);return 0; }第二题 试题编号&#xff1a;2022-13-0302 试题名称&#xff1a;长正整数相加 …

springcloud小说阅读网站源码

开发工具&#xff1a; 大等于jdk1.8&#xff0c;大于mysql5.5&#xff0c;nodejs&#xff0c;idea&#xff08;eclipse&#xff09;&#xff0c;vscode&#xff08;webstorm&#xff09; 技术说明&#xff1a; springcloud springboot mybatis vue elementui 功能介绍&…

2023年9月少儿编程 中国电子学会图形化编程等级考试Scratch编程二级真题解析(判断题)

2023年9月scratch编程等级考试二级真题 判断题(共10题,每题2分,共20分) 26、下列两个程序运行效果一样 答案:对 考点分析:考查积木综合使用,重点考查重复执行和坐标积木 两个程序都是在x=0,y=100的时候停止,所以正确 27、甲、乙和丙,一位是山东人,一位是河南人,…

2023云栖大会,Salesforce终敲开中国CRM市场

2015年被视为中国CRM SaaS元年&#xff0c;众多CRM SaaS创业公司和厂商在Salesforce的榜样作用下涌入了CRM SaaS赛道。在全球市场&#xff0c;Salesforce是CRM SaaS领域的领导厂商&#xff0c;连续多年占据了全球CRM SaaS第一大厂商地位。然而&#xff0c;Salesforce作为业务类…