LeNet对MNIST 数据集中的图像进行分类--keras实现

我们将训练一个卷积神经网络来对 MNIST 数据库中的图像进行分类,可以与前面所提到的CNN实现对比CNN对 MNIST 数据库中的图像进行分类-CSDN博客

加载 MNIST 数据库

MNIST 是机器学习领域最著名的数据集之一。

  • 它有 70,000 张手写数字图像 - 下载非常简单 - 图像尺寸为 28x28 - 灰度图像
from keras.datasets import mnist

# 使用 Keras 导入预洗牌 MNIST 数据库
(X_train, y_train), (X_test, y_test) = mnist.load_data()

print("The MNIST database has a training set of %d examples." % len(X_train))
print("The MNIST database has a test set of %d examples." % len(X_test))

将前六个训练图像可视化 

import matplotlib.pyplot as plt
%matplotlib inline
import matplotlib.cm as cm
import numpy as np

# 绘制前六幅训练图像
fig = plt.figure(figsize=(20,20))
for i in range(6):
    ax = fig.add_subplot(1, 6, i+1, xticks=[], yticks=[])
    ax.imshow(X_train[i], cmap='gray')
    ax.set_title(str(y_train[i]))

 查看图像的更多细节

def visualize_input(img, ax):
    ax.imshow(img, cmap='gray')
    width, height = img.shape
    thresh = img.max()/2.5
    for x in range(width):
        for y in range(height):
            ax.annotate(str(round(img[x][y],2)), xy=(y,x),
                        horizontalalignment='center',
                        verticalalignment='center',
                        color='white' if img[x][y]<thresh else 'black')

fig = plt.figure(figsize = (12,12)) 
ax = fig.add_subplot(111)
visualize_input(X_train[0], ax)

预处理输入图像:通过将每幅图像中的每个像素除以 255 来调整图像比例

# normalize the data to accelerate learning
mean = np.mean(X_train)
std = np.std(X_train)
X_train = (X_train-mean)/(std+1e-7)
X_test = (X_test-mean)/(std+1e-7)

print('X_train shape:', X_train.shape)
print(X_train.shape[0], 'train samples')
print(X_test.shape[0], 'test samples')

对标签进行预处理:使用单热方案对分类整数标签进行编码

from keras.utils import np_utils

num_classes = 10 
# print first ten (integer-valued) training labels
print('Integer-valued labels:')
print(y_train[:10])

# one-hot encode the labels
# convert class vectors to binary class matrices
y_train = np_utils.to_categorical(y_train, num_classes)
y_test = np_utils.to_categorical(y_test, num_classes)

# print first ten (one-hot) training labels
print('One-hot labels:')
print(y_train[:10])

重塑数据以适应我们的 CNN(和 input_shape)

# input image dimensions 28x28 pixel images. 
img_rows, img_cols = 28, 28

X_train = X_train.reshape(X_train.shape[0], img_rows, img_cols, 1)
X_test = X_test.reshape(X_test.shape[0], img_rows, img_cols, 1)
input_shape = (img_rows, img_cols, 1)

print('image input shape: ', input_shape)
print('x_train shape:', X_train.shape)

定义模型架构

论文地址:lecun-01a.pdf

要在 Keras 中实现 LeNet-5,请阅读原始论文并从第 6、7 和 8 页中提取架构信息。以下是构建 LeNet-5 网络的主要启示:

  • 每个卷积层的滤波器数量:从图中(以及论文中的定义)可以看出,每个卷积层的深度(滤波器数量)如下:C1 = 6、C3 = 16、C5 = 120 层。
  • 每个 CONV 层的内核大小:根据论文,内核大小 = 5 x 5
  • 每个卷积层之后都会添加一个子采样层(POOL)。每个单元的感受野是一个 2 x 2 的区域(即 pool_size = 2)。请注意,LeNet-5 创建者使用的是平均池化,它计算的是输入的平均值,而不是我们在早期项目中使用的最大池化层,后者传递的是输入的最大值。如果您有兴趣了解两者的区别,可以同时尝试。在本实验中,我们将采用论文架构。
  • 激活函数:LeNet-5 的创建者为隐藏层使用了 tanh 激活函数,因为对称函数被认为比 sigmoid 函数收敛更快。一般来说,我们强烈建议您为网络中的每个卷积层添加 ReLU 激活函数。

需要记住的事项

  • 始终为 CNN 中的 Conv2D 层添加 ReLU 激活函数。除了网络中的最后一层,密集层也应具有 ReLU 激活函数。
  • 在构建分类网络时,网络的最后一层应该是具有软最大激活函数的密集(FC)层。最终层的节点数应等于数据集中的类别总数。
from keras.models import Sequential
from keras.layers import Conv2D, AveragePooling2D, Flatten, Dense
#Instantiate an empty model
model = Sequential()

# C1 Convolutional Layer
model.add(Conv2D(6, kernel_size=(5, 5), strides=(1, 1), activation='tanh', input_shape=input_shape, padding='same'))

# S2 Pooling Layer
model.add(AveragePooling2D(pool_size=(2, 2), strides=2, padding='valid'))

# C3 Convolutional Layer
model.add(Conv2D(16, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'))

# S4 Pooling Layer
model.add(AveragePooling2D(pool_size=(2, 2), strides=2, padding='valid'))

# C5 Fully Connected Convolutional Layer
model.add(Conv2D(120, kernel_size=(5, 5), strides=(1, 1), activation='tanh', padding='valid'))

#Flatten the CNN output so that we can connect it with fully connected layers
model.add(Flatten())

# FC6 Fully Connected Layer
model.add(Dense(84, activation='tanh'))

# Output Layer with softmax activation
model.add(Dense(10, activation='softmax'))

# print the model summary
model.summary()

编译模型

我们将使用亚当优化器

# the loss function is categorical cross entropy since we have multiple classes (10) 


# compile the model by defining the loss function, optimizer, and performance metric
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

训练模型

LeCun 和他的团队采用了计划衰减学习法,学习率的值按照以下时间表递减:前两个历元为 0.0005,接下来的三个历元为 0.0002,接下来的四个历元为 0.00005,之后为 0.00001。在论文中,作者对其网络进行了 20 个历元的训练。

from keras.callbacks import ModelCheckpoint, LearningRateScheduler

# set the learning rate schedule as created in the original paper
def lr_schedule(epoch):
    if epoch <= 2:     
        lr = 5e-4
    elif epoch > 2 and epoch <= 5:
        lr = 2e-4
    elif epoch > 5 and epoch <= 9:
        lr = 5e-5
    else: 
        lr = 1e-5
    return lr

lr_scheduler = LearningRateScheduler(lr_schedule)

# set the checkpointer
checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, 
                               save_best_only=True)

# train the model
hist = model.fit(X_train, y_train, batch_size=32, epochs=20,
          validation_data=(X_test, y_test), callbacks=[checkpointer, lr_scheduler], 
          verbose=2, shuffle=True)

在验证集上加载分类准确率最高的模型

# load the weights that yielded the best validation accuracy
model.load_weights('model.weights.best.hdf5')

计算测试集的分类准确率

# evaluate test accuracy
score = model.evaluate(X_test, y_test, verbose=0)
accuracy = 100*score[1]

# print test accuracy
print('Test accuracy: %.4f%%' % accuracy)

评估模型

import matplotlib.pyplot as plt

f, ax = plt.subplots()
ax.plot([None] + hist.history['accuracy'], 'o-')
ax.plot([None] + hist.history['val_accuracy'], 'x-')
# 绘制图例并自动使用最佳位置: loc = 0。
ax.legend(['Train acc', 'Validation acc'], loc = 0)
ax.set_title('Training/Validation acc per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('acc')
plt.show()

import matplotlib.pyplot as plt

f, ax = plt.subplots()
ax.plot([None] + hist.history['loss'], 'o-')
ax.plot([None] + hist.history['val_loss'], 'x-')

# Plot legend and use the best location automatically: loc = 0.
ax.legend(['Train loss', "Val loss"], loc = 0)
ax.set_title('Training/Validation Loss per Epoch')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
plt.show()

 

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/209325.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

CMMI认证含金量高吗

一、CMMI认证含金量解答 CMMI&#xff0c;即能力成熟度模型集成&#xff0c;是由美国卡内基梅隆大学软件工程研究所开发的一种评估企业软件开发过程成熟度的模型。CMMI认证的含金量究竟高不高呢&#xff1f;答案是肯定的。CMMI认证被誉为软件开发行业的“金牌标准”&#xff0…

vue循环v-for遍历图表

循环遍历图表 index.vue主页面 <view v-if"powerPage"><view v-for"(item, index) in powerDetailsData.addMap" :key"index"><PowerEChartsCity:echartData"powerDetailsData.addMap[index]"></PowerEChartsC…

经典神经网络——VGGNet模型论文详解及代码复现

论文地址&#xff1a;1409.1556.pdf。 (arxiv.org)&#xff1b;1409.1556.pdf (arxiv.org) 项目地址&#xff1a;Kaggle Code 一、背景 ImageNet Large Scale Visual Recognition Challenge 是李飞飞等人于2010年创办的图像识别挑战赛&#xff0c;自2010起连续举办8年&#xf…

Unity 关于SpriteRenderer 和正交相机缩放

float oldWidth 750f;float oldHeight 1334f;float newWidth Screen.width;float newHeight Screen.height;float oldAspect oldWidth / oldHeight;float newAspect newWidth / newHeight;//水平方向缩放float horizontalCompressionRatio newAspect / oldAspect;//垂直…

Python 中的 FileSystem Connector:打通文件系统的便捷通道

更多Python学习内容&#xff1a;ipengtao.com 大家好&#xff0c;我是涛哥&#xff0c;今天为大家分享 Python 中的 FileSystem Connector&#xff1a;打通文件系统的便捷通道&#xff0c;全文4100字&#xff0c;阅读大约11分钟。 在现代软件开发中&#xff0c;文件系统是不可或…

【Python表白系列】一起去看流星雨吧!(完整代码)

文章目录 流星雨环境需求完整代码详细分析系列文章流星雨 环境需求 python3.11.4PyCharm Community Edition 2023.2.5pyinstaller6.2.0(可选,这个库用于打包,使程序没有python环境也可以运行,如果想发给好朋友的话需要这个库哦~)【注】 python环境搭建请见:https://want5…

数据结构算法-选择排序算法

引言 说起排序算法&#xff0c;那可就多了去&#xff0c;首先了解什么叫排序 以B站为例&#xff1a; 蔡徐坤在B站很受欢迎呀&#xff0c;先来看一下综合排序 就是播放量和弹幕量&#xff0c;收藏量 一键三连 都很高这是通过一些排序算法 才能体现出综合排序 蔡徐坤鬼畜 按照播…

flutter-一个可以输入的数字增减器

效果 参考文章 代码 在参考文章上边&#xff0c;主要是改了一下样式&#xff0c;逻辑也比较清楚&#xff0c;对左右两边添加增减方法。 我在此基础上加了_numcontroller 输入框的监听。 加了数字输入框的控制 keyboardType: TextInputType.number, //设置键盘为数字 inputF…

异常处理啊

异常处理 异常 程序运行过程中&#xff0c;发生错误导致异常退出&#xff08;不是程序的语法问题&#xff0c;而是代码的逻辑问题&#xff0c;编译不出错&#xff09;。 e.g. string 字符串&#xff0c;使用 at 函数访问其中的字符元素时&#xff0c;如果越界&#xff0c;程…

Next.js初步使用

文章目录 安装和运行页面静态文件 React初步&#xff0c;但不熟悉React也可以学习本文。 安装和运行 Next.js是一个基于React的服务端渲染框架&#xff0c;可以实现构建高性能、可扩展的React应用&#xff0c;提供了很多方便的工具和功能&#xff0c;包括自动代码分割、服务器…

rdf-file:SM2加解密

一&#xff1a;SM2简介 SM2是中国密码学算法标准中的一种非对称加密算法&#xff08;包括公钥和私钥&#xff09;。SM2主要用于数字签名、密钥交换和加密解密等密码学。 生成秘钥&#xff1a;用于生成一对公钥和私钥。公钥&#xff1a;用于加密数据和验证数字签名。私钥&…

【代码】两阶段鲁棒优化/微电网经济调度入门到编程

内容包括 matlab-yalmipcplex微电网两阶段鲁棒经济调度&#xff08;刘&#xff09; matlab-yalmipcplex两阶段鲁棒微电网容量经济优化调度 两阶段鲁棒优化CCG列于约束生成和Benders代码&#xff0c;可扩展改编&#xff0c;复现自原外文论文 【赠送】虚拟储能单元电动汽车建…

【二叉树】常见题目解析(2)

题目1&#xff1a;104. 二叉树的最大深度 - 力扣&#xff08;LeetCode&#xff09; 题目1描述&#xff1a; 题目1分析及解决&#xff1a; &#xff08;1&#xff09;base case&#xff1a;当前节点为null时&#xff0c;以当前节点为根节点的树最大深度是0。 &#xff08;2&…

【C/PTA —— 13.指针2(课外实践)】

C/PTA —— 13.指针2&#xff08;课外实践&#xff09; 一.函数题6-1 鸡兔同笼问题6-2 冒泡排序6-3 字符串反正序连接6-4 计算最长的字符串长度6-5 查找星期 二.编程题7-1 C程序设计 实验5-7 数组指针作函数参数7-2 查找奥运五环色的位置 一.函数题 6-1 鸡兔同笼问题 int Chic…

Nginx反向代理详解

Nginx反向代理详解 nginx反向代理是一种常用的服务器架构设计方案&#xff0c;其原理是将客户端请求先发送到反向代理服务器&#xff0c;反向代理服务器再将请求转发到后端真实服务器处理&#xff0c;并将处理结果返回给客户端&#xff0c;从而实现负载均衡、高可用、安全和减…

VSC++: 双进制回文

缘由双进制回文数&#xff0c;一道C程序题&#xff0c;求解&#xff01;&#xff01;&#xff01;&#xff1f;_编程语言-CSDN问答 int 合成100回文(int 数) { int 合 0, 倒 数>10 && 数 < 100 ? 数 / 10 : 数;while (倒)合 * 10, 合 倒 % 10, 倒 / 10, (合…

.net framwork4.6操作MySQL报错Character set ‘utf8mb3‘ is not supported 解决方法

文章目录 .net framwork4.6操作MySQL报错Character set ‘utf8mb3‘ is not supported 解决方法详细报错内容解决方案修改数据修改表修改字段 .net framwork4.6操作MySQL报错Character set ‘utf8mb3‘ is not supported 解决方法 详细报错内容 System.NotSupportedException…

Elasticsearch分词器--空格分词器(whitespace analyzer)

介绍 文本分析&#xff0c;是将全文本转换为一系列单词的过程&#xff0c;也叫分词。analysis是通过analyzer(分词器)来实现的&#xff0c;可以使用Elasticearch内置的分词器&#xff0c;也可以自己去定制一些分词器。除了在数据写入时将词条进行转换&#xff0c;那么在查询的时…

使用策略模式彻底消除if-else

文章目录 使用策略模式彻底消除if-else1. 场景描述2. if-else方式3. 策略模式 使用策略模式彻底消除if-else 如果一个对象有很多的行为&#xff0c;如果不用恰当的模式&#xff0c;这些行为就只好使用多重的条件选择语句来实现&#xff0c;这样会显得代码逻辑很臃肿&#xff0c…

C++学习之路(十五)C++ 用Qt5实现一个工具箱(增加16进制颜色码转换和屏幕颜色提取功能)- 示例代码拆分讲解

上篇文章&#xff0c;我们用 Qt5 实现了在小工具箱中添加了《Base64图片编码预览功能》功能。为了继续丰富我们的工具箱&#xff0c;今天我们就再增加两个平时经常用到的功能吧&#xff0c;就是「 16进制颜色码转RGB文本 」和 「屏幕颜色提取」功能。下面我们就来看看如何来规划…