Pytorch参数优化

前言:
当我们训练神经网络时,我们需要调整模型的参数,使得损失函数的值逐渐减小,从而优化模型。但是模型的参数我们一般是无法看见的,所以我们必须学会对参数的更新,下面,我将介绍两种参数更新的方法

下面以梯度下降法为例进行展示:

  1. 手动遍历参数更新

在PyTorch中,模型的参数是通过torch.nn.Parameter类来表示的,并存储在模型的parameters()方法返回的迭代器中。

for param in models.parameters():
    param.data -= param.grad.data * lr
  • 我们遍历模型models中的每个参数,通过param.data来访问参数的值,即参数的张量。在训练过程中,通过反向传播计算得到每个参数的梯度,这些梯度存储在param.grad.data中。梯度表示损失函数关于参数的变化率,通过更新参数,我们期望能够朝着损失函数下降的方向调整参数值。
  • 学习率lr是梯度下降法的超参数,它决定了每次更新参数的步幅。在梯度下降中,我们通过梯度与学习率的乘积来更新参数的值。这个操作使得参数朝着损失函数下降最快的方向更新,从而优化模型。
  1. 参数优化器

torch.optim是PyTorch中用于实现优化算法的模块。它提供了多种常用的优化器,可以用于自动调整模型参数以最小化损失函数,从而实现神经网络的训练。
优化器的作用是根据模型的梯度信息来更新模型的参数,以最小化损失函数。在神经网络的训练过程中,优化器会不断地调整参数值,使得模型的预测结果与真实标签更接近,从而提高模型的性能。
torch.optim模块提供了许多优化器,常见的包括:

  • SGD(Stochastic Gradient Descent,随机梯度下降):每次迭代使用单个样本计算梯度,更新模型参数。是最经典的优化算法之一。
  • Adam(Adaptive Moment Estimation,自适应矩估计):结合了动量法和RMSprop方法,并进行了参数的偏差校正。在深度学习中广泛使用,通常能够快速收敛。
  • RMSprop(Root Mean Square Propagation,均方根传播):调整学习率来适应不同的参数。
  • Adagrad(Adaptive Gradient Algorithm,自适应梯度算法):对每个参数使用不同的学习率,以适应不同参数的更新频率。
  • Adadelta:是对Adagrad的扩展,使用了更稳定的学习率。
  • AdamW:是对Adam优化器的改进版本,添加了权重衰减。

使用torch.optim优化器的基本流程是:

  1. 定义神经网络模型。
  2. 定义损失函数。
  3. 创建优化器对象,将模型的参数传递给优化器。
  4. 在每个训练迭代中,执行以下步骤:
    a. 前向传播计算预测值。
    b. 计算损失函数。
    c. 将优化器的梯度清零。
    d. 反向传播计算梯度。
    e. 使用优化器来更新模型参数。
import torch
from torch.optim import SGD

# ... 定义模型和其他训练相关的代码 ...

# 定义优化器
optimizer = SGD(models.parameters(), lr=lr)	#传入参数(参数和梯度),超参数(学习率)
# 迭代进行训练
for epoch in range(epoch_n):
    y_pred = models(x)  # 前向传播,计算预测值
    loss = loss_fn(y_pred, y)  # 计算均方误差损失
    if epoch % 1000 == 0:
        print("epoch:{}, loss:{:.4f}".format(epoch, loss.item()))
    optimizer.zero_grad()  # 将模型参数的梯度清零,避免梯度累积
    loss.backward()  # 反向传播,计算梯度
    optimizer.step()  # 使用优化器来自动更新模型参数

完整演示

import torch
import torch.nn as nn
import torch.optim as optim

# 定义神经网络模型
class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc = nn.Linear(2, 1)

    def forward(self, x):
        return self.fc(x)

# 定义训练数据和目标数据
x_train = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]], dtype=torch.float32)
y_train = torch.tensor([[3.0], [5.0], [7.0]], dtype=torch.float32)

# 创建神经网络模型和损失函数
model = SimpleModel()
loss_fn = nn.MSELoss()

# 创建优化器对象,将模型参数传递给优化器
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 定义训练轮数
epochs = 1000

# 训练过程
for epoch in range(epochs):
    # 前向传播
    y_pred = model(x_train)
    
    # 计算损失函数
    loss = loss_fn(y_pred, y_train)
    
    # 将优化器的梯度缓存清零
    optimizer.zero_grad()
    
    # 反向传播
    loss.backward()
    
    # 使用优化器来更新模型参数
    optimizer.step()
    
    if epoch % 100 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

# 在训练完成后,可以使用训练好的模型来进行预测
x_new = torch.tensor([[4.0, 5.0], [5.0, 6.0]], dtype=torch.float32)
with torch.no_grad():
    y_pred_new = model(x_new)
    print("Predictions for new data:")
    print(y_pred_new)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/52395.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Vulnhub: hacksudo: search靶机

kali:192.168.111.111 靶机:192.168.111.170 信息收集 端口扫描 nmap -A -sC -v -sV -T5 -p- --scripthttp-enum 192.168.111.170 80端口目录爆破 feroxbuster -k -d 1 --url http://192.168.111.170 -w /opt/zidian/SecLists-2022.2/Discovery/Web…

机器学习之Boosting和AdaBoost

1 Boosting和AdaBoost介绍 1.1 集成学习 集成学习 (Ensemble Learning) 算法的基本思想就是将多个分类器组合,从而实现一个预测效果更好的集成分类器。 集成学习通过建立几个模型来解决单一预测问题。它的工作原理是生成多个分类器/模型,各自独立地学…

ChatGPT长文本对话输入方法

ChatGPT PROMPTs Splitter 是一个开源工具,旨在帮助你将大量上下文数据分成更小的块发送到 ChatGPT 的提示,并根据如何处理所有块接收到 ChatGPT(或其他具有字符限制的语言模型)的方法。 推荐:用 NSDT设计器 快速搭建可…

iOS开发-NotificationServiceExtension实现实时音视频呼叫通知响铃与震动

iOS开发-NotificationServiceExtension实现实时音视频呼叫通知响铃与震动 在之前的开发中,遇到了实时音视频呼叫通知,当App未打开或者App在后台时候,需要通知到用户,用户点击通知栏后是否接入实时音视频的视频或者音频通话。 在…

【雕爷学编程】MicroPython动手做(17)——掌控板之触摸引脚

知识点:什么是掌控板? 掌控板是一块普及STEAM创客教育、人工智能教育、机器人编程教育的开源智能硬件。它集成ESP-32高性能双核芯片,支持WiFi和蓝牙双模通信,可作为物联网节点,实现物联网应用。同时掌控板上集成了OLED…

ChatGLM-6B 部署与 P-Tuning 微调实战-使用Pycharm实战

国产大模型ChatGLM-6B微调部署入门-使用Pycharm实战 1.ChatGLM模型介绍 ChatGLM-6B 是一个开源的、支持中英双语的对话语言模型,基于 General Language Model (GLM) 架构,具有 62 亿参数。结合模型量化技术,用户可以在消费级的显卡上进行本…

【指针二:穿越编程边界的超能力】

本章重点 5. 函数指针 6. 函数指针数组 7. 指向函数指针数组的指针 8. 回调函数 五、函数指针 首先看一段代码: 输出的是两个地址相同,这两个相同的地址都是 test 函数的地址。 那我们的函数的地址要想保存起来,怎么保存? 下面我…

Install the Chinese input method on Linux

Open terminal and input: sudo -i apt install fcitx fcitx-googlepinyinWait for it to finish. Search fcitx: "设置"-->"输入法": Finally, we get the following result: Ctrl Space:Switch the input method. The test …

HbuilderX运行时遇见文件找不到问题

错误类型 解决方法 找到报错的文件 系统提示crypto-js 和 sm-crypto 找不到,然后注释掉找不到的文件 运行成功!!!

【OpenCV • c++】图像几何变换 | 图像坐标映射

🚀 个人简介:CSDN「博客新星」TOP 10 , C/C 领域新星创作者💟 作 者:锡兰_CC ❣️📝 专 栏:【OpenCV • c】计算机视觉🌈 若有帮助,还请关注➕点赞➕收藏&#xff…

今天学学消息队列RocketMQ:消息类型

RocketMQ支持的消息类型有三种:普通消息、顺序消息、延时消息、事务消息。以下内容的代码部分都是基于rocketmq-spring-boot-starter做的。 普通消息 普通消息是一种无序消息,消息分布在各个MessageQueue当中,以保证效率为第一使命。这种消息…

AI绘画Stable Diffusion原理之Autoencoder-Latent

前言 传送门: stable diffusion:Git|论文 stable-diffusion-webui:Git Google Colab Notebook:Git kaggle Notebook:Git 今年AIGC实在是太火了,让人大呼许多职业即将消失,比如既能帮…

【Vscode | R | Win】R Markdown转html记录-Win

Rmd文件转html R语言环境Vscode扩展安装及配置配置radian R依赖包pandoc安装配置pandoc环境变量验证是否有效转rmd为html 注意本文代码块均为R语言代码,在R语言环境下执行即可 R语言环境 官网中去下载R语言安装包以及R-tool 可自行搜寻教程 无需下载Rstudio Vscod…

Linux:ELK:日志分析系统(使用elasticsearch集群)

原理 1. 将日志进行集中化管理(beats) 2. 将日志格式化(logstash) 将其安装在那个上面就对那个进行监控 3. 对格式化后的数据进行索引和存储(elasticsearch) 4. 前端数据的展示(kibana&…

python多进程编程(模式与锁)

multiprocessing的三种模式 fork,【拷贝几乎所有资源】【支持文件对象/线程锁等传参】【unix】【任意位置开始】【快】spawn,【run参数传参必备资源】【不支持文件对象/线程锁等传参】【unix、win】【main代码块开始】【慢】forkserver,【ru…

C++ 类和对象

面向过程/面向对象 C语言是面向过程,关注过程,分析出求解问题的步骤,通过函数调用逐步解决问题 C是基于面对对象的,关注的是对象——将一件事拆分成不同的对象,依靠对象之间的交互完成 引入 C语言中结构体只能定义…

41. linux通过yum安装postgresql

文章目录 1.下载安装包2.关闭内置PostgreSQL模块:3.安装postgresql服务:4.初始化postgresql数据库:5.设置开机自启动:6.启动postgresql数据库7.查看postgresql进程8.通过netstat命令或者lsof 监听默认端口54329.使用find命令查找了一下postgresql.conf的配置位置10.修改postgre…

ARM将常数加载到寄存器方法之LDR伪指令

一、是什么? LDR Rd,const伪指令可在单个指令中构造任何32位数字常数,使用伪指令可以生成超过MOV和MVN指令 允许范围的常数. 实现原理: (1)如果可以用MOV或MVN指令构造该常数,则汇编程序会生成适当的指令 (2)如果不能用MOV或MVN指令构造该常数,则汇编程序会执行下列…

QEMU源码全解析19 —— QOM介绍(8)

接前一篇文章:QEMU源码全解析18 —— QOM介绍(7) 本文内容参考: 《趣谈Linux操作系统》 —— 刘超,极客时间 《QEMU/KVM》源码解析与应用 —— 李强,机械工业出版社 特此致谢! 上一回讲到了Q…

用C语言实现堆排序算法

1.设计思路 排序的思想将一个数组按递增的顺序进行排序,将数组的第一个位置空下(下标为0),因为会导致子节点和本身同一个结点(i和2i一致),每次堆排序在下标1的位置放上了最大值,然后…