pytorch-多分类实战之手写数字识别

目录

  • 1. 网络设计
  • 2. 代码实现
    • 2.1 网络代码
    • 2.2 train
  • 3. 完整代码

1. 网络设计

输入是手写数字图片28x28,输出是10个分类0~9,有两个隐藏层,如下图所示:
在这里插入图片描述

2. 代码实现

2.1 网络代码

第一层将784降维到200,第二次使用200不降维,输出层200降维到10,每一层之后加一个激活函数relu,每一层都需要梯度信息所以requires_grad=True;
forward函数最后不要加softmax,因为后面CrossEntropyLoss中包含了softmax操作。
在这里插入图片描述

2.2 train

优化目标是w1、b1、w2、b2、w3、b3,使用SGD优化器,使用CrossEntropyLoss计算loss
在这里插入图片描述

3. 完整代码

import  torch
import  torch.nn as nn
import  torch.nn.functional as F
import  torch.optim as optim
from    torchvision import datasets, transforms


batch_size=200
learning_rate=0.01
epochs=10

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=True)



w1, b1 = torch.randn(200, 784, requires_grad=True),\
         torch.zeros(200, requires_grad=True)
w2, b2 = torch.randn(200, 200, requires_grad=True),\
         torch.zeros(200, requires_grad=True)
w3, b3 = torch.randn(10, 200, requires_grad=True),\
         torch.zeros(10, requires_grad=True)

# torch.nn.init.kaiming_normal_(w1)
# torch.nn.init.kaiming_normal_(w2)
# torch.nn.init.kaiming_normal_(w3)


def forward(x):
    x = x@w1.t() + b1
    x = F.relu(x)
    x = x@w2.t() + b2
    x = F.relu(x)
    x = x@w3.t() + b3
    x = F.relu(x)
    return x



optimizer = optim.SGD([w1, b1, w2, b2, w3, b3], lr=learning_rate)
criteon = nn.CrossEntropyLoss()

for epoch in range(epochs):

    for batch_idx, (data, target) in enumerate(train_loader):
        data = data.view(-1, 28*28)

        logits = forward(data)
        loss = criteon(logits, target)

        optimizer.zero_grad()
        loss.backward()
        # print(w1.grad.norm(), w2.grad.norm())
        optimizer.step()

        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))


    test_loss = 0
    correct = 0
    for data, target in test_loader:
        data = data.view(-1, 28 * 28)
        logits = forward(data)
        test_loss += criteon(logits, target).item()

        pred = logits.data.max(1)[1]
        correct += pred.eq(target.data).sum()

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

如下图:
未使用torch.nn.init.kaiming_normal_(w1)初始化参数的情况,可以看出Loss在2.302585后就不下降了。
在这里插入图片描述
如下图:使用了torch.nn.init.kaiming_normal_(w1)初始化参数的情况下,Loss下降还是比较快的。
在这里插入图片描述
因此使用好的初始化参数对网络的训练起到至关重要的作用

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/539544.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Vue3 + ElementUI】表单校验无效(写法:this.$refs[‘formName‘].validate((valid) =>{} ))

一. 表单校验 1.1 template模块 el-form 中 若校验&#xff0c;ref 和 rules 必须要有 <template><div style"padding:20px"><el-form ref"formName" :model"form" :rules"formRules" label-width"120px"…

【opencv】示例-grabcut.cpp 使用OpenCV库的GrabCut算法进行图像分割

left mouse button - set rectangle SHIFTleft mouse button - set GC_FGD pixels CTRLleft mouse button - set GC_BGD pixels 这段代码是一个使用OpenCV库的GrabCut算法进行图像分割的C程序。它允许用户通过交互式方式选择图像中的一个区域&#xff0c;并利用GrabCut算法尝试…

idea项目编译时报错:GC overhead limit exceeded

问题描述 今天idea构建一个新的项目时报错&#xff1a;GC overhead limit exceeded&#xff0c;错误是发生在编译阶段&#xff0c;而不是运行阶段。 ava: GC overhead limit exceeded java.lang.OutOfMemoryError: GC overhead limit exceededat com.sun.tools.javac.resources…

STM32H743VIT6使用STM32CubeMX通过I2S驱动WM8978(2)

接前一篇文章&#xff1a;STM32H743VIT6使用STM32CubeMX通过I2S驱动WM8978&#xff08;1&#xff09; 本文参考以下文章及视频&#xff1a; STM32CbueIDE Audio播放音频 WM8978 I2S_stm32 cube配置i2s录音和播放-CSDN博客 STM32第二十二课&#xff08;I2S&#xff0c;HAL&am…

STM32 串口接收定长,不定长数据

本文为大家介绍如何使用 串口 接收定长 和 不定长 的数据。 文章目录 前言一、串口接收定长数据1. 函数介绍2.代码实现 二、串口接收不定长数据1.函数介绍2. 代码实现 三&#xff0c;两者回调函数的区别比较四&#xff0c;空闲中断的介绍总结 前言 一、串口接收定长数据 1. 函…

pbootcms百度推广链接打不开显示404错误页面

PbootCMS官方在2023年4月21日的版本更新中&#xff08;对应V3.2.5版本&#xff09;&#xff0c;对URL参数添加了如下判断 if(stripos(URL,?) ! false && stripos(URL,/?tag) false && stripos(URL,/?page) false && stripos(URL,/?ext_) false…

Composer安装与配置

Composer&#xff0c;作为PHP的依赖管理工具&#xff0c;极大地简化了PHP项目中第三方库的安装、更新与管理过程。本文将详细介绍Composer的安装步骤、基本配置方法&#xff0c;以及一些实用的操作示例&#xff0c;帮助读者快速上手并熟练运用Composer。 一、Composer安装 环…

Python --- 怎么把Python当计算器用?(小白自学笔记)

怎么把Python当计算器用&#xff1f;(小白自学笔记) Part I&#xff1a;标准数学包的导入 今天刚刚装了python&#xff0c;打算用它来取代matlab的基本计算功能&#xff0c;当我的日常计算器用。(这里还有一个捷径&#xff0c;如果你跟我一样也是纯小白的话&#xff0c;直接问c…

如何确保软件通过SmartScreen验证,消除用户下载时的警告提示?

在当前的网络时代&#xff0c;各种软件应用程序深深渗透到人们的日常生活和工作中&#xff0c;许多企业选择自行开发应用程序以推进其业务发展。但在发布应用程序后&#xff0c;软件所有者经常会遇到一个挑战&#xff0c;即用户在下载时可能会遇到微软SmartScreen提示“此应用程…

JVM修炼之路【11】- 解决内存溢出、内存泄漏 以及相关案例

前面的10篇 都是基础的知识&#xff0c;包括类加载的过程 类加载的细节&#xff0c;jvm内存模型 垃圾回收 等等&#xff0c; 这一篇我们开始实战了解一下 各种疑难杂症&#xff1a;怎么监控 怎么发现 怎么解决 内存溢出 内存泄漏 这两个概念在垃圾回收器里面已经讲过了&#…

Java前置一些知识

文章目录 搭建Java环境安装path环境变量Java技术体系 Java执行原理JDK组成跨平台Java内存分配 IDEA管理Java程序 搭建Java环境 安装 oralce官网下载 JDK17 Windows 傻瓜式的点下一步就行&#xff0c;注意&#xff1a;安装目录不要有空格、中文 java 执行工具 javac 编译工具…

2024年第十四届MathorCup数学应用挑战赛C题解析(更新中)

2024年第十四届MathorCup数学应用挑战赛C题解析&#xff08;更新中&#xff09; 题目题目解析(更新中&#xff09;问题一问题二问题三 题目 C题 物流网络分拣中心货量预测及人员排班电商物流网络在订单履约中由多个环节组成&#xff0c;图1是一个简化的物流 网络示意图。其中&a…

状态模式:管理对象状态转换的动态策略

在软件开发中&#xff0c;状态模式是一种行为型设计模式&#xff0c;它允许一个对象在其内部状态改变时改变它的行为。这种模式把与特定状态相关的行为局部化&#xff0c;并且将不同状态的行为分散到对应的状态类中&#xff0c;使得状态和行为可以独立变化。本文将详细介绍状态…

ActiveMQ 01 消息中间件jmsMQ

消息中间件之ActiveMQ 01 什么是JMS MQ 全称&#xff1a;Java MessageService 中文&#xff1a;Java 消息服务。 JMS 是 Java 的一套 API 标准&#xff0c;最初的目的是为了使应用程序能够访问现有的 MOM 系 统&#xff08;MOM 是 MessageOriented Middleware 的英文缩写&am…

django基于python的法院执法案件管理系统

本课题使用Python语言进行开发。代码层面的操作主要在PyCharm中进行&#xff0c;将系统所使用到的表以及数据存储到MySQL数据库中&#xff0c;方便对数据进行操作本课题基于WEB的开发平台&#xff0c;设计的基本思路是&#xff1a; 框架&#xff1a;django/flask 后端&#xff…

LwIP 之八 详解 IP RAW 编程、示例、API 源码、数据流

我们最为熟知的网络通信程序接口应该是 Socket。LwIP 自然也提供了 Socket 编程接口,不过,LwIP 的 Socket 编程接口都是使用最底层的接口来实现的。我们这里要学习的 IP RAW 编程则是指的直接使用 LwIP 的提供的 RAW API 来直接实现应用层功能。这里先来一张图,对 LwIP 内部…

【Godot4自学手册】第三十六节圆形移动或扇形移动的铁球

在第三十四节我实现了来回无限滚动的伤害铁刺球&#xff0c;这一节我准备实现一个圆形移动或扇形移动&#xff0c;并带有链条的铁球。效果如下&#xff1a; 一、实现原理 绕一点做圆周运动&#xff0c;简单的说就是&#xff1a; 每一帧根据旋转的角度计算出下一个位置的坐标…

【c++leetcode】14. Longest Common Prefix

问题入口 解决方案 class Solution { public:string longestCommonPrefix(vector<string>& v) {string ans "";sort(v.begin(), v.end());int n v.size();string first v[0],last v[n - 1];for(int i 0; i < min(first.size(),last.size()); i){…

实现网站图片水印

要实现网站图片水印&#xff0c;有几种方式&#xff1a;1、对于自己想要上传图片先通过某些软件增加水印&#xff0c;然后再上传到图片服务器。2、通过上传客户端&#xff08;eg&#xff1a;picgo&#xff09;功能或插件直接自动水印以及上传服务器。本文主要聚焦于第二种方式&…

前端重置表单的多个Demo

目录 前言1. 纯重置2. reset重置3. resetFields重置4. 彩蛋 前言 由于从Java转全栈&#xff0c;对于前端的相关知识目前 以点科普面&#xff0c;此处的总结 重置前端表单内容&#xff0c;防止影响后续操作 其基本知识只需要通过点击按钮触发重置表单 1. 纯重置 可以通过按钮…