深度学习笔记(八)——构建网络的常用辅助增强方法:数据增强扩充、断点续训、可视化和部署预测

文中程序以Tensorflow-2.6.0为例
部分概念包含笔者个人理解,如有遗漏或错误,欢迎评论或私信指正。
截图和程序部分引用自北京大学机器学习公开课

要构建一个完善可用的神经网络,除了设计网络结构以外,还需要添加一些辅助代码来增强网络运行的稳定性,鲁棒性。可以用来增强的方向主要有 个,首先是数据输入前的预处理环节,其次是数据在训练过程中的优化,最后的数据在训练结束后的导出和可视化,同时能够及时保存结果和继续上一次训练在实际工作中是十分有效的。

训练的前奏曲——数据集

在前面的代码中往往是直接加载现有的数据集,然后送入网络进行学习,实际工作研究中,数据往往需要重新花费不少时间去采集、标准化和标注。对于不同的数据集一般可以有不同的处理方式。首先是最常见的图像数据,在进行图像学习之前先要制作图像数据集,一个好的数据集可以帮助我们达到事半功倍的效果。制作图像数据集时,有几个基础要求:

  • 有比较明显的特征,图像中背景信息和语义信息有比较明显的区分,语义信息就是我们关注的对象;
  • 尺寸合适,图像尺寸能够满足网络输入的基本要求,现在大部分主流手机拍照得到的照片尺寸太大,不合适在验证阶段使用;
  • 数据格式能够加载,图像数据格式最好是RGB或者灰度图;
  • 数据覆盖全面,数据源尽可能多的覆盖研究对象可能出现或者存在的场景;
  • 标注不出错,标注时不能只图快,要保证一定的精确度

当然,研究的数据不一定都是图像,还有可能是时间序列数据,是多组传感器的采集值,是一段文本等等,但是都避不开 数据格式、数据规模、标签准确度这几个关键因素。
下面还是以基础的图像处理来说明数据集处理中常用的几个方法。

数据增强,扩充数据集

当被用来的训练的数据由于成本、时间等原因受到限制的时候可以通过数据的扩充适当的将原始的数据增加出一定比例。但这不意味着可以无限制的增加数据规模,只有当已经采集到的数据达到一定规模之后,数据的扩充作为锦上添花才能起到比较好的作用。
在数据扩充之前要先解决一个问题,就是如何把数据从原始的文件夹中读取到程序中。通常可以使用python自带的os库遍历某个路径下所有符合要求格式的数据,然后以此加载数据,随后通过PIL或Pands对读取的数据进行格式化操作,最后转化为numpy数组,再导入为tensor格式就可以用来训练了。在加载数据和格式化数据时,要注意做好特征数据和标签数据的对应。
下面我们尝试从指定的本文文件中读取数据和对应的标签,首先假设我们有这样的一组数据:
在这里插入图片描述
右键,在所在位置打开终端输入:

dir/b>file_name.csv

这段代码会将当前目录下的所有文件名以此写到一个csv表格中。我们可以使用excel打开这个表格,修改其中数据的分类,或者根据自己的需要进行修改。当然不一定要使用csv,也可以使用txt后缀,这样文件就直接输出到txt文本中。
针对csv我们可以逐行或者逐列读取:

import pandas as pd
# 读取 CSV 文件
df = pd.read_csv('../Data/MNIST/file_name.csv')
# 获取列数据
train_data_file = df['train']
train_data_label = df['train_label']
test_data_file = df['test']
test_data_label = df['test_label']
# 打印列数据
print(train_data_file )
print(train_data_label)

表格数据:
在这里插入图片描述
程序输出:

0    0.png
1    1.png
2    2.png
3    3.png
4    4.png
5    5.png
Name: train, dtype: object
0    0
1    1
2    2
3    3
4    4
5    5
Name: train_label, dtype: int64

那么我们直接构建一个函数读取指定的数据内容的函数

import pandas as pd
from PIL import Image
import numpy as np
# 读取 CSV 文件
def readImage(image_path, file_path):
    csv_file = pd.read_csv(file_path)
    # 获取列数据
    train_data_file = csv_file['train']
    train_data_label = csv_file['train_label']
    test_data_file = csv_file['test']
    test_data_label = csv_file['test_label']
    x, y_ , t, yt_= [], [], [], []
    for _index in np.arange(0, train_data_file.shape[0], 1):
        if pd.notna(train_data_file.iloc[_index]):  # 判断如果数据非空
            img_ = Image.open(image_path + train_data_file[_index])
            img_ = np.array(img_.convert('L'))
            img_ = img_ / 255.  # 数据标准归一化
            x.append(img_)
            y_.append(train_data_label[_index])
    for _index in np.arange(0, test_data_file.shape[0], 1):
        if pd.notna(test_data_file.iloc[_index]):
            img_ = Image.open(image_path + test_data_file[_index])
            img_ = np.array(img_.convert('L'))
            img_ = img_ / 255.
            t.append(img_)
            yt_.append(test_data_label[_index])
    return (x, y_), (t, yt_)

(train_img, train_lab), (test_img, test_lab)  = readImage(image_path='../Data/MNIST/', file_path='../Data/MNIST/file_name.csv')

把数据读取为numpy之后就可以进一步载入tf中,利用tf函数进行数据的扩充。扩充的方法主要有:随机平移、缩放、0填充、随机旋转。

img_prossess_Gen = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale='所有数据乘以这个数(倍乘)',
    horizontal_flip='是否随机水平旋转 Boolean',
    rotation_rang='随机旋转的角度范围 Int',
    width_shift_range='随机宽度偏移量',
    height_shift_range='随机高度便宜量',
    zoom_range='随机缩放的范围 Float or [lower, upper].'
)
img_prossess_Gen.fit(train_img)

所以结合前面分类博客中的代码,我们可以得到一个简单的例子:

cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
datagen = ImageDataGenerator(	# 定义数据扩充项
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    validation_split=0.2)
datagen.fit(x_train)	# 扩充训练数据
model.fit(datagen.flow(x_train, y_train, batch_size=32,	# 根据扩充数据并辅助分组后进行训练
         subset='training'),
         validation_data=datagen.flow(x_train, y_train,
         batch_size=8, subset='validation'),
         steps_per_epoch=len(x_train) / 32, epochs=epochs)

到此我们实现了加载数据并且扩充数据,利用扩充数据实现网络训练的操作。

训练的曲谱——过程优化

当数据量较大的时候并且网络结构比较复杂的情况下,我们比较希望能够在训练的过程中按照一定阶段保存训练模型。并且在未来某个时候重新加载数据继续训练。当然,我们也可以在一个足够大的通用数据集先进行预训练,让网络学习到数据的通用特征。然后将得到的模型导出,并重新加载细节上符合工作研究要求的数据重新开始训练。
同时在训练的过程中,也希望网络能够记住每个迭代计算的结果,并且计时的把训练过程中最好的模型保存下来。这个操作在网络训练的后期会显得十分重要。

断点续训,模型保存和读取

首先应该指定一个模型的保存路径,如果在路径中已经有需要的历史训练数据,就直接加载历史模型。值得注意,保存的动态模型格式是 ckpt

# 记录模型保存路径
checkpoint_save_path = "./checkpoint/Baseline.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
    print('-------------load the model-----------------')
    model.load_weights(checkpoint_save_path)     # 如果有模型则加载后使用

在加载之后还需要能够保存模型,这里使用keras提供的训练过程记录器,通过提供一个训练过程的回调函数来检测训练和保存模型。

# 定义保存和记录数据的回调器
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,	# 保存模型
                                                 save_best_only=True)	# 只保存最好的模型
# 使用数据扩充,并添加回调控制器,用来记录模型
history = model.fit(image_gen_train.flow(x_train, y_train, batch_size=32),
                    epochs=5, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])
# 训练过程的历史参数可以通过 history 查看

参数查看

在上面的程序中已经实现了模型的保存。但是跟具体的还可以把网络中每个层中每个连接的权重参数偏置项,和卷积计算的结果,卷积核的参数,偏置保存下来。

# 保存网络权重参数
# print(model.trainable_variables)
file = open('./weights.txt', 'w')
for v in model.trainable_variables:		# 逐行的将网络中的所有参数写入到 weights.txt 文本文件中
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')
file.close()

训练过程可视化、tensorboard

如何观察参数的效果,除了在终端打印训练过程中的Loss,准确度以外,可以将这些关键的数据保存下来,这样调整参数后不同的效果就可以通过曲线图像的形式保存出来,便于观察变化趋势,指导设计者调节参数。在绘制曲线之前首先要在训练位置的函数输出 history 参数,后续通过调用这个参数中的数据在加上matplotlib来画出曲线。

# 显示训练集和验证集的acc和loss曲线
acc = history.history['sparse_categorical_accuracy']
val_acc = history.history['val_sparse_categorical_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
plt.subplot(1, 2, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend()
plt.subplot(1, 2, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.title('Training and Validation Loss')
plt.legend()
plt.show()

在这里插入图片描述

根据图像不难看出随着迭代次数的增加,训练集和测试集的测试准确率不断上升,损失值不断下降。此时可以增加训练的迭代次数,并微调相关参数,查看网络可能出现的不同的效果。
除了上面提及的讲训练结果导出画图的方法,还可以安装tensorboard,一般安装tensorflow2时会配套自动安装。只需要在模型训练结果保存的位置打开终端,启动对应的虚拟环境,然后输入 tensorboard,就可以在给出的网页中查看到实时的训练参数。在训练的函数中需要加入关于tensorboard的回调函数。

# 设置TensorBoard输出的回调函数
tf_callback = tf.keras.callbacks.TensorBoard(log_dir="./logs")	# 设置log文件存放地址,这里是相对路径,也可以使用绝对路径
# 使用数据扩充
history = model.fit(image_gen_train.flow(x_train, y_train, batch_size=64),
                    epochs=8, validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback, tf_callback])	# 这里的回调包含保存模型和输出tensorboard训练参数

运行后再终端启动虚拟环境,输入指令加上你设置的log文件存放地址

tensorboard --logdir YOU_LOG_PATH

在这里插入图片描述
在tensorboard中可以观察模型的结构和训练中的参数数据。

训练的终章——结果可视化与数据评价

训练得到一个新的模型了,那如何使用它。这就需要再继续添加新的代码,用来单独加载模型并实现模型的前向推理,最后推理结果给出。
根据上面的例子,我们以手写数字的代码为例,可以很容易得到如下代码:

import cv2
# opencv-python==4.5.1.48
# 1. 加载图像
model_path = './checkpoint/mnist/mnist.ckpt'
image_path = '../Data/MNIST/'

new_model = mnisModel()
new_model.load_weights(model_path)

preNum = int(input("place input how many jpg file while be test:"))
for i in range(preNum):
   imgNum = int(input("place input png name:"))
   img_path = image_path+str(imgNum)+'.png'
   print("read image:{}".format(img_path))

   img_ = cv2.imread(img_path)
   resized_img = cv2.resize(img_, (28, 28), interpolation=cv2.INTER_AREA)
   gray_img = cv2.cvtColor(resized_img, cv2.COLOR_BGR2GRAY)
   # 4. 准备图像数据,进行归一化和添加批次维度
   cv2.imshow("input num", img_)
   img_for_prediction = gray_img.astype(np.float32) / 255.0  # 归一化到 [0, 1]
   img_for_prediction = np.expand_dims(img_for_prediction, axis=0)  # 添加批次维度

   result = model.predict(img_for_prediction)
   predNum = tf.argmax(result, axis=1)
   print("predice num is: ")
   tf.print(predNum)

通过上面的程序最终我们实现了加载已经有的模型,然后继续开始前向推理,并输出推理的结果。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/327829.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Flume 之自定义Sink

1、简介 前文我们介绍了 Flume 如何自定义 Source, 并进行案例演示,本文将接着前文,自定义Sink,在这篇文章中,将使用自定义 Source 和 自定义的 Sink 实现数据传输,让大家快速掌握Flume这门技术。 2、自定…

【PostgreSQL】安装和常用命令教程

PostgreSQL window安装教程 window安装PostgreSQL 建表语句: DROP TABLE IF EXISTS student; CREATE TABLE student (id serial NOT NULL,name varchar(100) NOT NULL,sex varchar(5) NOT NULL,PRIMARY KEY (id) );INSERT INTO student (id, name, sex) VALUES (…

【电力电子】2 开、闭环单相桥式SPWM逆变仿真电路

【仅供参考】 【2022.11西南交大电力电子仿真】 目录 1 开环单相桥式SPWM逆变电路搭建及波形记录 2 闭环单相桥式SPWM逆变电路搭建及波形记录 1 开环单相桥式SPWM逆变电路搭建及波形记录 采用单极性调制法,按老师PPT(如下图)所示进行单相…

图解基础排序算法(冒泡、插入、选择)(山东大学实验二)

目录 ⚽前言: 🏐 冒泡排序: 设定: 分类: 起源: 图解冒泡: 图中绿色: 图中橙色: 整体思路: 交换思路: 核心代码: &#x…

基于WebSocket双向通信技术实现-下单提醒和催单(后端)

学习复盘和总结项目亮点。 扩展:该功能能应用在,各种服务类项目中。(例如:酒店、洗脚城等系ERP系中提醒类服务) 4. 来单提醒 4.1 需求分析和设计 用户下单并且支付成功后,需要第一时间通知外卖商家。通…

服务网关 Gateway

服务网关 Gateway Spring Cloud Gateway 是 Spring Cloud 生态系统中的网关,它基于 Spring5.0 SpringBoot2.0 WebFlux(基于高性能的 Reactor 模式响应式通信框架 Netty,异步非阻塞模型)等技术开发。旨在为微服务架构提供一种简…

如何在CentOS 7 中搭建Python 3.0 环境

1、下载 通过https://www.python.org/ftp/python/下载Python安装包,这里下载Python-3.10.9.tgz; 2、上传 借助MobaXterm等工具将Python安装包上传至/opt目录; 3、解压 将JDK压缩文件解压至/opt目录:tar -xvf /opt/Python-3.1…

React Store及store持久化的使用

1.安装 npm insatll react-redux npm install reduxjs/toolkit npm install redux-persist2. 使用React Toolkit创建counterStore并配置持久化 store/modules/counterStore.ts: import { createSlice } from reduxjs/toolkit// 定义状态类型 interface Action {…

Python数据分析案例33——新闻文本主题多分类(Transformer, 组合模型) 模型保存

案例背景 对于海量的新闻,我们可能需要进行文本的分类。模型构建很重要,现在对于自然语言处理基本都是神经网络的方法了。 本次这里正好有一组质量特别高的新闻数据,涉及 教育 科技 社会 时政 财经 房产 家居 七大主题,基本涵盖…

编译原理1.1习题 语言处理器

图源:文心一言 编译原理习题整理~🥝🥝 作为初学者的我,这些习题主要用于自我巩固。由于是自学,答案难免有误,非常欢迎各位小伙伴指正与讨论!👏💡 第1版:自…

HCIP -- ospf实验

要求: 实现: R3:int t 0/0/0ip address 172.16.1.4 255.255.255.248 (配置虚拟接口ip地址)tunnel-protocol gre p2mp (配置接口协议为p2mp)source 43.0.0.2 (配置源)osp…

【UE 材质】简单的纹理失真、溶解效果

目录 1. 失真效果 2. 溶解效果 3. 失真溶解 我们一开始有这样一个纹理 1. 失真效果 其中纹理节点“DistortTexture”的纹理为引擎自带的纹理“T_Noise01”,我们可以通过控制参数“失真度”来控制纹理的失真程度 2. 溶解效果 3. 失真溶解

奇安信天擎 rptsvr 任意文件上传漏洞复现

0x01 产品简介 奇安信天擎是奇安信集团旗下一款致力于一体化终端安全解决方案的终端安全管理系统(简称“天擎”)产品。通过“体系化防御、数字化运营”方法,帮助政企客户准确识别、保护和监管终端,并确保这些终端在任何时候都能可信、安全、合规地访问数据和业务。天擎基于…

微软推出付费版Copilot

关注卢松松,会经常给你分享一些我的经验和观点。 微软已经超越苹果,成了全球市值最高的公司,其他公司都因为AI大裁员,而微软正好相反,当然这个原因很简单:就是微软强制把AI全面接入到系统里来了。而Copilot…

Electron+React项目打包踩坑记录

首先,如何打包 写下本文的时间是 2024/01/16,搜索了网络上 ElectronReact 的打包方式,中间行不通,本文采用的方式是记录本文时 Electron 快速入门(https://www.electronjs.org/zh/docs/latest/tutorial/quick-start)记录的打包方式…

【Java 设计模式】创建型之抽象工厂模式

文章目录 1. 定义2. 应用场景3. 代码实现4. 应用示例结语 在软件开发中,抽象工厂模式是一种常见的创建型设计模式,它提供了一种创建一系列相关或相互依赖对象的接口,而无需指定它们具体的类。抽象工厂模式的核心思想是将一组相关的产品组合成…

spring boot shardingsphere mybatis-plus druid mysql 搭建mysql数据库读写分离架构

spring boot shardingsphere mybatis-plus druid mysql 搭建mysql数据库读写分离架构 ##关于window mysql主从搭建简单教程 传送门 window mysql5.7 搭建主从同步环境-CSDN博客 ##父pom.xml <?xml version"1.0" encoding"UTF-8"?> <project…

北交所交易手续费标准?哪家证券公司开通北交所券商交易手续费佣金万2?

北交所&#xff08;Beijing Exchange&#xff09;是指位于中国北京的一家金融交易所。北交所是中国政府为推动金融改革和国际化市场而设立的交易场所。它提供包括股票、债券、期货、外汇等多种金融产品的交易服务。北交所的目标是促进中国金融市场的发展&#xff0c;吸引国内外…

机器人持续学习基准LIBERO系列7——计算并可视化点云

0.前置 机器人持续学习基准LIBERO系列1——基本介绍与安装测试机器人持续学习基准LIBERO系列2——路径与基准基本信息机器人持续学习基准LIBERO系列3——相机画面可视化及单步移动更新机器人持续学习基准LIBERO系列4——robosuite最基本demo机器人持续学习基准LIBERO系列5——…

git项目管理

Git工作流程图 git 基础指令 git init #创建本地仓库,创建成功后&#xff0c;当前目录会多一个.git文件夹 git status #查看修改状态 git add . #添加工作区到暂存区 git commit -m 注释内容 #提交暂存区到本地仓库&#xff08;commit&#xff09; git log …