机器学习预测-CNN手写字识别

介绍

这段代码是使用PyTorch实现的卷积神经网络(CNN),用于在MNIST数据集上进行图像分类。让我一步步解释:

  1. 导入库:代码导入了必要的库,包括PyTorch(torch)、神经网络模块(torch.nn)、函数模块(torch.nn.functional)、图像数据集(torchvision)以及数据处理(torch.utils.data)和可视化(matplotlib.pyplot)的工具。

  2. 设置超参数:定义了超参数,如批大小(Batch_size)、epoch数量(Epoch)和学习率(Lr)。

  3. 加载MNIST数据集:使用torchvision.datasets.MNIST加载MNIST数据集。该数据集包含了0到9的手写数字的灰度图像。transform=torchvision.transforms.ToTensor()将PIL图像转换为PyTorch张量。

  4. 可视化样本数据:打印数据集的大小,并显示数据集中的第一张图像及其相应的标签。

  5. 准备测试数据:准备测试数据与训练数据类似。加载MNIST测试数据集,并选择前2000个图像进行测试。

  6. 创建数据加载器:使用torch.utils.data.DataLoader创建训练数据的数据加载器。它有助于在训练过程中对数据进行分批和混洗。

  7. 定义CNN架构:通过子类化nn.Module来定义CNN类。该架构包括两个卷积层(self.con1self.con2),后面跟有ReLU激活函数和最大池化层。卷积层的输出被展平并馈入全连接层(self.out),产生最终输出。

  8. 初始化CNN:创建CNN类的实例。

  9. 定义损失函数和优化器:使用交叉熵损失(nn.CrossEntropyLoss)作为损失函数,使用随机梯度下降(torch.optim.SGD)作为优化器。

  10. 训练CNN:在指定的epoch数量循环内训练模型。在循环内,将训练数据通过模型,计算损失,进行梯度反向传播,并由优化器更新模型参数。

  11. 测试模型:每50次迭代训练时,对测试数据集进行评估。将测试预测与真实标签进行比较,计算准确率。

  12. 打印结果:训练结束后,打印模型预测及前10个测试样本的真实标签。

总的来说,这段代码训练了一个CNN模型,用于在MNIST数据集上对手写数字进行分类,并在单独的测试数据集上评估其性能。

代码

# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.utils.data as Data
import matplotlib.pyplot as plt

# define hyper parameters
Batch_size = 100
Epoch = 1
Lr = 0.5
#DOWNLOAD_MNIST = True # 若没有数据,用此生成数据

# define train data and test data
train_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=True,
    download=False,
    transform=torchvision.transforms.ToTensor()
)
print(train_data.data.size())
print(train_data.targets.size())
print(train_data.data[0])
# 画一个图片显示出来
plt.imshow(train_data.data[0].numpy(),cmap='gray')
plt.title('%i'%train_data.targets[0])
plt.show()
# print(train_data.data.shape)           # torch.Size([60000, 28, 28])
# print(train_data.targets.size())        # torch.Size([60000])
# print(train_data.data[0].size())        # torch.Size([28, 28])
# plt.imshow(train_data.data[0].numpy(), cmap='gray')
# plt.show()
test_data = torchvision.datasets.MNIST(
    root='./mnist',
    train=False,
    # transform=torchvision.transforms.ToTensor()
)
test_x = torch.unsqueeze(test_data.data, dim=1).type(torch.FloatTensor)[:2000]
test_y = test_data.targets[:2000]
# print(test_x.shape)                         # torch.Size([2000, 1, 28, 28])
# print(test_y.shape)                         # torch.Size([2000])
train_loader = Data.DataLoader(
    dataset=train_data,
    shuffle=True,
    batch_size=Batch_size,
)

# define network structure
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.con1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5, stride=1, padding=2),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2)
        )
        self.con2 = nn.Sequential(
            nn.Conv2d(16, 32, 5, 1, 2),
            nn.ReLU(),
            nn.MaxPool2d(2)
        )
        self.out = nn.Linear(32 * 7 * 7, 10)

    def forward(self, x):
        x = self.con1(x)            # (batch, 16, 14, 14)
        x = self.con2(x)            # (batch, 32, 7, 7)
        x = x.view(x.size(0), -1)
        out = self.out(x)             # (batch_size, 10)
        return out

cnn = CNN()
# print(cnn)
optimizer = torch.optim.SGD(cnn.parameters(), lr=Lr)
loss_fun = nn.CrossEntropyLoss()

for epoch in range(Epoch):
    for i, (x, y) in enumerate(train_loader):
        output = cnn(x)
        loss = loss_fun(output, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % 50 == 0:
            test_output = torch.max(cnn(test_x), dim=1)[1]
            loss = loss_fun(cnn(test_x), test_y).item()
            accuracy = torch.sum(torch.eq(test_output, test_y)).item() / test_y.numpy().size
            print('Epoch:', Epoch, '|loss:%.4f' % loss, '|accuracy:%.4f' % accuracy)

print('real value', test_data.targets[: 10].numpy())
print('train value', torch.max(cnn(test_x)[: 10], dim=1)[1].numpy())




结果

real value [7 2 1 0 4 1 4 9 5 9]
train value [7 2 1 0 4 1 4 9 5 9]

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/644094.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Linux】Linux的安装

文章目录 一、Linux环境的安装虚拟机 镜像文件云服务器(可能需要花钱) 未完待续 一、Linux环境的安装 我们往后的学习用的Linux版本为——CentOs 7 ,使用 Ubuntu 也可以 。这里提供几个安装方法: 电脑安装双系统(不…

LeetCode热题100——矩阵

73.矩阵清零 题目 给定一个 *m* x *n* 的矩阵,如果一个元素为 0 ,则将其所在行和列的所有元素都设为 0 。请使用 原地 算法。 示例 1: 输入:matrix [[1,1,1],[1,0,1],[1,1,1]] 输出:[[1,0,1],[0,0,0],[1,0,1]] 示例…

OpenAI撤回有争议的决定:终止永久性非贬损协议

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗?订阅我们的简报,深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同,从行业内部的深度分析和实用指南中受益。不要错过这个机会,成为AI领…

Docker提示某网络不存在如何解决,添加完网络之后如何删除?

Docker提示某网络不存在如何解决? 创建 Docker 网络 假设现在需要创建一个名为my-mysql-network的网络 docker network create my-mysql-network运行容器 创建网络之后,再运行 mysqld_exporter 容器。完整命令如下: docker run -d -p 9104…

力扣刷题---2283. 判断一个数的数字计数是否等于数位的值【简单】

题目描述 给你一个下标从 0 开始长度为 n 的字符串 num &#xff0c;它只包含数字。 如果对于 每个 0 < i < n 的下标 i &#xff0c;都满足数位 i 在 num 中出现了 num[i]次&#xff0c;那么请你返回 true &#xff0c;否则返回 false 。 示例 1&#xff1a; 输入&a…

机器人物理引擎

机器人物理引擎是用于计算并模拟机器人及其交互环境在虚拟世界中运动轨迹的组件。 MuJoCo&#xff08;Multi-Joint Dynamics with Contact&#xff09;&#xff1a; 基于广义坐标和递归算法&#xff0c;专注于模拟多关节系统如人形机器人。采用了速度相关的算法来仿真连接点力…

AI菜鸟向前飞 — LangChain系列之十四 - Agent系列:从现象看机制(上篇)

上一篇介绍了Agent与LangGraph的基础技能Tool的必知必会 AI菜鸟向前飞 — LangChain系列之十三 - 关于Tool的必知必会 前面已经详细介绍了Promp、RAG&#xff0c;终于来到Agent系列&#xff08;别急后面还有LangGraph&#xff09;&#xff0c;大家可以先看下这张图&#xff1…

网络模型-路由策略

一、路由策略 路由策略(Routing Policy)作用于路由&#xff0c;主要实现了路由过滤和路由属性设置等功能&#xff0c;它通过改变路由属性(包括可达性)来改变网络流量所经过的路径。目的:设备在发布、接收和引入路由信息时&#xff0c;根据实际组网需要实施一些策略&#xff0c…

C++:关联容器及综合运用:

关联容器和顺序容器有着根本的不同:关联容器中的元素是按关键字来保存和访问的,而顺序容器中的元素是按它们在容器中的位置来顺序保存和访问的。关联容器因此相比与顺序容器支持高效的关键字查找和访问。 其底层数据结构&#xff1a;顺序关联容器 ->红黑树&#xff0c;插入…

Redis离线安装(单机)

目录 1-环境准备1-1下载redis-4.0.11.tar.gz1-2gcc环境 2-上传解压3-编译安装(需要gcc环境)4-配置redis5-启动Redis6-开启防火墙(root)7-添加开机启动脚本8-设置权限9-设置开机启动10-测试redis服务11-检查是否安装成功12-创建redis命令软连接13-测试redis14-必要时设置防火墙 …

禅道密码正确但是登录异常处理

禅道密码正确&#xff0c;但是登录提示密码错误的异常处理 排查内容 # 1、服务器异常&#xff0c;存储空间、数据库异常 # 2、服务异常&#xff0c;文件丢失等异常问题定位 # 1、df -h 排查服务器存储空间 # 2、根据my.php排查数据库连接是否正常 # 3、修改my.pho,debugtrue…

外企也半夜发布上线吗?

0 别把问题想得太复杂 如果有灰度发布的能力&#xff0c;最好白天发布&#xff1b;如果没有灰度发布&#xff0c;只能在半夜发布。 即使有灰度发布能力&#xff0c;也不要沾沾自喜&#xff0c;好好反思一下你们的灰度发布是否真的经得起考验&#xff0c;还是仅仅是装装样子。…

区块链技术和应用二

前言 学习长安链的一些基本原理 官网&#xff1a;长安链开源文档 b站课程&#xff1a;区块链基础与应用 一、共识算法 1.1 POW工作量证明 最长链共识&#xff0c;没听明白 1.2 51%攻击 二、区块链的发展 2.1 区块链1.0到3.0 2.2 共有链、联盟链、私有链 2.3 发展趋势 2.4 扩…

Spring Boot Interceptor(拦截器使用及原理)

之前的博客中讲解了关于 Spring AOP的思想和原理&#xff0c;而实际开发中Spring Boot对于AOP的思想的具体实现就是Spring Boot Interceptor。在 Spring Boot 应用程序开发中&#xff0c;拦截器&#xff08;Interceptor&#xff09;是一个非常有用的工具。它允许我们在 HTTP 请…

Unity修改Project下的Assets的子文件的图标

Unity修改文件夹的图标 示例&#xff1a; 在右键可以创建指定文件夹。 github链接 https://github.com/SeaeeesSan/SimpleFolderIconCSDN资源的链接 https://download.csdn.net/download/GoodCooking/89347361 去GitHub下载支持原作者哦。重要的事情 截图来自GitHub 。 U…

文件编码格式查看和转换

1、查看文件编码格式 记事本&#xff1a;打开文件后&#xff0c;点击“文件”--“另存为”&#xff0c;可查看文件的编码格式。**Notepad**&#xff1a;打开文件后&#xff0c;即可在右下角查看文件的编码格式。vim&#xff1a;打开文件后&#xff0c;输入“:set fileencoding…

网络安全基础技术扫盲篇 — 名词解释

网络模块基础&#xff08;网络拓扑图、网络设备、安全设备&#xff09; 用通俗易懂的话说&#xff1a; 网络拓扑图&#xff1a;它就像一张网络世界的地图&#xff0c;它展现了我们数不清的网站、服务器和设备是如何相互连接的。用简单的话说&#xff0c;它就是给我们指路、告…

人工智能 框架 paddlepaddle 飞桨 使用指南 使用例子 线性回归模型demo 详解

安装过程&使用指南&线性回归模型 使用例子 本来预想 是安装 到 conda 版本的 11.7的 但是电脑没有gpu 所以 安装过程稍有变动,下面简单讲下 conda create -n paddle_env117 python=3.9 由于想安装11.7版本 py 是3.9 所以虚拟环境名称也是 paddle_env117 activa…

C语言 | Leetcode C语言题解之第111题二叉树的最小深度

题目&#xff1a; 题解&#xff1a; typedef struct {int val;struct TreeNode *node;struct queNode *next; } queNode;void init(queNode **p, int val, struct TreeNode *node) {(*p) (queNode *)malloc(sizeof(queNode));(*p)->val val;(*p)->node node;(*p)->…