基于CNN对彩色图像数据集CIFAR-10实现图像分类--keras框架实现

项目地址(kaggle):基于CNN对彩色图像数据集CIFAR-10实现图像分类--keras | Kaggle

项目地址(Colab):https://colab.research.google.com/drive/1gjzglPBfQKuhfyT3RlltCLUPgfccT_G9

 导入依赖

在tensorflow-keras-gpu环境中导入下面依赖:

from keras.datasets import cifar10

from keras import regularizers
from keras.callbacks import ModelCheckpoint
from keras.layers import Conv2D, Activation, BatchNormalization, MaxPooling2D, Dropout, Flatten, Dense
from keras.models import Sequential
from keras.preprocessing.image import ImageDataGenerator
from matplotlib import pyplot
from keras import optimizers
import numpy as np

 准备训练数据

本次实验使用的是keras提供的CIFAER-10数据集,这些数据集是经过预处理,基本可以当作神经网络的输入直接使用,其中包含5000张32x32大小的彩色训练图像和超过10个类别的标注,以及10000张测试图像。

打印数据集
Keras提供的CIFAR-10数据集已被划分为训练集和测试集,并打印测试集和训练集的形状。

# download and split the data
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')

print("training data = ", x_train.shape)
print("testing data = ", x_test.shape)

 数据归一化处理

要对图像的像素值进行归一化处理,应将每个像素减去平均值并以所得结果除以标准差。

# Normalize the data to speed up training
mean = np.mean(x_train)
std = np.std(x_train)
x_train = (x_train-mean)/(std+1e-7)
x_test = (x_test-mean)/(std+1e-7)

# let's look at the normalized values of a sample image
x_train[0]

对标签进行one-hot编码 

# one-hot encode the labels in train and test datasets
# we use “to_categorical” function in keras
from keras.utils import to_categorical
num_classes = 10
y_train = to_categorical(y_train,num_classes)
y_test = to_categorical(y_test,num_classes)

# let's display one of the one-hot encoded labels
y_train[0]

构建模型架构 

模型的网络结构配置如下:

(1)之前在一个卷积层后面加一个池化层,而在全新的架构中,将在每两个卷积层后面加一个池化层,这个想法是受到VGGNet的启发

(2)这里的卷积层的dilation_rate设置为3x3,并将池化层的pool_size设置为2x2。

(3)每隔一个卷积层就添加dropout层,舍弃率p的取值为0.2-0.4

(4)在Keras中,L2正则化被添加到卷积层中

# build the model

# number of hidden units variable
# we are declaring this variable here and use it in our CONV layers to make it easier to update from one place
base_hidden_units = 32

# l2 regularization hyperparameter
weight_decay = 1e-4

# instantiate an empty sequential model
model = Sequential()

# CONV1
# notice that we defined the input_shape here because this is the first CONV layer.
# we don’t need to do that for the remaining layers
model.add(Conv2D(base_hidden_units, (3,3), padding='same', kernel_regularizer=regularizers.l2(weight_decay), input_shape=x_train.shape[1:]))
model.add(Activation('relu'))
model.add(BatchNormalization())

# CONV2
model.add(Conv2D(base_hidden_units, (3,3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
model.add(Activation('relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.2))

# CONV3
model.add(Conv2D(2*base_hidden_units, (3,3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
model.add(Activation('relu'))
model.add(BatchNormalization())

# CONV4
model.add(Conv2D(2*base_hidden_units, (3,3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
model.add(Activation('relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.3))

# CONV5
model.add(Conv2D(4*base_hidden_units, (3,3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
model.add(Activation('relu'))
model.add(BatchNormalization())

# CONV6
model.add(Conv2D(4*base_hidden_units, (3,3), padding='same', kernel_regularizer=regularizers.l2(weight_decay)))
model.add(Activation('relu'))
model.add(BatchNormalization())
model.add(MaxPooling2D(pool_size=(2,2)))
model.add(Dropout(0.4))

# FC7
model.add(Flatten())
model.add(Dense(num_classes, activation='softmax'))

# print model summary
model.summary()

模型摘要如下:

 数据增强

本实验将随意采用旋转、高度、和宽度变换、水平翻转等数据增强技术。处理问题时,请检查看网络没有进行分类或分类结果较差的图像,并尝试理解网络在这些图像上表现不佳的原因,然后提出改进假设并进行试验。分析、试验、评估并重复这个过程,通过纯粹的数据分析和对网络性能的理解来做出决定

# data augmentation
datagen = ImageDataGenerator(
    rotation_range=15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    horizontal_flip=True,
    vertical_flip=False
    )

# compute the data augmentation on the training set
datagen.fit(x_train)

训练模型 

训练模型之前先讨论一些超参数的设置策略。

(1)batch_size:batch_size越大,算法学习的越快。可将初始值设置为64,然后将该值翻倍来加速训练。

(2)epochs:开始时将值设为50,但是发现网络仍在改进,所以不断则更加训练轮数并观察训练结果

(3)optimizer:本实验实验了Adam优化器。因新版本的keras很多优化器找不到配置文件的问题,最终解决Adam优化器配置的问题。


# training
from tensorflow.keras.optimizers import legacy
batch_size = 128
epochs=200

checkpointer = ModelCheckpoint(filepath='model.125epochs.hdf5', verbose=1, save_best_only=True)

# you can try any of these optimizers by uncommenting the line
#optimizer = rmsprop(lr=0.001,decay=1e-6)
optimizer = legacy.Adam(learning_rate=0.0001,decay=1e-6)

#optimizer =keras.optimizers.rmsprop(lr=0.0003,decay=1e-6)
model.compile(loss='categorical_crossentropy', optimizer=optimizer, metrics=['accuracy'])
history = model.fit(datagen.flow(x_train, y_train, batch_size=batch_size), callbacks=[checkpointer],
                steps_per_epoch=x_train.shape[0] // batch_size, epochs=epochs,verbose=2,
                validation_data=(x_test,y_test))

评估模型 

调用Keras的evalute函数来评估模型并打印结果

# evaluating the model
scores = model.evaluate(x_test, y_test, batch_size=128, verbose=1)
print('\nTest result: %.3f loss: %.3f' % (scores[1]*100,scores[0]))

打印学习曲线,分析训练性能

# plot learning curves of model accuracy

pyplot.plot(history.history['accuracy'], label='train')
pyplot.plot(history.history['val_accuracy'], label='test')
pyplot.legend()
pyplot.show()

 

调参

为了提升模型的结果,需要对模型进一步改进:

(1)增加训练的轮数:通过上述效果可以得出模型在125轮之前一直在增加,可将模型的训练轮数进行进一步增加。

(2)使用更深的网络结构:尝试添加更多层来提升模型的复杂度,以增强其学习能力。

(3)降低学习率:通过降低学习率learning_rate的方式使其模型使用更长的时间去学习。

(4)使用不同的CNN架构。

最终我们经过多次调参得到如下结果

序号

batch_size

epochs

learning_rate

Test result

学习曲线

1

128

125

0.0001

86.560

2

128

200

0.0001

87.360

3

256

200

0.001

86.930

4

256

200

0.0001

88.120

5

256

200

0.0003

87.820

 我们经过了五次实验发现当batch_size=256,epochs=200,learning_rate=0.0001的时候,Test result最高,分类效果最好,当然可以继续尝试添加更多层来提升模型的复杂度,以增强其学习能力。

异常问题与解决方案

1、报错:Failed to get convolution algorithm. cudnn failed to initialize

解决办法:在模型前面加上这几句话,意思大概也是运行内存增加

physical_devices = tf.config.experimental.list_physical_devices('GPU')

if len(physical_devices) > 0:

     for k in range(len(physical_devices)):

         tf.config.experimental.set_memory_growth(physical_devices[k], True)

     print('memory growth:', tf.config.experimental.get_memory_growth(physical_devices[k]))

else:

     print("Not enough GPU hardware devices available")

2、报错:lr "参数已被弃用,请使用 "learning_rate "参数。 super().__init__(name, **kwargs)

解决办法:将lr换为learning_rate

3、报错:`Model.fit_generator` is deprecated and will be removed in a future version. Please use `Model.fit`, which supports generators.

解决办法:将Model.fit_generator改为Model.fit

4、报错:No module named ‘adam’

解决办法:将keras.optimizers.adam改为legacy.Adam,并重新导入legacy包

5、报错:Image transformations require SciPy. Install SciPy.

解决办法:重新安装SciPy

6、报错----> 3 pyplot.plot(history.history['acc'], label='train')

      4 pyplot.plot(history.history['val_acc'], label='test')

      5 pyplot.legend()

KeyError: 'acc'

解决办法:将acc替换为accuracy;val_acc替换为val_accuracy

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/214816.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

第一百八十八回 分享三个使用TextField的细节

文章目录 1. 概念介绍2. 使用方法2.1 修改组件的填充颜色2.2 修改组件的高度2.3 给组件添加圆角3. 示例代码4. 内容总结我们在上一章回中介绍了"DropdownButton组件"相关的内容,本章回中将介绍**TextField组件的细节.**闲话休提,让我们一起Talk Flutter吧。 1. 概念…

EasyRecovery易恢复2024最新免费版电脑数据恢复软件功能介绍

EasyRecovery从(易恢复2024)支持恢复不同存储介质数据,在Windows中恢复受损和删除文件,以及能检索数据格式化或损坏卷,甚至还可以从初始化磁盘。同时,你只需要最简单的操作就可以恢复数据文件,如&#xff1…

YITH Product Shipping for WooCommerce商城产品配送运输插件

点击访问原文 YITH Product Shipping for WooCommerce商城产品配送运输插件 - 易服客工作室 YITH Product Shipping for WooCommerce商城产品配送运输插件根据商店的每个产品处理不同的运费,例如您可以为每个州、地区或城市设置不同的费用。 根据店铺的单品处理不…

搭建 ebpf 开发测试环境

0 内容说明 这部分主要讲述了如何通过官网学习ebpf,以及如何搭建自己的ebpf开发测试环境,主要是需要安装哪些工具链。 1 ebpf在线学习 ebpf官网中提供了一个快速在线学习ebpf的路径,在这个学习平台中一共有两项学习内容,一个是…

在Spring Boot中隔离@Async异步任务的线程池

在异步任务执行的时候,我们知道其背后都有一个线程池来执行任务,但是为了控制异步任务的并发不影响到应用的正常运作,我们需要对线程池做好相关的配置,以防资源过度使用。这个时候我们就考虑将线程池进行隔离了。 那么我们为啥要…

高校人员信息管理系统C++

代码:https://mbd.pub/o/bread/ZZeZk5lx 一、基本内容论述 1、问题描述 某高校有四类员工:教师、实验员、行政人员、教师兼行政人员;共有的信息包括:编号、姓名、性别、年龄等。其中,教师还包含的信息有:所…

2023年GopherChina大会-核心PPT资料下载

一、峰会简介 自 Go 语言诞生以来,中国便是其应用最早和最广的国家之一,根据 Jetbrains 在 2021 年初做的调查报告,总体来说目前大概有 110 万专业的开发者 选择 Go 作为其主要开发语言。就其全球分布而言, 居住在亚洲的开发者最多&#xff…

矩阵元素求和:按行、按列、所有元素np.einsum()

【小白从小学Python、C、Java】 【计算机等考500强证书考研】 【Python-数据分析】 矩阵元素求和: 按行、按列、所有元素 np.einsum() [太阳]选择题 下列说法正确的是: import numpy as np A np.array([[1, 2],[3, 4]]) print("【显示】A") p…

为何要3次握手?TCP协议的稳定性保障机制

🚀 作者主页: 有来技术 🔥 开源项目: youlai-mall 🍃 vue3-element-admin 🍃 youlai-boot 🌺 仓库主页: Gitee 💫 Github 💫 GitCode 💖 欢迎点赞…

绝地求生在steam叫什么?

绝地求生在Steam的全名是《PlayerUnknowns Battlegrounds》,简称为PUBG。作为一款风靡全球的多人在线游戏,PUBG于2017年3月23日正式上线Steam平台,并迅速成为一部热门游戏。 PUBG以生存竞技为核心玩法,玩家将被投放到一个辽阔的荒…

布林线BOLL的实战应用技巧

一、认识布林线BOLL 布林线BOLL,又称布林带,是股市中非常常用的一个技术指标。 以金斗云智投电脑版软件为例,任意打开一支个股,选择BOLL指标,在K线区域就可以看到上中下排列的3条线,这3条线就组成了布林带。…

【Three.js】创建CAD标注线

目录 📖前言 🐞创建箭头对象 🐬创建文字 👻箭头两端的线段 ✈️封装方法 📖前言 CAD标注线在工程和制造领域中被广泛用于标记零部件、装配体和机械系统的尺寸、距离、角度等信息。它们帮助工程师和设计师更好地理…

web自动化 -- selenium及应用

selenium简介 随着互联网的发展,前端技术不断变化,数据加载方式也不再是通过服务端渲染。现在许多网站使用接口或JSON数据通过JavaScript进行渲染。因此,使用requests来爬取内容已经不再适用,因为它只能获取服务器端网页的源码&am…

计算机网络体系的形成

目录 1、开放系统互连参考模型OSI/RM 2、两种国际标准 3、协议与划分层次 4、网络协议的三要素 5、划分层次 (1)文件发送模块使两个主机交换文件 (2)通信服务模块 (3)接入网络模块 6、分层带来的好…

Linux下Python调用C语言

一:Python调用C语言场景 1,已经写好的C语言代码,不容易用Python实现,想直接通过Python调用写好的C语言代码 2,C比Python快(只是从语言层面,不能绝对说C程序就是比Python快) 3&…

docker配置redis插件

docker配置redis插件 运行容器redis_6390 docker run -it \ --name redis_6390 \ --privileged \ -p 6390:6379 \ --network wn_docker_net \ --ip 172.18.12.19 \ --sysctl net.core.somaxconn1024 \ -e TIME_ZONE"Asia/Shanghai" -e TZ"Asia/Shanghai"…

电磁兼容EMC理论基础汇总

目录 0. 序言 1. EMC的基础介绍 1.1 EMC电磁兼容的定义 1.2 EMC的重要性 1.3 EMC的三要素 2. 库仑定律 3. 趋肤效应与趋肤深度 4. 电阻抗公式 4.1 电阻 4.2 容抗 4.3 感抗 4.4 电路元件的非理想性 5. 麦克斯韦方程组 5.1 高斯磁定律 5.2 高斯定律 5.3 法拉…

【FPGA】Verilog:二进制并行加法器 | 超前进位 | 实现 4 位二进制并行加法器和减法器 | MSI/LSI 运算电路

Ⅰ. 前置知识 0x00 并行加法器和减法器 如果我们要对 4 位加法器和减法器进行关于二进制并行运算功能,可以通过将加法器和减法器以 N 个并行连接的方式,创建一个执行 N 位加法和减法运算的电路。 4 位二进制并行加法器 4 位二进制并行减法器 换…

YoloV8改进策略:Swift Parameter-free Attention,无参注意力机制,超分模型的完美迁移

摘要 https://arxiv.org/pdf/2311.12770.pdf https://github.com/hongyuanyu/SPAN SPAN是一种超分网络模型。SPAN模型通过使用参数自由的注意力机制来提高SISR的性能。这种注意力机制能够增强重要信息并减少冗余,从而在图像超分辨率过程中提高图像质量。 具体来说,SPAN模…

J-Link RTT的使用(原理 + 教程 + 应用 + 代码)

MCU:STM32F407VE MDK:5.29 IAR:8.32 目录--点击可快速直达 目录 写在前面什么是RTT?RTT的工作原理RTT的性能快速使用教程高级使用教程附上测试代码2019年12月27日更新--增加打印float的功能 写在前面 本文介绍了J-Link RTT的部分使用内容,很多地方参考和使用…