学习率调整策略 | PyTorch 深度学习实战

前一篇文章,深度学习里面的而优化函数 Adam,SGD,动量法,AdaGrad 等 | PyTorch 深度学习实战

本系列文章 GitHub Repo: https://github.com/hailiang-wang/pytorch-get-started

本篇文章内容来自于 强化学习必修课:引领人工智能新时代【梗直哥瞿炜】

PyTorch 学习率调整策略

  • 常见的学习率调节器
    • 学习率衰减
    • 指数衰减
    • 余弦学习率调节
    • 预热
  • 示例程序
    • 执行结果
      • 没有使用学习率自动调节时
      • 使用了学习率自动调节
      • 结论
  • 常见学习率调节器
  • Links

常见的学习率调节器

在这里插入图片描述

学习率衰减

在这里插入图片描述

指数衰减

在这里插入图片描述

余弦学习率调节

实现学习率循环降低或升高的效果

在这里插入图片描述

预热

在这里插入图片描述

示例程序

下面以指数衰减调节器(ExponentialLR)为例子,展示同样的数据条件下:不衰减学习率和衰减学习率两种情况下,损失函数loss的收敛情况。

import torch
torch.manual_seed(777)


'''
Learning rate scheduler
'''
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset # 构造数据集加载器
from torch.utils.data import random_split # 划分数据集

torch.manual_seed(777)# for reproducibility为了重复使用

############################
# 生成数据
############################

# 定义函数
def f(x,y):
    return x**2 + 2*y**2

# 定义初始值
num_samples = 1000 # 1000 个样本点
X = torch.rand(num_samples) # 均匀分布
Y = torch.rand(num_samples)
Z = f(X,Y) + 3 * torch.randn(num_samples)


# Concatenates a sequence of tensors along a new dimension.
# All tensors need to be of the same size.
# https://pytorch.org/docs/stable/generated/torch.stack.html
dataset = torch.stack([X,Y,Z], dim=1)
# print(dataset.shape) # torch.Size([1000, 3])

# split data, 按照 7:3 划分数据集
train_size = int(0.7 * len(dataset))
test_size = len(dataset) - train_size

train_dataset, test_dataset = random_split(dataset=dataset, lengths=[train_size, test_size])

# 将数据封装为数据加载器
# narrow 函数对数据进行切片操作,
# 
train_dataloader = DataLoader(TensorDataset(train_dataset.dataset.narrow(1,0,2), train_dataset.dataset.narrow(1,2,1)), batch_size=32, shuffle=False)
test_dataloader = DataLoader(TensorDataset(test_dataset.dataset.narrow(1,0,2), test_dataset.dataset.narrow(1,2,1)), batch_size=32, shuffle=False)

############################
# 模型定义
############################

# 定义一个简单的模型
class Model(nn.Module):

    def __init__(self):
        super().__init__()
        self.hidden = nn.Linear(2, 8)
        self.output = nn.Linear(8, 1)

    def forward(self, x):
        x = torch.relu(self.hidden(x))
        return self.output(x)


############################
# 模型训练
############################

# 超参数
num_epochs = 100
learning_rate = 0.1 # 学习率,调大一些更直观

# 定义损失函数
loss_fn = nn.MSELoss()

# 通过两次训练,对比有无调节器的效果
for with_scheduler in [False, True]:

    # 定义训练和测试误差数组
    train_losses = []
    test_losses = []

    # 初始化模型
    model = Model()

    # 定义优化器
    optimizer = torch.optim.SGD(model.parameters(), lr = learning_rate)

    # 定义学习率调节器
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=0.99)

    # 迭代训练
    for epoch in range(num_epochs):
        # 设定模型工作在训练模式
        model.train()
        train_loss = 0

        # 遍历训练集
        for inputs, targets in train_dataloader:
            # 预测、损失函数、反向传播
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_fn(outputs, targets)
            loss.backward()
            optimizer.step()
            
            # 记录 loss
            train_loss += loss.item()

        # 计算 loss 并记录到训练误差
        train_loss /= len(train_dataloader)
        train_losses.append(train_loss)

        # 在测试数据集上评估
        model.eval()
        test_loss = 0

        with torch.no_grad():
            # 遍历测试集
            for inputs, targets in test_dataloader:
                # 预测、损失函数
                outputs = model(inputs)
                loss = loss_fn(outputs, targets)
                # 记录 loss
                test_loss += loss.item()

            # 计算 loss 并记录到测试误差
            test_loss /= len(test_dataloader)
            test_losses.append(test_loss)

        
        # 是否更新学习率
        if with_scheduler:
            scheduler.step()

    # 绘制训练和测试误差曲线
    plt.figure(figsize= (8, 4))
    plt.plot(range(num_epochs), train_losses, label="Train")
    plt.plot(range(num_epochs), test_losses, label="Test")
    plt.title("{0} lr_scheduler".format("With " if with_scheduler else "Without"))
    plt.legend()
    # plt.ylim((1,2))
    plt.show()

执行结果

没有使用学习率自动调节时

在这里插入图片描述

使用了学习率自动调节

在这里插入图片描述

结论

使用了学习率自动调节,学习的速度更快,收敛速度更快。

常见学习率调节器

## 学习率衰减,例如每训练 100 次就将学习率降低为原来的一半
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=100, gamma=0.5)
## 指数衰减法,每次迭代将学习率乘上一个衰减率
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=optimizer,gamma=0.99)
## 余弦学习率调节,optimizer 初始学习率为最大学习率,eta_min 是最小学习率,T_max 是最大的迭代次数
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=100, eta_min=0.00001)
## 自定义学习率,通过一个 lambda 函数自定义实现学习率调节器
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda epoch: 0.99 ** epoch)
## 预热
warmup_steps = 20
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer=optimizer, lr_lambda=lambda t: min(t/warmup_steps, 0.001))

Links

  • PyTorch学习率调整策略.ipynb
  • 6.2 动态调整学习率
  • 【学习率】torch.optim.lr_scheduler学习率10种调整方法整理
  • 11.11. 学习率调度器
  • Pytorch – 手动调整学习率以及使用torch.optim.lr_scheduler调整学习率

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/965905.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

在 Flownex 中创建自定义工作液

在这篇博文中,我们将了解如何在 Flownex 中为流网添加和定义一种新的流体温度相关工作材料。 Flownex 物料管理界面 在 Flownex 中使用与温度相关的流体材料时,了解其特性与温度的关系非常重要。这种了解可确保准确预测各种热条件下的流体行为&#xff0…

记一次golang环境的变化

前两天编译打包了了个文件,把env的 goos 搞坏了 导致运行项目一直报错 先是这样 go: unsupported GOOS/GOARCH pair windows/amd64再是这样 /amd64supported GOOS/GOARCH pair linux咱就说,咱也是知道环境配置的有问题 ( go env GOOS &…

算法【Java】—— 动态规划之子序列问题

最长递增子序列 https://leetcode.cn/problems/longest-increasing-subsequence 状态表示:和之前的经验一样,dp[i] 表示 以 i 为结尾元素的所有递增子序列中最大长度是多少 状态转移方程推导:从 i 前面的元素开始寻找,当 nums[j…

ASP.NET Core标识框架Identity

目录 Authentication与Authorization 标识框架(Identity) Identity框架的使用 初始化 自定义属性 案例一:添加用户、角色 案例二:检查登录用户信息 案例三:实现密码的重置 步骤 Authentication与Authorizatio…

124,【8】buuctf web [极客大挑战 2019] Http

进入靶场 查看源码 点击 与url有关,抓包 over

windows下安装Open Web UI

windows下安装openwebui有三种方式,docker,pythonnode.js,整合包. 这里我选择的是第二种,非docker. 非Docker方式安装 1. 安装Python: 下载并安装Python 3.11,建议安装路径中不要包含中文字符,并勾选“Add python 3.11 to Path”选项。 安…

Mac 基于Ollama 本地部署DeepSeek离线模型

最近节日期间最火的除了《哪吒》就是deepseek了,毕竟又让西方各个层面都瑟瑟发抖的产品。DeepSeek凭借其强大的AI能力真的是在全球多个领域展现出强大的影响力。由于受到外部势力的恶意攻击倒是deepseek官方服务不稳定,国内其他厂家的适配版本也不是很稳…

解决aspose将Excel转成PDF中文变成方框的乱码问题

原文网址:解决aspose将Excel转成PDF中文变成方框的乱码问题_IT利刃出鞘的博客-CSDN博客 简介 本文介绍如何解决aspose将Excel转成PDF中文变成方框的乱码问题。 问题描述 用aspose将word、excel等转成PDF后,英文展示正常,但中文全部变成了…

Jupyter Notebook自动保存失败等问题的解决

一、未生成配置文件 需要在命令行中,执行下面的命令自动生成配置文件 jupyter notebook --generate-config 执行后会在 C:\Users\用户名\.jupyter目录中生成文件 jupyter_notebook_config.py 二、在网页端打开Jupyter Notebook后文件保存失败;运行代码…

【漫话机器学习系列】083.安斯库姆四重奏(Anscombe‘s Quartet)

安斯库姆四重奏(Anscombes Quartet) 1. 什么是安斯库姆四重奏? 安斯库姆四重奏(Anscombes Quartet)是一组由统计学家弗朗西斯安斯库姆(Francis Anscombe) 在 1973 年 提出的 四组数据集。它们…

Axure设计教程:动态排名图(中继器实现)

一、开篇 在Axure原型设计中,动态图表是展示数据和交互效果的重要元素。今天,我们将学习如何使用中继器来创建一个动态的排名图,该图表不仅支持自动轮播,还可以手动切换,极大地增强了用户交互体验。此教程旨在提供一个…

MySQL视图索引操作

创建学生表; mysql> create table Student(-> Sno int primary key auto_increment,-> Sname varchar(30) not null unique,-> Ssex char(2) check (Ssex男 or Ssex女) not null,-> Sage int not null,-> Sdept varchar(10) default 计算机 not …

【正点原子K210连载】第六十七章 音频FFT实验 摘自【正点原子】DNK210使用指南-CanMV版指南

第六十七章 音频FFT实验 本章将介绍CanMV下FFT的应用,通过将时域采集到的音频数据通过FFT为频域。通过本章的学习,读者将学习到CanMV下控制FFT加速器进行FFT的使用。 本章分为如下几个小节: 32.1 maix.FFT模块介绍 32.2 硬件设计 32.3 程序设…

基于 Ollama+Docker+OpenWebUI 的本地化部署deepseek流程

搭建deepseek 安装Ollama Ollama官方下载地址 下载完成后双击打开Ollama进行安装,点击install 安装完成后系统会弹出下图提示代表安装成功并且已启动 验证安装 ollama -v安装完成后,cmd 打开命令行窗口,输入 “ollama -v” 测试,显示 olla…

Mac 部署Ollama + OpenWebUI完全指南

文章目录 💻 环境说明🛠️ Ollama安装配置1. 安装[Ollama](https://github.com/ollama/ollama)2. 启动Ollama3. 模型存储位置4. 配置 Ollama 🌐 OpenWebUI部署1. 安装Docker2. 部署[OpenWebUI](https://www.openwebui.com/)(可视化…

C#常用集合优缺点对比

先上结论&#xff1a; 在C#中&#xff0c;链表、一维数组、字典、List<T>和ArrayList是常见的数据集合类型&#xff0c;它们各有优缺点&#xff0c;适用于不同的场景。以下是它们的比较&#xff1a; 1. 一维数组 (T[]) 优点&#xff1a; 性能高&#xff1a;数组在内存中…

额外题目汇总2-链表

链表 1.24. 两两交换链表中的节点 力扣题目链接(opens new window) 给定一个链表&#xff0c;两两交换其中相邻的节点&#xff0c;并返回交换后的链表。 你不能只是单纯的改变节点内部的值&#xff0c;而是需要实际的进行节点交换。 思路 使用虚拟头结点会很方便&#xff…

Nginx 中启用 Gzip 压缩以优化网页加载速度

&#x1f3e1;作者主页&#xff1a;点击&#xff01; Nginx-从零开始的服务器之旅专栏&#xff1a;点击&#xff01; &#x1f427;Linux高级管理防护和群集专栏&#xff1a;点击&#xff01; ⏰️创作时间&#xff1a;2025年2月7日17点14分 目录 1. 配置网页压缩 目的 …

《云夹:高效便捷的书签管理利器》

在信息爆炸的时代&#xff0c;我们每天都会浏览大量的网页&#xff0c;遇到许多有价值的内容。如何高效地管理这些网页书签&#xff0c;以便随时快速访问&#xff0c;成为了一个重要的问题。云夹作为一款出色的书签管理工具&#xff0c;为我们提供了完美的解决方案。 强大的功能…

学习数据结构(6)链表OJ

1.移除链表元素 解法一&#xff1a;&#xff08;我的做法&#xff09;在遍历的同时移除&#xff0c;代码写法比较复杂 解法二&#xff1a;创建新的链表&#xff0c;遍历原链表&#xff0c;将非val的节点尾插到新链表&#xff0c;注意&#xff0c;如果原链表结尾是val节点需要将…