动手学深度学习(Pytorch版)代码实践 -深度学习基础-10权重衰减

10权重衰减

"""
正则化是处理过拟合的常用方法:在训练集的损失函数中加入惩罚项,以降低学习到的模型的复杂度。
保持模型简单的一个特别的选择是使用L2惩罚的权重衰减。这会导致学习算法更新步骤中的权重衰减。
"""

import torch
from torch import nn
from d2l import torch as d2l
import liliPytorch as lp

n_train, n_test, num_input, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_input,1)) * 0.01, 0.05

train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)

test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)


#初始化模型参数
def init_params():
    w = torch.normal(0,1,size=(num_input,1), requires_grad=True)
    b = torch.zeros(1,requires_grad=True)
    return [w,b]

#定义L2范数惩罚
def l2_penalty(w):
    return torch.sum(w.pow(2)) / 2

def l1_penalty(w):
    return torch.sum(torch.abs(w))

# 定义模型
def linreg(X, w, b):
    """线性回归模型"""
    return torch.matmul(X, w) + b

# 定义损失函数
def squared_loss(y_hat, y):
    """均方损失函数"""
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2

# 定义优化函数
def sgd(params, lr, batch_size):
    """小批量随机梯度下降"""
    # 更新参数时不需要计算梯度
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size  # 参数更新
            param.grad.zero_()  # 梯度清零

#定义训练代码实现
def train(lambd):
    w, b = init_params()
    net, loss = lambda X: linreg(X, w, b), squared_loss
    num_epochs, lr = 100, 0.003
    animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            # 增加了L2范数惩罚项,
            # 广播机制使l2_penalty(w)成为一个长度为batch_size的向量
            l = loss(net(X), y) + lambd * l2_penalty(w)
            l.sum().backward()
            sgd([w, b], lr, batch_size)
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),
                                     d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数是:', torch.norm(w).item())


#忽略正则化直接训练¶
# train(lambd=0)
#w的L2范数是: 14.630496978759766

# 使用权重衰减
# train(lambd=3)
# d2l.plt.show() 


#权重衰减-简洁实现
def train_concise(wd):
    net = nn.Sequential(nn.Linear(num_input, 1))
    for param in net.parameters():
        param.data.normal_()
    loss = nn.MSELoss(reduction='none')
    num_epochs, lr = 100, 0.003
    # 偏置参数没有衰减
    trainer = torch.optim.SGD([
        {"params":net[0].weight,'weight_decay': wd},
        {"params":net[0].bias}], lr=lr)
    animator = lp.Animator(xlabel='epochs', ylabel='loss', yscale='log',
                            xlim=[5, num_epochs], legend=['train', 'test'])
    for epoch in range(num_epochs):
        for X, y in train_iter:
            trainer.zero_grad()
            l = loss(net(X), y)
            l.mean().backward()
            trainer.step()
        if (epoch + 1) % 5 == 0:
            animator.add(epoch + 1,
                         (d2l.evaluate_loss(net, train_iter, loss),
                          d2l.evaluate_loss(net, test_iter, loss)))
    print('w的L2范数:', net[0].weight.norm().item())

train_concise(0)
d2l.plt.show() 
# w的L2范数是: 0.33992505073547363

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/733238.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python基础教程(三十):math模块

💝💝💝首先,欢迎各位来到我的博客,很高兴能够在这里和您见面!希望您在这里不仅可以有所收获,同时也能感受到一份轻松欢乐的氛围,祝你生活愉快! 💝&#x1f49…

AudioSep:从音频中分离出特定声音(人声、笑声、噪音、乐器等)本地一键整合包下载

AudioSep是一种 AI 模型,可以使用自然语言查询进行声音分离。这一创新性的模型由Audio-AGI开发,使用户能够通过简单的语言描述来分离各种声音源。 比如在嘈杂的人流车流中说话的录音中,可以分别提取干净的人声说话声音和嘈杂的人流车流噪声。…

对于初学者,该如何选择大模型框架 LlamaIndex 与 LangChain ?

节前,我们星球组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学. 针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。 汇总合集&…

任务3.8.4 利用RDD实现分组排行榜

文章目录 1. 任务说明2. 解决思路3. 准备成绩文件4. 采用交互式实现5. 采用Spark项目实战概述:使用Spark RDD实现分组排行榜任务背景任务目标技术选型实现步骤1. 准备数据2. 数据上传至HDFS3. 启动Spark Shell4. 读取数据生成RDD5. 数据处理6. 计算TopN7. 输出结果8…

「动态规划」如何求最长湍流子数组的长度?

78. 最长湍流子数组https://leetcode.cn/problems/longest-turbulent-subarray/description/ 给定一个整数数组arr,返回arr的最长湍流子数组的长度。如果比较符号在子数组中的每个相邻元素对之间翻转,则该子数组是湍流子数组。更正式地来说,…

算法期末整理

目录 一 算法概述 二 递归与分治策略 三 动态规划 四 贪心算法 五 回溯法 六 分支限界法 七 随机化算法 八 线性规划与网络流 一 算法概述 算法的概念 通俗地讲,算法是指解决问题的一种方法或一个过程。更严格地讲,算法是由若干条指令组成的有穷…

设计模式——职责链模式

职责链模式(Chain of Responsibility) 使多个对象都有机会处理请求,从而避免请求的发送者和接收者之间的耦合关系。将这个对象连接成一条链,并沿着这条链传递请求,直到有一个对象处理它为止。   发出请求的客户端并不知道职责链中哪一个对象…

python watchdog 配置文件热更新

目录 一、Watchdog示例 二、aiohttp服务配置热更新 在同事的golang代码中学习到了config.json热更新的功能,这里自己也学习了一下python写web服务的时候怎么来实现配置的热更新。主要是利用Watchdog这个第三方python库,来监控文件系统的改变&#xff0…

实战电商大数据项目搭建||电商大数据采集||电商API接口

我会提供给你大概1亿条真实的互联网用户上网数据,至于来源,我先不告诉你,绝对是你在网络上无法找到的宝贵数据源。 此外,还会给你提供一个基于当前数据特点而设计的大数据处理方案。 当然,为了防止用户的隐私部分被泄露…

每日优秀影视分享❗❗

一、热门电影推荐 《头脑特工队 2》:皮克斯再次为观众带来了这部经典动画的续集。 影片讲述了刚步入青春期的小女孩莱莉脑海中的复杂情绪进行的一场奇妙冒险。 这部电影不仅延续了前作的优秀品质,更在情感深度和视觉呈现上有了进一步的提升。 《艾尔登…

同时使用接口文档swagger和knife4j

项目场景: springboot项目中同时使用接口文档swagger和knife4j 问题描述 在实体类中设置了字段必填的属性,在访问接口文档时出现异常 实体类关键代码片段 /*** 部门表 sys_dept*/ public class SysDept extends BaseEntity {private static final lo…

Python基础入门

目录 1. 什么是Python? 2. 安装Python 3. Python基础语法 4. 数据结构 5. 文件操作 6. Python标准库 总结 1. 什么是Python? Python是一种高级编程语言,由Guido van Rossum于1991年发布。它以其简单易读的语法和强大的功能而闻名&…

高效22KW双向DCDC储能、充电电源模块项目设计开发

22kW 双向CLL谐振变换器的目标是输出电压范围宽、高效率和高功率密度的双向应用,如电动汽车车载充电器和储能系统。研究了一种新的灵活的 CLLC 双向谐振变换器增益控制方案,以便在充放电模式下实现高效率和宽电压增益范围。得益于 Wolfspeed C3MTM 1200V…

读《文明之光》第2册总结

《文明之光》系列大致按照从地球诞生到近现代的顺序讲述了人类文明进程的各个阶段,每个章节相对独立,全景式地展现了人类文明发展历程中的多样性。《文明之光》系列第二册讲述了从近代科学兴起,到工业革命时代,以及原子能应用这一…

【代码随想录】【算法训练营】【第46天】 [121]买卖股票的最佳时机 [122]买卖股票的最佳时机II [123]买卖股票的最佳时机III

前言 思路及算法思维,指路 代码随想录。 题目来自 LeetCode。 day 46,周六,坚持很困难~ 题目详情 [121] 买卖股票的最佳时机 题目描述 121 买卖股票的最佳时机 解题思路 前提: 思路: 重点: 代码实…

Springboot应用的信创适配

CentOS7在2024.6.30停止维护后,可替代的Linux操作系统-CSDN博客 全面国产化之路-信创-CSDN博客 信创适配评测-CSDN博客 Springboot应用的信创适配 Springboot应用的信创适配,如上图所示需要适配的很多,从硬件、操作系统、中间件&#xff08…

开启声音的奇幻之旅:AI声音变换器的魔法秘籍与创意应用

AI视频生成:小说文案智能分镜智能识别角色和场景批量Ai绘图自动配音添加音乐一键合成视频https://aitools.jurilu.com/这个充满科技魔力的时代,AI Voice Changer 就像一把神奇的钥匙,能为我们打开声音的魔法之门。今天,就让我带你…

基于CentOS Stream 9平台 安装/卸载 Redis7.0.15

已更正systemctl管理Redis服务问题 1. 官方下载地址 https://redis.io/downloads/#redis-downloads 1.1 下载或上传到/opt/coisini目录下: mkdir /opt/coisini cd /opt/coisini wget https://download.redis.io/releases/redis-7.0.15.tar.gz2. 解压 tar -zxvf re…

【实战分享】雷池社区版助力构建高可用、安全的Web应用架构

引言 在日益复杂的网络环境中,构建坚不可摧的安全防线成为每一位网站守护者的重要使命。本文将深入剖析一套集CDN加速、高效Nginx代理与雷池WAF深度防护于一体的现代网站安全架构设计,特别强调雷池WAF在此架构中的核心作用及其对整体安全性的提升策略。…

Linux中部署MySQL环境(本地安装)

进入官网:http://www.mysql.com 选择社区版本得到MySQL 选择对应的版本和系统进行安装 用wget进行软件包下载 wget https://cdn.mysql.com//Downloads/MySQL-8.0/mysql-8.0.32-1.el9.x86_64.rpm-bundle.tar解压该软件包 tar -xf mysql-8.0.32-1.el9.x86_64.rpm-bu…