深度学习——图像分类(CNN)—训练模型

训练模型

    • 1.导入必要的库
    • 2.定义超参数
    • 3.读取训练和测试标签CSV文件
    • 4.确保标签是字符串类型
    • 5.显示两个数据框的前几行以了解它们的结构
    • 6.定义图像处理参数
    • 7.创建图像数据生成器
    • 8.设置目录路径
    • 9.创建训练和验证数据生成器
    • 10.构建模型
    • 11.编译模型
    • 12.训练模型并收集历史
    • 13.绘制损失和准确率曲线
    • 14.保存图表
    • 15.保存模型到本地

1.导入必要的库

pandas as pd: Pandas是一个强大的数据分析和处理库,它提供了数据结构(如DataFrame)和工具,用于数据操作和分析。
tensorflow.keras.preprocessing.image import ImageDataGenerator: ImageDataGenerator是Keras的一部分,它用于图像数据的预处理和增强,例如,随机裁剪、旋转、缩放等。
tensorflow.keras.models import Sequential: Sequential模型是Keras中的一种模型,它允许您顺序地堆叠层。
tensorflow.keras.layers: 包含了Keras中所有的层类型,如Conv2D、MaxPooling2D、Flatten、Dense等。
tensorflow.keras.optimizers: 包含了Keras中所有的优化器类型,如Adam、SGD等。
sklearn.model_selection import train_test_split: train_test_split是Scikit-Learn的一部分,它用于将数据集分割为训练集和测试集。
numpy as np: NumPy是一个用于科学计算的库,它提供了高效的数组处理能力,对于图像处理等任务非常有用。
sklearn.preprocessing import LabelBinarizer: LabelBinarizer是Scikit-Learn的一部分,它用于将类别标签转换为二进制数组。
matplotlib.pyplot as plt: Matplotlib是一个绘图库,pyplot是其中的一个模块,它提供了一个类似于MATLAB的绘图框架。
import pickle: pickle是Python的标准库,它用于序列化Python对象,以便将它们保存到文件或从文件中加载。

import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout
from tensorflow.keras.optimizers import Adam
from sklearn.model_selection import train_test_split
import numpy as np
from sklearn.preprocessing import LabelBinarizer
import matplotlib.pyplot as plt
import pickle

2.定义超参数

INIT_LR = 0.01
EPOCHS = 30
BS = 32

3.读取训练和测试标签CSV文件

train_labels.csv和test_labels.csv在资源中。

# 读取训练标签CSV文件
train_labels_filename = 'train_labels.csv'
train_labels_df = pd.read_csv(train_labels_filename)

# 读取测试标签CSV文件
test_labels_filename = 'test_labels.csv'
test_labels_df = pd.read_csv(test_labels_filename)

4.确保标签是字符串类型

train_labels_df[‘label’] = train_labels_df[‘label’].astype(str):

train_labels_df['label']:这是train_labels_df DataFrame中名为label的列。
.astype(str):这是Pandas中的一个方法,用于将列的数据类型转换为字符串类型。

test_labels_df[‘label’] = test_labels_df[‘label’].astype(str):

test_labels_df['label']:这是test_labels_df DataFrame中名为label的列。
.astype(str):这是Pandas中的一个方法,用于将列的数据类型转换为字符串类型。

train_labels_df['label'] = train_labels_df['label'].astype(str)
test_labels_df['label'] = test_labels_df['label'].astype(str)

5.显示两个数据框的前几行以了解它们的结构

print(train_labels_df.head())
print(test_labels_df.head())

6.定义图像处理参数

img_width:这是一个变量,用于存储图像的宽度。
img_height:这是一个变量,用于存储图像的高度。
= 150, 150:这行代码将img_width和img_height变量分别设置为150。

img_width, img_height = 150, 150

7.创建图像数据生成器

ImageDataGenerator:这是Keras中的一个类,用于创建一个数据生成器,用于图像数据的增强和预处理。
rescale=1./255:这是一个参数,用于将图像的像素值从0到255的范围转换为0到1的范围,这是常见的图像预处理步骤。
validation_split=0.2:这是一个参数,用于指定训练数据中用于验证的比例。在这里,20%的数据将用于验证,80%的数据将用于训练。
data_gen:这是生成的ImageDataGenerator对象,它将在后续的训练过程中用于生成增强的图像数据。

data_gen = ImageDataGenerator(rescale=1./255, validation_split=0.2)

8.设置目录路径

train和test压缩文件在资源中

# 并且数据集应该存储在环境可访问的路径中
train_dir = 'D:/rgzn/face/DATASET/train'  # 包含子文件夹的父目录
test_dir = 'D:/rgzn/face/DATASET/test'    # 包含子文件夹的父目录

9.创建训练和验证数据生成器

#flow_from_dataframe:这是Keras中的一个方法,用于创建一个数据生成器,它可以从DataFrame中加载图像和标签。
train_data_gen = data_gen.flow_from_dataframe(

#要加载的数据源
dataframe=train_labels_df,
#包含图像文件的目录
directory=train_dir,  
#DataFrame中包含图像路径的列名。
x_col='image',
#DataFrame中包含标签的列名。
y_col='label',
#目标图像的大小
target_size=(img_width, img_height),
#每次迭代中从数据生成器中获取的样本数量。
batch_size=32,
#随机种子,用于确保每次运行时生成相同的数据增强
seed=42,
#数据集的子集,用于训练。
    subset='training',
)
validation_data_gen = data_gen.flow_from_dataframe(
    dataframe=test_labels_df,
    directory=test_dir,  # 包含子文件夹的父目录
    x_col='image',
    y_col='label',
    target_size=(img_width, img_height),
    batch_size=32,
seed=42,
#数据集的子集,用于验证。
    subset='validation',
)

10.构建模型

# 构建模型
model = Sequential()
model.add(Conv2D(32, (3, 3), activation='relu', input_shape=(150, 150, 3)))
model.add(MaxPooling2D(pool_size=(2, 2)))

# 新增的卷积层
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

model.add(Conv2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))

# 展平层
model.add(Flatten())

# 全连接层
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))

# 输出层
model.add(Dense(7, activation='softmax'))

11.编译模型

model.compile(loss='categorical_crossentropy',
              optimizer='adam',
              metrics=['accuracy'])

model:这是之前创建和配置的Keras模型。
compile:这是Keras中的一个方法,用于编译模型,指定训练过程中使用的损失函数、优化器和评估指标。
loss='categorical_crossentropy':这是模型使用的损失函数,适用于多类分类问题。
optimizer='adam':这是模型使用的优化器,用于调整模型的权重以最小化损失函数。
metrics=['accuracy']:这是模型使用的评估指标,用于评估模型在训练数据上的性能。

12.训练模型并收集历史

history = model.fit(train_data_gen, epochs=EPOCHS, validation_data=validation_data_gen, batch_size=BS)

fit:这是Keras中的一个方法,用于训练模型。
train_data_gen:这是之前创建的训练数据生成器。
epochs=EPOCHS:这是训练过程中重复训练数据的次数。
validation_data=validation_data_gen:这是用于验证模型的数据。
batch_size=BS:这是每次迭代中从数据生成器中获取的样本数量。
history:这是训练过程中记录的性能指标,如损失和准确率。

13.绘制损失和准确率曲线

N = np.arange(0, EPOCHS)
#设置图表的样式
plt.style.use('ggplot')
plt.figure()

plt.plot(N, history.history['loss'], label='train_loss')
plt.plot(N, history.history['val_loss'], label='val_loss')
plt.plot(N, history.history['accuracy'], label='train_acc')
plt.plot(N, history.history['val_accuracy'], label='val_acc')

plt.title("Training Loss And Accuracy (CNN)")
plt.xlabel('Epoch #')
plt.ylabel('Loss/Accuracy')
plt.legend()
plt.axis([0, EPOCHS, 0, 2])

14.保存图表

plt.savefig('plot.png')

15.保存模型到本地

print('[INFO] 正在保存模型')
model.save('model.h5')

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/637132.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【AD21】PCB板尺寸与层名称标注

PCB绘制完成后,需要给上级或生产制造商发送输出文件,输出文件中包含板尺寸标识和层标识可以方便工作的交接。 1. 板尺寸标识 首先板尺寸标识所在的层要在与板框不同的机械层,这里我选择机械5层。 点击放置->尺寸->线性尺寸 这里板尺…

微信小程序uniapp+django洗脚按摩足浴城消费系统springboot

原生wxml开发对Node、预编译器、webpack支持不好,影响开发效率和工程构建。所以都会用uniapp框架开发 前后端分离,后端给接口和API文档,注重前端,接近原生系统 使用Navicat或者其它工具,在mysql中创建对应名称的数据库&#xff0…

利用大模型构造数据集,并微调大模型

一、前言 目前大模型的微调方法有很多,而且大多可以在消费级显卡上进行,每个人都可以在自己的电脑上微调自己的大模型。 但是在微调时我们时常面对一个问题,就是数据集问题。网络上有许多开源数据集,但是很多时候我们并不想用这…

Gerchberg-Saxton (GS) 和混合输入输出(Hybrid Input-Output, HIO)算法

文章目录 1. 简介2. 算法描述3. 混合输入输出(Hybrid Input-Output, HIO)算法3.1 HIO算法步骤3.2 HIO算法的优势3.3 算法描述 4. 算法实现与对比5. 总结参考文献 1. 简介 Gerchberg-Saxton (GS) 算法是一种常用于相位恢复和光学成像的迭代算法。该算法最…

【抽代复习笔记】18-置换练习题(2)及两个重要定理

最近一直忙于学校的事情,好久没更新了,实在抱歉。接下来几期大概也会更得慢一些,望见谅。 练习4:写出4次对称群S4中所有置换。 解:由上一篇笔记结尾的定理我们知道,4次对称群的阶(也就是所含元…

JSON的序列化与反序列化以及VSCode执行Run Code 报错

JSON JSON: JavaScript Object Notation JS对象简谱 , 是一种轻量级的数据交换格式。 JSON格式 { "name":"金苹果", "info":"种苹果" } 一个对象:由一个大括号表示.括号中通过键值对来描述对象的属性 (可以理解为, 大…

2024年 电工杯 (A题)大学生数学建模挑战赛 | 园区微电网风光储协调优化配置 | 数学建模完整代码解析

DeepVisionary 每日深度学习前沿科技推送&顶会论文&数学建模与科技信息前沿资讯分享,与你一起了解前沿科技知识! 本次DeepVisionary带来的是电工杯的详细解读: 完整内容可以在文章末尾全文免费领取&阅读! 问题重述…

MVS net笔记和理解

文章目录 传统的方法有什么缺陷吗?MVSnet深度的预估 传统的方法有什么缺陷吗? 传统的mvs算法它对图像的光照要求相对较高,但是在实际中要保证照片的光照效果很好是很难的。所以传统算法对镜面反射,白墙这种的重建效果就比较差。 …

【Python自动化测试】:Unittest单元测试与HTMLTestRunner自动生成测试用例的好帮手

读者大大们好呀!!!☀️☀️☀️ 🔥 欢迎来到我的博客 👀期待大大的关注哦❗️❗️❗️ 🚀欢迎收看我的主页文章➡️寻至善的主页 文章目录 🔥前言🚀unittest编写测试用例🚀unittest测…

【408精华知识】Cache类题目解题套路大揭秘

有关Cache的题目,需要理解Cache的工作原理,也即给出一个地址,要知道如何在Cache中寻找或者如何将其从主存中复制入Cache,同时理解Cache中具体是如何存储的,包含三种存储方式,分别是直接映射、全相联映射、组…

clion/pycharm 安装中文

楼主版本 2024.1 mac 操作系统,理论上不同版本和不同操作系统操作应该大同小异 首先找到插件的位置 方式一 1、进入工程,右上角找到设置 2、找到插件(欢迎界面也能找到这个) 方式二 在欢迎界面找到插件 最后 插件商店搜索 l…

矩阵乘法不满足交换律-反证法

假定有2个矩阵A和B A*B 不等于 B*A 手写证明: A*B为 B*A为 由此可以看出,矩阵乘法不满足交换律!!

Python | Leetcode Python题解之第100题相同的树

题目: 题解: class Solution:def isSameTree(self, p: TreeNode, q: TreeNode) -> bool:if not p and not q:return Trueif not p or not q:return Falsequeue1 collections.deque([p])queue2 collections.deque([q])while queue1 and queue2:node…

centos7和centos8安装mysql5.6 5.7 8.0

https://dev.mysql.com/downloads/repo/yum/ 注意构造下http://repo.mysql.com/mysql-community-release-el*-*.noarch.rpm 【以centos7为例】 安装mysql5.6 wget http://repo.mysql.com/mysql-community-release-el7-5.noarch.rpm rpm -ivh mysql-community-release-el7-5…

初识Qt:从Hello world到对象树的深度解析

Qt中的对象树深度解析 Hello world1.图形化界面创建命令行式创建在栈上创建在堆上创建为什么传文本需要QString,std::string不行吗?那为什么要传入this指针?为什么new后不用显示调用delete函数呢,不会造成内存泄漏问题吗&#xff…

国产操作系统上使用SQLynx连接数据库 _ 统信 _ 麒麟 _ 中科方德

原文链接:国产操作系统上使用SQLynx连接数据库 | 统信 | 麒麟 | 中科方德 Hello,大家好啊!今天我们将探讨如何在国产操作系统上使用SQLynx。这是一款功能强大的数据库管理工具,可以帮助用户高效地管理和操作数据库。本文将详细介绍…

2024 电工杯高校数学建模竞赛(A题)数学建模完整思路+完整代码全解全析

你是否在寻找数学建模比赛的突破点?数学建模进阶思路! 作为经验丰富的数学建模团队,我们将为你带来2024电工杯数学建模竞赛(B题)的全面解析。这个解决方案包不仅包括完整的代码实现,还有详尽的建模过程和解…

Docker搭建mysql性能测试环境

OpenEuler使用Docker搭建mysql性能测试环境 一、安装Docker二、docker安装mysql三、测试mysql连接 一、安装Docker 建立源文件vim /etc/yum.repos.d/docker-ce.repo增加内容[docker-ce-stable] nameDocker CE Stable - $basearch baseurlhttps://repo.huaweicloud.com/docker…

NLP(18)--大模型发展(2)

前言 仅记录学习过程,有问题欢迎讨论 Transformer结构: LLM的结构变化: Muti-head 共享: Q继续切割为muti-head,但是K,V少切,比如切为2个,然后复制到n个muti-head减少参数量,加速训练 atte…

STM32-串口通信波特率计算以及寄存器的配置详解

您好,我们一些喜欢嵌入式的朋友一起建立的一个技术交流平台,本着大家一起互相学习的心态而建立,不太成熟,希望志同道合的朋友一起来,抱歉打扰您了QQ群372991598 串口通信基本原理 处理器与外部设备通信的两种方式 并行…