【深度学习_TensorFlow】过拟合

写在前面

过拟合与欠拟合


欠拟合: 是指在模型学习能力较弱,而数据复杂度较高的情况下,模型无法学习到数据集中的“一般规律”,因而导致泛化能力弱。此时,算法在训练集上表现一般,但在测试集上表现较差,泛化性能不佳。

过拟合: 是指模型在训练数据上表现很好,但在测试数据上表现不佳。这是由于模型过于复杂,记住了训练数据中的噪声和模式,而没有学到一般规则。本文将探讨过拟合问题以及其解决方法。

过拟合的解决方法


方法描述
增加训练数据
(More data)通过增加训练数据量,简单粗暴最有效,可以减少过拟合现象。
正则化
(regularization)在损失函数中添加一项,以惩罚模型的复杂度。常用的正则化方法包括L1正则化、L2正则化和dropout。
早停法
(Early stopping)在训练过程中,每次迭代后都会评估模型在验证集上的性能。如果性能在连续若干次迭代中没有提高,就停止训练。
数据增强
(Data augmentation)通过改变原有数据减少过拟合。例如,可以通过旋转、缩放等方式对图像数据进行增强。

写在中间

可以看到随着网络层数增加,模型变得复杂,过拟合现象变得愈发严重。接下来,我们将介绍一系列方法来帮助检测并抑制过拟合现象。

在这里插入图片描述

1. 交叉验证

增加数据集是最有效的方法,但是代价往往是昂贵的,所以要充分利用好现有的数据集。前面我们介绍了数据集需要划分为训练集和测试集,但我们为了挑选模型超参数和检测过拟合线性,一般需要将原来的训练集再次切分为新的训练集和验证集(validation set)。最终数据集被切分为 训练集、验证集、测试集。这三部分数据集的功能如下:

类别描述
训练集用于训练模型的参数,通过学习训练数据集来进行模型训练
验证集用于评估训练过程中的模型表现,调整模型的超参数
测试集用于评估最终训练好的模型在真实数据上的表现,测试验证模型的性能,评估模型的预测能力和泛化能力

验证集和测试集的区分

验证集使命:根据验证集的表现来调整模型的各种超参数的设置,提升模型的泛化能力。

测试集使命:就是检验模型的能力,其表现不能用来反馈模型的调整(就如你不能拿着期末考试原题来练习,否则期末高分就不能体现出你平时学习的真实状况),我们的办法就是从平常的练习题中抽取几道题组成验证集来检验你的能力。

这是一个将mnist手写数字识别测试集切分的例子,将6万张图像的前5万张划分为训练集,后1万张划分为验证集

(x, y), (x_test, y_test) = datasets.mnist.load_data()

# 60k训练集切分为 50k训练集和 10k验证集
# (x_train, y_train), (x_val, y_val)
x_train, x_val = tf.split(x, num_or_size_splits=[50000, 10000])
y_train, y_val = tf.split(y, num_or_size_splits=[50000, 10000])

# 训练集
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.map(preprocess).shuffle(10000).batch(128)

# 验证集
val_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_db = val_db.map(preprocess).shuffle(10000).batch(128)

# 测试集
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).shuffle(10000).batch(128)

但是这样切分还是有局限性的,训练集只能在前5万张图像中出现,验证集只能在后1万张图像中出现,是否有方法让验证集能使用前5万张的图像呢?还真有,这就是标题所提及的**交叉验证:**我们将训练集的6万张图像划成n份,每训练取其中的 n - 1 份作为训练集,取 1 份作为验证集,而非固定前5万张为训练集,后 1 万张为验证集。

# 创建一个范围为60000的索引数组
idx = tf.range(60000)
# 随机打乱索引数组
idx = tf.random.shuffle(idx)
# 使用索引数组从x和y中获取训练集的样本
# 训练集的样本数量为50000
x_train, y_train = tf.gather(x, idx[:50000]), tf.gather(y, idx[:50000])
# 使用索引数组从x和y中获取验证集的样本
# 验证集的样本数量为10000
x_val, y_val = tf.gather(x, idx[-10000:]), tf.gather(y, idx[-10000:])

也可以不用手动划分,在训练函数中增加参数,即可自动化操作

# 使用60k训练集对网络进行训练,设置训练周期数为6
# 设置验证集占比为0.1,即数据集中10%的数据将作为验证集
# 设置每个2个训练周期进行一次验证
network.fit(train_db, epochs=6, validation_split=0.1, validation_freq=2)

2. 正则化

之所以出现过拟合的现象,是因为模型太过复杂,这里通过限制网络参数的稀疏性来约束网络的实际容量,这种约束一般通过在损失函数上添加额外的参数稀疏性惩罚项实现,常用的正则化的方式有L0、L1、L2正则化,dropout正则化。


简单介绍

关于正则化的数学原理,本人也搞不明白,也就不滥竽充数了。

有关实现简单概括就是在损失函数中引入模型权重参数的L范数,使学习到的权重参数稀疏化。

正则化的方式可以手动实现,也可以调用API实现:其中手动实现主要在计算loss值的后面,调用API主要在创建层的时候。

import tensorflow as tf
from tensorflow.keras import layers, regularizers

# 网络构建
# 这会在模型损失函数中加入权重参数的L2范数作为惩罚项,力度由0.001控制。
network = Sequential([layers.Dense(256, kernel_regularizer=regularizers.l2(0.001), activation='relu'),
                      layers.Dense(128, kernel_regularizer=regularizers.l2(0.001), activation='relu'),
                      layers.Dense(64, kernel_regularizer=regularizers.l2(0.001), activation='relu'),
                      layers.Dense(32, kernel_regularizer=regularizers.l2(0.001), activation='relu'),
                      layers.Dense(10)])
# 参数构建
network.build(input_shape=(None, 28*28))
# 模型展示
network.summary()
# 截取手动前向计算的代码
for step, (x, y) in enumerate(train_db):
    # 创建一个 GradientTape,用于记录计算过程
    with tf.GradientTape() as tape:
        
        x = tf.reshape(x, (-1, 28*28))  # [b, 28, 28] => [b, 784]
       
        out = network(x)  # [b, 784] => [b, 10]
        
        y_onehot = tf.one_hot(y, depth=10)  # [b] => [b, 10]
        # 使用交叉熵损失函数计算 loss
        loss = tf.reduce_mean(tf.losses.categorical_crossentropy(y_onehot, out, from_logits=True))
        
  
        loss_regularization = []
        for p in network.trainable_variables:  # 重点在这里,遍历网络中所有的可训练参数(network.trainable_variables)。
            loss_regularization.append(tf.nn.l2_loss(p))  # # 对每个参数计算L2正则化项(tf.nn.l2_loss(p)),这会返回一个标量。
        loss_regularization = tf.reduce_sum(tf.stack(loss_regularization))  # # 将所有参数的L2正则化项求和,得到正则化损失loss_regularization。
        # 将损失函数定义为交叉熵损失和 L2 正则化损失的和
        loss = loss + 0.0001 * loss_regularization
    # 使用 tape 计算损失函数关于网络参数的梯度,并应用优化器进行反向传播更新参数
    grads = tape.gradient(loss, network.trainable_variables)
    optimizer.apply_gradients(zip(grads, network.trainable_variables))

正则化效果

在这里插入图片描述

Dropout


通过随机断开神经网络的连接,减少每次训练时实际参与计算的模型的参数量,但是在测试时,Dropout会恢复所有的连接,保证模型测试时获得最好的性能。

在这里插入图片描述

我们以层方式来实现以上功能

import tensorflow as tf
from tensorflow.keras import layers, regularizers

# 网络构建

network = Sequential([layers.Dense(256, activation='relu'),
                      layers.Dropout(0.5),  # 有0.5的概率断开与下一层神经元的连接
                      layers.Dense(128, activation='relu'),
                      layers.Dropout(0.5),
                      layers.Dense(64, activation='relu'),
                      layers.Dense(32, activation='relu'),
                      layers.Dense(10)])
# 参数构建
network.build(input_shape=(None, 28*28))
# 模型展示
network.summary()

for step, (x, y) in enumerate(train_db):
    # 训练时
    with tf.GradientTape() as tape:
      out = network(x, training=True)
      
    # 测试时
    out = network(x, training=False)

3. Early stopping

早停法

那么如何选择合适的 Epoch 就提前停止训练(Early Stopping),避免出现过拟合现象呢?我们可以通过观察验证指标的变化,来预测最适合的 Epoch 可能的位置。具体地,对于分类问题,我们可以记录模型的验证准确率,并监控验证准确率的变化,当发现验证准确率连续𝑛个 Epoch 没有下降时,可以预测可能已经达到了最适合的 Epoch 附近,从而提前终止训练。

from tensorflow.keras.callbacks import EarlyStopping


# 数据集读取···

# 定义早停法回调函数
early_stopping = EarlyStopping(monitor='val_loss',  # 监视验证集loss
                               patience=3,  # 当验证集loss在3个epoch内都没有改善则停止训练
                               mode='min',  # 监测loss时一般设置为min,监测准确值时一般设置为max
                               verbose=1,  # 检测值改善时打印一条信息
                               restore_best_weights=True  # 将权重恢复到最好的一个epoch
                               )
 # 网络构建
 
 # 参数构建

 # 模型装配

 # 模型训练,添加参数
 network.fit(train_db, epochs=100,
            validation_data=val_db, validation_steps=10,
            callbacks=[early_stopping])

这里我们会对手写数字识别的代码再次进行修改,来使用上面提及的方法,你可以通过修改repeat的方式来复制数据集使训练数据增多,更改epochs的方式来增加训练次数,经过测试,这段代码在10 epochs 之后便达到了过拟合

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics, regularizers
from tensorflow.keras.callbacks import EarlyStopping
# 处理每一张图像
def preprocess(x, y):

    x = tf.cast(x, dtype=tf.float32) / 255.
    x = tf.reshape(x, [28 * 28])
    y = tf.cast(y, dtype=tf.int32)
    y = tf.one_hot(y, depth=10)

    return x, y

# 数据集读取
(x, y), (x_test, y_test) = datasets.mnist.load_data()

# # 固定切分
# # 60k训练集切分为 50k训练集和 10k验证集
# # (x_train, y_train), (x_val, y_val)
# x_train, x_val = tf.split(x, num_or_size_splits=[50000, 10000])
# y_train, y_val = tf.split(y, num_or_size_splits=[50000, 10000])

# 交叉验证切分
idx = tf.range(60000)
idx = tf.random.shuffle(idx)
x_train, y_train = tf.gather(x, idx[:50000]), tf.gather(y, idx[:50000])
x_val, y_val = tf.gather(x, idx[-10000:]), tf.gather(y, idx[-10000:])

# 训练集
train_db = tf.data.Dataset.from_tensor_slices((x_train, y_train))
train_db = train_db.map(preprocess).shuffle(10000).batch(128).repeat(10)

# 验证集
val_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
val_db = val_db.map(preprocess).shuffle(10000).batch(128)

# 测试集
test_db = tf.data.Dataset.from_tensor_slices((x_test, y_test))
test_db = test_db.map(preprocess).shuffle(10000).batch(128)

# 定义早停法回调函数
early_stopping = EarlyStopping(monitor='val_loss',  # 监视验证集loss
                               patience=3,  # 当验证集loss在2个epoch内都没有改善则停止训练
                               mode='min',  # 监测loss时一般设置为min,监测准确值时一般设置为max
                               verbose=1,  # 检测值改善时打印一条信息
                               restore_best_weights=True  # 将权重恢复到最好的一个epoch
                               )

# 网络构建
# 正则化:在模型损失函数中加入权重参数的L2范数作为惩罚项,力度由0.001控制。
# Dropout:添加dropout层来随机断开连接
network = Sequential([layers.Dense(256, kernel_regularizer=regularizers.l2(0.001), activation='relu'),
                      layers.Dropout(0.5),
                      layers.Dense(128, kernel_regularizer=regularizers.l2(0.001), activation='relu'),
                      layers.Dropout(0.5),
                      layers.Dense(64, kernel_regularizer=regularizers.l2(0.001), activation='relu'),
                      layers.Dense(32, kernel_regularizer=regularizers.l2(0.001), activation='relu'),
                      layers.Dense(10)])
# 参数构建
network.build(input_shape=(None, 28*28))
# 模型展示
network.summary()
# 模型装配
network.compile(optimizer=optimizers.Adam(learning_rate=0.01),
                loss=tf.losses.CategoricalCrossentropy(from_logits=True),
                metrics=['accuracy'])
# 模型训练
network.fit(train_db, epochs=50,
            validation_data=val_db, validation_steps=10,
            callbacks=[early_stopping])
# 模型评估
print('模型评估:')
network.evaluate(test_db)

写在最后

👍🏻点赞,你的认可是我创作的动力!
⭐收藏,你的青睐是我努力的方向!
✏️评论,你的意见是我进步的财富!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/90034.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

FPGA应用于图像处理

FPGA应用于图像处理 FPGA(Field-Programmable Gate Array)直译过来就是现场可编程门阵列。是一种可以编程的逻辑器件,具有高度的灵活性,可以根据具体需求就像编程来实现不同的功能。 FPGA器件属于专用的集成电流中的一种半定制电…

TinyVue - 华为云 OpenTiny 出品的企业级前端 UI 组件库,免费开源,同时支持 Vue2 / Vue3,自带 TinyPro 中后台管理系统

华为最新发布的前端 UI 组件库,支持 PC 和移动端,自带了 admin 后台系统,完成度很高,web 项目开发又多一个选择。 关于 OpenTiny 和 TinyVue 在上个月结束的华为开发者大会2023上,官方正式进行发布了 OpenTiny&#…

使用VSCode SSH实现公网远程连接本地服务器开发的详细教程

文章目录 前言1、安装OpenSSH2、vscode配置ssh3. 局域网测试连接远程服务器4. 公网远程连接4.1 ubuntu安装cpolar内网穿透4.2 创建隧道映射4.3 测试公网远程连接 5. 配置固定TCP端口地址5.1 保留一个固定TCP端口地址5.2 配置固定TCP端口地址5.3 测试固定公网地址远程 前言 远程…

【二分】搜索旋转数组

文章目录 不重复数组找最小值,返回下标重复数组找最小值,返回下标不重复数组找target,返回下标重复数组找target,返回bool重复数组找target,返回下标 不重复数组找最小值,返回下标 class Solution {public …

Windows下 MySql通过拷贝data目录迁移数据库的方法

MySQL数据库的文件目录下图所示, 现举例说明通过COPY文件夹data下数据库文件,进行数据拷贝的步骤;源数据库运行在A服务器上,拷贝到B服务器,假定B服务器上MySQL数据库已经安装完成,为空数据库。 首先进入A服…

华为云渲染实践

// 编者按:云计算与网络基础设施发展为云端渲染提供了更好的发展机会,华为云随之长期在自研图形渲染引擎、工业领域渲染和AI加速渲染三大方向进行云渲染方面的探索与研究。本次LiveVideoStackCon 2023上海站邀请了来自华为云的陈普,为大家分…

百度“AI智障”到AI智能体验之旅

目录 前言一、百度PLATO1.抬杠第一名2.听Ta瞎扯淡3.TA当场去世了4.智障与网友的高光时刻 二、文心一言1.设计测试用例2.随意发问3.手机端约会神器 三、体验总结:四、千帆大模型 前言 最近收到了文心一言3.5大模型的内测资格,正巧之前也体验过它的前身&q…

Request对象和response对象

一、概念 request对象和response对象是通过Servlet容器(如Tomcat)自动创建并传递给Servlet的。 Servlet容器负责接收客户端的请求,并将请求信息封装到request对象中,然后将request对象传 递给相应的Servlet进行处理。类似地&…

SpringBoot入门篇1 - 简介和工程创建

目录 SpringBoot是由Pivotal团队提供的全新框架, 其设计目的是用来简化Spring应用的初始搭建以及开发过程。 1.创建入门工程案例 ①创建新模块,选择Spring初始化,并配置模块相关基础信息 ②开发控制器类 controller/BookController.jav…

短视频矩阵系统接口部署技术搭建

前言 短视频矩阵系统开发涉及到多个领域的技术,包括视频编解码技术、大数据处理技术、音视频传输技术、电子商务及支付技术等。因此,短视频矩阵系统开发人员需要具备扎实的计算机基础知识、出色的编程能力、熟练掌握多种开发工具和框架,并掌握…

全套解决方案:基于pytorch、transformers的中文NLP训练框架,支持大模型训练和文本生成,快速上手,海量训练数据!

全套解决方案:基于pytorch、transformers的中文NLP训练框架,支持大模型训练和文本生成,快速上手,海量训练数据! 1.简介 目标:基于pytorch、transformers做中文领域的nlp开箱即用的训练框架,提…

开源网安受邀参加软件供应链安全沙龙,推动企业提升安全治理能力

​8月23日下午,合肥软件行业软件供应链安全沙龙在中安创谷科技园举办。此次沙龙由合肥软件产业公共服务中心联合中安创谷科技园公司共同主办,开源网安软件供应链安全专家王晓龙、尹杰受邀参会并带来软件供应链安全方面的精彩内容分享,共同探讨…

政府网站定期巡检:构建高效、安全与透明的数字政务

在数字时代,政府网站已不仅仅是一个信息发布窗口,更是政府与公众互动的桥梁、政务服务的主要渠道以及数字化治理的重要平台。因此,确保政府网站的高效运行、信息安全与透明公开就显得尤为重要。在此背景下,定期的网站巡检与巡查成…

xfs ext4 结合lvm 扩容、缩容 —— 筑梦之路

ext4 文件系统扩容、缩容操作 扩容系统根分区 根文件系统在 /dev/VolGroup/lv_root 逻辑卷上,文件系统类型为ext4,大小为10G,现在要将其扩容成20G。 给空闲空间分区# 调整分区类型为LVM,也就是8e类型 fdisk /dev/sdb# 选定分区后使…

2023年高教社杯 国赛数学建模思路 - 案例:FPTree-频繁模式树算法

文章目录 算法介绍FP树表示法构建FP树实现代码 建模资料 ## 赛题思路 (赛题出来以后第一时间在CSDN分享) https://blog.csdn.net/dc_sinor?typeblog 算法介绍 FP-Tree算法全称是FrequentPattern Tree算法,就是频繁模式树算法&#xff0c…

JavaScript函数调用其他函数

在JavaScript中,函数可以调用其他函数。这通常被称为函数组合,它允许你通过将较简单的函数组合在一起来创建更复杂的功能。 例如:还是以之前的水果加工举例,但是现在我们需要输出,这个苹果有几块,橘子有几块…

计算机竞赛 基于大数据的时间序列股价预测分析与可视化 - lstm

文章目录 1 前言2 时间序列的由来2.1 四种模型的名称: 3 数据预览4 理论公式4.1 协方差4.2 相关系数4.3 scikit-learn计算相关性 5 金融数据的时序分析5.1 数据概况5.2 序列变化情况计算 最后 1 前言 🔥 优质竞赛项目系列,今天要分享的是 &…

基于Java的旅游信息推荐系统设计与实现,springboot+vue,MySQL数据库,前后端分离,完美运行,有三万字论文。

基于Java的旅游信息推荐系统设计与实现,springbootvue,MySQL数据库,前后端分离,完美运行,有三万字论文。 前台主要功能:登录注册、旅游新闻、景区信息、美食信息、旅游线路、现在留言、收藏、预定旅游线路…

CAD打开对象捕捉设置的快捷键是什么?

CAD打开对象捕捉设置的快捷键是什么呢?今天就教大家如何操作。 方法 打开对象捕捉设置的快捷命令:SE。空格确定即可。 也可以输入快捷命令:DS也一样可以打开对象捕捉设置。血糖测试仪什么牌子好?盘点血糖检测仪的三大品牌! | 共…

visual studio 2022.NET Core 3.1 未显示在目标框架下拉列表中

问题描述 在Visual Studio 2022我已经安装了 .NET core 3.1 并验证可以运行 .NET core 3.1 应用程序,但当创建一个新项目时,目标框架的下拉列表只允许 .NET 6.0和7.0。而我在之前用的 Visual Studio 2019,可以正确地添加 .NET 核心项目。 …