【Python】学习率调整策略详解和示例

学习率调整得当将有助于算法快速收敛和获取全局最优,以获得更好的性能。本文对学习率调度器进行示例介绍。

  • 学习率调整的意义
  • 基础示例
    • 无学习率调整方法
    • 学习率调整方法一
    • 多因子调度器
    • 余弦调度器
  • 结论

学习率调整的意义

首先,学习率的大小很重要。如果它太大,优化就会发散;如果它太小,训练就会需要过长时间,或者我们最终只能得到次优的结果(陷入局部最优)。我们之前看到问题的条件数很重要。直观地说,这是最不敏感与最敏感方向的变化量的比率。

其次,衰减速率同样很重要。如果学习率持续过高,我们可能最终会在最小值附近弹跳,从而无法达到最优解。 简而言之,我们希望速率衰减,但要比慢,这样能成为解决凸问题的不错选择

另一个同样重要的方面是初始化。这既涉及参数最初的设置方式,又关系到它们最初的演变方式。这被戏称为预热(warmup),即我们最初开始向着解决方案迈进的速度有多快。一开始的大步可能没有好处,特别是因为最初的参数集是随机的。最初的更新方向可能也是毫无意义的

鉴于管理学习率需要很多细节,因此大多数深度学习框架都有自动应对这个问题的工具。本文将梳理不同的调度策略对准确性的影响,并展示如何通过学习率调度器(learning rate scheduler)来有效管理。

基础示例

我们从一个简单的问题开始,这个问题可以轻松计算,但足以说明要义。 为此,我们选择了一个稍微现代化的LeNet版本(激活函数使用relu而不是sigmoid,汇聚层使用最大汇聚层而不是平均汇聚层),并应用于Fashion-MNIST数据集。 此外,我们混合网络以提高性能。

无学习率调整方法

import math
import torch
from torch import nn
from torch.optim import lr_scheduler, SGD
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

def load_data_fashion_mnist(batch_size):
    # 定义数据预处理
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ])

    # 加载训练集和测试集
    train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)
    test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)

    # 创建数据加载器
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_loader, test_loader
def net_fn():
    model = nn.Sequential(
        nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),
        nn.MaxPool2d(kernel_size=2, stride=2),
        nn.Flatten(),
        nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
        nn.Linear(120, 84), nn.ReLU(),
        nn.Linear(84, 10))
    return model


def train(net, train_loader, test_loader, num_epochs, loss, optimizer, device, scheduler=None):
    net.to(device)
    running_loss = 0.0
    train_losses = []
    test_losses = []
    test_accuracies = []
    for epoch in range(num_epochs):
        for i, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = net(inputs)
            loss_value = loss(outputs, labels)

            # Backward and optimize
            loss_value.backward()
            optimizer.step()

            # Print statistics
            running_loss += loss_value.item()

            # if i % 200 == 199:  # print every 200 mini-batches
            #     print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 200}')
            #     running_loss = 0.0
        train_losses.append(running_loss / len(train_loader))
        # Evaluate the model on the test dataset
        test_loss, test_acc = evaluate(net, test_loader, device)
        test_losses.append(test_loss)
        test_accuracies.append(test_acc)
        print(f'Epoch {epoch+1}, Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}, Test Acc: {test_accuracies[-1]:.2f}')

        if scheduler:
            if scheduler.__module__ == lr_scheduler.__name__:

                scheduler.step()
            else:

                for param_group in  optimizer.param_groups:
                    param_group['lr'] = scheduler(epoch)

    plt.figure(figsize=(10, 6))
    plt.plot(range(1, num_epochs + 1), train_losses, label='Training Loss')
    plt.plot(range(1, num_epochs + 1), test_losses, label='Test Loss')
    plt.plot(range(1, num_epochs + 1), test_accuracies, label='Test Accuracy')
    plt.title('Training, Test Losses and Test Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Loss / Accuracy')
    plt.legend()
    plt.grid(True)
    plt.savefig("1.jpg")
    plt.show()





def evaluate(model, data_loader, device):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for inputs, labels in data_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            test_loss += nn.CrossEntropyLoss(reduction='sum')(outputs, labels).item()
            _, predicted = torch.max(outputs.data, 1)
            correct += (predicted == labels).sum().item()
    test_loss /= len(data_loader.dataset)
    accuracy = correct / len(data_loader.dataset)
    #accuracy = 100. * correct / len(data_loader.dataset)
    return test_loss, accuracy


# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Define the model
model = net_fn()

# Define the loss function
loss = nn.CrossEntropyLoss()

# Define the optimizer
lr=0.3
optimizer = SGD(model.parameters(), lr=lr)



# Load the dataset
batch_size=128
train_loader, test_loader=load_data_fashion_mnist(batch_size)
num_epochs=30
train(model, train_loader, test_loader, num_epochs, loss, optimizer, device)

这里没有使用学习率调整策略。训练过程和结果如下图所示:

.
.
.
.
Epoch 23, Train Loss: 0.1247, Test Loss: 0.3939, Test Acc: 0.90
Epoch 24, Train Loss: 0.1236, Test Loss: 0.4370, Test Acc: 0.89
Epoch 25, Train Loss: 0.1167, Test Loss: 0.4117, Test Acc: 0.89
Epoch 26, Train Loss: 0.1169, Test Loss: 0.4440, Test Acc: 0.89
Epoch 27, Train Loss: 0.1163, Test Loss: 0.4336, Test Acc: 0.89
Epoch 28, Train Loss: 0.1055, Test Loss: 0.4312, Test Acc: 0.90
Epoch 29, Train Loss: 0.1065, Test Loss: 0.4942, Test Acc: 0.89
Epoch 30, Train Loss: 0.1051, Test Loss: 0.4763, Test Acc: 0.89

在这里插入图片描述

学习率调整方法一

设置在每个迭代轮数(甚至在每个小批量)之后向下调整学习率。 例如,以动态的方式来响应优化的进展情况。

在代码最后添加SquareRootScheduler类,并更新train()函数参数,其它内容不变。

class SquareRootScheduler:
    def __init__(self, lr=0.1):
        self.lr = lr

    def __call__(self, num_update):
        return self.lr * pow(num_update + 1.0, -0.5)

scheduler = SquareRootScheduler(lr=0.1)
train(model, train_loader, test_loader, num_epochs, loss, optimizer, device,scheduler)

运行代码,可得相应参数值和变化过程,如下所示。

Epoch 23, Train Loss: 0.1823, Test Loss: 0.2811, Test Acc: 0.90
Epoch 24, Train Loss: 0.1801, Test Loss: 0.2800, Test Acc: 0.90
Epoch 25, Train Loss: 0.1767, Test Loss: 0.2819, Test Acc: 0.90
Epoch 26, Train Loss: 0.1747, Test Loss: 0.2800, Test Acc: 0.91
Epoch 27, Train Loss: 0.1720, Test Loss: 0.2818, Test Acc: 0.90
Epoch 28, Train Loss: 0.1689, Test Loss: 0.2856, Test Acc: 0.90
Epoch 29, Train Loss: 0.1669, Test Loss: 0.2907, Test Acc: 0.90
Epoch 30, Train Loss: 0.1641, Test Loss: 0.2813, Test Acc: 0.90

在这里插入图片描述
我们可以看出曲线比没有策略时平滑了很多,效果有所提升。

多因子调度器

多因子调度器。
在这里插入图片描述
在这里插入图片描述
代码部分修改:

scheduler =lr_scheduler.MultiStepLR(optimizer, milestones=[15, 30], gamma=0.5)

运行结果为:
在这里插入图片描述
可见效果不理想,出现过拟合现象。

余弦调度器

余弦调度器是 (Loshchilov and Hutter, 2016)提出的一种启发式算法。 它所依据的观点是:我们可能不想在一开始就太大地降低学习率,而且可能希望最终能用非常小的学习率来“改进”解决方案。 这产生了一个类似于余弦的调度,函数形式如下所示,学习率的值在
之间。
在这里插入图片描述
代码中添加CosineScheduler类和修改scheduler。

class CosineScheduler:
    def __init__(self, max_update, base_lr=0.01, final_lr=0,
               warmup_steps=0, warmup_begin_lr=0):
        self.base_lr_orig = base_lr
        self.max_update = max_update
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.warmup_begin_lr = warmup_begin_lr
        self.max_steps = self.max_update - self.warmup_steps

    def get_warmup_lr(self, epoch):
        increase = (self.base_lr_orig - self.warmup_begin_lr) \
                       * float(epoch) / float(self.warmup_steps)
        return self.warmup_begin_lr + increase

    def __call__(self, epoch):
        if epoch < self.warmup_steps:
            return self.get_warmup_lr(epoch)
        if epoch <= self.max_update:
            self.base_lr = self.final_lr + (
                self.base_lr_orig - self.final_lr) * (1 + math.cos(
                math.pi * (epoch - self.warmup_steps) / self.max_steps)) / 2
        return self.base_lr


#scheduler = SquareRootScheduler(lr=0.1)
#scheduler =lr_scheduler.MultiStepLR(optimizer, milestones=[15, 30], gamma=0.5)
scheduler = CosineScheduler(max_update=20, base_lr=0.3, final_lr=0.01)
train(model, train_loader, test_loader, num_epochs, loss, optimizer, device,scheduler)

运行结果如下:

在这里插入图片描述
过拟合现象消失,效果提升。

结论

在开发时应根据自己需要,选择合适的学习率调整策略。优化在深度学习中有多种用途。对于同样的训练误差而言,选择不同的优化算法和学习率调度,除了最大限度地减少训练时间,可以导致测试集上不同的泛化和过拟合量。

注:部分内容摘选子书籍《动手学深度学习》

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/490491.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

音乐制作利器 :FL Studio21中文编曲音乐制作软件免费下载

一、引言 在音乐的世界里&#xff0c;每个人都有自己独特的音色和表达方式。而今天&#xff0c;我们要为你推荐一款能让您的音乐创作更上一层楼的神器——FL Studio21中文编曲音乐制作软件。这款功能强大的音乐制作软件&#xff0c;不仅拥有丰富的音色库和高效的编辑功能&#…

Quartz

Quartz 1.核心概念1.1 核心概念图1.2 demo 2.Job2.1为什么设计成JobDetailJob, 而不直接使用Job2.2 间隔执行时, 每次都会创建新的Job实例2.3 定时任务默认都是并发执行的&#xff0c;不会等待上一次任务执行完毕2.3.1 不允许并发执行 2.4 在运行时, 通过JobDataMap向Job传递数…

Python自动化测试环境搭建

&#x1f345; 视频学习&#xff1a;文末有免费的配套视频可观看 &#x1f345; 关注公众号&#xff1a;互联网杂货铺&#xff0c;回复1 &#xff0c;免费获取软件测试全套资料&#xff0c;资料在手&#xff0c;涨薪更快 请事先自行安装好​​Pycharm​​​软件哦&#xff0c;我…

深度学习模型部署(十二)CUDA编程-绪

CUDA 运行时 API 与 CUDA 驱动 API 速度没有差别&#xff0c;实际中使用运行时 API 较多&#xff0c;运行时 API 是在驱动 API 上的一层封装。​ CUDA 是什么&#xff1f;​ CUDA(Compute Unified Device Architecture) 是 nvidia 推出的一个通用并行技术架构&#xff0c;用它…

基于冠豪猪优化器(CPO)的无人机路径规划

该优化算法是2024年新发表的一篇SCI一区top论文具有良好的实际应用和改进意义。一键运行main函数代码自动保存高质量图片 1、冠豪猪优化器 摘要&#xff1a;受冠豪猪(crest Porcupine, CP)的各种防御行为启发&#xff0c;提出了一种新的基于自然启发的元启发式算法——冠豪猪…

视觉轮速滤波融合1讲:理论推导

视觉轮速滤波融合理论推导 文章目录 视觉轮速滤波融合理论推导1 坐标系2 轮速计2.1 运动学模型2.2 外参 3 状态和协方差矩阵3.1 状态3.2 协方差矩阵 4 Wheel Propagation4.1 连续运动学4.2 离散积分4.2.1 状态均值递推4.2.2 协方差递推 5 Visual update5.1 视觉残差与雅可比5.2…

蓝桥杯2023年第十四届省赛真题-买瓜|DFS+剪枝

题目链接&#xff1a; 0买瓜 - 蓝桥云课 (lanqiao.cn) 蓝桥杯2023年第十四届省赛真题-买瓜 - C语言网 (dotcpp.com) &#xff08;蓝桥官网的数据要求会高一些&#xff09; 说明&#xff1a; 这道题可以分析出&#xff1a;对一个瓜有三种选择&#xff1a; 不拿&#xff0c…

Vue3基础笔记(2)事件

一.事件处理 1.内联事件处理器 <button v-on:click"count">count1</button> 直接将事件以表达式的方式书写~ 每次单击可以完成自增1的操作~ 2.方法事件处理器 <button click"addcount(啦啦啦~)">count2</button> 如上&…

每日必学Linux命令:mv命令

mv命令是move的缩写&#xff0c;可以用来移动文件或者将文件改名&#xff08;move (rename) files&#xff09;&#xff0c;是Linux系统下常用的命令&#xff0c;经常用来备份文件或者目录。 一&#xff0e;命令格式&#xff1a; mv [选项] 源文件或目录 目标文件或目录二&am…

Open WebUI大模型对话平台-适配Ollama

什么是Open WebUI Open WebUI是一种可扩展、功能丰富、用户友好的大模型对话平台&#xff0c;旨在完全离线运行。它支持各种LLM运行程序&#xff0c;包括与Ollama和Openai兼容的API。 功能 直观的界面:我们的聊天界面灵感来自ChatGPT&#xff0c;确保了用户友好的体验。响应…

轻松掌握C语言中的sqrt函数,快速计算平方根的魔法秘诀

C语言文章更新目录 C语言学习资源汇总&#xff0c;史上最全面总结&#xff0c;没有之一 C/C学习资源&#xff08;百度云盘链接&#xff09; 计算机二级资料&#xff08;过级专用&#xff09; C语言学习路线&#xff08;从入门到实战&#xff09; 编写C语言程序的7个步骤和编程…

第1章 实时3D渲染流水线

前言 本书所剖析的Unity 3D内置着色器代码版本是2017.2.0f3&#xff0c;读者可以从Unity 3D官网下载这些着色器代码。这些代码以名为builtin_shaders-2017.2.0f3.zip的压缩包的形式提供&#xff0c;解压缩后&#xff0c;内有4个目录和1个license.txt文件。 目录CGIncludes存放了…

苍穹外卖项目-01(开发流程,介绍,开发环境搭建,nginx反向代理,Swagger)

目录 一、软件开发整体介绍 1. 软件开发流程 1 第1阶段: 需求分析 2 第2阶段: 设计 3 第3阶段: 编码 4 第4阶段: 测试 5 第5阶段: 上线运维 2. 角色分工 3. 软件环境 1 开发环境(development) 2 测试环境(testing) 3 生产环境(production) 二、苍穹外卖项目介绍 …

Docker搭建LNMP环境实战(05):CentOS环境安装Docker-CE

前面几篇文章讲了那么多似乎和Docker无关的实战操作&#xff0c;本篇总算开始说到Docker了。 1、关于Docker 1.1、什么是Docker Docker概念就是大概了解一下就可以&#xff0c;还是引用一下百度百科吧&#xff1a; Docker 是一个开源的应用容器引擎&#xff0c;让开发者可以…

SE注意力模块学习笔记《Squeeze-and-Excitation Networks》

Squeeze-and-Excitation Networks 摘要引言什么是全局平均池化&#xff1f; 相关工作Deep architectures Squeeze-and-Excitation Blocks3.1. Squeeze: Global Information Embedding3.2. Excitation: Adaptive Recalibration3.3. Exemplars: SE-Inception and SE-ResNet 5. Im…

百科词条编辑必备指南,让你轻松上手创建

1.注册账号&#xff1a;首先&#xff0c;你需要注册一个百科平台的账号。例如&#xff0c;对于百度百科&#xff0c;你需要有一个百度账号。 搜索词条&#xff1a;在百科全书平台上搜索您想要编辑的词条。如果词条已经存在&#xff0c;可以直接编辑&#xff1b;如果词条不存在&…

(已解决)vue3使用富文本出现样式乱码

我在copy代码到项目里面时候发现我的富文本乱码了 找了一圈不知道是哪里vue3不适配还是怎么&#xff0c;后来发现main.js还需要引入 import VueQuillEditor from vue-quill-editor // require styles 引入样式 import quill/dist/quill.core.css import quill/dist/quill.snow…

计算机组成原理(超详解!!) 第三节 运算器(浮点加减乘)

1.浮点加法、减法运算 操作过程 1.操作数检查 如果能够判断有一个操作数为0&#xff0c;则没必要再进行后续一系列操作&#xff0c;以节省运算时间。 2.完成浮点加减运算的操作 (1) 比较阶码大小并完成对阶 使二数阶码相同&#xff08;即小数点位置对齐&#xff09;…

力扣Lc21--- 389. 找不同(java版)-2024年3月26日

1.题目描述 2.知识点 &#xff08;1&#xff09;在这段代码中&#xff1a; // 统计字符串s中每个字符的出现次数for (int i 0; i < s.length(); i) {count[s.charAt(i) - a];}对于字符串s “abcd”&#xff1a; 当 i 0&#xff0c;s.charAt(i) ‘a’&#xff0c;ASCII…

牛客小白月赛89(A,B,C,D,E,F)

比赛链接 官方视频讲解&#xff08;个人觉得讲的还是不错的&#xff09; 这把BC偏难&#xff0c;差点就不想做了&#xff0c;对小白杀伤力比较大。后面的题还算正常点。 A 伊甸之花 思路&#xff1a; 发现如果这个序列中最大值不为 k k k&#xff0c;我们可以把序列所有数…