pytorch实现 --- 手写数字识别

        本篇文章是博主在人工智能等领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对人工智能等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在Pytorch

       Pytorch(1)---pytorch实现 --- 手写数字识别》

pytorch实现 --- 手写数字识别

目录

1.项目介绍

2.实现方法

3.程序代码

4.运行结果


1.项目介绍

        使用pytorch实现手写数字识别,十分简单的小项目,环境搭建好,一跑就通。


2.实现方法

2.1方式1        

 安装库:

pip install numpy torch torchvision matplotlib

 运行:

python test.py

首次运行会下载MNIST数据集,请保持网络畅通

2.2方式2

        如果使用pycharm,已经安装好了pytorch环境,那么直接在pytorch环境中运行下面这份代码就好。


3.程序代码

"""手写数字识别项目
    时间:2023.11.6
    环境:pytorch
    作者:Rainbook
"""

import torch
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST
import matplotlib.pyplot as plt

class Net(torch.nn.Module):  # 定义一个Net类,神经网络的主体
    def __init__(self):  # 全连接层,四个
        super().__init__()
        self.fc1 = torch.nn.Linear(28*28, 64)  # 输入层输入28*28,输出64
        self.fc2 = torch.nn.Linear(64, 64)  # 中间层,输入64,输出64
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 10)  # 中间层(隐藏层)的最后一层,输出10个特征值
    
    def forward(self, x):  # 前向传播过程
        # self.fc1(x)全连接线性计算,再套上一个激活函数torch.nn.functional.relu()
        x = torch.nn.functional.relu(self.fc1(x))
        x = torch.nn.functional.relu(self.fc2(x))
        x = torch.nn.functional.relu(self.fc3(x))
        # 最后一层进行softmax归一化,log_softmax是为了提高计算稳定性,在softmax后面套上了一个对数运算
        x = torch.nn.functional.log_softmax(self.fc4(x), dim=1)
        return x


def get_data_loader(is_train):
    to_tensor = transforms.Compose([transforms.ToTensor()])  # 定义数据转换类型tensor,多维数组(张量)
    """下载MNIST数据集,
        "":当前位置
        is_train:判断是训练集还是测试集;
        batch_size:一个批次包含15张图片;
        shuffle:数据随机打乱的
    """
    data_set = MNIST("", is_train, transform=to_tensor, download=True)
    return DataLoader(data_set, batch_size=15, shuffle=True)  # 数据加载器


def evaluate(test_data, net):  # 用来评估神经网络
    n_correct = 0
    n_total = 0
    with torch.no_grad():
        for (x, y) in test_data:
            outputs = net.forward(x.view(-1, 28*28))  # 计算神经网络的预测值
            for i, output in enumerate(outputs):  # 对每个批次的预测值进行比较,累加正确预测的数量
                if torch.argmax(output) == y[i]:
                    n_correct += 1
                n_total += 1
    return n_correct / n_total  # 返回正确率


def main():
    # 导入训练集和测试集
    train_data = get_data_loader(is_train=True)
    test_data = get_data_loader(is_train=False)
    net = Net()  # 初始化神经网络

    # 打印初始网络的正确率,应当是10%附近。手写数字有十种结果,随机猜的正确率就是1/10
    print("initial accuracy:", evaluate(test_data, net))
    """训练神经网络
    pytorch的固定写法
    """
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)
    for epoch in range(5):  # 需要在一个数据集上反复训练神经网络,epoch网络轮次,提高数据集的利用率
        for (x, y) in train_data:
            net.zero_grad()  # 初始化
            output = net.forward(x.view(-1, 28*28))  # 正向传播
            # 计算差值,nll_loss对数损失函数,为了匹配log_softmax的log运算
            loss = torch.nn.functional.nll_loss(output, y)
            loss.backward()  # 反向误差传播
            optimizer.step()  # 优化网络参数
        print("epoch", epoch, "accuracy:", evaluate(test_data, net))  # 打印当前网络的正确率

    """测试神经网络
        训练完成后,随机抽取3张图片进行测试
    """
    for (n, (x, _)) in enumerate(test_data):
        if n > 3:
            break
        predict = torch.argmax(net.forward(x[0].view(-1, 28*28)))  # 测试结果
        plt.figure(n)  # 画出图像
        plt.imshow(x[0].view(28, 28))  # 像素大小28*28
        plt.title("prediction: " + str(int(predict)))  # figure的标题
    plt.show()


if __name__ == "__main__":
    main()

4.运行结果

4.1正确率

4.2测试结果

        参考资料来源:B站

        文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者关注VX公众号:Rain21321,联系作者。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/117922.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

海康威视解码器维修DS-6900系列DS-6916UD

海康威视解码器常见维修型号:DS-6916UD/DS-6901/DS-6904/DS-6908/DS-6910/DS-6912UD/6A16 DS-6A16UD 产品类型:视音频解码器纠错 I/O接口:输入 DVI-I纠错;输出 VGA,BNC纠错;音频输入 HDMI纠错 产品特性 …

安科瑞关于新能源电动汽车有序充电的对策-安科瑞黄安南

摘要 随着我国能源战略发展以及低碳行动的实施,电动汽车已逐步广泛应用,而电动汽车的应用非常符合当今社会对环保意识的要求,以及有效节省化石燃料的消耗。由于其没有污染排放的优点以及政府部门的关注,电动汽车将成为以后出行的…

【网络知识必知必会】聊聊数据链路层以太网

文章目录 前言1. 认识以太网2. 以太网帧格式已经有了ip地址, 为什么还要有 mac 地址呢?认识MTUMTU对IP协议的影响MTU对UDP协议的影响MTU对于TCP协议的影响 总结 前言 本文继续来聊聊网络传输中数据链路层中的一个代表协议, 以太网. 以太这个词其实最早出现在物理学当中, 在早…

基于SpringAOP实现自定义接口权限控制

文章目录 一、接口鉴权方案分析1、接口鉴权方案2、角色分配权限树 二、编码实战1、定义权限树与常用方法2、自定义AOP注解3、AOP切面类(也可以用拦截器实现)4、测试一下 一、接口鉴权方案分析 1、接口鉴权方案 目前大部分接口鉴权方案,一般…

HTML5的语义元素

HTML5语义元素&#xff1a; HTML5提供新的语义元素来明确一个web页面的不同部分&#xff1a;<head>、<nav>、<section>、<article>、<aside>、<figcation>、<figure>、<footer>。 1&#xff09;、<section>元素&#x…

dockerfile避坑笔记(VMWare下使用Ubuntu在Ubuntu20.04基础镜像下docker打包多个go项目)

一、docker简介 docker是一种方便跨平台迁移应用的程序&#xff0c;通过docker可以实现在同一类操作系统中&#xff0c;如Ubuntu和RedHat两个linux操作系统中&#xff0c;实现程序的跨平台部署。比如我在Ubuntu中打包了一个go项目的docker镜像&#xff08;镜像为二进制文件&am…

2023年11月5日网规考试备忘

早上题目回忆&#xff1a; pki体系 ipsec&#xff0c;交换安全&#xff08;流量抑制&#xff09; aohdlc bob metclaf —ethernet pon tcp三次握手 OSPF lsa&#xff1f;交换机组ospf配置问题&#xff0c;ping网关可通&#xff0c;AB不通 raid6 300G*8 网络利用率 停等协议10…

【C++初阶】一、入门知识讲解(C++关键字、命名空间、C++输入输出、缺省参数、函数重载)

相关代码gitee自取&#xff1a; C语言学习日记: 加油努力 (gitee.com) 接上期&#xff1a; 【数据结构初阶】十一、归并排序(比较排序)的讲解和实现 &#xff08;递归版本 非递归版本 -- C语言实现&#xff09;-CSDN博客 引入&#xff1a;什么是C C语言是结构化和模块化的…

剑指JUC原理-9.Java无锁模型

&#x1f44f;作者简介&#xff1a;大家好&#xff0c;我是爱吃芝士的土豆倪&#xff0c;24届校招生Java选手&#xff0c;很高兴认识大家&#x1f4d5;系列专栏&#xff1a;Spring源码、JUC源码&#x1f525;如果感觉博主的文章还不错的话&#xff0c;请&#x1f44d;三连支持&…

Flink SQL 窗口聚合详解

1.滚动窗⼝&#xff08;TUMBLE&#xff09; **滚动窗⼝定义&#xff1a;**滚动窗⼝将每个元素指定给指定窗⼝⼤⼩的窗⼝&#xff0c;滚动窗⼝具有固定⼤⼩&#xff0c;且不重叠。 例如&#xff0c;指定⼀个⼤⼩为 5 分钟的滚动窗⼝&#xff0c;Flink 将每隔 5 分钟开启⼀个新…

如何在知识付费系统小程序开发中实现社区互动和用户参与

在知识付费系统小程序的开发中&#xff0c;实现社区互动和用户参与可以通过以下步骤实现&#xff1a; 1. 建立用户身份验证和管理系统 // 后端示例代码&#xff08;Node.js&#xff09; // 用户注册 app.post(/register, (req, res) > {const { username, email, passwor…

如何在电脑上制作可视化待办任务清单?

在现代高效工作的节奏下&#xff0c;上班族们需要管理大量的待办任务和工作事项。可视化的待办任务清单能够使我们清晰地了解自己的任务进度和工作优先级。每天打开电脑&#xff0c;我们可以直观地看到还有哪些任务需要完成&#xff0c;避免遗漏和混乱。而如何将这些任务清单可…

数据结构之堆的实现(图解➕源代码)

一、堆的定义 首先明确堆是一种特殊的完全二叉树&#xff0c;分为大根堆和小根堆&#xff0c;接下来我们就分别介绍一下这两种不同的堆。 1.1 大根堆&#xff08;简称&#xff1a;大堆&#xff09; 在大堆里面&#xff1a;父节点的值 ≥ 孩子节点的值 我们的兄弟节点没有限制&…

Nacos2.2.3版本运行startup.cmd出现闪退,无错误信息解决方法

Nacos2.2.3版本运行startup.cmd出现闪退&#xff0c;无错误信息解决方法 一、问题描述二、解决方法 一、问题描述 当我下载好nacos2.2.3版解压之后&#xff0c;直接双击startup.cmd出现闪退&#xff0c;而且 没有错误提示信息。后来经过一番搜索尝试&#xff0c;终于解决了自己…

Spring 中 @Qualifier 注解还能这么用?

今天想和小伙伴们聊一聊 Qualifier 注解的完整用法&#xff0c;同时也顺便分析一下它的实现原理。 说到 Qualifier&#xff0c;有的小伙伴可能会觉得诧异&#xff0c;这也只得写一篇文章&#xff1f;确实&#xff0c;但凡有点开发经验&#xff0c;多多少少可能都遇到过 Qualif…

《算法通关村—轻松搞定合并二叉树》

《算法通关村—轻松搞定合并二叉树》 描述 leetcode 617 给你两棵二叉树&#xff1a; root1 和 root2 。 想象一下&#xff0c;当你将其中一棵覆盖到另一棵之上时&#xff0c;两棵树上的一些节点将会重叠&#xff08;而另一些不会&#xff09;。你需要将这两棵树合并成一棵…

酒水展示预约小程序的效果如何

酒的需求度非常高&#xff0c;各种品牌、海量经销商组成了庞大市场&#xff0c;而在实际经营中&#xff0c;酒水品牌、经销商、门店经营者等环节往往也面临着品牌传播拓客引流难、产品展示预约订购难、营销难、销售渠道单一等痛点。 那么商家们应该怎样解决呢&#xff1f; 可以…

Vue3多页面开发实践

前言&#xff1a; 项目需求&#xff0c;把项目中的一个路由页面单摘出来作为一个新的项目。项目部署到服务器上后&#xff0c;通过一个链接的形式可以直接访问到新项目的页面。 解决方式&#xff1a; 使用Vue多页面方式打包项目 实现步骤&#xff1a; 1、在项目的src目录下&am…

MySQL(8):聚合函数

聚合函数介绍 聚合函数&#xff1a; 对一组数据进行汇总的函数&#xff0c;输入的是一组数据的集合&#xff0c;输出的是单个值。 聚合函数类型&#xff1a;AVG(),SUM(),MAX(),MIN(),COUNT() AVG / SUM 只适用于数值类型的字段&#xff08;或变量&#xff09; SELECT AVG(…

【IK分词器安装】

安装IK分词器&#xff1a; 下载链接&#xff08;如果es版本不同可以修改下版本号&#xff09;&#xff1a;https://github.com/medcl/elasticsearch-analysis-ik/releases/download/v7.12.1/elasticsearch-analysis-ik-7.12.1.zip 通常下载是比较慢的&#xff1a;有需要可以从…