pytorch学习——正则化技术——权重衰减

一、概念介绍

 

        权重衰减(Weight Decay)是一种常用的正则化技术,它通过在损失函数中添加一个惩罚项来限制模型的复杂度,从而防止过拟合。 

        在训练参数化机器学习模型时, 权重衰减(weight decay)是最广泛使用的正则化的技术之一, 它通常也被称为L2正则化。

1.1理解:

权重衰减(weight_decay)本质上是一个L2正则化系数

那什么是参数的正则化?从我的理解上,就是让参数限定在一定范围,目的是为了不让模型对训练集过拟合。

注:应对过拟合最好的方法还是扩大有效样本(但成本过高)

1.2如何控制模型容量?

1.将模型变得比较小,减少里面参数的数量

2.缩小参数的取值范围

注:权重衰退就是通过限制参数的取值来实现

1.3硬性限制

即使得w的每个项的平方都小于θ这个值,最强情况下就是θ等于0,即所有w都等于0

1.4柔性限制

 即损失函数后面加了一个非负项,为了使损失函数最小化,就得使得后面项足够小——起到限制w的作用,相比于硬性限制,柔性限制并没有将w的值限制在一个固定范围内。

1.5图解对最优解的影响

 

 上式为不加限制条件的最优解,即图中的绿色中心点,但该点会使得||w||^2这一项较大,其和并不是最优解。

而加上限制的最优点即为图中两曲线的交叉点

1.6更新参数法则

 

 1.7总结

   ~权重衰减是通过L2正则项使得模型参数不会过大,从而控制复杂度

   ~正则项权重是控制模型复杂度的超参数

二、示例演示

2.1模型构造

生成公式如下:

# 导入需要的库
import torch
from torch import nn
from d2l import torch as d2l

# 定义训练和测试数据集的大小,输入特征的维度和批次大小
n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5

# 定义真实的权重true_w和偏差true_b,并将其初始化为0.01和0.05
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05

# 使用d2l.synthetic_data函数生成训练数据train_data和测试数据test_data
# 生成的数据是通过真实的权重和偏差加上一些噪声生成的
train_data = d2l.synthetic_data(true_w, true_b, n_train)
test_data = d2l.synthetic_data(true_w, true_b, n_test)

# 使用d2l.load_array函数将训练数据train_data和测试数据test_data
# 转换为数据迭代器train_iter和test_iter
train_iter = d2l.load_array(train_data, batch_size)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

2.2初始化模型参数

def init_params():
    w = torch.normal(0, 1, size=(num_inputs, 1), requires_grad=True)
    b = torch.zeros(1, requires_grad=True)
    return [w, b]
# 初始化模型参数w和b
# w的形状为(num_inputs, 1),从正态分布中随机生成
# b初始化为0
# 参数需要计算梯度,requires_grad参数被设置为True
# 返回一个包含w和b的列表

2.3定义L2范数

def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2

2.4定义训练代码实现

        下面的代码将模型拟合训练数据集,并在测试数据集上进行评估。

函数的具体实现如下:

  1. 首先通过init_params()函数初始化模型参数w和b。

  2. 定义net函数为线性回归模型,loss为平方损失函数。

  3. 设置训练的轮数num_epochs和学习率lr,同时创建一个可视化工具animator,用于可视化训练过程中的损失值。

  4. 在每个epoch中,遍历训练数据集train_iter,对每个小批量数据(X, y)进行如下操作:

    • 计算模型的输出net(X),并计算损失函数loss(net(X), y)。

    • 加上L2范数惩罚项lambd * l2_penalty(w),其中l2_penalty(w)为权重w的L2范数。

    • 对损失函数进行反向传播,并使用SGD来更新模型参数w和b。

  5. 每5个epoch,计算训练集和测试集上的损失值,并使用animator将损失值可视化。

  6. 训练结束后,输出模型参数w的L2范数。

# 带有L2正则化的线性回归训练过程
# lambd表示L2正则化的强度

# 初始化模型参数w和b
w, b = init_params()

# 定义线性回归模型net和平方损失函数loss
net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_loss

# 设置训练的轮数num_epochs和学习率lr
# 创建一个可视化工具animator,用于可视化训练过程中的损失值
num_epochs, lr = 100, 0.003
animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                        xlim=[5, num_epochs], legend=['train', 'test'])

# 在每个epoch中,遍历训练数据集train_iter,对每个小批量数据(X, y)进行如下操作:
for epoch in range(num_epochs):
    for X, y in train_iter:
        # 计算模型的输出net(X),并计算损失函数loss(net(X), y)
        # 加上L2范数惩罚项lambd * l2_penalty(w),其中l2_penalty(w)为权重w的L2范数
        # 对损失函数进行反向传播,并使用SGD来更新模型参数w和b
        l = loss(net(X), y) + lambd * l2_penalty(w)
        l.sum().backward()
        d2l.sgd([w, b], lr, batch_size)

    # 每5个epoch,计算训练集和测试集上的损失值,并使用animator将损失值可视化
    if (epoch + 1) % 5 == 0:
        animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                 d2l.evaluate_loss(net, test_iter, loss)))

# 训练结束后,输出模型参数w的L2范数
print('w的L2范数是:', torch.norm(w).item())

 2.5训练结果展示

        在这段代码中,lambd是一个超参数,表示L2正则化的强度。在每个小批量数据的损失函数中,会加上L2范数惩罚项,以控制模型的复杂度和防止过拟合。L2正则化的强度由超参数lambd控制,lambd越大,模型的复杂度就越小,对训练数据的拟合程度就越差,但是可以更好地控制过拟合。反之,lambd越小,模型的复杂度就越大,对训练数据的拟合程度就越好,但是可能会过拟合。在模型训练过程中,我们通常会使用交叉验证等技术来选择最优的超参数lambd。

2.5.1忽略正则化直接训练

        其中用lambd = 0禁用权重衰减后运行这个代码。 注意,虽然训练误差有了减少,但测试误差没有减少, 这意味着出现了严重的过拟合。

 2.5.2使用权重衰减

        下面,我们使用权重衰减来运行代码。 注意,在这里训练误差增大,但测试误差减小。 得到预期效果。

 三.简洁实现代码

# 导入需要的库
import torch
from torch import nn
from d2l import torch as d2l

def train_concise(wd):
    # 定义训练和测试数据集的大小,输入特征的维度和批次大小
    n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5

    # 使用nn.Sequential定义了一个单层全连接神经网络net
    # 并将其参数使用param.data.normal_()方法初始化为随机值
    net = nn.Sequential(nn.Linear(num_inputs, 1))
    for param in net.parameters():
        param.data.normal_()

    # 使用nn.MSELoss定义平方损失函数loss
    # 该损失函数的reduction参数设置为'none',表示不对损失值进行降维
    loss = nn.MSELoss(reduction='none')

    # 设置训练的轮数num_epochs和学习率lr
    # 使用torch.optim.SGD定义一个优化器trainer,该优化器的参数包括网络的权重和偏差,以及权重衰减系数wd
    num_epochs, lr = 100, 0.003
    trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)

    # 创建一个可视化工具animator,用于可视化训练过程中的损失值
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])

    # 在每个epoch中,遍历训练数据集train_iter,对每个小批量数据(X, y)进行如下操作:
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 将优化器trainer的梯度清零
            # 计算模型的输出net(X),并计算损失函数loss(net(X), y)
            # 对损失函数进行反向传播,并使用优化器trainer来更新模型参数
            trainer.zero_grad()
            l = loss(net(X), y)
            l.mean().backward()
            trainer.step()
        # 每5个epoch,计算训练集和测试集上的损失值,并使用animator将损失值可视化。
            if (epoch + 1) % 5 == 0:
                animator.add(epoch + 1,
                            (d2l.evaluate_loss(net, train_iter, loss),
                            d2l.evaluate_loss(net, test_iter, loss)))
        print('w的L2范数:', net[0].weight.norm().item())
train_concise(0)    #lambd设置为0

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/55341.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

windows环境下adb 下载和配置,连接手机。

ADB下载地址: https://adbdownload.com/ 选择下载windows系统的。 下载后解压,查看adb.exe所在的目录,如下 这里将路径复制下来:D:\ADB 配置到系统环境变量中。 然后再打开cmd,输入adb version查看版本。 出现…

2023.8.1号论文阅读

文章目录 MCPA: Multi-scale Cross Perceptron Attention Network for 2D Medical Image Segmentation摘要本文方法实验结果 SwinMM: Masked Multi-view with SwinTransformers for 3D Medical Image Segmentation摘要本文方法实验结果 MCPA: Multi-scale Cross Perceptron Att…

极简在线商城系统,支持docker一键部署

Hmart 给大家推荐一个简约自适应电子商城系统,针对虚拟商品在线发货,支持企业微信通知,支持docker一键部署,个人资质也可搭建。 前端 后端 H2 console 运行命令 docker run -d --name mall --restartalways -p 8080:8080 -e co…

10. Mybatis 项目的创建

目录 1. Mybatis 概念 2. 第一个 Mybits 查询 2.1 创建数据库和表 2.2 添加 Mybatis 框架支持 2.3 添加配置文件 2.4 配置 MyBatis 中的 XML 路径 2.5 添加业务代码 在学习 Mybatis 之前,我们需要知道 Mybatis 和 Spring 没有任何的关系。如果一定要强调二者…

ChatGPT安全技术

前言 近期,Twitter 博主 lauriewired 声称他发现了一种新的 ChatGPT"越狱"技术,可以绕过 OpenAI 的审查过滤系统,让 ChatGPT 干坏事,如生成勒索软件、键盘记录器等恶意软件。 他利用了人脑的一种"Typoglycemia&q…

Kubernetes系列

文章目录 1 详解docker,踏入容器大门1.1 引言1.2 初始docker1.3 docker安装1.4 docker 卸载1.5 docker 核心概念和底层原理1.5.1 核心概念1.5.2 docker底层原理 1.6 细说docker镜像1.6.1 镜像的常用命令 1.7 docker 容器1.8 docker 容器数据卷1.8.1 直接命令添加1.8.2 Dockerfi…

远程控制平台四之优化部署

服务器端打包 把服务器打成jar包对于后台开发的朋友来说小菜一碟,但对于前端开发可能有些细节要注意一下,尤其是有依赖其他第三方库的情况下,这里梳理了一下流程: File – Project Structure – Artifacts – add – JAR – From modules and dependencies 选中module和主…

OpenCVForUnity(十)扩张与侵蚀效果

文章目录 前言扩张案例展示 侵蚀案例展示 结语: 前言 在这个教程中,您将学习两种常见的图像形态运算符:侵蚀和膨胀。为此,您将使用OpenCV库中的两个函数:erode 和 dilate。 形态操作是一组基于形状的图像处理操作。形态…

电气防火限流式保护器在汽车充电桩使用上的作用

【摘要】 随着电动汽车行业的不断发展,电动汽车充电设施的使用会变得越来越频繁和广泛。根据中汽协数据显示,2022年上半年,我国新能源汽车产销分别完成266.1万辆和260万辆,同比均增长1.2倍,市场渗透率达21.6%。因此,电动汽车的安全…

【MySQL】数据库基本使用

文章目录 一、数据库介绍二、数据库使用2.1 登录MySQL2.2 基本使用2.2.1 显示当前 MySQL 实例中所有的数据库列表2.2.2 创建数据库2.2.3 创建数据库表2.2.4 在表中插入数据2.2.5 在表中查询数据 三、服务器、数据库、表之间的关系四、SQL语句分类五、存储引擎 一、数据库介绍 …

sql入门基础-2

Dml语句 对数据的增删改查 关键字 Insert增 Update删 Delete改 添加数据 给指定字段添加数据 Insert into 表明 (字段名1,字段名2) values(值1,值2); 给全部字段添加数据--(根据位置对应添加到字段下) Insert into 表名 values…

套接字通信(C/C++ 多线程)----基于线程池的并发服务器

(一)大家可以看我写的这三篇,了解一下: 基于linux下的高并发服务器开发(第四章)- 多线程实现并发服务器_呵呵哒( ̄▽ ̄)"的博客-CSDN博客https://blog.csdn.net/weixin_4198701…

【JavaWeb】Javascript经典案例

Javascript经典案例 注意&#xff1a;该文章是参考b站<20个JS经典案例>进行学习的&#xff0c;没有CSS的组成。 在慢慢更新中…哈哈哈哈&#xff0c;太慢了 文章目录 1.支付定时器2.验证码生成及校验 1.支付定时器 代码实现&#xff1a; confirm.html <!DOCTYPE html…

2.04 商品搜索功能实现

根据关键字获取分类查询对应的分页商品信息&#xff0c;并可以价格和销量进行排序切换 步骤1&#xff1a;mapper.xml编写sql语句 <!-- k: 默认&#xff0c;代表默认排序&#xff0c;根据name--> <!-- c: 根据销量排序--> <!-- p: 根据价格排序--> <sel…

消息队列 - 数据库操作

这里写自定义目录标题 前言数据表的插入删除操作关于实现接口类的几个注意实现实现封装创建DataBaseManager 类另一种获取Bean对象的方式 对数据库进行单元测试 前言 上一篇博客, 我们将消息队列的实体类创建完毕了, 并且还写了一些关于数据库的操作, 接下来我们继续进行关于数…

Java throw和throws 关键字

在Java中&#xff0c;异常可以分为两种类型&#xff1a; 未检查的异常&#xff1a;它们不是在编译时而是在运行时被检查&#xff0c;例如&#xff1a;ArithmeticException&#xff0c;NullPointerException&#xff0c;ArrayIndexOutOfBoundsException&#xff0c;Error类下的异…

wordpress 学习贴

安装问题 我的使用环境为docker环境&#xff0c;php、nginx、mysql分别处于3个容器中&#xff0c; 提示异常&#xff0c;打开debug模式&#xff0c;会发现 No such file or directory Warning: mysqli_real_connect(): (HY000/2002): No such file or directory 这个其实问题其…

Linux操作系统3-项目部署

手动部署 步骤 1.在idea中将文件项目进行打包 2.自定义一个文件目录&#xff0c;上传到Linux 3.使用 java -jar jar包名就可以进行运行 注意,如果需要启动该项目&#xff0c;需要确定所需的端口是否打开 采用这种方式&#xff0c;程序运行的时候会出现霸屏&#xff0c;并且会…

最近写了10篇Java技术博客【SQL和画图组件】

&#xff08;1&#xff09;Java获取SQL语句中的表名 &#xff08;2&#xff09;Java SQL 解析器实践 &#xff08;3&#xff09;Java SQL 格式化实践 &#xff08;4&#xff09;Java 画图 画图组件jgraphx项目整体介绍&#xff08;一&#xff09; 画图组件jgraphx项目导出…

计算机毕设 深度学习实现行人重识别 - python opencv yolo Reid

文章目录 0 前言1 课题背景2 效果展示3 行人检测4 行人重识别5 其他工具6 最后 0 前言 &#x1f525; 这两年开始毕业设计和毕业答辩的要求和难度不断提升&#xff0c;传统的毕设题目缺少创新和亮点&#xff0c;往往达不到毕业答辩的要求&#xff0c;这两年不断有学弟学妹告诉…