基于python开发用于深度学习模型训练过程loss值曲线的平滑处理模块

深度学习网络模型的loss曲线是训练过程中非常重要的一个监控指标,它能够直观地反映模型的学习状态以及可能存在的问题。以下是对深度学习网络模型loss曲线的详细介绍:

一、loss曲线的基本概念

在深度学习的训练过程中,loss函数用于衡量模型预测结果与实际标签之间的差异。loss曲线则是通过记录每个epoch(或者迭代步数)的loss值,并将其以图形化的方式展现出来,以便我们更好地理解和分析模型的训练过程。

二、loss曲线的解读

  1. loss值的变化趋势:
    • 如果loss值随着训练的进行而逐渐降低,说明模型正在学习并优化,这是一个正常的训练过程。
    • 如果loss值在训练初期迅速下降,但随后趋于稳定或波动较小,可能意味着模型已经收敛,或者陷入了局部最优解。
    • 如果loss值在训练过程中出现剧烈波动,可能是学习率设置不当、模型结构复杂度过高等原因导致的。
  2. 训练和验证loss的对比:
    • 训练loss和验证loss的差距可以反映模型的过拟合程度。如果训练loss持续下降而验证loss却开始上升,说明模型可能出现了过拟合现象。
    • 理想的训练过程应该是训练loss和验证loss都逐渐下降,且两者之间的差距较小。
  3. 不同阶段的loss变化:
    • 在训练初期,由于模型参数是随机初始化的,因此loss值通常会比较大。随着训练的进行,loss值会逐渐降低并趋于稳定。
    • 在训练后期,如果模型没有出现过拟合现象,loss值应该能够稳定在一个较低的水平上。

三、loss曲线的绘制与监控

在深度学习框架(如TensorFlow、PyTorch等)中,通常都提供了绘制loss曲线的功能。通过调用相应的API或库(如matplotlib、Visdom等),我们可以方便地绘制出训练过程中的loss曲线,并对其进行实时监控和分析。

四、loss曲线的优化策略

针对loss曲线反映出的问题,我们可以采取以下优化策略:

  1. 调整学习率:学习率是影响loss曲线变化的重要因素之一。如果学习率设置得过大,可能会导致loss值在训练过程中出现剧烈波动;如果学习率设置得过小,则可能会导致训练过程过于缓慢。因此,我们需要根据loss曲线的变化情况来适时调整学习率的大小。
  2. 添加正则化项:正则化项可以有效地防止模型过拟合。通过向损失函数中添加正则化项(如L1正则化、L2正则化等),我们可以限制模型参数的复杂度,从而降低过拟合的风险。
  3. 使用更复杂的模型结构:如果模型的复杂度不够高,可能无法充分拟合训练数据中的复杂模式。在这种情况下,我们可以尝试使用更复杂的模型结构(如增加网络层数、使用更复杂的激活函数等)来提高模型的拟合能力。
  4. 增加训练数据:增加训练数据可以提供更多的信息供模型学习,从而降低过拟合的风险。如果条件允许的话,我们可以尝试增加训练数据的数量或多样性来提高模型的性能。

实际工作中,经常会需要训练构建深度学习模型,相信做这块工作的同学对于loss曲线一定不会陌生的,大家肯定也都经常在模型开发过程中实际去绘制模型的loss曲线,在一些特殊的场景下需要对原始产生的loss曲线进行平滑处理,这里主要是记录实践这块的内容。

这里我以经常使用的keras框架,介绍下我常用的讲模型训练过程日志进行记录存储的方式,核心代码实现如下:

#记录日志
history = model.fit(
    X_train,
    y_train,
    validation_data=(X_test, y_test),
    #传入回调
    callbacks=[checkpoint],
    epochs=nepoch,
    batch_size=32,
)
print(history.history.keys())
# loss提取
lossdata, vallossdata = history.history["loss"], history.history["val_loss"]
# 绘制loss曲线
plot_both_loss_acc_pic(
    lossdata, vallossdata, picpath=saveDir + "train_val_loss.png"
)
history = {}
#提取训练过程对应的log
history["loss"], history["val_loss"] = lossdata, vallossdata
#存储日志数据
with open(saveDir + "history.json", "w") as f:
    f.write(json.dumps(history))

这里我给出history.json的样例数据,如下所示:

{"loss": [0.0239631230070447, 0.0075705342770514186, 0.004838935030165967, 0.0037340148873459002, 0.002886130001751231, 0.0024534663854011295, 0.0023201104651917924, 0.002976924244579323, 0.002085769131966776, 0.0018753843622715224, 0.0019806173175960172, 0.002174197305382795, 0.001658159012761194, 0.001545024904081888, 0.001667008826952705, 0.0013947380403409929, 0.0012537746476829388, 0.0014786866023657216, 0.0016623390785946131, 0.0016191040555174983, 0.0014966548395261134, 0.001477120676483648, 0.0016280919435364668, 0.0017182350213351422, 0.0038554738028685545, 0.0027464564262130392, 0.0017087835722348526, 0.0014510032096478255, 0.001268975875018749, 0.001481830868523139, 0.001604654047318627, 0.0011948789410770326, 0.001490574051798416, 0.001524109376014187, 0.0015062743931394868, 0.0013054789145924908, 0.0011241542828178905, 0.0010764475793075279, 0.0011480460991939996, 0.0012678029520214276, 0.0012396599495106504, 0.0011639709738934618, 0.0012134075943700145, 0.0012499850020485322, 0.001329989843023205, 0.0011846670753724083, 0.001357856133803473, 0.0015265580290890642, 0.0012421558107066537, 0.001249045898042552, 0.0013697822622925414, 0.0010749583650784015, 0.0010974660338928532, 0.0010916401195769782, 0.0010911698460223627, 0.001078350035803241, 0.001045568893730859, 0.001084814094107926, 0.0011569271895574074, 0.0011443737715600441, 0.001247118570911225, 0.0012540589338988402, 0.0011518743058927274, 0.0015513227900919035, 0.0017111857056945697, 0.0015170943725414776, 0.001481423723410487, 0.0011165965530857377, 0.0016210588698042031, 0.002381780790270182, 0.0011541179547393296, 0.0013710694562572288, 0.0012280710985459404, 0.001037340645381916, 0.0010694707121014697, 0.0009750368017871479, 0.001008019566722004, 0.0011101727457727573, 0.0012511928422843225, 0.001071397447170223, 0.0011470449074543591, 0.0015238439674756194, 0.0010109543884446336, 0.0011297101726488506, 0.001058421874235954, 0.001103364821769398, 0.001025826505723811, 0.0010999036314539848, 0.001329398845137427, 0.0017114742325290903, 0.0011102726525873048, 0.0011274378092930091, 0.0011542693009646294, 0.0011940637438370937, 0.0012636104229160712, 0.0013925317771055863, 0.00100061368093664, 0.0011615896552567776, 0.0010081333990953022, 0.001092779955855081], "val_loss": [0.004923976918584422, 0.00991965542106252, 0.01076323433877214, 0.003843901578434988, 0.011352231488318036, 0.0016448196832482753, 0.0016787166668923179, 0.02015221753696862, 0.003941944209049995, 0.0029026116281257648, 0.0045372380556440665, 0.0014563935330155992, 0.001654032355260202, 0.0013641683946173688, 0.0015195327850769421, 0.0011578488043405262, 0.0012232080662598539, 0.00509458419768826, 0.0012246073744455843, 0.0023663273782738923, 0.0011423173363590124, 0.006865146876263775, 0.0020036918448137217, 0.007316410553788668, 0.001553758288929729, 0.0013593508740948317, 0.0025380967877266045, 0.0023082743653120765, 0.0013224915555359697, 0.00858367411909919, 0.0009927515703326974, 0.0010470627885201553, 0.0011798253622959907, 0.0024045295798905976, 0.0015412836871722614, 0.0038771789925368992, 0.0015362703578399592, 0.001756014192697445, 0.00334732801114258, 0.000975109149983741, 0.0046767660281866, 0.0018946981394516401, 0.0021767043220614524, 0.004211987026869074, 0.0009522750635177975, 0.0021094563270085734, 0.0037733877482088772, 0.001548874757549799, 0.0027838850510306656, 0.008273044527557335, 0.00123940688829048, 0.0016841785786183257, 0.0009756766973479994, 0.001928586675479126, 0.0011492695222075685, 0.0012013394433827336, 0.0010477521618380897, 0.00121309975940293, 0.0030147337820380926, 0.0013649057897150909, 0.0023210468165895067, 0.0011219763923068775, 0.0017544153219971217, 0.0030385015789713516, 0.0016239398731206739, 0.0031037202962723217, 0.002162101651590906, 0.003717466969484169, 0.0033957000386803165, 0.0009902583321826043, 0.00193247984708777, 0.001976198960389746, 0.0027693257654870028, 0.0025635553493262514, 0.0013357499459648113, 0.0012082410958675226, 0.001168333794186382, 0.0025652966841957286, 0.0010059437105634347, 0.0009358364489775054, 0.0036403173617528457, 0.0009317236960142556, 0.0015049418612187238, 0.0017247698554695634, 0.0010254738238903596, 0.001047871537096063, 0.0009514076437641818, 0.0036001800608478096, 0.014663169037942824, 0.002012193938227076, 0.001970677826351388, 0.0037272977164799445, 0.0012484785829560438, 0.002330363199182198, 0.0011025683723059237, 0.0013020975239525893, 0.001059662765137067, 0.0009167807317632986, 0.0009290355350362676, 0.0012791703282058924]}

接下来看下原始loss数据绘制出来的对比可视化曲线:

整体波形不断,也反映了模型实际训练过程并不够稳定,这里抛开模型训练的因素,单纯地基于曲线数据进行分析,想要对其进行平滑处理,得到的效果如下:

核心实现就是使用scipy.signal.savgol_filter方法,scipy.signal.savgol_filter 是 SciPy 库中的一个函数,用于对一维数据序列应用 Savitzky-Golay 平滑滤波器。这个滤波器是一种局部多项式回归的技术,能够在平滑数据的同时尽量保留数据的特征形状,如峰值和谷值,因此特别适用于信号去噪和数据平滑处理,尤其是那些包含噪声的实验数据或时序数据。

以下是 savgol_filter 函数的主要参数及其说明:

  • x (array_like):要过滤的一维数据序列。如果 x 不是单精度或双精度浮点数数组,它将在过滤前被转换为这种类型。

  • window_length (int):滤波器窗口的长度,即应用于数据点上的局部多项式拟合所使用的相邻数据点的数量。这个值必须是奇数,并且 polyorder + 1 <= window_length

  • polyorder (int):拟合局部数据点的多项式的阶数。它决定了平滑程度,阶数越高,可以拟合更复杂的曲线,但也会更多地改变原始数据的特性。必须满足 polyorder < window_length

  • deriv (int, 可选):指定是否计算导数以及计算哪阶导数。默认为0,表示直接平滑数据;大于0的值用于计算相应阶数的导数。

  • delta (float, 可选):采样点之间的间距,默认为1.0。仅在计算导数(deriv > 0)时使用。

  • axis (int, 可选):当输入数据 x 的维度大于1时,指定沿哪个轴应用滤波器。默认为-1,表示最后一个轴。

  • mode (str, 可选):决定如何处理边界效应,可选值有 'mirror''constant''nearest''wrap''interp'。默认为 'interp',表示通过线性插值来扩展数据以处理边界。选择 'mirror' 会在边界处镜像数据,而 'constant' 则会使用边缘值填充。

  • cval (float, 可选):当 mode'constant' 时使用的常数值。默认为0.0。

使用示例

假设我们有一个包含噪声的一维数据列表 data,我们可以使用 savgol_filter 来平滑这些数据:

from scipy.signal import savgol_filter
import numpy as np

# 假设 data 是一个包含噪声的数据序列
data = np.random.randn(100)  # 生成随机噪声数据作为示例
window_length = 5  # 窗口长度
polyorder = 3  # 多项式阶数

# 应用 Savitzky-Golay 滤波器
smoothed_data = savgol_filter(data, window_length, polyorder)

# 然后可以绘制原始数据和平滑后的数据进行对比
import matplotlib.pyplot as plt

plt.figure()
plt.plot(data, label='Noisy data')
plt.plot(smoothed_data, label='Smoothed data')
plt.legend()
plt.show()

借助于scipy.signal.savgol_filter方法,我们可以非常方便快捷地实现对原生loss曲线的平滑化处理,这里为了直观对比效果,我们绘制对比可视化曲线,如下所示:

有需要的也都可以尝试下。

完整代码实现如下:

def lossPloter(train_loss,val_loss):
    """
    loss曲线对比可视化
    """
    iters = range(len(train_loss))
    #单独绘制原始loss曲线
    plt.clf()
    plt.figure(figsize=(10,6))
    plt.plot(iters, train_loss, 'red', linewidth = 2, label='train loss')
    plt.plot(iters, val_loss, 'coral', linewidth = 2, label='val loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('A Loss Curve')
    plt.legend(loc="upper right")
    plt.savefig("original_loss.png")
    num = 5 if len(train_loss)<25 else 15
    #插值平滑处理
    train_loss_smooth=scipy.signal.savgol_filter(train_loss, num, 3)
    val_loss_smooth=scipy.signal.savgol_filter(val_loss, num, 3)
    for i in range(5):
        val_loss_smooth=scipy.signal.savgol_filter(val_loss_smooth, num, 3)
    #二者同时绘制
    plt.clf()
    plt.figure(figsize=(10,6))
    plt.plot(iters, train_loss, 'red', linewidth = 2, label='train loss')
    plt.plot(iters, val_loss, 'coral', linewidth = 2, label='val loss')
    plt.plot(iters, train_loss_smooth, 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
    plt.plot(iters, val_loss_smooth, '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('A Loss Curve')
    plt.legend(loc="upper right")
    plt.savefig("compare_loss.png")
    plt.cla()
    plt.close("all")
    #单独绘制平滑曲线
    plt.clf()
    plt.figure(figsize=(10,6))
    plt.plot(iters, train_loss_smooth, 'green', linewidth = 2, label='smooth train loss')
    plt.plot(iters, val_loss_smooth, 'blue', linewidth = 2, label='smooth val loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Loss Curve')
    plt.legend(loc="upper right")
    plt.savefig("smooth_loss.png")

会得到三幅图像:
original_loss.png: 原始loss对比曲线

smooth_loss.png: 平滑化的loss对比曲线

compare_loss.png: 二者对比曲线

感兴趣的话可以尝试下!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/661111.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

23种设计模式之一— — — —装饰模式详细介绍与讲解

装饰模式详细讲解 一、定义二、装饰模式结构核心思想模式角色模式的UML类图应用场景模式优点模式缺点 实例演示图示代码演示运行结果 一、定义 装饰模式&#xff08;别名&#xff1a;包装器&#xff09; 装饰模式&#xff08;Decorator Pattern&#xff09;是结构型的设计模式…

【PB案例学习笔记】-12秒表实现

写在前面 这是PB案例学习笔记系列文章的第11篇&#xff0c;该系列文章适合具有一定PB基础的读者。 通过一个个由浅入深的编程实战案例学习&#xff0c;提高编程技巧&#xff0c;以保证小伙伴们能应付公司的各种开发需求。 文章中设计到的源码&#xff0c;小凡都上传到了gite…

云原生架构内涵_3.主要架构模式

云原生架构有非常多的架构模式&#xff0c;这里列举一些对应用收益更大的主要架构模式&#xff0c;如服务化架构模式、Mesh化架构模式、Serverless模式、存储计算分离模式、分布式事务模式、可观测架构、事件驱动架构等。 1.服务化架构模式 服务化架构是云时代构建云原生应用的…

【Java用法】java中计算两个时间差

java中计算两个时间差 不多说&#xff0c;直接上代码&#xff0c;可自行查看示例 package org.example.calc;import java.time.LocalDateTime; import java.time.format.DateTimeFormatter; import java.time.temporal.ChronoUnit;public class MinusTest {public static void…

迅睿CMS邮箱设置QQ邮箱为例

邮箱设置 1、服务器地址两个&#xff0c;普通与企业。 普通&#xff1a;ssl://smtp.qq.com企业&#xff1a;ssl://smtp.exmail.qq.com 2、端口号为&#xff1a;465 3、邮箱账号&#xff1a;填写自己的QQ邮箱作为发布服务器。 4、邮箱密码&#xff1a;到QQ邮箱账号中获取“…

c++编程(15)——list的模拟实现

欢迎来到博主的专栏——c编程 博主ID&#xff1a;代码小豪 文章目录 前言list的数据结构list的默认构造尾插与尾删iterator插入和删除构造、析构、赋值copy构造initializer_list构造operator 析构函数 前言 受限于博主当前的技术水平&#xff0c;暂时还不能模拟实现出STL当中用…

C语言数据结构堆排序、向上调整和向下调整的时间复杂度的计算、TopK问题等的介绍

文章目录 前言一、堆排序1. 排升序&#xff08;1&#xff09;. 建堆&#xff08;2&#xff09;. 排序 2. 拍降序&#xff08;1&#xff09;. 建堆&#xff08;2&#xff09;. 排序 二、建堆时间复杂度的计算1. 向上调整时间复杂度2. 向下调整时间复杂度 三、TopK问题总结 前言 …

数据库设计实例---学习数据库最重要的应用之一

一、引言【可忽略】 在学习“数据库系统概述”这门课程时&#xff0c;我一直很好奇&#xff0c;这样一门必修课&#xff0c;究竟教会了我什么呢&#xff1f; 由于下课后&#xff0c;&#xff0c;没有拓展自己的眼界&#xff0c;上课时又局限于课堂上老师的讲课水平&#xff0c;…

Java+mysql酒店管理系统

1&#xff0e;引言 1.1编写的目的 本文档为酒店管理系统需求分析报告&#xff0c;为酒店管理系统的设计的主要依据&#xff0c;主要针对酒店管理系统的概要设计和详细设计人员&#xff0c;作为项目验收的主要依据。 1.2背景 本软件全称为阳光酒店管理系统。 1.3 参考资料 …

Windows和Linux系统部署Docker(2)

目录 一、Linux系统部署docker 前置环境&#xff1a; 1.安装需要的软件包&#xff0c; yum-util 提供yum-config-manager功能 2.添加阿里云 docker-ce 仓库 3.安装docker软件包 4.启动 docker并设置开机自启 5.查看版本&#xff1a; 二、windows系统部署docker 1.查看…

.NET 直连SAP HANA数据库

前言 上个项目碰到的需求&#xff0c;IT部门要求直连SAP的HANA数据库&#xff0c;以只读的权限读取SAP部门开发的CDS视图&#xff0c;是个有点复杂的工程&#xff0c;需要从成品一直往前追溯到原材料的产地&#xff0c;和交货单、工单、采购订单有相当程度上的关联 IT部门要求…

代码随想录算法训练营第五十四天||392.判断子序列、115.不同的子序列

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 一、392.判断子序列 思路 二、115.不同的子序列 思路 一、392.判断子序列 给定字符串 s 和 t &#xff0c;判断 s 是否为 t 的子序列。 字符串的一个子序列是…

Money Trees

思路分析: 利用双指针 l1始终作为起点,ri,不断更新终点 #include<iostream> #include<cstring> #include<string> #include<algorithm> #define int long long using namespace std; int w[2000005],h[2000005],s[2000005]; int t,n,m,l,r; signed m…

信息学奥赛初赛天天练-15-阅读程序-深入解析二进制原码、反码、补码,位运算技巧,以及lowbit的神奇应用

更多资源请关注纽扣编程微信公众号 1 2021 CSP-J 阅读程序1 阅读程序&#xff08;程序输入不超过数组或字符串定义的范围&#xff1b;判断题正确填 √&#xff0c;错误填&#xff1b;除特 殊说明外&#xff0c;判断题 1.5 分&#xff0c;选择题 3 分&#xff09; 源码 #in…

什么是访问控制漏洞

什么是AC Bugs&#xff1f; 实验室 Vertical privilege escalation 仅通过隐藏目录/判断参数来权限控制是不安全的&#xff08;爆破url/爬虫/robots.txt/Fuzz/jsfinder&#xff09; Unprotected functionality 访问robots.txt 得到隐藏目录&#xff0c;访问目录 &#xff0c;…

使用Jmeter进行性能测试的基本操作方法

&#x1f525; 交流讨论&#xff1a;欢迎加入我们一起学习&#xff01; &#x1f525; 资源分享&#xff1a;耗时200小时精选的「软件测试」资料包 &#x1f525; 教程推荐&#xff1a;火遍全网的《软件测试》教程 &#x1f4e2;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1…

长难句打卡5.29

Today, professors routinely treat the progressive interpretation of history and progressive public policy as the proper subject of study while portraying conservative or classical liberal ideas — such as free markets and self-reliance — as falling outsid…

【SPSS】基于因子分析法对水果茶调查问卷进行分析

&#x1f935;‍♂️ 个人主页&#xff1a;艾派森的个人主页 ✍&#x1f3fb;作者简介&#xff1a;Python学习者 &#x1f40b; 希望大家多多支持&#xff0c;我们一起进步&#xff01;&#x1f604; 如果文章对你有帮助的话&#xff0c; 欢迎评论 &#x1f4ac;点赞&#x1f4…

No input file specified.(‘.user.ini’文件问题宝塔复制到本地,其他情况可跳过)

症状 病因 一般是宝塔直接copy到本地的情况。 宝塔面板中的.user.ini文件是一个重要的配置文件&#xff0c;它主要用于配置PHP运行环境和网站环境。以下是.user.ini文件的主要作用和操作建议&#xff1a; 防止跨目录访问和文件跨目录读取。这是.user.ini文件的主要作用之一&a…

采用Java+ SpringBoot+ IntelliJ+idea开发的ADR药物不良反应监测系统源码

采用Java SpringBoot IntelliJidea开发的ADR药物不良反应监测系统源码 ADR药物不良反应监测系统有哪些应用场景&#xff1f; ADR药物不良反应监测系统有哪些应用场景&#xff1f; ADR药物不良反应监测系统具有广泛的应用场景&#xff0c;以下是一些主要的应用场景&#xff1a…