【动手学深度学习】(六)权重衰退

文章目录

  • 一、理论知识
  • 二、代码实现
    • 2.1从零开始实现
    • 2.2简洁实现
  • 【相关总结】

主要解决过拟合

一、理论知识

1、使用均方范数作为硬性限制(不常用)
通过限制参数值的选择范围来控制模型容量
在这里插入图片描述
通常不限制偏移b
小的在这里插入图片描述意味着更强的正则项
使用均方范数作为柔性限制
对于每个在这里插入图片描述都可以找到在这里插入图片描述使得之前的目标函数等价于下面的:
在这里插入图片描述

可以通过拉格朗日乘子来证明
超参数在这里插入图片描述控制了正则项的重要程度

在这里插入图片描述
在这里插入图片描述
参数更新法则
在这里插入图片描述
总结:

  • 权重衰退通过L2正则项使得模型参数不会过大,从而控制模型复杂度
  • 正则项权重是控制模型复杂度的超参数

二、代码实现

权重衰减是最广泛使用的正则化技术之一
1.首先,人工生成数据
在这里插入图片描述
我们选择标签是关于输入的线性函数。 标签同时被均值为0,标准差为0.01高斯噪声破坏。 为了使过拟合的效果更加明显,我们可以将问题的维数增加到, 并使用一个只包含20个样本的小训练集。

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
# print(torch.ones((num_inputs, 1)))
# print(true_w)
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
# print(train_iter)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

2.1从零开始实现

只需将的平方惩罚添加到原始目标函数中。

def init_params():
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w,b]

定义L2范数惩罚

def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2

定义训练代码

def train(lambd):
    w,b = init_params()
    net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                           xlim=[5,num_epochs], legend=['train', 'test'])
    
    for epoch in range(num_epochs):
        for X, y in train_iter:
#             增加了L2范数惩罚项
# 广播机制使l2_penalty(w)成为一个长度为torch_size的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)
            l.sum().backward()
            d2l.sgd([w,b], lr, batch_size)
        if(epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                    d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:',torch.norm(w).item())

忽略正则化直接训练
用lambd = 0禁用权重衰减

train(lambd=0)

w的L2范数是: 13.702591896057129
在这里插入图片描述
使用权重衰退

train(lambd=3)

w的L2范数是: 0.36873573064804077
在这里插入图片描述

2.2简洁实现

在实例化优化器时直接通过weight_decay指定weight decay超参数

def train_concise(wd):
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        param.data.normal_()
    loss = nn.MSELoss(reduction='none')
    num_epochs, lr = 100, 0.003
    # 偏置参数没有衰减
    trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y)
            l.mean().backward()
            trainer.step()
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数:', net[0].weight.norm().item())
train_concise(0)

w的L2范数: 12.619434356689453
在这里插入图片描述

train_concise(3)

w的L2范数: 0.3909929692745209
在这里插入图片描述

【相关总结】

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/223379.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

深入理解TDD(测试驱动开发):提升代码质量的利器

在日常的软件开发工作中,我们常常会遇到这样的问题:如何在繁忙的项目进度中,保证我们的代码质量?如何在不断的迭代更新中,避免引入新的错误?对此,有一种有效的开发方式能帮助我们解决这些问题&a…

CSS中区分行高,行间距

行高(line height) —文字占有的实际高度 —使用line-height来设置行高 行高类似于我们上学单线本,单线本是一行一行,线与线之间的距离就是行高,控制文字行与行之间的距离, 网页中的文字实际上也是写在一个…

14. 二叉树遍历

从物理结构的角度来看,树是一种基于链表的数据结构,因此其遍历方式是通过指针逐个访问节点。然而,树是一种非线性数据结构,这使得遍历树比遍历链表更加复杂,需要借助搜索算法来实现。 二叉树常见的遍历方式包括层序遍…

0基础学java-day14

一、集合 前面我们保存多个数据使用的是数组,那么数组有不足的地方,我们分析一下 1.数组 2 集合 数据类型也可以不一样 3.集合的框架体系 Java 的集合类很多,主要分为两大类,如图 :[背下来] package com.hspedu.c…

Domino多Web站点托管

大家好,才是真的好。 看到一篇文档,大概讲述的是他在家里架了一台Domino服务器,上面跑了好几个Internet的Web网站(使用Internet站点)。再租了一台云服务器,上面安装Nginx做了反向代理,代理访问…

matlab实践(十):贝塞尔曲线

1.贝塞尔曲线 贝塞尔曲线的原理是基于贝塞尔曲线的数学表达式和插值算法。 贝塞尔曲线的数学表达式可以通过控制点来定义。对于二次贝塞尔曲线,它由三个控制点P0、P1和P2组成,其中P0和P2是曲线的起点和终点,P1是曲线上的一个中间点。曲线上…

前端JavaScript入门-day08-正则表达式

目录 介绍 语法 元字符 边界符 量词 字符类: 修饰符 介绍 正则表达式(Regular Expression)是用于匹配字符串中字符组合的模式。在 JavaScript中,正则表达式也是对象,通常用来查找、替换那些符合正则表达式的…

Distilling the Knowledge in a Neural Network(2015.5)(d补)

文章目录 Abstract1 Introduction2 Distillation2.1 Matching logits is a special case of distillation Results 论文链接 Abstract 提高几乎所有机器学习算法性能的一种非常简单的方法是在相同的数据上训练许多不同的模型,然后对它们的预测进行平均[3]。不幸的是…

Node.js安装和下载(保姆级教程,别再再说你不会了)

1.浏览器搜索node.js 2.打开官网(选择Other Download) ​ 3.根据你的计算机版本选择 4.找到你下载的程序(双击打开) 5.双击后的效果如下: 6.继续下一步 7.选择安装路径然后下一步 8.然后继续下一步 9. 直接下一步&am…

P6 Linux 系统中的文件类型

目录 前言 ​编辑 01 linux系统查看文件类型 02 普通文件 - 03 目录文件 d 04 字符设备文件 c 和块设备文件 b 05 符号链接文件 l 06 管道文件 p 07 套接字文件 s 总结 前言 🎬 个人…

数据增强改进,实现检测目标copypaste,增加目标数据量,提升精度

🗝️YOLOv8实战宝典--星级指南:从入门到精通,您不可错过的技巧   -- 聚焦于YOLO的 最新版本, 对颈部网络改进、添加局部注意力、增加检测头部,实测涨点 💡 深入浅出YOLOv8:我的专业笔记与技术总结   -- YOLOv8轻松上手, 适用技术小白,文章代码齐全,仅需 …

postgresql自带指令命令系列二

简介 在安装postgresql数据库的时候会需要设置一个关于postgresql数据库的PATH变量 export PATH/home/postgres/pg/bin:$PATH,该变量会指向postgresql安装路径下的bin目录。这个安装目录和我们在进行编译的时候./configure --prefix [指定安装目录] 中的prefix参…

consistency model

Consistency is All You Need - wrong.wang什么都不用做生成却快了十倍其实也并非完全不可能https://wrong.wang/blog/20231111-consistency-is-all-you-need/[学科基础] 从布朗运动到扩散模型采样算法 - 知乎引言 扩散模型是近年来新出现的一种生成模型,很多工作将…

现货白银简单介绍

在贵金属投资领域,现货白银是当前国际上最为流行、交投最为活跃的白银投资方式,其交易市场遍布全球,包括伦敦、苏黎世、纽约、芝加哥及香港等主要市场,是一种以杠杆交易和做市商的形式进行的现货交易。 现货白银可以说是当下交易模…

Python (二) 读写excel文件

程序员的公众号:源1024,获取更多资料,无加密无套路! 最近整理了一波电子书籍资料,包含《Effective Java中文版 第2版》《深入JAVA虚拟机》,《重构改善既有代码设计》,《MySQL高性能-第3版》&…

1996-2021年世界各国WGI全球治理指标:政治稳定、制度控制、国家治理、控制腐败、自由指数数据

1996-2021年世界各国WGI全球治理指标:政治稳定、制度控制、国家治理、控制腐败、自由指数数据 1、时间:1996-2021年 2、指标:Voiceand Accountability、Political Stability No Violence、Government Effectiveness、Regulatory Quality、R…

tomcat控制台中文信息显示乱码

问题现象 我的tomcat版本是10.1版本。 在cmd下启动tomcat,会新打开控制台输出窗口: 控制台窗口输出的中文信息是乱码: 问题原因 产生这个问题的原因是:控制台窗口的编码和输出到控制台窗口的日志信息编码不一致。 查看tomc…

《opencv实用探索·十一》opencv之Prewitt算子边缘检测,Roberts算子边缘检测和Sobel算子边缘检测

1、前言 边缘检测: 图像边缘检测是指在图像中寻找灰度、颜色、纹理等变化比较剧烈的区域,它们可能代表着物体之间的边界或物体内部的特征。边缘检测是图像处理中的一项基本操作,可以用于人脸识别、物体识别、图像分割等多个领域。 边缘检测…

如何在服务器上运行python文件

目录 前置准备 详细步骤 一,在服务器安装Anaconda 下载安装包 上传文件到服务器 安装环境 二,创建虚拟环境 创建环境 三,测试执行python文件 执行python文件 查看进程状态 总结 前置准备 如何在个人服务器上运行python文件&#x…

elk+kafka+filebeat

elk1 cd /opt 把filebeat投进去 tar -xf filebeat-6.7.2-linux-x86_64.tar.gz mv filebeat-6.7.2-linux-x86_64 filebeat cd filebeat/ yum -y install nginx systemctl restart nginx vim /usr/share/nginx/html/index.html this is nginx cp filebeat.yml filebeat.yml.…