权重衰减(Weight Decay)

       在深度学习中,权重衰减(Weight Decay)是一种常用的正则化技术,旨在减少模型的过拟合现象。权重衰减通过向损失函数添加一个正则化项,以惩罚模型中较大的权重值。

一、权重衰减

       在深度学习中,模型的训练过程通常使用梯度下降法(或其变种)来最小化损失函数。梯度下降法的目标是找到损失函数的局部最小值,使得模型的预测能力最好。然而,当模型的参数(即权重)过多或过大时,容易导致过拟合问题,即模型在训练集上表现很好,但在测试集上表现较差。

       权重衰减通过在损失函数中引入正则化项来解决过拟合问题。正则化项通常使用L1范数或L2范数来度量模型的复杂度。L2范数正则化(也称为权重衰减)是指将模型的权重的平方和添加到损失函数中,乘以一个较小的正则化参数$ \lambda $这个额外的项迫使模型学习到较小的权重值,从而减少模型的复杂度。

       具体而言,对于一个深度学习模型的损失函数$L(w, b)$,其中$w,b$表示模型的参数(权重和偏置),权重衰减可以通过以下方式实现:

$ L'\left( w,b \right) =L\left( w,b \right) +\lambda \cdot \lVert w \rVert ^2 $

       其中,$ L'\left( w,b \right) $是添加了权重衰减的损失函数,$ \lVert w \rVert ^2 $表示参数的L2范数的平方和,$ \lambda $是正则化参数,用于控制正则化项的重要性。

       在训练过程中,梯度下降法将同时更新损失函数和权重。当计算梯度时,权重衰衰减的正则化项将被添加到梯度中,从而导致权重更新的幅度减小。这使得模型的权重趋向于减小,避免过拟合现象。

       需要注意的是,正则化参数$ \lambda $的选择对模型的性能有重要影响。较小的$ \lambda $值会导致较强的正则化效果,可能会使模型欠拟合。而较大的$ \lambda $值可能会减少正则化效果,使模型过拟合。因此,选择合适的正则化参数是权衡模型复杂度和泛化能力的关键。

       偏置(biases)在神经网络中起到平移激活函数的作用,通常不会像权重那样导致过度拟合。偏置的主要作用是调整激活函数的位置,使其更好地对应所需的输出。由于偏置的影响较小,因此将权重衰减应用于偏置通常不是常见的做法。

二、权重衰减数学解释

       L2范数正则化在解决过拟合问题方面具有一定的效果,这是因为它在损失函数中引入了权重的平方和作为正则化项。下面我将解释一下L2范数正则化的数学原理。

       在深度学习中,我们的目标是最小化损失函数,该函数包括两部分:经验误差和正则化项。对于L2范数正则化,我们将正则化项定义为权重的平方和的乘以一个正则化参数$ \lambda $

       针对损失函数$ L'\left( w,b \right)$,我们使用梯度下降法来最小化这个损失函数。在梯度下降的每一步中,我们计算损失函数的梯度,然后更新权重。对于L2范数正则化,梯度的计算中包含了正则化项的贡献。

       具体来说,我们计算损失函数对权重w的梯度,记为$ \nabla L\left( w,b \right) $。那么加入L2范数正则化后的梯度可以写为:

$ \nabla L'\left( w,b \right) =\nabla L\left( w,b \right) +2\lambda w $

       这里,$ 2\lambda w $是正则化项的梯度贡献,其中$ 2\lambda $是正则化参数$ \lambda $的倍数,$w$是权重的梯度。

       当我们使用梯度下降法更新权重时,梯度的负方向指示了损失函数下降的方向。由于L2范数正则化项的存在,权重的梯度会受到惩罚,从而导致权重的更新幅度减小。

       这种减小权重更新幅度的效果使得模型倾向于学习到较小的权重值,从而降低了模型的复杂度。通过减小权重的幅度,L2范数正则化可以有效地控制模型的过拟合,提高模型的泛化能力。

       总结起来,L2范数正则化通过引入权重的平方和作为正则化项,在梯度计算和权重更新中对权重进行惩罚,从而减小了模型的复杂度,防止过拟合现象的发生。

也可以参考李沐老师的课件:

三、代码从零开始实现

import torch
from torch import nn
from d2l import torch as d2l

1、生成数据

       首先,我们像以前一样生成一些数据,生成公式如下:

$y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2).$

       我们选择标签是关于输入的线性函数。标签同时被均值为0,标准差为0.01高斯噪声破坏。为了使过拟合的效果更加明显,我们可以将问题的维数增加到$d = 200$(w的长度为200),并使用一个只包含20个样本的小训练集。

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5   # 训练集长度为20、验证机长度为100、权重参数有200个、批量大小为5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05   # 真实的权重和偏置
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

2、初始化模型参数

       我们将定义一个函数来随机初始化模型参数。

def init_params():
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]

3、定义L2范数惩罚

       实现这一惩罚最方便的方法是对所有项求平方后并将它们求和。

def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2

4、定义训练代码实现

       下面的代码将模型拟合训练数据集,并在测试数据集上进行评估。和之前线性回归一样,线性网络和平方损失没有变化,所以我们通过`d2l.linreg`和`d2l.squared_loss`导入它们。唯一的变化是损失现在包括了惩罚项。

def train(lambd):
    w, b = init_params()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 增加了L2范数惩罚项,
            # 广播机制使l2_penalty(w)成为一个长度为batch_size的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)
            l.sum().backward()
            d2l.sgd([w, b], lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:', torch.norm(w).item())

5、忽略正则化直接训练

       我们现在用`lambd = 0`禁用权重衰减后运行这个代码。注意,这里训练误差有了减少,但测试误差没有减少,这意味着出现了严重的过拟合。

train(lambd=0)
w的L2范数是: 12.963241577148438

 

6、使用权重衰减

       下面,我们使用权重衰减来运行代码。注意,在这里训练误差增大,但测试误差减小。这正是我们期望从正则化中得到的效果。

train(lambd=3)
w的L2范数是: 0.3556520938873291

 

四、简洁实现

       由于权重衰减在神经网络优化中很常用,深度学习框架为了便于我们使用权重衰减,将权重衰减集成到优化算法中,以便与任何损失函数结合使用。此外,这种集成还有计算上的好处,允许在不增加任何额外的计算开销的情况下向算法中添加权重衰减。由于更新的权重衰减部分仅依赖于每个参数的当前值,因此优化器必须至少接触每个参数一次。

1、定义训练代码实现

       在下面的代码中,我们在实例化优化器时直接通过`weight_decay`指定weight decay超参数。默认情况下,PyTorch同时衰减权重和偏移。这里我们只为权重设置了`weight_decay`,所以偏置参数$b$不会衰减。

def train_concise(wd):
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        param.data.normal_()
    loss = nn.MSELoss(reduction='none')
    num_epochs, lr = 100, 0.003
    # 偏置参数没有衰减
    trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd}, {"params":net[0].bias}],
                              lr=lr)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y)
            l.mean().backward()
            trainer.step()
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数:', net[0].weight.norm().item())

2、忽略正则化直接训练

train_concise(0)
w的L2范数: 13.727912902832031

3、使用权重衰减

train_concise(3)
w的L2范数: 0.3890590965747833

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/250813.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

leetcode移除元素

移除元素 题目分析题解代码:c版本c版本 题目 给你一个数组 nums 和一个值 val,你需要 原地 移除所有数值等于 val 的元素,并返回移除后数组的新长度。 不要使用额外的数组空间,你必须仅使用 O(1) 额外空间并 原地 修改输入数组。 元素的顺…

六大场景36种数据分析模型及方法图示,数据分析师必备!

【关注微信公众号:跟强哥学SQL,回复“笔试”免费领取大厂SQL笔试题。】 我一直认为,实际工作中,精通数据分析工具仅仅只是数据思维训练的一部分,掌握丰富的数据分析方法和模型实际上更为重要。 基于科学的数据分析方法…

『PyTorch』张量和函数之gather()函数

文章目录 PyTorch中的选择函数gather()函数 参考文献 PyTorch中的选择函数 gather()函数 import torch a torch.arange(1, 16).reshape(5, 3) """ result: a [[1, 2, 3],[4, 5, 6],[7, 8, 9],[10, 11, 12],[13, 14, 15]] """# 定义两个index…

Tomcat-指定启动jdk、修改使用的jdk版本

修改tomcat配置文件setclasspath.sh 配置文件首行增加以下代码,指定启动的jdk: export JAVA_HOME/opt/softwares/jdk1.8.0_211/ export JRE_HOME/opt/softwares/jdk1.8.0_211/jre

【Kafka】Kafka的重复消费和消息丢失问题

文章目录 前言一、重复消费1.1 重复消费出现的场景1.1.1 Consumer消费过程中,进程挂掉/异常退出1.1.2 消费者消费时间过长 1.2 重复消费解决方案1.2.1 针对于消费端挂掉等原因造成的重复消费问题1.2.2 针对于Consumer消费时间过长带来的重复消费问题 二、消息丢失2.…

【Qt5】QVersionNumber

2023年12月10日,周日上午 QVersionNumber 是 Qt 框架中用于表示版本号的类。它提供了一种方便的方式来处理和比较版本号,特别是在应用程序或库需要与特定版本的依赖项进行交互时。 以下是一个简单的示例,演示了如何使用 QVersionNumber&…

Spring-temp

IOC/DI实现步骤 1.配置元数据 2.实例化IOC 3.获取Bean 基于XML配置方式 管理组件 1.基于构造函数:有参、无参 2.基于静态工厂方法:有参、无参 依赖注入 1.构造函数 2.setter方法 Bean组件高级特性 1.作用域 2.生命周期 FactoryBean 基于注解 IOC Bean作…

Python如何匹配库的版本

目录 1. 匹配库的版本 2. Python中pip,库,编译环境的问题回答总结 2.1 虚拟环境 2.2 pip,安装库,版本 1. 匹配库的版本 (别的库的版本冲突同理) 在搭建pyansys环境的时候,安装grpcio-tools…

加权准确率WA,未加权平均召回率UAR和未加权UF1

加权准确率WA,未加权平均召回率UAR和未加权UF1 1.加权准确率WA,未加权平均召回率UAR和未加权UF12.参考链接 1.加权准确率WA,未加权平均召回率UAR和未加权UF1 from sklearn.metrics import classification_report from sklearn.metrics impor…

Python中的程序逻辑经典案例详解

我的博客 文章首发于公众号:小肖学数据分析 Python作为一种强大的编程语言,以其简洁明了的语法和强大的标准库,成为了理想的工具来构建这些解决方案。 本文将通过Python解析几个经典的编程问题。 经典案例 水仙花数 问题描述&#xff1a…

三勾商城新功能-电子面单发货

商家快递发货时可以选择在线下单,在线获取和打印电子面单。免去手写面单信息以及避免填写运单号填错,系统会自动填写对应发货商品的运单信息 快递100电子面单1、进入快递100,点击登录 2、登录成功后,点击“电子面单与云打印” 3、进入电子面单与云打印后…

Druid-spring-boot-starter源码阅读-其余组件自动装配

前面我们看完了整个DruidDataSource初始化流程,但是其实Druid除了最核心的数据源之外,还有其他需要自动配置的,细心的人可能看到了,就是利用Import注解导入的四个类。 DruidFilterConfiguration public class DruidFilterConfigu…

解决:TypeError: write() argument must be str, not tuple

解决:TypeError: write() argument must be str, not tuple 文章目录 解决:TypeError: write() argument must be str, not tuple背景报错问题报错翻译报错位置代码报错原因解决方法今天的分享就到此结束了 背景 在使用之前的代码时,报错&…

缓存的定义及重要知识点

文章目录 缓存的意义缓存的定义缓存原理缓存的基本思想缓存的优势缓存的代价 缓存的重要知识点 缓存的意义 在互联网高访问量的前提下,缓存的使用,是提升系统性能、改善用户体验的唯一解决之道。 缓存的定义 缓存最初的含义,是指用于加速 …

联邦边缘学习中的知识蒸馏综述

联邦边缘学习中的知识蒸馏综述 移动互联网的快速发展伴随着智能终端海量用户数据的产生。如何在保护数据隐私的前提下,利用它们训练出性能优异的机器学习模型,一直是业界关注的难点。为此,联邦学习应运而生,它允许在终端本地训练并协同边缘服务器进行模型聚合来实现分布式机器…

下午好~ 我的论文【yolo1~4】(第二期)

写在前面:本来是一期的,我看了太多内容了,于是分成三期发吧 TAT (捂脸) 文章目录 YOLO系列v1v2v3v4 YOLO系列 v1 You Only Look Once: Unified, Real-Time Object Detection 2015 ieee computer society 12.3 CCF-C…

网站提示“不安全”

当你在浏览网站时,有时可能会遇到浏览器提示网站不安全的情况。这通常是由于网站缺乏SSL证书所致。那么,从SSL证书的角度出发,我们应该如何解决这个问题呢? 首先,让我们简单了解一下SSL证书。SSL证书是一种用于保护网站…

论文阅读——Semantic-SAM

Semantic-SAM可以做什么: 整合了七个数据集: 一般的分割数据集,目标级别分割数据集:MSCOCO, Objects365, ADE20k 部分分割数据集:PASCAL Part, PACO, PartImagenet, and SA-1B The datasets are SA-1B, COCO panopt…

Grafana Loki 快速尝鲜

Grafana Loki 是一个支持水平扩展、高可用的聚合日志系统,跟其他的聚合日志系统不同,Loki只对日志的元数据-标签进行索引,日志数据会被压缩并存储在对象存储中,甚至可以存储在本地文件系统中,能够有效降低成本&#xf…

Python基础09-学生管理系统

零、文章目录 Python基础09-学生管理系统 1、学员管理系统功能概述 (1)最终效果图 (2)功能概述 需求:进入系统显示系统功能界面,功能如下: 【1】添加学员信息->add_student【2】删除学员…