深度学习实验第T1周:实现mnist手写数字识别

>- **🍨 本文为[🔗365天深度学习训练营](https://mp.weixin.qq.com/s/0dvHCaOoFnW8SCp3JpzKxg) 中的学习记录博客**
>- **🍖 原作者:[K同学啊](https://mtyjkh.blog.csdn.net/)**

目录

目录

一、前言

二、我的环境

三、前期准备

1.设置GPU

2.导入数据

3.归一化

 4.可视化图片

5.调整图片格式

四、构建简单的cnn网络

五.编译模型

六、训练模型

七预测

八、知识点详解

1. MNIST手写数字数据集介绍

2. 神经网络程序说明

 3.模型结构说明

八、总结


一、前言

作为一名研究牲,一定要了解pytorch和tensorflow。下面我来介绍一下。

TensorFlow和PyTorch是两个流行的开源机器学习库,它们都支持深度学习模型的开发和训练。尽管它们在很多方面有相似之处,但它们之间也存在一些关键的区别:

1. **设计哲学**:
   - **TensorFlow**:最初由Google Brain团队开发,TensorFlow的设计更倾向于生产环境,强调模型的可扩展性和部署的灵活性。TensorFlow提供了一个静态计算图,这意味着在执行之前,整个计算图需要被定义和优化。
   - **PyTorch**:由Facebook的AI研究团队开发,PyTorch的设计更倾向于研究和快速原型开发,强调动态性和易用性。PyTorch使用动态计算图,允许在运行时修改图。

2. **易用性**:
   - **TensorFlow**:对于初学者来说可能稍微复杂一些,因为它需要用户理解计算图的概念。
   - **PyTorch**:提供了一个更接近于NumPy的API,使得从NumPy过渡到深度学习更加自然。

3. **灵活性**:
   - **TensorFlow**:由于其静态图的特性,可能在某些需要高度灵活性的场景下不如PyTorch灵活。
   - **PyTorch**:动态图使得在运行时修改模型变得更加容易,这对于研究和快速迭代非常有用。

4. **性能**:
   - 两者在性能上都非常出色,但TensorFlow在某些情况下可能因为其优化的静态图而提供更好的性能。

5. **社区和生态系统**:
   - **TensorFlow**:由于其较早的发布和Google的支持,拥有一个庞大的社区和丰富的库。
   - **PyTorch**:虽然起步较晚,但社区发展迅速,特别是在研究领域。

6. **部署**:
   - **TensorFlow**:提供了TensorFlow Serving等工具,使得模型部署更加方便。
   - **PyTorch**:模型部署可能需要更多的工作,但PyTorch与ONNX(Open Neural Network Exchange)的集成正在改善这一状况。

7. **多GPU支持**:
   - **TensorFlow**:从设计之初就考虑了多GPU支持。
   - **PyTorch**:虽然也支持多GPU,但在某些情况下可能需要更多的手动配置。

8. **API一致性**:
   - **TensorFlow**:API在不同版本之间可能发生变化,这可能会影响向后兼容性。
   - **PyTorch**:API相对稳定,变化较少。

选择哪个框架往往取决于个人偏好、项目需求和团队熟悉度。两者都是强大的工具,能够支持复杂的深度学习任务。

二、我的环境

三、前期准备

1.设置GPU

import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")

if gpus:
    gpu0 = gpus[0] #如果有多个GPU,仅使用第0个GPU
    tf.config.experimental.set_memory_growth(gpu0, True) #设置GPU显存用量按需使用
    tf.config.set_visible_devices([gpu0],"GPU")

2.导入数据

import tensorflow as tf
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

# 导入mnist数据,依次分别为训练集图片、训练集标签、测试集图片、测试集标签
(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()

3.归一化

归一化与标准化icon-default.png?t=N7T8https://blog.csdn.net/qq_38251616/article/details/126048261

# 将像素的值标准化至0到1的区间内。(对于灰度图片来说,每个像素最大值是255,每个像素最小值是0,也就是直接除以255就可以完成归一化。)
train_images, test_images = train_images / 255.0, test_images / 255.0
# 查看数据维数信息
train_images.shape,test_images.shape,train_labels.shape,test_labels.shape
"""
输出:((60000, 28, 28), (10000, 28, 28), (60000,), (10000,))
"""

 4.可视化图片

# 将数据集前20个图片数据可视化显示
# 进行图像大小为20宽、10长的绘图(单位为英寸inch)
plt.figure(figsize=(20,10))
# 遍历MNIST数据集下标数值0~49
for i in range(20):
    # 将整个figure分成5行10列,绘制第i+1个子图。
    plt.subplot(2,10,i+1)
    # 设置不显示x轴刻度
    plt.xticks([])
    # 设置不显示y轴刻度
    plt.yticks([])
    # 设置不显示子图网格线
    plt.grid(False)
    # 图像展示,cmap为颜色图谱,"plt.cm.binary"为matplotlib.cm中的色表
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    # 设置x轴标签显示为图片对应的数字
    plt.xlabel(train_labels[i])
# 显示图片
plt.show()

5.调整图片格式

#调整数据到我们需要的格式
train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))

train_images.shape,test_images.shape,train_labels.shape,test_labels.shape
"""
输出:((60000, 28, 28, 1), (10000, 28, 28, 1), (60000,), (10000,))
"""

四、构建简单的cnn网络

网络结构图

(1)第一步构建cnn网络模型

(2)第二步:加载并打印模型

(3)第三步: 输出结果​编辑

# 创建并设置卷积神经网络
# 卷积层:通过卷积操作对输入图像进行降维和特征抽取
# 池化层:是一种非线性形式的下采样。主要用于特征降维,压缩数据和参数的数量,减小过拟合,同时提高模型的鲁棒性。
# 全连接层:在经过几个卷积和池化层之后,神经网络中的高级推理通过全连接层来完成。
model = models.Sequential([
    # 设置二维卷积层1,设置32个3*3卷积核,activation参数将激活函数设置为ReLu函数,input_shape参数将图层的输入形状设置为(28, 28, 1)
    # ReLu函数作为激活励函数可以增强判定函数和整个神经网络的非线性特性,而本身并不会改变卷积层
    # 相比其它函数来说,ReLU函数更受青睐,这是因为它可以将神经网络的训练速度提升数倍,而并不会对模型的泛化准确度造成显著影响。
    layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)),
    #池化层1,2*2采样
    layers.MaxPooling2D((2, 2)),                   
    # 设置二维卷积层2,设置64个3*3卷积核,activation参数将激活函数设置为ReLu函数
    layers.Conv2D(64, (3, 3), activation='relu'),  
    #池化层2,2*2采样
    layers.MaxPooling2D((2, 2)),                   
    
    layers.Flatten(),                    #Flatten层,连接卷积层与全连接层
    layers.Dense(64, activation='relu'), #全连接层,特征进一步提取,64为输出空间的维数,activation参数将激活函数设置为ReLu函数
    layers.Dense(10)                     #输出层,输出预期结果,10为输出空间的维数
])
# 打印网络结构
model.summary()

五.编译模型

"""
这里设置优化器、损失函数以及metrics
这三者具体介绍可参考我的博客:
https://blog.csdn.net/qq_38251616/category_10258234.html
"""
# model.compile()方法用于在配置训练方法时,告知训练时用的优化器、损失函数和准确率评测标准
model.compile(
	# 设置优化器为Adam优化器
    optimizer='adam',
	# 设置损失函数为交叉熵损失函数(tf.keras.losses.SparseCategoricalCrossentropy())
    # from_logits为True时,会将y_pred转化为概率(用softmax),否则不进行转换,通常情况下用True结果更稳定
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    # 设置性能指标列表,将在模型训练时监控列表中的指标
    metrics=['accuracy'])

 # model.compile()方法用于在配置训练方法时,告知训练时用的优化器、损失函数和准确率评测标准

六、训练模型

"""
这里设置输入训练数据集(图片及标签)、验证数据集(图片及标签)以及迭代次数epochs
关于model.fit()函数的具体介绍可参考我的博客:
https://blog.csdn.net/qq_38251616/category_10258234.html
"""
history = model.fit(
    # 输入训练集图片
	train_images, 
	# 输入训练集标签
	train_labels, 
	# 设置10个epoch,每一个epoch都将会把所有的数据输入模型完成一次训练。
	epochs=10, 
	# 设置验证集
    validation_data=(test_images, test_labels))

七预测

通过下面的网络结构我们可以简单理解为,输入一张图片,将会得到一组数,这组代表这张图片上的数字为0~9中每一个数字的几率(并非概率),out数字越大可能性越大,仅此而已

在这一步中部分同学会因为 matplotlib 版本原因报 Invalid shape (28, 28, 1) for image data 的错误提示,可以将代码改为 plt.imshow(test_images[1].reshape(28,28)) 。 

plt.imshow(test_images[1])

 

#输出测试集中第一张图片的预测结果
pre = model.predict(test_images) # 对所有测试图片进行预测
pre[1] # 输出第一张图片的预测结果

八、知识点详解

本文使用的是最简单的CNN模型- -LeNet-5,如果是第一次接触深度学习的话,可以先试着把代码跑通,然后再尝试去理解其中的代码。

1. MNIST手写数字数据集介绍

MNIST手写数字数据集来源于是美国国家标准与技术研究所,是著名的公开数据集之一。数据集中的数字图片是由250个不同职业的人纯手写绘制,数据集获取的网址为:MNIST handwritten digit database, Yann LeCun, Corinna Cortes and Chris Burges(下载后需解压)。我们一般会采用(train_images, train_labels), (test_images, test_labels) = datasets.mnist.load_data()这行代码直接调用,这样就比较简单

MNIST手写数字数据集中包含了70000张图片,其中60000张为训练数据,10000为测试数据,70000张图片均是28*28,数据集样本如下:

如果我们把每一张图片中的像素转换为向量,则得到长度为28*28=784的向量。因此我们可以把训练集看成是一个[60000,784]的张量,第一个维度表示图片的索引,第二个维度表示每张图片中的像素点。而图片里的每个像素点的值介于0-1之间。

2. 神经网络程序说明

 3.模型结构说明

各层的作用

  • 输入层:用于将数据输入到训练网络
  • 卷积层:使用卷积核提取图片特征
  • 池化层:进行下采样,用更高层的抽象表示图像特征
  • Flatten层:将多维的输入一维化,常用在卷积层到全连接层的过渡
  • 全连接层:起到“特征提取器”的作用
  • 输出层:输出结果

八、总结

本周的任务中,实现了手写数字识别的任务,第一点就是准备数据集,本次数据集是可以直接下载的不用导入,构建模型,使用的是最基础的- -LeNet-5,卷积层提取特征,池化层降采样,重复两遍之后来个flatten层拉伸一下,便于全连接层输入,全连接层得出分类结果。优化器损失函数直接放在# model.compile()方法用于在配置训练方法时,告知训练时用的优化器、损失函数和准确率评测标准#方法里面了,最后直接训练即可。整体比较顺利。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/751627.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【数据集划分——针对于原先图片已经整理好类别】训练集|验证集|测试集

目标:用split-folders进行数据集划分 学习资源:https://www.youtube.com/watch?vC6wbr1jJvVs 努力的小巴掌 记录计算机视觉学习道路上的所思所得。 现在已经有了数据集,并且,注意,是已经划分好类别的! …

esp12实现的网络时钟校准

网络时间的获取是通过向第三方服务器发送GET请求获取并解析出来的。 在本篇博客中,网络时间的获取是一种自动的行为,当系统成功连接WiFi获取到网络天气后,系统将自动获取并解析得到时间和日期,为了减少误差每两分钟左右进行一次校…

【Docker】创建 swarm 集群

目录 1. 更改防火墙设置 2. 安装 Docker 组件 3. 启动 Docker 服务,并检查服务状态。 4. 修改配置文件,监听同一端口号。 5. 下载 Swarm 组件 6. 创建集群,加入节点 7. 启动集群 8. 查询集群节点信息 9. 查询集群具体信息 10. 查询…

用Python设置Excel工作表网格线的隐藏与显示

Excel表格界面的直观性很大程度上得益于表格中的网格线设计,这些线条帮助用户精确对齐数据,清晰划分单元格。网格线是Excel界面中默认显示的辅助线,用于辅助定位,与单元格边框不同,不影响打印输出。然而,在…

【Linux:文件描述符】

文件描述符: 文件描述符的分配原则:最小未分配原则 每一个进程中有一个task_struct结构体(PCB),而task_struct中含有struct file_sturct*file的指针,该指针指向了一个struct files_struct的结构体该结构体中含有一个f…

Python27 神经网络中的重要概念和可视化实现

1. 神经网络背后的直观知识 神经网络的工作方式非常相似:它接受多个输入,经过多个隐藏层中的多个神经元进行处理,并通过输出层返回结果,这个过程在技术上称为“前向传播”。 接下来,将神经网络的输出与实际输出进行比…

Java | Leetcode Java题解之第202题快乐数

题目&#xff1a; 题解&#xff1a; class Solution {private static Set<Integer> cycleMembers new HashSet<>(Arrays.asList(4, 16, 37, 58, 89, 145, 42, 20));public int getNext(int n) {int totalSum 0;while (n > 0) {int d n % 10;n n / 10;totalS…

Windows下activemq开启jmx

1.activemq版本信息 activemq&#xff1a;apache-activemq-5.18.4 2.Windows下activemq开启jmx 1.进入activemq conf目录&#xff0c;备份activemq.xml文件 2.编辑activemq.xml文件&#xff0c;在broker节点增加useJmx"true" <broker xmlns"http://active…

Vuetify3:​快捷回到顶部

在Vuetify 3中&#xff0c;要实现回到顶部&#xff0c;我们需要创建悬浮按钮&#xff0c;如下&#xff1a; <template><v-list><div class"position-fixed right-0 bottom-0" style"top:50%;"><v-list-item ><v-btn icon"…

黑马点评项目总结1-使用Session发送验证码和登录login和 使用Redis存储验证码和Redis的token登录

黑马先是总结了从session实现登录&#xff0c;然后是因为如果使用了集群方式的服务器的话&#xff0c;存在集群共享session互相拷贝效率低下的问题&#xff0c;接着引出了速度更快的内存型的kv数据库Redis&#xff0c; 使用Session发送验证码和登录login 举个例子&#xff1a…

『Django』模型入门教程-操作MySQL

theme: smartblue 点赞 关注 收藏 学会了 本文简介 一个后台如果没有数据库可以说废了一半。日常开发中大多数时候都在与数据库打交道。Django 为我们提供了一种更简单的操作数据库的方式。 在 Django 中&#xff0c;模型(Model)是用来定义数据库结构的类。每个模型类通常对…

kali下安装使用蚁剑(AntSword)

目录 0x00 介绍0x01 安装0x02 使用1. 设置代理2. 请求头配置3. 编码器 0x00 介绍 蚁剑&#xff08;AntSword&#xff09;是一个webshell管理工具。 官方文档&#xff1a;https://www.yuque.com/antswordproject/antsword 0x01 安装 在kali中安装蚁剑&#xff0c;分为两部分&am…

matlab绘制二维曲线,如何设置线型、颜色、标记点类型、如何设置坐标轴、matlab 图表标注、在图中标记想要的点

matlab绘制二维曲线&#xff0c;如何设置线型、颜色、标记点类型、如何设置坐标轴、matlab 图表如何标注、如何在图中标记想要的点 matlab绘制二维曲线&#xff0c;如何在图中标记想要的点。。。如何设置线型、颜色、标记点类型。。。如何设置坐标轴。。。matlab 图表标注操作…

视频网站系统

摘 要 随着互联网的快速发展和人们对视频内容的需求增加&#xff0c;视频网站成为了人们获取信息和娱乐的重要平台。本论文基于SpringBoot框架&#xff0c;设计与实现了一个视频网站系统。首先&#xff0c;通过对国内外视频网站发展现状的调研&#xff0c;分析了视频网站的背景…

潮玩手办盲盒前端项目模版的技术探索与应用案例

一、引言 在数字化时代&#xff0c;随着消费者对个性化和艺术化产品的需求日益增长&#xff0c;潮玩手办和盲盒市场逐渐崭露头角。为了满足这一市场需求&#xff0c;前端技术团队需要构建一个功能丰富、用户友好的在线平台。本文旨在探讨潮玩手办盲盒前端项目模版的技术实现&a…

C++ | Leetcode C++题解之第201题数字范围按位与

题目&#xff1a; 题解&#xff1a; class Solution { public:int rangeBitwiseAnd(int m, int n) {while (m < n) {// 抹去最右边的 1n n & (n - 1);}return n;} };

序列检测器(Moore型)

目录 描述 输入描述&#xff1a; 输出描述&#xff1a; 参考代码 描述 请用Moore型状态机实现序列“1101”从左至右的不重叠检测。 电路的接口如下图所示。当检测到“1101”&#xff0c;Y输出一个时钟周期的高电平脉冲。 接口电路图如下&#xff1a; 输入描述&#xff1a…

携程任我行有什么用?

眼看一直到十月份都没啥假期了 五一出去玩买了几张携程的卡&#xff0c;想着买景点门票、酒店啥的能有优惠&#xff0c;但最后卡里的钱没用完不说&#xff0c;还有几张压根就没用出去 但是我又不想把卡一直闲置在手里&#xff0c;就怕过期了 最后在收卡云上99.1折出掉了&…

注意力机制在大语言模型中的应用

在大语言模型中&#xff0c;注意力机制&#xff08;Attention Mechanism&#xff09;用于捕获输入序列中不同标记&#xff08;token&#xff09;之间的关系和依赖性。这种机制可以动态地调整每个标记对当前处理任务的重要性&#xff0c;从而提高模型的性能。具体来说&#xff0…