训练 CNN 对 CIFAR-10 数据中的图像进行分类-keras实现

1. 加载 CIFAR-10 数据库

import keras
from keras.datasets import cifar10

# 加载预先处理的训练数据和测试数据
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

2. 可视化前 24 个训练图像

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

fig = plt.figure(figsize=(20,5))
for i in range(36):
    ax = fig.add_subplot(3, 12, i + 1, xticks=[], yticks=[])
    ax.imshow(np.squeeze(x_train[i]))

3. 通过将每幅图像中的每个像素除以 255 来调整图像比例

事实上,代价函数的形状是一个碗,但如果特征的比例非常不同,它也可能是一个拉长的碗。下图显示了梯度下降法在特征 1 和特征 2 比例相同的训练集上的应用(左图),以及在特征 1 的值远小于特征 2 的训练集上的应用(右图)。

Tips : 使用梯度下降法时,应确保所有特征的比例相似,以加快训练速度,否则收敛时间会更长。

# rescale [0,255] --> [0,1]
x_train = x_train.astype('float32')/255
x_test = x_test.astype('float32')/255

 4. 将数据集分为训练集、测试集和验证集

from keras.utils import to_categorical

# 对标签进行一次热编码
num_classes = len(np.unique(y_train))
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

# 将训练集分为训练集和验证集
(x_train, x_valid) = x_train[5000:], x_train[:5000]
(y_train, y_valid) = y_train[5000:], y_train[:5000]

# 打印训练集的形状
print('x_train shape:', x_train.shape)

# 打印训练、验证和测试图像的数量
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
print(x_valid.shape[0], 'validation samples')

5. 定义模型架构 

from keras.models import Sequential
from keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout

model = Sequential()
model.add(Conv2D(filters=16, kernel_size=2, padding='same', activation='relu', 
                        input_shape=(32, 32, 3)))
model.add(MaxPooling2D(pool_size=2))
model.add(Conv2D(filters=32, kernel_size=2, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Conv2D(filters=64, kernel_size=2, padding='same', activation='relu'))
model.add(MaxPooling2D(pool_size=2))
model.add(Dropout(0.3))
model.add(Flatten())
model.add(Dense(500, activation='relu'))
model.add(Dropout(0.4))
model.add(Dense(10, activation='softmax'))

model.summary()

6. 编译模型 

# compile the model
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

7. 训练模型

from keras.callbacks import ModelCheckpoint   

# 训练模型
checkpointer = ModelCheckpoint(filepath='model.weights.best.hdf5', verbose=1, save_best_only=True)

hist = model.fit(x_train, y_train, batch_size=32, epochs=100,
          validation_data=(x_valid, y_valid), callbacks=[checkpointer], 
          verbose=2, shuffle=True)

8. 加载验证精度最高的模型

# 加载验证精度最高的权重
model.load_weights('model.weights.best.hdf5')

 9. 计算测试集的分类精度

# 评估和打印测试精度
score = model.evaluate(x_test, y_test, verbose=0)
print('\n', 'Test accuracy:', score[1])

10. 可视化一些预测

这可能会让你对网络错误分类某些对象的原因有所了解。

# 在测试集上得到预测
y_hat = model.predict(x_test)

# 定义文本标签 (source: https://www.cs.toronto.edu/~kriz/cifar.html)
cifar10_labels = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
# 绘制测试图像的随机样本、预测标签和基本真实图像
fig = plt.figure(figsize=(20, 8))
for i, idx in enumerate(np.random.choice(x_test.shape[0], size=32, replace=False)):
    ax = fig.add_subplot(4, 8, i + 1, xticks=[], yticks=[])
    ax.imshow(np.squeeze(x_test[idx]))
    pred_idx = np.argmax(y_hat[idx])
    true_idx = np.argmax(y_test[idx])
    ax.set_title("{} ({})".format(cifar10_labels[pred_idx], cifar10_labels[true_idx]),
                 color=("green" if pred_idx == true_idx else "red"))

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/206992.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

出一个画质demo

#灵感# 画质demo, 适应的场景不一定多,但演示的功能却一个不少。简单列一下,短时间出一个画质demo的流程。 目录 1、基础检查 2、目标确认 3、调试 4、释放demo 参数 1、基础检查 分辨率、帧率、max gain、bit depth(输出raw图像位宽&a…

【Flink进阶】-- Flink kubernetes operator 快速入门与实战

1、课程目录 2、课程链接 https://edu.csdn.net/course/detail/38831

新型信息基础设施下的IP追溯技术:构建数字化安全新境界

随着新型信息基础设施的快速发展,IP(Internet Protocol)追溯技术在数字化安全领域变得愈发重要。IP追溯不仅能够帮助识别网络攻击源,提升网络安全水平,还有助于数字证据追踪、合规性审计等方面。本文将探讨新型信息基础…

Linux shell for jar test

Linux shell 脚本,循环解析命令行传入的所有参数,并按照不同的传参实现对不同的 java jar文件 进行测试执行。 [rootlocalhost demo]# cat connTest.sh #!/bin/bash# Linux shell for qftool java jar test# modes DEFAULT_MODE2jarfiles[1]common-1.0…

智慧城市管理的得力助手——无人机管理平台

近年来,随着无人机技术的迅猛发展,无人机管理平台应运而生,成为这一行业的重要基础设施。该平台充分利用智能无人机管理和物联网技术,实现了对城市重要领域的精细化动态管控。 一、无人机管理平台的特点 1、全方位数据获取&#…

Vue3使用LeaderLine

LeaderLine官方文档在这里 1.安装插件 npm install leader-line-vue 2.导出LeaderLine import LeaderLine from leader-line-vue; 3.创建连接线 let line LeaderLine.setLine(startElement, endElement, { startPlug: disc, endPlug: disc,color: white, size: 2 }); …

XSS骚操作

在网上看到的xss无字母的一个骚操作,可以先看一下下面这两个代码: 那么xss也是可以这样的,比如: 实战效果: 这里先将这个编码一下,直接塞到靶场看效果: alert(1)的编译代码: X[![]]!…

【Android】使用intent.putExtra()方法在启动Activity时传递数据

食用方法 在Android中,你可以使用Intent对象来在启动Activity时传递数据。以下是一个示例,展示了如何在startActivity时传递数据到被启动的Activity: 在启动Activity的地方,创建一个Intent对象,并使用putExtra()方法…

Java开发项目之KTV点歌系统设计和实现

项目介绍 本系统实现KTV点歌管理的信息化,可以方便管理员进行更加方便快捷的管理。系统的主要使用者分为管理员和普通用户。 管理员功能模块: 个人中心、用户管理、歌曲库管理、歌曲类型管理、点歌信息管理。 普通用户功能模块: 个人中心…

AC-DC 220V转12V 500毫安非隔离恒压恒流降压芯片

AC-DC 220V转12V 500毫安非隔离恒压恒流降压芯片是一款高性能的电源管理芯片,它能够将220V的交流电压降低到12V直流电压,并且具有恒压恒流输出、多模式控制、低待机功耗、高精度输出、内置软启动、多种保护功能等特点。 该芯片的非隔离系统恒压恒流输出可…

浅析什么是组态图和组态图设计

随着计算机技术和工业自动化水平迅速提高,而车间现场种类繁杂的控制设备和过程监控装置使得传统的工业控制软件无法满足用户的各种需求。在“组态”概念出现之前,工程技术人员需要通过编写程序来实现某一任务,不但工作量大、周期长&#xff0…

深度学习手势检测与识别算法 - opencv python 计算机竞赛

文章目录 0 前言1 实现效果2 技术原理2.1 手部检测2.1.1 基于肤色空间的手势检测方法2.1.2 基于运动的手势检测方法2.1.3 基于边缘的手势检测方法2.1.4 基于模板的手势检测方法2.1.5 基于机器学习的手势检测方法 3 手部识别3.1 SSD网络3.2 数据集3.3 最终改进的网络结构 4 最后…

揭秘各种编程语言在不同领域中的精彩表现

文章目录 🔊博主介绍🥤本文内容📢文章总结📥博主目标 🔊博主介绍 🌟我是廖志伟,一名Java开发工程师、Java领域优质创作者、CSDN博客专家、51CTO专家博主、阿里云专家博主、清华大学出版社签约作…

golang 函数选项模式

一 什么是函数选项模式 函数选项模式允许你使用接受零个或多个函数作为参数的可变构造函数来构建复杂结构。我们将这些函数称为选项,由此得名函数选项模式。 例子: 有业务实体Animal结构体,构造函数NewAnimal()&…

【opencv】计算机视觉基础知识

目录 前言 1、什么是计算机视觉 2、图片处理基础操作 2.1 图片处理:读入图像 2.2 图片处理:显示图像 2.3 图片处理:图像保存 3、图像处理入门基础 3.1 图像成像原理介绍 3.2 图像分类 3.2.1 二值图像 3.2.2灰度图像 3.2.3彩色图像…

Unity DOTS《群体战斗弹幕游戏》核心技术分析之3D角色动画

最近DOTS发布了正式的版本, 我们来分享现在流行基于群体战斗的弹幕类游戏,实现的核心原理。今天给大家介绍大规模战斗群体3D角色的动画如何来实现。 DOTS 对角色动画支持的局限性 截止到Unity DOTS发布的版本1.0.16,目前还是无法很好的支持3D角色动画。在DOTS 的b…

iOS系统上待办事项提醒软件哪个好

在这个快节奏的生活中,各种待办事项充斥了我们的日常工作和生活,尤其对于像我这样的iPhone用户而言,一款能够在iOS系统上快速和准确记录和提醒待办事项的软件,显得至关重要。 正如前几天,我正沉浸在工作中的时突然被领…

采购流程的简要概述

外部采购流程 一般来讲,企业的采购业务一般是对外采购活动,一个比较典型采购业务循环通常包括:需求提报、货源确定和供应商选择、采购订单处理、采购订单状态跟踪监控、到厂收货、发票校验、付款。 以下对几个节点进行详细的解释&#xff…

亚信科技AntDB数据库完成中国信通院数据库迁移工具专项测试

近日,在中国信通院“可信数据库”数据库迁移工具专项测试中,湖南亚信安慧科技有限公司(简称:亚信安慧科技)数据库数据同步平台V2.1产品依据《数据库迁移工具能力要求》、结合亚信科技AntDB分布式关系型数据库产品&…

JavaScript添加快捷键、取消浏览器默认的快捷操作、js查看键盘按钮keycode值

document.addEventListener("keydown",function (event) {// 如果不知道按键对应的数字(keyCode)是多少可以弹出查看一下// alert(event.keyCode)if (event.ctrlKey && event.altKey && event.view["0"] null){if(…