深度学习——自己的训练集——训练模型(CNN)

训练模型

    • 1.导入必要的库
    • 2.加载类别名称
    • 3.创建标签映射字典
    • 4.加载图像数据和对应的标签
    • 5.构建和编译CNN模型
    • 6.训练模型
    • 7.保存训练好的模型

1.导入必要的库

导入处理数据和训练模型时需要的库
os: 这个模块提供了与操作系统交互的功能,比如文件和目录操作。
cv2: 这是OpenCV库的别名,它是一个强大的计算机视觉库,用于图像和视频处理。
numpy as np: NumPy是一个用于科学计算的库,它提供了高效的数组处理能力,对于图像处理等任务非常有用。
tensorflow as tf: TensorFlow是一个开源的机器学习库,用于构建和训练各种类型的机器学习模型。

import os
import cv2
import numpy as np
import tensorflow as tf

2.加载类别名称

with open(‘99/classes.txt’, ‘r’) as f:

with open(...) as f::这是上下文管理器(context manager),用于自动处理文件资源的打开和关闭。当with语句执行完成后,文件会自动关闭,即使遇到异常也是如此。
'99/classes.txt':这是要打开的文件的路径。
'r':这是文件打开模式,表示以只读方式打开文件。
f:这是上下文管理器创建的文件对象,可以用来读取文件内容。

classes = f.read().splitlines():

f.read():这个方法调用用于读取文件的全部内容,并将结果作为一个字符串返回。
.splitlines():这个方法调用用于将字符串按照行分隔符(通常是换行符\n)分割成一个列表。
classes:这个变量存储了分割后的列表,其中每个元素都是一个从文件中读取的标签名称。

with open('99/classes.txt', 'r') as f:
    classes = f.read().splitlines()

3.创建标签映射字典

创建了一个标签映射字典,用于将标签索引转换为实际的标签名称。

label_mapping = {
    '0': 'sad',
    '1': 'happy',
    '2': 'amazed',
    '3': 'anger'
}

4.加载图像数据和对应的标签

从文件夹中加载了图像数据和对应的标签。

image_folder = '561'
label_folder = '99'

X_train = []
y_train = []

#遍历image_folder文件夹中的所有文件
for image_file in os.listdir(image_folder):
#创建一个完整的文件路径,将image_folder目录的路径和image_file(文件或子目录的名称)连接起来。
image_path = os.path.join(image_folder, image_file)

#cv2.imread(image_path):这个函数调用用于读取图像文件。
image = cv2.imread(image_path)

#如果图像成功加载,将图像数据添加到X_train列表中。
    if image is not None:
        X_train.append(image)

	#将label_folder目录的路径和image_file(去除.jpg扩展名后的文件名)连接起来,并在最后加上.txt扩展名
        label_file = os.path.join(label_folder, image_file.replace('.jpg', '.txt'))
        with open(label_file, 'r') as f:
            label_index = f.readline().strip().split()[0]  # 只取第一个数字作为标签索引
            label_name = label_mapping[label_index]
            label = classes.index(label_name)
            y_train.append(label)

X_train = np.array(X_train)
y_train = np.array(y_train)

5.构建和编译CNN模型

model = tf.keras.Sequential([
    tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(image.shape)),
    tf.keras.layers.MaxPooling2D((2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(len(classes), activation='softmax')
])

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

model = tf.keras.Sequential([ … ]):

tf.keras.Sequential:这是一个Keras模型,用于创建一个包含顺序堆叠的层的模型。
[ ... ]:这是一个列表,其中包含了模型中的层。
model:这个变量存储了创建的Keras Sequential模型。

tf.keras.layers.Conv2D(32, (3, 3), activation=‘relu’,
input_shape=(image.shape)):

tf.keras.layers.Conv2D:这是一个2D卷积层,用于提取图像的局部特征。
32:这是卷积层的输出通道数。
(3, 3):这是卷积核的大小,即每个卷积核覆盖的像素区域。
activation='relu':这是激活函数,用于在每个卷积层之后应用。
input_shape=(image.shape):这是输入层的形状,它是从image.shape获得的,确保模型的输入形状与图像数据的形状匹配。

tf.keras.layers.MaxPooling2D((2, 2)):

tf.keras.layers.MaxPooling2D:这是一个2D最大池化层,用于通过取每个池化区域的最大值来减小特征图的大小。
(2, 2):这是池化窗口的大小,即每个池化操作覆盖的像素区域。

tf.keras.layers.Flatten():

tf.keras.layers.Flatten:这是一个扁平化层,用于将2D或多维数组展平为一维数组。

tf.keras.layers.Dense(128, activation=‘relu’):

tf.keras.layers.Dense:这是一个全连接层,用于在模型中添加更多的非线性变换。
128:这是全连接层的神经元数量。
activation='relu':这是激活函数,用于在每个全连接层之后应用。

tf.keras.layers.Dense(len(classes), activation=‘softmax’):

tf.keras.layers.Dense:这是一个全连接层,用于在模型中添加更多的非线性变换。
len(classes):这是全连接层的神经元数量,它等于类别的数量。
activation='softmax':这是激活函数,用于在每个全连接层之后应用,以产生一个概率分布。

model.compile(optimizer=‘adam’,
loss=‘sparse_categorical_crossentropy’, metrics=[‘accuracy’]):

model.compile:这个方法用于编译模型,指定训练过程中使用的优化器、损失函数和评估指标。
optimizer='adam':这是模型使用的优化器,用于调整模型的权重以最小化损失函数。
loss='sparse_categorical_crossentropy':这是模型使用的损失函数,用于评估模型在训练数据上的性能。
metrics=['accuracy']:这是模型使用的评估指标,用于评估模型在训练数据上的性能。

6.训练模型

model.fit(X_train, y_train, epochs=20, batch_size=32)

model.fit:这是Keras中的一个方法,用于训练模型。
X_train:这是模型的输入数据,它是一个NumPy数组。
y_train:这是模型的目标数据,它是一个NumPy数组。
epochs=20:这是训练过程中重复训练数据的次数。
batch_size=32:这是每次梯度更新的样本数量。

7.保存训练好的模型

model.save('cnn_model.h5')

在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/648282.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024-5-10-从0到1手写配置中心Config之Spring Value热更新

定义SpringValueProcessor处理类 实现BeanPostProcessor后置处理器接口,扫描所有的Spring value,保存起来。实现ApplicationListener接口,在配置变更时,更新所有的spring value 实现BeanPostProcessor后置处理器接口 实现postPr…

移动云:连接未来的智慧之旅

随着数字化转型的加速,云服务在各行各业中的应用越来越广泛。移动云不仅提供了灵活的计算和存储资源,还通过创新的技术手段,为企业和开发者解决了许多实际问题。在这个变革的大背景下,移动云服务作为中国移动倾力打造的云业务品牌…

155. 最小栈

题目: 设计一个支持 push ,pop ,top 操作,并能在常数时间内检索到最小元素的栈。 实现 MinStack 类: MinStack() 初始化堆栈对象。void push(int val) 将元素val推入堆栈。void pop() 删除堆栈顶部的元素。int top() 获取堆栈顶…

深入解析内置模块OS:让你的Python代码更懂操作系统

新书上架~👇全国包邮奥~ python实用小工具开发教程http://pythontoolsteach.com/3 欢迎关注我👆,收藏下次不迷路┗|`O′|┛ 嗷~~ 目录 一、OS模块简介与基础应用 二、文件与目录操作详解 三、OS模块的高级应用:双色…

【算法】前缀和——除自身以外数组的乘积

本节博客是用前缀和算法求解“除自身以外数组的乘积”,有需要借鉴即可。 目录 1.题目2.前缀和算法3.变量求解4.总结 1.题目 题目链接:LINK 2.前缀和算法 1.创建两个数组 第一个数组第i位置表示原数组[0,i-1]之积第二个数组第i位置表示原数组[i1,n-1]…

How to limit request by IP on nginx?

/etc/nginx/conf.d/default.conf 1.Define a limit_req_zone # 定義限流區塊 limit_req_zone $binary_remote_addr zonelimit_zone:10m rate2r/s; limit_req_zone $binary_remote_addr zonelimit_zone:10m rate2r/s; 是一个 Nginx 配置指令,用于定义请求限制区域和…

【linux】多线程(2)

文章目录 线程的应用生产消费者模型自制锁生产消费队列成员参数生产函数消费函数 任务处理方式主函数 POSIX信号量sem_wait()sem_post() 线程池应用场景示例 单例模式饿汉实现单例 吃完饭, 立刻洗碗, 这种就是饿汉方式. 因为下一顿吃的时候可以立刻拿着碗就能吃饭.懒汉实现单例…

GMSL2硬件设计V1.1

一、说明 GMSL(Gigabit Multimedia Serial Links),中文名称为千兆多媒体串行链路,是Maxim公司(现属于ADI)推出的一种高速串行接口,通过同轴电缆或屏蔽双绞线(STP)传输高速串行数据,用于汽车摄像头和显示器应用。GMSL2就是指ADI专有的第二代千兆多媒体串行链路技术,传输…

重生之while在鸣潮学习HTML

个人主页:终端 今天是开荒的第五天,数据坞都刷了吗,没刷就过来学html! 目录 JavaWeb学习路线 1.HTML入门 1.1什么是HTML 1.2HTML&CSS&JavaScript的作用 1.3什么是超文本 1.4什么是标记语言 1.5HTML基础结构 1.6HTML的…

通过acme.sh和cloudflare实现免费ssl证书自动签发

参考使用acme.sh通过cloudflare自动签发免费ssl证书 | LogDicthttps://www.logdict.com/archives/acme.shshi-yong-cloudflarezi-dong-qian-fa-mian-fei-sslzheng-shu

Jmeter-使用手册(_5.5版本)

JMeter是一个Java桌面应用程序,具有使用Swing图形API的图形界面。可以进行接口、性能等测试,也可以对任何数据库进行同样的测试,具有可移植性,可跨平台支持Windows,Linux,Mac上使用。 JMeter运行场景不仅可…

【openlayers系统学习】4.2Mapbox 样式渲染图层

二、Mapbox 样式渲染图层 显然我们目前的地图需要一些样式。 VectorTile​ 图层的样式与 Vector​ 图层的样式工作方式完全相同。那里描述的样式在这里也适用。 对于这样的地图,创建数据驱动的样式(对矢量图层操作)非常简单。但矢量切片也用…

AIGC 003-Controlnet升级你的SD让图像生成更加可控!

AIGC 003-Controlnet升级你的SD让图像生成更加可控! 文章目录 0 论文工作1 论文方法2 效果 0 论文工作 ControlNet 论文 (Adding Conditional Control to Text-to-Image Diffusion Models) 提出了一种名为 ControlNet 的神经网络结构,旨在为大型文本到图…

趣店集团golang一面要个20K,Channel什么情况下会出现死锁,有遇到过吗?

结束后面试官加了VX,并询问方便二面的时间,一直还没回复,拖着拖着给忘啦... 面试题 1、自我介绍 2、你在团队里头负责哪一块,这个物流开放平台流量多大 3、为什么今年3月份被从物流开放团队转到了finance财务部门,感…

[SWPUCTF 2021 新生赛]pop

常见的魔术方法 魔术方法__construct() 类的构造函数,在对象实例化时调用 __destruct() 类的析构函数,在对象被销毁时被调用 __call() 在对象中调用一个不可访问的对象时被调用,比如一个对象被调用时,里面没有程序想调用的属性 …

​​​【收录 Hello 算法】10.4 哈希优化策略

目录 10.4 哈希优化策略 10.4.1 线性查找:以时间换空间 10.4.2 哈希查找:以空间换时间 10.4 哈希优化策略 在算法题中,我们常通过将线性查找替换为哈希查找来降低算法的时间复杂度。我们借助一个算法题来加深理解。 Question 给…

云上聚智共创未来 | 移动云的项目实战,10分钟让你获得高度可玩的个人博客网站

🎬 鸽芷咕:个人主页 🔥 个人专栏: 《C干货基地》《粉丝福利》 ⛺️生活的理想,就是为了理想的生活! 引入 随着互联网的发展各种以前看起来离我们比较遥远的词越来越近了,比如 云服务、大数据、区块链、容器这些听起来…

04_前端三大件JS

文章目录 JavaScript1.JS的组成部分2.JS引入2.1 直接在head中通过一对script标签定义脚本代码2.2创建JS函数池文件,所有html文件共享调用 3.JS的数据类型和运算符4.分支结构5.循环结构6.JS函数的声明7.JS中自定义对象8.JS_JSON在客户端使用8.1JSON串格式8.2JSON在前…

#12松桑前端后花园周刊-SolidStart、Vercel融资、Angular18、Nextjs15RC、p5.js、ChromeDevTools引入AI

⚡️行业动态 SolidStart 1.0 元框架发布 Solidjs 核心团队发布其元框架 SolidStart 1.0 正式版,其特点如下:基于文件系统的路由;支持SSR、流式SSR、CSR、SSG渲染模式;通过代码分割、树摇和无用代码删除构建优化;基于…

LabVIEW超快激光微纳制造系统设计

LabVIEW超快激光微纳制造系统设计 在当前的制造行业中,精密加工技术的需求日益增长,尤其是在微纳尺度上。超快激光制造技术,以其独特的加工精度和加工效率,成为了精密加工领域的重要手段。然而,大多数超快激光制造系统…