pytorch学习笔记(十二)

 以下代码是以CIFAR10这个10分类的图片数据集训练过程的完整的代码。

训练部分

train.py主要包含以下几个部件:

  • 准备训练、测试数据集
  • 用DateLoader加载两个数据集,要设置好batchsize
  • 创建网络模型(具体模型在model.py中)
  • 设置损失函数
  • 设置优化器,其中要包含优化的参数和学习率
  • 初始化一些参数,如训练测试的次数、以及训练的轮数epoch
  • 以训练轮数为循环进入训练
  • 从训练数据中加载数据,将数据(模型的输出和目标(标签))送进损失函数中计算损失
  • 梯度清零,并且反向传播损失函数,用优化器进行参数更新,并累计训练步数。
  • 在保证不调优的情况下看正确率(with)
    从测试集中拿数据,一样的讨论算损失,但是要算正确率
  • 用tensorboard可是话训练的结果

关于imgs, targets =data这句代码中的targets解释

  1. imgs (Images): 这个变量通常包含一批图像数据。在计算机视觉任务中,这些图像是模型的输入,可以是任何形式的视觉数据,比如照片、视频帧、医学影像等。在训练过程中,这些图像通过神经网络进行前向传播以生成预测结果。

  2. targets (Targets): 这个变量包含与 imgs 中每个图像对应的标签或目标。标签的具体形式取决于执行的任务:

    • 分类任务中,targets 可能是类别标签,例如识别图像中的对象(猫、狗、汽车等)。
    • 对象检测任务中,targets 可能包括对象的边界框(bounding boxes)和类别。
    • 语义分割任务中,targets 可能是每个像素的类别标签。
    • 回归任务中,targets 可能是一些连续值,如在面部关键点检测中的坐标点。

在训练过程中targets用于损失函数(交叉熵损失、均方误差等),这是模型学习并优化其参数的基础。损失函数衡量了模型预测和真实目标之间的差异,训练目标是最小化差异。

关于optimizer.step()的解释

在机器学习中,这玩意是个关键操作,就是用来根性模型参数的。

优化器和梯度下降,常用的优化算法(SGD、Adam、RMSprop等)来调整网络参数(如权重和偏差),以最小化损失函数。这个过程被称为梯度下降。

训练过程中的步骤:

  • Forward Pass:输入数据进行前向计算,生成预测。
  • 计算损失函数,比较网络的预测和真实计算损失
  • 反向传播:通过反向传播损失,计算每个参数梯度 loss.backward()来完成。
  • 更新参数optimizer.step()被调用来更新网络的参数。根据计算出的梯度和定义的优化算法,它会调整参数以减小损失。

注意: 

optimizer.step()根据优化器预定义的规则和计算出的梯度来更新模型参数。在调用它之后,会执行optimizer.zero_grad(),以便下一次迭代时从干净的状态开始。

import torch.nn
import torchvision
from torch.utils.tensorboard import SummaryWriter

from model import *
from torch.utils.data import DataLoader
from torch import nn

#准备数据集
train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=torchvision.transforms.ToTensor(),
                                           download=True)
#测试数据集
test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=torchvision.transforms.ToTensor(),
                                         download=True)
#length 长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度:{}".format(train_data_size))
print("测试数据集的长度:{}".format(test_data_size))

#利用DataLoader 来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)


#创建网络模型
tudui = Tudui()

#损失函数
loss_fn = nn.CrossEntropyLoss()

#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

#设置训练网络的一些参数
#记录训练的次数
total_train_step = 0
#测试的次数
total_test_step = 0
#训练的轮数
epoch = 10

#添加 tensorboard
writer = SummaryWriter("../logs_train")

for i in range(epoch):
    print("-------------第{}轮训练开始-------------".format(i+1))

    #训练步骤开始
    #并不需要把网络设置成训练状态才能进行训练
    tudui.train()
    for data in train_dataloader:
        imgs, targets =data
        outputs = tudui(imgs)
        loss = loss_fn(outputs, targets)
        #梯度清零
        #优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_train_step = total_train_step + 1
        #避免无用信息覆盖
        if total_train_step % 100 == 0:
            print("训练次数: {},loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)
    #测试步骤
    #也不是需要把网络设置成eval状态才能进行网络的一个测试
    tudui.eval()
    total_test_loss = 0
    #看正确率
    total_accuracy = 0
    #在with里面的代码没有了梯度,保证不会进行调优
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets =data
            outputs = tudui(imgs)
            #一部分数据在网络模型上的损失
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy = total_accuracy + accuracy
    print("整体测试集上的Loss:{}".format(total_test_loss))
    print("整体测试的正确率:{}".format(total_accuracy/test_data_size))
    writer.add_scalar("train_loss", loss.item(), total_test_step)
    writer.add_scalar("test_accuracy", total_accuracy/test_data_size, total_test_step)
    #测试的步骤+1否则图画不出来
    total_train_step = total_test_step + 1

    torch.save(tudui, "tudui_{}.pth".format(i))
    print("模型已保存")
writer.close()

上面是一个训练过程,下面介绍一下训练准确率怎么得来的。

假设有一个2分类的模型

Model(2分类)

#下面是得分

Outputs = [[0.2,0.3],[0.1,0.4]]

#通过Argmax 变成

Preds = [1]

                [1]

Inputs target=[0][1]

Preds==inputs target

#上面的这个式子返回的就是T or F

#加起来就是分类正确的个数了。

[false,true].sum()=1

                                       

这边注意一下output.argmax(x)的方向,x是0或是1,0的方向是竖着来的,1的方向是横着来的。

import torch
outputs = torch.tensor([[0.1,0.2],
                        [0.3,0.4]])
print(outputs.argmax(1))
preds = outputs.argmax(1)
targets = torch.tensor([0,1])
print((preds == targets).sum())

-----------------------------------------------------未完待续1------------------------------------------------------------- 

 训练的一些细节:

如果有Dropout和BatchNorm等一些特殊层,需要

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/360007.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【大数据】Flink SQL 语法篇(三):窗口聚合(TUMBLE、HOP、SESSION、CUMULATE)

Flink SQL 语法篇(三):窗口聚合 1.滚动窗口(TUMBLE)1.1 Group Window Aggregation 方案(支持 Batch / Streaming 任务)1.2 Windowing TVF 方案(1.13 只支持 Streaming 任务&#xff…

自己实现的小功能

小功能实现 2024/1/31 问题一: 将文本模式的csv文件作为表编辑之后,先要再变回来。找了5分钟都没找到,去网上搜也没搜到 解决方案 复制一份,对没错。 不是把表遍历一遍,重新将数据写入。 3.5给的答案就是重新写入…

网络防御基础介绍和基本的策略集

1.什么是防火墙 防火墙的主要职责在于:控制和防护 --- 安全策略 --- 防火墙可以根据安全策略来抓取流量之后做出对应的动作。 2.防火墙的分类 防火墙吞吐量 --- 防火墙同一时间处理的数据量 3.防火墙的历史 4.防火墙的分类 1.基于数据包的防火墙-包过滤防火墙 缺…

深度学习(9)--pydot库和graphviz库安装流程详解

目录 一.pydot库安装 二.graphviz库安装 一.pydot库安装 pydot的安装可直接在编译器安装相关包,以PyCharm举例: 如果搜索可用软件包显示为空,记得在此处把使用Conda软件包管理器”点亮 二.graphviz库安装 点击链接下载安装包graphviz-2.38…

网络协议与攻击模拟_11DHCP欺骗防护

开启DHCP 监听 ip dhcp snooping 指定监听vlan ip dhcp snooping vlan 1 由于开启监听后,交换机上的接口就全部变成非信任端口, 非信任端口会拒绝DHCP报文,会造成正常的DHCP请求和响应都无法完成。 现在是请求不到IP地址的,…

字符串匹配算法(BF、KMP)

一 字符串匹配算法—BF算法 BF算法简称暴力破解算法,时间复杂度很容易计算为O(m*n)(当n>>m时候) 本身字符串S,长度为m 模式字符串T,长度为n 最差情况,需要匹配(n-m)mm才可以成功,所以时间复杂度就是…

tarojs View多行文本无法换行问题解决

问题:未换行 code: 解决: 加上换行属性的css就好了 white-space: break-spaces;

银行ATM监控对讲系统分机可视对讲分机|ATM音视频终端IP网络可视对讲终端IP对讲终端对讲分机IP网络对讲系统

SV-6301T可视对讲终端 (单键) 产品简介 产品简介: 一键报警可视对讲终端是用于平安城市、银行、医院,智慧养老,景区,智慧路灯,平安校园,智慧电梯,无人超市等方案中的一…

哈希表算法模版

模拟散列哈希表 活动 - AcWing 拉链法 思路&#xff1a; 代码如下&#xff1a; #include <cstring> #include <iostream>using namespace std;const int N 1e5 3; // 取大于1e5的第一个质数&#xff0c;取质数冲突的概率最小 可以百度//* 开一个槽 h int h[…

jmeters响应结果反写csv文件及参数化

1.http响应结果反写csv文件 1.1各参数设置级别 线程组&#xff08;一级&#xff09;---->请求默认值、请求头、http请求、察看结果树&#xff08;二级&#xff09;----->正则表达式、BeanShell 后置处理程序&#xff08;三级&#xff09;。 1.2.正则表达式提取反写参数…

Backtrader 文档学习-Cheat-On-Open

Backtrader 文档学习-Cheat-On-Open 1.概述 V1.9.44.116增加了Cheat On Open的支持。对于全押的人来说&#xff0c;这似乎是一个必需的功能&#xff0c;用bar的收盘价后进行计算&#xff0c;希望与开盘价相匹配。 当开盘价差距&#xff08;上涨或下跌&#xff0c;取决于买入或…

SpringClound项目相关

nacos本机模式非虚拟机启动也可正常连接 nacos中的配置中心相当于在application.yml中的相关配置&#xff0c;转移位置&#xff0c;内容同application.yml完全一样均可。 黑马项目导入后&#xff0c;依赖缺失&#xff1a; 首先尝试maven重新加载&#xff0c;控制台提示传递依…

聊一聊GPT、文心、通义、混元

我使用同一个Prompt提示词“请以记叙文的文体来写”&#xff0c;分别发送给GPT-3.5&#xff08;调用API&#xff09;、文心、通义、混元&#xff0c;下面是它们各自生成的文本内容&#xff0c;大家一看便知了。 GPT-3.5&#xff1a; 在我个人使用GPT模型的过程中&#xff0c;我…

ESP32-C3 vscode USB-Serial-JTAG 调试

硬件 接线 查看驱动 vs code配置 debugging via builtin USB-JTAG 配置调试UART 配置下载类型 创建调试配置 调试 参考 esp32c3内置USB-Serial-JTAG的使用 链接: link 看了之后&#xff0c;还是不会ESP32-C3的调试及下载&#xff0c;你过来打我&#xff01;&#xff01;&…

KAFKA高可用架构涉及常用功能整理

KAFKA高可用架构涉及常用功能整理 1. kafka的高可用系统架构和相关组件2. kafka的核心参数2.1 常规配置2.2 特殊优化配置 3. kafka常用命令3.1 常用基础命令3.1.1 创建topic3.1.2 获取集群的topic列表3.1.3 获取集群的topic详情3.1.4 删除集群的topic3.1.5 获取集群的消费组列表…

微信小程序之下拉刷新事件、上拉触底事件和案例

学习的最大理由是想摆脱平庸&#xff0c;早一天就多一份人生的精彩&#xff1b;迟一天就多一天平庸的困扰。各位小伙伴&#xff0c;如果您&#xff1a; 想系统/深入学习某技术知识点… 一个人摸索学习很难坚持&#xff0c;想组团高效学习… 想写博客但无从下手&#xff0c;急需…

【方法】RAR分卷压缩文件如何打开?

当RAR压缩文件比较大&#xff0c;不利于传输时&#xff0c;我们可以把文件压缩成分卷文件&#xff0c;那压缩后的分卷文件如何打开呢&#xff1f;今天就来说说RAR分卷压缩文件的两种打开方法。 方法一&#xff1a; 和普通压缩包一样&#xff0c;打开分卷压缩包也需要用到解压…

Web3与个人隐私:打破数据壁垒的新时代

随着科技的不断发展&#xff0c;Web3技术的兴起为我们带来了一个全新的数字时代&#xff0c;重新定义了个人隐私的概念与实践。在这个时代&#xff0c;我们不再被动地成为数据经济的被动参与者&#xff0c;而是迎来了一个更加安全、透明和个人主导的网络生态。 1. 去中心化的数…

比FTP更好用的企业远程传输大文件工具居然是这个!

在数字化浪潮的推动下&#xff0c;企业对于数据传输的速度和安全性有了更高的要求。传统的FTP协议&#xff0c;尽管历史悠久&#xff0c;但在当前的企业应用场景中&#xff0c;其局限性逐渐暴露。企业现在寻求的是能够提供快速、安全、便捷且经济高效的文件传输解决方案。本文旨…

springboot整合mqtt实现消息订阅和推送

前言 mica-mqtt-client-spring-boot-starter是一个基于Spring Boot的MQTT客户端启动器&#xff0c;它集成了mica-mqtt客户端&#xff0c;提供了在Spring Boot应用程序中使用MQTT协议进行消息通信的能力。以下是关于mica-mqtt-client-spring-boot-starter的简介&#xff1a; 特…