卷积神经网络实现手写数字识别 pytorch

卷积神经网络实现手写数字识别

导入相关的库

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt

设置相应的超参数

EPOCH=1
BATCH_SIZE=50
LR=0.001
DOWNLOAD_MNIST=False

加载训练数据

train_data=torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    transform=torchvision.transforms.ToTensor(),    #将下载的文件转换成pytorch认识的tensor类型,且将图片的数值大小从(0-255)归一化到(0-1)
    download=True
)

任意显示一个图片

print(train_data.data.size())
print(train_data.targets.size())
plt.imshow(train_data.data[0].numpy(),cmap='gray')
plt.title('%i'%train_data.targets[0])
plt.show()
torch.Size([60000, 28, 28])
torch.Size([60000])

在这里插入图片描述

设置训练集和测试集

train_loader=Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
test_data=torchvision.datasets.MNIST(
    root='./mnist',
    train=False,
)

取数据并处理

with torch.no_grad():
    test_x=Variable(torch.unsqueeze(test_data.data, dim=1)).type(torch.FloatTensor)[:2000]/255   #只取前两千个数据吧,差不多已经够用了,然后将其归一化。
    test_y=test_data.targets[:2000]  

卷积网络的定义

class CNN(nn.Module):
    def __init__(self):
        super(CNN,self).__init__()
        self.conv1=nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=16,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.conv2=nn.Sequential(
            nn.Conv2d(
                in_channels=16,
                out_channels=32,
                kernel_size=5,
                stride=1,
                padding=2,
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.out=nn.Linear(32*7*7,10)
    def forward(self,x):
        x=self.conv1(x)
        x=self.conv2(x)     #(batch,32,7,7)
        x=x.view(x.size(0),-1)    #(batch ,32 * 7 * 7)
        output=self.out(x)
        return output

卷积神经网络的结构

cnn=CNN()
print(cnn)
CNN(
  (conv1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (conv2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (out): Linear(in_features=1568, out_features=10, bias=True)
)

训练以及预测

optimizer=torch.optim.Adam(cnn.parameters(),lr=LR)
# 指定损失函数使用交叉信息熵
loss_fn=nn.CrossEntropyLoss()

'''
开始训练我们的模型哦
'''
step=0
for epoch in range(EPOCH):
    #加载训练数据
    for step,data in enumerate(train_loader):
        x,y=data
        b_x=Variable(x)
        b_y=Variable(y)
        output=cnn(b_x)
        loss=loss_fn(output,b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if (step%50==0):
            test_output=cnn(test_x)
            y_pred=torch.max(test_output,1)[1].data.squeeze()
            accuracy=sum(y_pred==test_y).item()/test_y.size(0)
            print('now epoch :  ', epoch, '   |  loss : %.4f ' % loss.item(), '     |   accuracy :   ' , accuracy)
test_output=cnn(test_x[:10])
y_pred=torch.max(test_output,1)[1].data.squeeze()
print(y_pred.tolist(),'predecton Result')
print(test_y[:10].tolist(),'Real Result')
now epoch :   0    |  loss : 2.3108       |   accuracy :    0.18
now epoch :   0    |  loss : 0.5315       |   accuracy :    0.838
now epoch :   0    |  loss : 0.2425       |   accuracy :    0.891
now epoch :   0    |  loss : 0.3464       |   accuracy :    0.921
now epoch :   0    |  loss : 0.2871       |   accuracy :    0.9405
now epoch :   0    |  loss : 0.1521       |   accuracy :    0.937
now epoch :   0    |  loss : 0.1424       |   accuracy :    0.952
now epoch :   0    |  loss : 0.0820       |   accuracy :    0.955
now epoch :   0    |  loss : 0.0958       |   accuracy :    0.9585
now epoch :   0    |  loss : 0.0694       |   accuracy :    0.9615
now epoch :   0    |  loss : 0.1700       |   accuracy :    0.9645
now epoch :   0    |  loss : 0.0583       |   accuracy :    0.9595
now epoch :   0    |  loss : 0.1821       |   accuracy :    0.9645
now epoch :   0    |  loss : 0.0640       |   accuracy :    0.966
now epoch :   0    |  loss : 0.1602       |   accuracy :    0.9735
now epoch :   0    |  loss : 0.0350       |   accuracy :    0.972
now epoch :   0    |  loss : 0.0863       |   accuracy :    0.9585
now epoch :   0    |  loss : 0.1025       |   accuracy :    0.9655
now epoch :   0    |  loss : 0.0898       |   accuracy :    0.972
now epoch :   0    |  loss : 0.0819       |   accuracy :    0.976
now epoch :   0    |  loss : 0.0593       |   accuracy :    0.976
now epoch :   0    |  loss : 0.0615       |   accuracy :    0.9745
now epoch :   0    |  loss : 0.0529       |   accuracy :    0.978
now epoch :   0    |  loss : 0.0372       |   accuracy :    0.977
[7, 2, 1, 0, 4, 1, 4, 9, 5, 9] predecton Result
[7, 2, 1, 0, 4, 1, 4, 9, 5, 9] Real Result

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/655542.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

将 KNX 接入 Home Assistant 之二 准备软件

写在前面: 在KNX官网也有关于 Home Assistant 的教程,地址是 Get started with Home Assistant x KNX 需要的东西是 a KNX IP Interface or Routera Raspberry Pian SD Card at least 32 GB 安装 Home Assistant 系统 下载镜像: 地址&…

网络原理-HTTPS协议

在前面说到HTTP中,我们通过抓包,可以看到许多信息内容,其中往往会包含一些敏感信息,这些都是明文传输,当我们的请求被黑客或者不法分子截获后,那就很危险了,因此衍生出了HTTPS协议来对传输数据进行加密。 一、加密原理 基本原理如下: 明文 密钥 > 密文 密文…

【每日刷题】Day50

【每日刷题】Day50 🥕个人主页:开敲🍉 🔥所属专栏:每日刷题🍍 🌼文章目录🌼 1. 654. 最大二叉树 - 力扣(LeetCode) 2. 119. 杨辉三角 II - 力扣&#xff08…

Gitlab不允许使用ssh拉取代码的解决方案

一、起因 之前一直是用ssh进行代码拉取,后来公司搞网安行动,不允许ssh进行连接拉取代码了 因为我是用shell写了个小型的CI/CD,部署前端项目用于后端联调的,因此在自动部署时,不方便人机交互,所以需要自动填充账密。 …

如何在群晖NAS安装WPS Office并使用公网地址远程访问处理文档表格

前言 作者简介: 懒大王敲代码,计算机专业应届生 今天给大家聊聊如何在群晖NAS安装WPS Office并使用公网地址远程访问处理文档表格,希望大家能觉得实用! 欢迎大家点赞 👍 收藏 ⭐ 加关注哦!💖&a…

水电自动抄表系统是什么?

1.简述:水电自动抄表系统 水电自动抄表系统是一种现代化计量检定解决方法,为提升公用事业服务项目的效率和精确性。传统式手动抄水表方法已经被这类高效率、精准的自动化系统所替代,它能够实时、远程控制地收集解决水电使用数据。 2.系统原…

列表、元组、字典和集合的区别

自学python如何成为大佬(目录):https://blog.csdn.net/weixin_67859959/article/details/139049996?spm1001.2014.3001.5501 在前面介绍了序列中的列表、元组、字典和集合的应用,下面通过表2对这几个数据序列进行比较。 表2 列表、元组、字典和集合的区别 数 据…

【算法】dd爱转转

✨题目链接: dd爱旋转 ✨题目描述 读入一个n∗n的矩阵,对于一个矩阵有以下两种操作 1:顺时针旋180 2:关于行镜像 如 变成 给出q个操作,输出操作完的矩阵 ✨输入描述: 第一行一个数n(1≤n≤1000),表示矩阵大小 接下来n行&#xff…

【全开源】优校管理系统支持微信小程序+微信公众号+H5

概述 优校管理系统(简称优校管)是基于FastAdmin和ThinkPHP进行开发的中小学信息化管理系统,拥有PC版、UniAPP版(高级授权)。支持微信小程序、H5等多平台,主要用于信息管理、教学管理、素养评价,支持多个学校(标准授权限5个,高级授…

Nat Genet|马丁院士团队利用单细胞和空间多组学描绘宫颈鳞癌免疫微环境图谱

宫颈鳞状细胞癌(CSCC)是宫颈癌最常见的组织学类型。虽然局部疾病预后良好,但晚期、转移性或复发性CSCC的5年生存率仅为16.5%,晚期宫颈癌中对免疫检查点阻断的反应有限。细胞异质性和可塑性是肿瘤进展和肿瘤免疫微环境(TIME)调节的…

自由应用大本营?开源免费的Android应用商店:F-Droid Client

F-Droid Client:拥抱开源,守护隐私,让自由软件成为您生活的一部分- 精选真开源,释放新价值。 概览 F-Droid Client是一个开源的Android应用商店,它在GitHub上免费提供。这个项目致力于收集和展示各类自由及开源软件&a…

响应式流和reactor框架进阶

响应式流和reactor框架进阶 响应式流创建、转换、处理 本文档主要介绍在响应式编程中如何从流中获取数据并处理。 前提条件 假设您已经能掌握Java基础、Maven使用、Lamda表达式、响应式编程等基础。 如何获取流中数据 🌏 说明 1、不要试图从流中获取数据出来&a…

牛客NC391 快乐数【simple 模拟法 Java/Go/PHP】

题目 题目链接: https://www.nowcoder.com/practice/293b9ddd48444fa493dd17da0feb192d 思路 直接模拟即可Java代码 import java.util.*;public class Solution {/*** 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值…

自动化测试和手动测试相比,哪个更具优势?

自动化测试和手动测试都是软件测试中常用的两种方法,它们在不同的场景下具有不同的优势。下面将从各个角度进行详细的比较,以帮助读者更好地了解两种方法。 1. 效率 自动化测试相对于手动测试具有更高的效率。手动测试需要人工操作执行测试步骤、输入测…

3步找回丢失数据,EasyRecovery让数据恢复如此简单!

在数字化时代,数据已成为我们生活和工作中不可或缺的一部分。然而,数据丢失却是一个令人头疼的问题。无论是因为意外删除、硬盘损坏还是系统崩溃,丢失的文件和数据都可能给我们造成巨大的经济损失和时间浪费。因此,如何高效利用现…

M功能-分布式锁-支付平台(五)

target:离开柬埔寨倒计时-218day 珍藏的图片又拿出来了 前言 M系统中的撮合引擎是最最核心的功能,第一版的撮合引擎不是我写的,也没有做交易对的动态分配这样的功能,都是基于抢锁方式来决定谁拥有该交易对的撮合权限,所…

【加密与解密(第四版)】第十四章笔记

第十四章 漏洞分析技术 14.1 软件漏洞原理 缓冲区溢出漏洞:栈溢出 堆溢出、整型溢出(存储溢出、计算溢出、符号问题) UAF(Use-After-Free)漏洞 14.2 ShellCode 功能模块:下载执行、捆绑、反弹shell 14.3 …

构造器--5.28

不用一个个属性赋值的方法: 知道了类的创建与使用,但是每次赋值都是一个个调用,我们可以用构造器使得方法简单一点,不用一个个调用属性赋值,直接传参就OK了; 点击类名然后ctrl可以查看构造器 public yanxi…

2024最新(PC+WEB+IOS+Android)即时通讯系统客户端仿默往IM源码下载

2024最新(PCWEBIOSAndroid)即时通讯系统客户端仿默往IM源码下载(总大小约2.4G) 系统功能配置灵活、海量并发、稳定可靠、数据安全,2小时快速部署、数据安全、单聊群聊、系统通知等通信功能,支持App、PC、Web等多端快速接入。 群功能&#xf…

Vue热更新出现内存溢出

Vue热更新出现内存溢出 vue-cli2遇到此问题的解决办法:vue-cli3遇到此问题的解决办法:方法一(已测试ok)方法二(未尝试) 开发项目有一段时间了,随着项目越来越大,打包的时间也相应的变…