【PyTorch】权重衰减

文章目录

  • 1. 理论介绍
  • 2. 实例解析
    • 2.1. 实例描述
    • 2.2. 代码实现

1. 理论介绍

  • 通过对模型过拟合的思考,人们希望能通过某种工具调整模型复杂度,使其达到一个合适的平衡位置。
  • 权重衰减(又称 L 2 L_2 L2正则化)通过为损失函数添加惩罚项,用来惩罚权重的 L 2 L_2 L2范数,从而限制模型参数值,促使模型参数更加稀疏或更加集中,进而调整模型的复杂度,即: L ( w , b ) + λ 2 ∥ w ∥ 2 L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2 L(w,b)+2λw2其中 λ \lambda λ权重衰减的超参数
  • L p L_p Lp范数: ∥ x ∥ p = ( ∑ i = 1 n ∣ x i ∣ p ) 1 / p \|\mathbf{x}\|_p = \left(\sum_{i=1}^n \left|x_i \right|^p \right)^{1/p} xp=(i=1nxip)1/p
    p = 1 p=1 p=1时称为 L 1 L_1 L1范数;当 p = 2 p=2 p=2时称为 L 2 L_2 L2范数。
    惩罚 L 1 L_1 L1范数会导致模型将权重集中在一小部分特征上, 而将其他权重清除为零, 这称为特征选择;惩罚 L 2 L_2 L2范数会导致模型在大量特征上均匀分布权重,使得模型对单个变量的观测误差更为稳定。
  • 通常不建议对偏置进行正则化,因为偏置的取值并不像权值那样会随着训练过程而变化,因此对偏置进行正则化对于控制模型的复杂度影响较小;另外,对偏置进行正则化可能会导致对数据中的偏移进行过度拟合,而减弱了模型对其他特征的学习。

2. 实例解析

2.1. 实例描述

使用以下公式生成包含20个样本的小训练集和100个样本的测试集,并用线性网络进行拟合: y = 0.05 + ∑ i = 1 200 0.01 x i + ϵ  where  ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^{200} 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=12000.01xi+ϵ where ϵN(0,0.012).

2.2. 代码实现

  • 主要代码
optimizer = optim.SGD([
            {"params": net.weight,"weight_decay": weight_decay},
            {"params": net.bias}
            ], lr=lr)
  • 完整代码
import os
import torch
from torch import nn, optim
from torch.utils.data import TensorDataset, DataLoader
from tensorboardX import SummaryWriter
from rich.progress import track

def data_generator(w, b, num):
    """为线性模型生成数据"""
    X = torch.randn(num, len(w))
    y = torch.sum(X @ w, dim=1) + b
    y += torch.normal(0, 0.01, y.shape)
    return X, y.reshape(-1, 1)

def load_dataset(*tensors):
    """加载数据集"""
    dataset = TensorDataset(*tensors)
    return DataLoader(dataset, batch_size, shuffle=True)

def evaluate_loss(dataloader, net, criterion):
    """评估模型在指定数据集上的损失"""
    num_examples = 0
    loss_sum = 0.0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.cuda(), y.cuda()
            loss = criterion(net(X), y)
            num_examples += y.shape[0]
            loss_sum += loss.sum()
        return loss_sum / num_examples


if __name__ == '__main__':
    # 全局参数设置
    lr = 0.003
    num_epochs = 100
    batch_size = 5

    # 创建记录器
    def log_dir():
        root = "runs"
        if not os.path.exists(root):
            os.mkdir(root)
        order = len(os.listdir(root)) + 1
        return f'{root}/exp{order}'
    writer = SummaryWriter(log_dir=log_dir())
    
    # 合成数据集
    num_inputs = 200
    n_train, n_test = 20, 100
    true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
    X, y = data_generator(true_w, true_b, n_train + n_test)

    # 加载数据集
    dataloader_train = load_dataset(X[:n_train], y[:n_train])
    dataloader_test = load_dataset(X[n_train:], y[n_train:])

    def loop(weight_decay):
        # 定义模型
        net = nn.Linear(num_inputs, 1).cuda()
        nn.init.normal_(net.weight)
        nn.init.constant_(net.bias, 0)
        criterion = nn.MSELoss(reduction='none')
        optimizer = optim.SGD([
            {"params": net.weight,"weight_decay": weight_decay},
            {"params": net.bias}
            ], lr=lr)

        # 训练循环
        for epoch in track(range(num_epochs), description=f'wd={weight_decay}'):
            for X, y in dataloader_train:
                X, y = X.cuda(), y.cuda()
                loss = criterion(net(X), y)
                optimizer.zero_grad()
                loss.mean().backward()
                optimizer.step()
            writer.add_scalars(f'wd={weight_decay}', {
                'train_loss': evaluate_loss(dataloader_train, net, criterion),
                'test_loss': evaluate_loss(dataloader_test, net, criterion),
            }, epoch)


    for weight_decay in [0, 3]:
        loop(weight_decay)
    writer.close()
  • 输出结果
    • weight_decay = 0
      0
    • weight_decay = 3
      3

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/224122.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

HTTPS 的通信加解密过程,证书为什么更安全?

目录 一、什么是https 二、HTTPS 的加解密过程 三、HTTPS 为什么更安全? 一、什么是https HTTPS(Hypertext Transfer Protocol Secure)是一种通过加密和身份验证保护数据传输安全的通信协议。它是在常用的HTTP协议基础上添加了 SSL/TLS 加…

Memory-augmented Deep Autoencoder for Unsupervised Anomaly Detection 论文阅读

Memorizing Normality to Detect Anomaly: Memory-augmented Deep Autoencoder for Unsupervised Anomaly Detection 摘要1.介绍2.相关工作异常检测Memory networks 3. Memory-augmented Autoencoder3.1概述3.2. Encoder and Decoder3.3. Memory Module with Attention-based S…

redis中使用事务

事务是指一个执行过程,要么全部执行成功,要么失败什么都不改变。不会存在一部分成功一部分失败的情况,也就是事务的ACID四大特性(原子性、一致性、隔离性、持久性)。但是redis中的事务并不是严格意义上的事务&#xff…

论MYSQL注入的入门注解

📑打牌 : da pai ge的个人主页 🌤️个人专栏 : da pai ge的博客专栏 ☁️宝剑锋从磨砺出,梅花香自苦寒来 📑什么是MySQL注入&…

AI模型平台Hugging Face存在API令牌漏洞;大型语言模型与任务模型

🦉 AI新闻 🚀 AI模型平台Hugging Face存在API令牌漏洞,黑客可窃取、修改模型 摘要:安全公司Lasso Security发现AI模型平台Hugging Face上存在API令牌漏洞,黑客可获取微软、谷歌等公司的令牌,并能够访问模…

【PyTorch】训练过程可视化

文章目录 1. 训练过程中的可视化1.1. alive_progress1.2. rich.progress 2. 训练结束后的可视化2.1. tensorboardX2.1.1. 安装2.1.2. 使用 1. 训练过程中的可视化 主要是监控训练的进度。 1.1. alive_progress 安装 pip install alive_progress使用 from alive_progress i…

持续集成交付CICD: Sonarqube REST API 查找与新增项目

目录 一、实验 1.SonarQube REST API 查找项目 2.SonarQube REST API 新增项目 一、实验 1.SonarQube REST API 查找项目 (1)Postman测试 转换成cURL代码 (2)Jenkins添加凭证 (3)修改流水线 pipeline…

HCIP考试实验

实验更新中,部分配置解析与分析正在完善中........... 实验拓扑图 实验要求 要求 1、该拓扑为公司网络,其中包括公司总部、公司分部以及公司骨干网,不包含运营商公网部分。 2、设备名称均使用拓扑上名称改名,并且区分大小写。 3…

SQL server 根据已有数据库创建相同的数据库

文章目录 用导出的脚本创建相同的数据库导出建表脚本再次建表 一些sql语句 用导出的脚本创建相同的数据库 导出建表脚本 首先,右击要导出的数据库名,依次选择任务-生成脚本。 简介(第一页)处选择下一步,然后来到选择…

MAMBA介绍:一种新的可能超过Transformer的AI架构

有人说,“理解了人类的语言,就理解了世界”。一直以来,人工智能领域的学者和工程师们都试图让机器学习人类的语言和说话方式,但进展始终不大。因为人类的语言太复杂,太多样,而组成它背后的机制,…

MAC 系统在vs code中,如何实现自动换行

目录 问题描述: 问题解决: 问题描述: 在vscode中,有些时候,一行内容过多,如果不能自动换行,就需要拖动页面,才能看到完整的内容。如下图两行所示: 问题解决&#xff1a…

国标GB28181设备注册安防监控平台EasyCVR不上线是什么原因?

安防视频监控EasyCVR平台兼容性强,可支持的接入协议众多,包括国标GB28181、RTSP/Onvif、RTMP,以及厂家的私有协议与SDK,如:海康ehome、海康sdk、大华sdk、宇视sdk、华为sdk、萤石云sdk、乐橙sdk等。平台能将接入的视频…

【Python】Flask + MQTT 实现消息订阅发布

目录 Flask MQTT 实现消息订阅发布准备开始1.创建Flask项目2创建py文件:mqtt_demo.py3.代码实现4.项目运行5.测试5.1 测试消息接收5.2 测试消息发布6、扩展 Flask MQTT 实现消息订阅发布 准备 本次项目主要使用到的库:flask_mqtt pip install flask…

如何自定义负载均衡策略

参考官方资源 Home Netflix/ribbon Wiki (github.com)6. 客户端负载均衡器:功能区 (spring.io)负载均衡策略 内置负载均衡规则类规则描述RoundRobinRule简单轮询服务列表来选择服务器。它是Ribbon默认的负载均衡规则。AvailabilityFilteringRule对以下两种服务器进…

10.Java程序设计-基于SSM框架的微信小程序家教信息管理系统的设计与实现

摘要是论文的开篇,用于简要概述研究的目的、方法、主要结果和结论。以下是一个简化的摘要示例,你可以根据实际情况进行修改和扩展: 摘要 随着社会的发展和教育需求的增长,家教服务作为一种个性化的学习方式受到了广泛关注。为了更…

STM32L051使用HAL库操作实例(13)- 读取IAQ-CORE-C传感器实例

目录 一、前言 二、传感器参数 三、STM32CubeMX配置(本文使用的STM32CubeMX版本为6.1.2)例程使用模拟I2C进行数据读取 1.MCU选型 2.使能时钟 3.时钟配置 4.GPIO口配置 四、配置STM32CubeMX生成工程文件 五、点击GENERATE CODE生成工程文件 六、…

系统设计-缓存介绍

该图说明了我们在典型架构中缓存数据的位置。 沿着流程有多个层次。 客户端应用程序:HTTP 响应可以由浏览器缓存。我们第一次通过 HTTP 请求数据,返回时在 HTTP 标头中包含过期策略;我们再次请求数据,客户端应用程序首先尝试从浏…

看图学源码之 Atomic 类源码浅析二(cas + 分治思想的原子累加器)

原子累加器 相较于上一节看图学源码 之 Atomic 类源码浅析一(cas 自旋操作的 AtomicXXX原子类)说的的原子类,原子累加器的效率会更高 XXXXAdder 和 XXXAccumulator 区别就是 Adder只有add 方法,Accumulator是可以进行自定义运算方…

工业4G路由器助力轨道交通城市地铁实现数字化转型

随着城市的科技不断发展,地铁系统的智能化程度也在不断提高。地铁闸机的网络部署已经成为地铁建设中必不可少环节。而4G路由器作为地铁闸机的网络通讯设备,助力轨道交通地铁闸机实现数字化转型。 工业4G路由器在地铁系统光纤宽带网络遇到故障或其他问题…

2023 金砖国家职业技能大赛网络安全省赛二三阶段样题(金砖国家未来技能挑战赛)

2023 金砖国家职业技能大赛网络安全省赛二三阶段样题(金砖国家未来技能挑战赛) 第二阶段: 安全运营 **背景:**作为信息安全技术人员必须能够掌握操作系统加固与安全管控、防火 墙一般配置、常见服务配置等相关技能,利…