【机器学习】036_权重衰退

一、范数

· 定义:向量的范数表示一个向量有多大(分量的大小)

L1范数:

        · 即向量元素绝对值之和,用符号 ‖ v ‖ 1 表示。

        · 公式:\left \| x \right \|_1 = \sum_{n}^{i=1}|x_i|

L2范数:

        · 即向量的模,向量各元素绝对值的平方之和再开根号,用符号 ‖ v ‖ 2 表示。

        · 公式:\left \| x \right \|_2=\sqrt{\sum_{n}^{i=1}x_i^2}

Lp范数:

        · 即向量范数的一般形式,各元素绝对值的p次幂之和再开p次根号,用符号 ‖ v ‖ p 表示。

        · 公式:\left \| x \right \|_p = (\sqrt[p]{\sum_{n}^{i=1}|x|^p})

二、权重衰减(L2正则化)

模型(函数)复杂度的度量:

· 一般通过线性函数 f(x) = w^Tx 中的权重向量的某个范数(如 \left \| w \right \|^2)来度量其复杂度

要想避免模型的过拟合,就要控制模型容量,使模型的权重向量尽可能小

· 通过限制参数值的选择范围来控制模型容量

衰减方法:

借助损失函数,将权重范数作为惩罚项添加到最小化损失中;使得损失函数的作用变为“最小化预测损失和惩罚项之和”。

损失函数公式如下:

J(w,b)=L(w,b)+\frac{\lambda }{2}\left \| w \right \|^2

· 其中,L(w,b) 是模型原本的损失函数,\frac{\lambda }{2}\left \| w \right \|^2 是新添加的惩罚项。

· 正则化常数 \lambda 用来描绘这种权衡,其为一个非负超参数。

· \lambda 的值越大,表示对 w 的约束较大;反之 \lambda 的值越小,表示对 w 的约束较小。

※为何选用平方范数而不是标准范数:

        · 便于计算。平方范数可以去掉平方根使得导数更容易计算,利于反向传播过程。

        · 使用L2范数是因为它会对权重向量的大分量施加巨大的惩罚,使各权重均匀分布。

        · L1范数惩罚会导致权重集中在某一小部分特征上,其它权重被清除为0(特征选择)。

使用该损失函数,就可以使梯度下降的优化算法在训练的每一步都衰减权重,避免过拟合发生。

如上图所示,现在模型的损失函数同时受两项影响,一是误差项,二是惩罚项。

        现在在等高线图上,梯度下降最终收敛的位置不再是某一个项所造成的最低点,因为在这时,可能误差项达到最小了,但是惩罚项很大,使得惩罚项拉着损失函数再向另一个方向移动。

        只有当达到了两个项共同作用下的一个平衡点时,损失函数才具有最小值,这个时候的模型往往复杂度也降低了,虽然有可能造成训练损失增大,但是测试损失会减小。

三、代码实现权重衰减

从零实现代码如下:

import matplotlib
import torch
from torch import nn
from d2l import torch as d2l

# 训练数据集、测试数据集、输入值、训练批次
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
# 初始化w和b的真实值
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
# 拿到训练数据
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

# 初始化模型参数w和b
def init_params():
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]
# 定义L2范数惩罚项
def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2
# 实现训练代码,读入参数为兰姆达(正则化参数)
def train(lambd):
    w, b = init_params()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 增加了L2范数惩罚项,
            # 广播机制使l2_penalty(w)成为一个长度为batch_size的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)
            l.sum().backward()
            d2l.sgd([w, b], lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:', torch.norm(w).item())
# 使用权重进行训练
train(lambd=3)

简洁实现代码如下:

import torch
from torch import nn
from d2l import torch as d2l

# 训练数据集、测试数据集、输入值、训练批次
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
# 初始化w和b的真实值
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
# 拿到训练数据
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

def train_concise(wd):
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        param.data.normal_()
    loss = nn.MSELoss(reduction='none')
    num_epochs, lr = 100, 0.003
    # 偏置参数没有衰减
    trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y)
            l.mean().backward()
            trainer.step()
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数:', net[0].weight.norm().item())

    train_concise(3)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/170516.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

基于springboot实现智能热度分析和自媒体推送平台系统项目【项目源码】计算机毕业设计

基于springboot实现智能热度分析和自媒体推送平台演示 系统开发平台 在该自媒体分享网站中,Eclipse能给用户提供更多的方便,其特点一是方便学习,方便快捷;二是有非常大的信息储存量,主要功能是用在对数据库中查询和编…

如果在手机没有root的情况下完成安卓手机数据恢复

您是否不小心从安卓设备中删除了重要数据? 担心如何取回您的照片、视频和文档? 有时您可能会不小心删除重要数据并使用安卓 root方法取回文件。 许多用户不喜欢根植他们的安卓设备,因为这是一种复杂的方法。 在本指南中,我们将向您…

【寒武纪(10)】linux arm aarch 是 opencv 交叉编译与使用

文章目录 1、直接找github 别人编译好的2、自主编译参考 3使用CMake检查 参考 1、直接找github 别人编译好的 测试很多,找到一个可用的。 https://github.com/dog-qiuqiu/libopencv 它用了超级模块! OpenCV的world模块也称为超级模块(supe…

【Java基础】Java导Excel攻略

💝💝💝欢迎来到我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:kwan 的首页,持续学…

HALCON根据需要创建自定义函数

任务要求: 创建函数myfun(a,b,c),输入浮点数a,b的值,计算c a b,将计算结果返回。 操作步骤: 1)打开HDevelop程序 2)打开函数菜单,选择“创建新函数”&#xff0c…

存储配置和挂载方式

存储配置 Iscsi简介 iSCSI 启动器,从本质上说,iSCSI 启动器是一个客户端设备,用于将请求连接并启动到服务器(iSCSI 目标)。 iSCSI 启动器有三种实现方式:可以完全基于硬件实现,比如 iSCSI H…

Conditional GAN

Text-to-Image 对于根据文字生成图像的问题,传统的做法就是训练一个NN,然后输入一段文字,输出对应一个图片,输出图片与目标图片越接近越好。存在的问题就是,比如火车对应的图片有很多张,如果用传统的NN来训…

在浏览器中使用WebRTC获取用户IP地址

本文翻译自 Discover WebRTC: Obtain User IP Addresses in the Browser,作者:Zack, 略有删改。 如果需要在程序中获取当前用户的IP,通常手段都是需要使用服务器。但现在借助WebRTC的强大功能,我们可以直接在浏览器客户…

基于springboot实现医院信管系统项目【项目源码+论文说明】

基于springboot实现医院信管系统演示 摘要 随着信息技术和网络技术的飞速发展,人类已进入全新信息化时代,传统管理技术已无法高效,便捷地管理信息。为了迎合时代需求,优化管理效率,各种各样的管理系统应运而生&#x…

重磅,瑞士药监局 发布 EU GMP附录1《无菌药品生产》官方解读!

近日,瑞士药监局发布了EU GMP附录1《无菌药品生产》(同时也是PIC/S和WHO GMP附录1)的解读文件,该文件侧重于新版EU、PIC/S和WHO GMP附录1的一些最重要的变化,也涵盖了长期以来反复引起问题的方面。反映了检查员对这些主…

python操作windows窗口,python库pygetwindow使用详解

文章目录 一、pygetwindow模块简介二、pygetwindow常用方法1、常用方法2、window常用方法 一、pygetwindow模块简介 pygetwindow是一个Python第三方库,用于获取、管理和操作窗口。它提供了一些方法和属性,使得在Python程序中可以轻松地执行各种窗口操作…

10个好用的Mac数据恢复软件推荐—恢复率高达99%

如果您正在寻找最好的 Mac 数据恢复软件来检索意外删除或丢失的文件,那么这里就是您的最佳选择。 我们理解,当您找不到 Mac 计算机或外部驱动器上保存的一些重要文件时,会感到多么沮丧和绝望。这些文件非常珍贵,无论出于何种原因…

通信原理板块——差错控制编码或纠错编码

微信公众号上线,搜索公众号小灰灰的FPGA,关注可获取相关源码,定期更新有关FPGA的项目以及开源项目源码,包括但不限于各类检测芯片驱动、低速接口驱动、高速接口驱动、数据信号处理、图像处理以及AXI总线等 1、背景 数字信号在传输过程中&…

(免费领源码)python#flask#mysql旅游数据可视化81319-计算机毕业设计项目选题推荐

摘要 信息化社会内需要与之针对性的信息获取途径,但是途径的扩展基本上为人们所努力的方向,由于站在的角度存在偏差,人们经常能够获得不同类型信息,这也是技术最为难以攻克的课题。针对旅游数据可视化等问题,对旅游数据…

AD9361寄存器功能笔记之本振频率设定

LO的产生过程如图: 各个模块都有高灵活性。 1、参考时钟即是AD9361全局参考时钟,可以是外接晶振的片上DCXO,或是外部输入的有驱动能力的时钟信号。根据FM-COMMS5的设计,参考时钟可以使用时钟Buffer 40MHz晶振构成的参考频率源。 …

实战 - 在Linux上部署各类软件

前言 为什么学习各类软件在Linux上的部署 在前面,我们学习了许多的Linux命令和高级技巧,这些知识点比较零散,同学们跟随着课程的内容进行练习虽然可以基础掌握这些命令和技巧的使用,但是并没有一些具体的实操能够串联起来这些知…

算法-简单-二叉树-翻转、对称

记录一下算法题的学习8 翻转二叉树 翻转二叉树题目 给你一棵二叉树的根节点 root ,翻转这棵二叉树,并返回其根节点。 举例:给定root[5,3,7,2,4,6,10] 翻转成为root[5,7,3,10,6,4,2] 即所有的根节点的左右节点都要互换位置,输出的…

BUUCTF 菜刀666 1

BUUCTF:https://buuoj.cn/challenges 题目描述: 流量分析,你能找到flag吗 注意:得到的 flag 请包上 flag{} 提交 密文: 下载附件,解压得到一个.pcapng文件。 解题思路: 1、双击文件,打开wir…

两种典型的雷达框架,traditional chain (待深入了解)和Capon Beamforming Chain(已经了解)

如图1所示,第1种是被称作“traditional chain”, 它的处理思路是adc数据作range-FFT,再到doppler-FFT,构建range-Dopper map,再到cfar,最后对候选点作angle-FFT,当然,这是最经典的framework&…