【深度学习笔记】09 权重衰减

09 权重衰减

    • 范数和权重衰减
    • 利用高维线性回归实现权重衰减
    • 权重衰减的简洁实现

范数和权重衰减

在训练参数化机器学习模型时,权重衰减(decay weight)是最广泛应用的正则化技术之一,它通常也被称为 L 2 L_2 L2正则化。这项技术通过函数与零的距离来衡量函数的复杂度,
因为在所有函数 f f f中,函数 f = 0 f = 0 f=0(所有输入都得到值 0 0 0
在某种意义上是最简单的。

一种简单的方法是通过线性函数
f ( x ) = w ⊤ x f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} f(x)=wx
中的权重向量的某个范数来度量其复杂性,
例如 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2
要保证权重向量比较小,
最常用方法是将其范数作为惩罚项加到最小化损失的问题中。
将原来的训练目标最小化训练标签上的预测损失,
调整为最小化预测损失和惩罚项之和。

损失由下式给出:

L ( w , b ) = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 . L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2. L(w,b)=n1i=1n21(wx(i)+by(i))2.

x ( i ) \mathbf{x}^{(i)} x(i)是样本 i i i的特征,
y ( i ) y^{(i)} y(i)是样本 i i i的标签,
( w , b ) (\mathbf{w}, b) (w,b)是权重和偏置参数。

为了惩罚权重向量的大小,
必须以某种方式在损失函数中添加 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2
我们通过正则化常数 λ \lambda λ来描述这种权衡,
这是一个非负超参数,我们使用验证数据拟合:

L ( w , b ) + λ 2 ∥ w ∥ 2 , L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2, L(w,b)+2λw2,

对于 λ = 0 \lambda = 0 λ=0,我们恢复了原来的损失函数。
对于 λ > 0 \lambda > 0 λ>0,我们限制 ∥ w ∥ \| \mathbf{w} \| w的大小。
这里我们仍然除以 2 2 2:当我们取一个二次函数的导数时,
2 2 2 1 / 2 1/2 1/2会抵消。

通过平方 L 2 L_2 L2范数,我们去掉平方根,留下权重向量每个分量的平方和。
这使得惩罚的导数很容易计算:导数的和等于和的导数。

L 2 L_2 L2正则化回归的小批量随机梯度下降更新如下式:

w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) . \begin{aligned} \mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right). \end{aligned} w(1ηλ)wBηiBx(i)(wx(i)+by(i)).

我们根据估计值与观测值之间的差异来更新 w \mathbf{w} w
然而,我们同时也在试图将 w \mathbf{w} w的大小缩小到零。
这就是为什么这种方法有时被称为权重衰减
我们仅考虑惩罚项,优化算法在训练的每一步衰减权重。
与特征选择相比,权重衰减为我们提供了一种连续的机制来调整函数的复杂度。
较小的 λ \lambda λ值对应较少约束的 w \mathbf{w} w
而较大的 λ \lambda λ值对 w \mathbf{w} w的约束更大。

是否对相应的偏置 b 2 b^2 b2进行惩罚在不同的实践中会有所不同,
在神经网络的不同层中也会有所不同。
通常,网络输出层的偏置项不会被正则化。

利用高维线性回归实现权重衰减

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

首先生成数据,生成公式如下:

y = 0.05 + ∑ i = 1 d 0.01 x i + ϵ  where  ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=1d0.01xi+ϵ where ϵN(0,0.012).

选择标签是关于输入的线性函数。
标签同时被均值为0,标准差为0.01高斯噪声破坏。
为了使过拟合的效果更加明显,我们可以将问题的维数增加到 d = 200 d = 200 d=200
并使用一个只包含20个样本的小训练集。

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

初始化模型参数

定义一个函数来随机初始化模型参数

def init_params():
    w = torch.normal(0, 1, size = (num_inputs, 1), requires_grad = True)
    b = torch.zeros(1, requires_grad = True)
    return [w, b]

定义 L 2 L_2 L2范数惩罚

def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2

定义训练代码实现

下面的代码将模型拟合训练数据集,并在测试数据集上进行评估。

def train(lambd):
    w, b = init_params()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 增加了L2范数惩罚项,
            # 广播机制使l2_penalty(w)成为一个长度为batch_size的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)
            l.sum().backward()
            d2l.sgd([w, b], lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:', torch.norm(w).item())

忽略正则化直接训练

用lamdb=0禁用权重衰减后运行代码。此时训练误差有所减少,但测试误差没有减少,这意味着出现了严重的过拟合。

train(lambd = 0)
w的L2范数是: 14.971677780151367

在这里插入图片描述

使用权重衰减

使用权重衰减来运行代码。此时训练误差增大,但测试误差减小。这正是我们期望从正则化中得到的效果。

train(lambd = 3)
w的L2范数是: 0.34405317902565

在这里插入图片描述

权重衰减的简洁实现

在实例化优化器时直接通过weight_decay指定weight decay超参数。默认情况下,PyTorch同时衰减权重和便宜。这里只为权重设置了weight_decay,所以偏置参数 b b b不会衰减。

def train_concise(wd):
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        param.data.normal_()
    loss = nn.MSELoss(reduction='none')
    num_epochs, lr = 100, 0.003
    # 偏置参数没有衰减
    trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y)
            l.mean().backward()
            trainer.step()
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数:', net[0].weight.norm().item())
train_concise(0)
w的L2范数: 13.416662216186523

在这里插入图片描述

train_concise(3)
w的L2范数: 0.39273694157600403

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/220281.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

HITOS_LAB5 进程运行轨迹的跟踪与统计

5. 进程运行轨迹的跟踪与统计 5.1. 实验目的 掌握 Linux 下的多进程编程技术;通过对进程运行轨迹的跟踪来形象化进程的概念;在进程运行轨迹跟踪的基础上进行相应的数据统计,从而能对进程调度算法进行实际的量化评价, 更进一步加…

通过时间交织技术扩展ADC采样速率的简要原理

前言 数据采集是将自然界中存在的模拟信号通过模数转换器(ADC)转换成数字信号,再对该数字信号进行相应的接收和处理。数据采集系统作为数据采集的手段,在移动通信、图向采集、无线电等领域有重要作用。随着电子信息技术的飞速发展…

8_企业架构缓存中间件分布式memcached

企业架构缓存中间件分布式memcached 学习目标和内容 1、能够理解描述网站业务访问流程 2、能够理解网站业务的优化方向 3、能够描述内存缓存软件Memcached的作用 4、能够通过命令行操作Memcached 5、能够操作安装php的memcached扩展 extension 6、能够实现session存储到memcach…

6G的Java软件安装包和2G的Maven仓库分享给大家

文章目录 🔊博主介绍🥤本文内容📢文章总结📥博主目标 🔊博主介绍 🌟我是廖志伟,一名Java开发工程师、Java领域优质创作者、CSDN博客专家、51CTO专家博主、阿里云专家博主、清华大学出版社签约作…

轨道交通故障预测与健康管理PHM系统的应用

轨道交通是现代城市中不可或缺的交通方式,它为人们提供了快速、高效和可靠的出行方式。然而,由于轨道交通系统的复杂性和高负荷运行,设备故障和运营中断问题时有发生。为了提高轨道交通系统的可靠性和安全性,故障预测与健康管理&a…

一文读懂中间件

前言:在程序猿的日常工作中, 经常会提到中间件,然而大家对中间件的理解并不一致,导致了一些不必要的分歧和误解。“中间件”一词被用来描述各种各样的软件产品,在不同文献中有着许多不同的中间件定义,包括操…

花店小程序商城制作攻略教程分享

现如今,随着互联网的快速发展,越来越多的实体店面对客流量不足的问题。特别是对于花店来说,客流量的多少直接影响着销售额和收益。为了解决这一问题,开发一个花店小程序商城成为了不可忽视的选择。 为了开发花店小程序商城&#x…

使用Docker在Debian上构建GRBL模拟器镜像:简明步骤和操作指南

概述编译编写 Dockerfile构建镜像运行测试其他 概述 本文将详细介绍如何在Debian系统上通过Docker构建GRBL模拟器镜像,以便进行数控机床的仿真测试。GRBL是一种开源的控制系统,用于控制三轴CNC机床、激光雕刻、激光切割,而在Docker容器中运…

Vue 官方周报 #122 - 如何使用Head插件

Hi 👋 本周的问题中,您将学习在Vue中如何使用Head插件。 unhead是一个与框架无关的文档头管理器,您可以使用它来管理页面元数据,如 Vue应用程序中的标题。 它用于Nuxt核心,是UnJS生态系统的一部分。 安装 首先&…

老师怎样夸学生

老师夸学生可以从以下几个方面入手: 1. 表扬学生的思维深度和独立思考能力。如果学生在文章中有独特的思考角度和深度的思考,老师可以直接点出来赞扬。 2. 赞美学生的语言表达。如果学生的文章用词精准、文笔流畅,老师可以夸奖学生的语言表达…

在外包待了6年,技术退步太明显......

先说情况,大专毕业,18年通过校招进入湖南某软件公司,干了接近6年的功能测试,今年年初,感觉自己不能够在这样下去了,长时间呆在一个舒适的环境会让一个人堕落!而我已经在一个企业干了四年的功能测试&#xf…

【“C++ 精妙之道:解锁模板奇谭与STL精粹之门“】

【本节目标】 1. 泛型编程 2. 函数模板 3. 类模板 4. 什么是STL 5. STL的版本 6. STL的六大组件 7. STL的重要性 8. 如何学习STL 9.STL的缺陷 1. 泛型编程 如何实现一个通用的交换函数呢? void Swap(int& left, int& right) {int temp left;lef…

运维04:nginx

源代码编译安装nginx yum工具安装:自动下载nginx,且安装到固定的位置源代码编译安装:更适用于专业的企业服务器环境 比起yum工具安装,会有更多额外的功能可以自定义安装路径、配置文件 安装环境 源代码编译安装(该方…

软件性能测试之压力测试详解

压力测试 压力测试是一种软件测试,用于验证软件应用程序的稳定性和可靠性。压力测试的目标是在极其沉重的负载条件下测量软件的健壮性和错误处理能力,并确保软件在危急情况下不会崩溃。它甚至可以测试超出正常工作点的测试,并评估软件在极端条…

卡通渲染总结《二》

关于技术的方面,一方面就是其轮廓边缘检测: 主要的方法可以被分为基于图片空间和对象空间,对象空间比图片空间会多一些立体坐标位置的信息。 轮廓类型分类 首先我们顶一下轮廓是什么,从一个视角看去如果一条边相邻的两个面其恰…

SpringSecurity6 | 默认用户生成

SpringSecurity6 | 默认用户生成 ✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客 💞当前专栏: Java…

加密市场进入牛初阶段?一场新的造富效应即将拉开帷幕!

周一(12月4日),比特币一度上涨至42000美元,创下自2022年4月以来的最高水平。从目前比特币的走势来看,加密市场无疑已然进入到牛初阶段。 在牛市初期,确实存在人们不相信牛市到来的情况。由于在熊市中亏损的心理阻碍和对市场进一步…

ROS2教程03 ROS2节点

ROS2节点 版权信息 Copyright 2023 Herman YeAuromix. All rights reserved.This course and all of its associated content, including but not limited to text, images, videos, and any other materials, are protected by copyright law. The author holds all right…

在机器学习或者深度学习中是否可以直接分为训练集和测试集而不需要验证集?我的答案如下:

文章目录 一、训练集是什么?二、验证集是什么?三、测试集是什么?四、是否可以直接分为训练集和测试集而不需要验证集?总结 在机器学习和深度学习项目中,通常会将数据集划分为三个部分:训练集,验…

python精细讲解,从代码出发,适合新手宝宝食用的python入门教学【持续更新中】

文章目录 1、输入输出1.1 输入语句1.2 输出语句 2、List列表操作2.1 取值取单个元素:[]取出现的第一个元素:index 2.2 添加操作追加:append插入:insert 2.3 删除操作removepopdelclear清空 copy复制操作列表相关的数学操作数数&am…