神经网络 torch.nn---优化器的使用

torch.optim - PyTorch中文文档 (pytorch-cn.readthedocs.io)

torch.optim — PyTorch 2.3 documentation

反向传播可以求出神经网路中每个需要调节参数的梯度(grad),优化器可以根据梯度进行调整,达到降低整体误差的作用。下面我们对优化器进行介绍。

构造优化器

optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
optimizer = optim.Adam([var1, var2], lr=0.0001)
  • 选择优化器的算法optim.SGD
  • 之后在优化器中放入模型参数model.parameters(),这一步是必备
  • 还可在函数中设置一些参数,如学习速率lr=0.01(这是每个优化器中几乎都会有的参数)

调用优化器中的step方法

step()方法就是利用我们之前获得的梯度,对神经网络中的参数进行更新。

for input, target in dataset:
    optimizer.zero_grad()
    output = model(input)
    loss = loss_fn(output, target)
    loss.backward()
    optimizer.step()
  • 步骤optimizer.zero_grad()是必选的。对每个参数的梯度进行清零

  • 我们的输入经过了模型,并得到了输出output

  • 之后计算输出和target之间的误差loss

  • 调用误差的反向传播loss.backwrd更新每个参数对应的梯度

  • 调用optimizer.step()对卷积核中的参数进行优化调整

  • 之后继续进入for循环,使用函数optimizer.zero_grad()对每个参数的梯度进行清零,防止上一轮循环中计算出来的梯度影响下一轮循环。

优化器的使用

优化器中算法共有的参数(其他参数因优化器的算法而异):

  • params: 传入优化器模型中的参数

  • lr: learning rate,即学习速率

关于学习速率

  • 一般来说,学习速率设置得太大,模型运行起来会不稳定

  • 学习速率设置得太小,模型训练起来会过

  • 建议在最开始训练模型的时候,选择设置一个较大的学习速率;训练到后面的时候,再选择一个较小的学习速率

代码实战:

选择CIFAR10

import torch
import torchvision
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Linear, Sequential, Flatten
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset = torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(),
                                       download=True)
dataloader = DataLoader(dataset, batch_size=1)

class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model = Sequential(
            Conv2d(3, 32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(32, 32, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Conv2d(32, 64, kernel_size=5, padding=2),
            MaxPool2d(kernel_size=2),
            Flatten(),
            Linear(64 * 4 * 4, 64),
            Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x

loss = nn.CrossEntropyLoss()
tudui = Tudui()
optim = torch.optim.SGD(tudui.parameters(), lr=0.01)
for epoch in range(20):
    running_loss = 0.0
    for data in dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        # print(outputs)
        # print(targets)
        result_loss = loss(outputs, targets)
        optim.zero_grad()
        result_loss.backward()
        optim.step()
        # print(result_loss)
        running_loss = running_loss + result_loss

    print(running_loss)

输出:

总结使用优化器训练的训练套路):

  • 设置损失函数loss function

  • 定义优化器optim

  • 从使用循环dataloader中的数据:for data in dataloder

    • 取出图片imgs,标签targets:imgs,targets=data

    • 将图片放入神经网络,并得到一个输出:output=model(imgs)

    • 计算误差:loss_result=loss(output,targets)

    • 使用优化器,初始化参数的梯度为0:optim.zero_grad()

    • 使用反向传播求出梯度:loss_result.backward()

    • 根据梯度,对每一个参数进行更新:optim.step()

  • 进入下一个循环,直到完成训练所需的循环次数

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/685111.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2024年6月6日第十五周六级系列1作文

对于2024年上半年六级作文预测的大纲 一、作文总体格式与要求 六级作文将继续沿用以往的大致格式,主要包括两个主要任务: Task 1: 针对给定话题的简短摘要或观点文章(150-200字)。Task 2: 就更复杂的话题进行深入分析的作文&…

23 - 每位教师所教授的科目种类的数量(高频 SQL 50 题基础版)

23 - 每位教师所教授的科目种类的数量 考点: 排序和分组 selectteacher_id,count(distinct subject_id) cnt fromTeacher group byteacher_id;

完整指南:远程管理 Linux 服务器的 Xshell6 和 Xftp6 使用方法(Xshell无法启动:要继续使用此程序........,的解决方法)

😀前言 在当今软件开发领域,远程管理 Linux 服务器已成为日常工作的重要组成部分。随着团队成员分布在不同的地理位置,远程登录工具的使用变得至关重要,它们为开发人员提供了访问和管理服务器的便捷方式。本文将介绍两款功能强大的…

敏捷开发:拥抱变化,持续交付价值的艺术

目录 敏捷开发:拥抱变化,持续交付价值的艺术 引言 第一部分:敏捷开发是什么? a.定义:敏捷开发的基本概念和核心原则 b.历史:敏捷宣言的诞生和敏捷开发的历史背景 c.价值观:敏捷宣言的12条…

一篇文章带你入门XXE

1.什么是XXE? XML External Entity(XXE)攻击是一种利用 XML 处理器的漏洞,通过引入恶意的外部实体来攻击应用程序的安全性。这种攻击通常发生在对用户提供的 XML 数据进行解析时,攻击者利用了 XML 规范允许引用外部实体…

USB (1)

再看USB 其实回过头看USB 是可以和无线 有线(internet) 可以和PCIe 甚至AXI类比的 它们共通处在于都是communication 有限通信 有7层网络协议 在USB也是分层的 不同的协议,依赖的介质不一样 这就要求相应的Physical层处理 USB的physical层是基于Serdes的,所以有串并/并串的…

JavaScript html css前端 日期对象 date对象 日期格式化 时间戳

日期对象 Date对象 Date 对象和 Math 对象不一样,他是一个构造函数,所以我们需要实例化后才能使用 Date 实例用来处理日期和时间 Date()使用方法 示例:获取当前时间 let now new Date() console.log(now) 示例:获取指定时间…

二叉树非递归遍历(C++)

文章目录 前言一、前序遍历二、中序遍历三、后序遍历总结 前言 我们之前学习过用递归解决二叉树的前序,中序,后序。 下面我们将用非递归,也就是遍历的方法对二叉树进行遍历 一、前序遍历 我们要解决这个问题最重要的就是用另一个角度看待这…

推荐个 Edge/Chrome/Firefox 都支持的 IP 定位查询扩展

作为一个博客站长,对 IP 地址应该都不陌生,可以说是跟站长的工作是息息相关的,反正明月几乎每天都会面临 IP 查询、定位的需要,今天让明月给找到了一个叫”IP 定位查询“的浏览器扩展,在 Edge 和 Firefox 下体验后感觉…

嵌入式Linux系统编程 — 2.2 标准I/O库:检查或复位状态

目录 1 检查或复位状态简介 2 feof()函数 2.1 feof()函数简介 2.2 示例程序 3 ferror()函数 4 clearerr()函数 4.1 clearerr()函数简介 4.2 示例程序 1 检查或复位状态简介 调用 fread() 函数读取数据时,如果返回值小于参数 nmemb 所指定的值,这…

【MySQL数据库】:MySQL内外连接

目录 内外连接和多表查询的区别 内连接 外连接 左外连接 右外连接 简单案例 内外连接和多表查询的区别 在 MySQL 中,内连接是多表查询的一种方式,但多表查询包含的范围更广泛。外连接也是多表查询的一种具体形式,而多表查询是一个更…

[笔试训练](三十四)100:[NOIP2008]ISBN号码101:kotori和迷宫102:矩阵最长递增路径

目录 100:[NOIP2008]ISBN号码 101:kotori和迷宫 102:矩阵最长递增路径 100:[NOIP2008]ISBN号码 题目链接:[NOIP2008]ISBN号码_牛客题霸_牛客网 (nowcoder.com) 题目&#xff1a; 题解: 简单模拟 #include <iostream> #include<string> using namespace std; str…

【ARM Cache 系列文章 7.2 – ARMv8/v9 MMU 页表配置详细介绍 03 】

请阅读【ARM Cache 及 MMU/MPU 系列文章专栏导读】 及【嵌入式开发学习必备专栏】 文章目录 表描述符 Table descriptor52-bit OA 颗粒为4KB 和16KB52-bit OA 颗粒为64KB48-bit OA 颗粒为4KB 和16KBStage 1 和 Stage 2 介绍第一阶段(Stage 1)转换的表描述符属性字段第二阶段(…

Apple开发者Identifier唯一身份标识创建

1. 选中Identifiers然后点击加号进入创建页面 2.选择要注册的标识类型 选择类型为App然后点击继续 输入应用描述与BundleID并勾选要开启的功能后点击继续 点击注册标识 注册成功后,会在标识列表中看到

关于浔川python社在全网博主原力月度排名泸州地区排名第三名的调查——浔川总社部

6月6日&#xff0c;浔川python社在全网博主原力月度排名泸州地区排名第三名。 引起了广大python爱好者等的不满&#xff0c;为什么浔川python社这么一个大的python社排名第三&#xff1f; 2024-06-04 21:59:52 浔川python社在全网博主原力月度排名泸州地区排名第一名。 2024-…

前端开发之性能优化

本文章 对各大学习技术论坛知识点&#xff0c;进行总结、归纳自用学习&#xff0c;共勉&#x1f64f; 文章目录 1. [CDN](https://www.bootcdn.cn/)2.懒加载3.缓存4.图片压缩5.图片分割6.sprite7.Code Splitting8.gzip9.GPU加速10.Ajax11.Tree Shaking12.Resource Hints 1. CD…

Centos修改默认端口22

修改Centos服务器ssh链接的默认端口22到任意端口&#xff0c;主要两个步骤&#xff1a; 修改配置文件&#xff0c;添加端口开放防火墙 一、 vim /etc/ssh/sshd_config 在文件中 #Port 22 行下增加 # Port 22 Port [修改后端口]注意&#xff1a; 这里 先将其中的#Port 22前的…

Redis集群之高可用可水平扩展

文章目录 一、Redis集群方案比较二、Redis高可用集群搭建三、Java操作redis集群四、集群的Spring Boot整合Redis 一、Redis集群方案比较 在redis3.0以前的版本要实现集群一般是借助哨兵sentinel工具来监控master节点的状态&#xff0c;如果master节点异 常&#xff0c;则会做主…

The First项目报告:AI+元宇宙+链游,Ultiverse能否引发新一轮GameFi浪潮?

2 月 19 日&#xff0c;由AI 驱动的 Web3 游戏制作和发布一站式平台 Ultiverse 宣布上线 Ulti-Pilot&#xff0c;Ulti-Pilot 允许用户以零成本的方式获得积分、SOUL、和 Ultiverse 生态的其他游戏内资产。 链游赛道一直是Web3领域热议的话题&#xff0c;其数字资产天然契合加密…

晓柯云一键部署WordPress主题

首先先在晓柯云注册一个账号&#xff08;首页 - 晓柯云官网-为草根站长提供优质的建站服务 (chgskj.cn)&#xff09; 然后根据提示完成注册&#xff08;博主选择用邮箱注册&#xff09; 这里博主踩了一个坑&#xff0c;下拉会有勾选的一个东西&#xff08;一定要勾选&#xff0…