Pytorch完整的模型训练套路

Pytorch完整的模型训练套路

文章目录

  • Pytorch完整的模型训练套路
  • 以CIFAR10为例实践

  1. 数据集加载步骤

使用适当的库加载数据集,例如torchvision、TensorFlow的tf.data等。
将数据集分为训练集和测试集,并进行必要的预处理,如归一化、数据增强等。

  1. 模型创建步骤

创建机器学习模型,可以是深度神经网络、传统机器学习模型或其它模型类型。
定义模型架构,包括输入层、隐藏层和输出层的结构、激活函数、损失函数等。

  1. 损失函数和优化器定义步骤

定义适当的损失函数来计算模型预测结果于真实标签之间的差异。
选择适当的优化器算法来更新模型参数,如随机梯度下降(SGD)、Adam等。

  1. 训练循环步骤

从训练集中获取一批样本数据,并将其输入模型进行前向传播。
计算损失函数,并根据损失函数进行反向传播和参数更新。
重复以上步骤,直到达到预定的训练次数或达到收敛条件。

  1. 测试循环步骤

从测试集中获取一批样本数据,并将其输入模型进行前向传播。
计算损失函数或评估指标,用于评估模型在测试集上的性能。

  1. 训练和测试过程的记录和输出步骤

使用适当的工具或库记录训练过程中的损失值、准确率、评估指标等。

  1. 结束训练步骤

根据训练结束条件、例如达到预定的训练次数或收敛条件,结束训练。可以保存模型参数或整个模型,以便日后部署和使用。

以CIFAR10为例实践

并利用tensorboard可视化

import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
'''数据集加载'''
train_data = torchvision.datasets.CIFAR10(root='dataset',train=True,transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root='dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)

# 训练数据集的长度
train_data_size = len(train_data)
print(f"训练数据集的长度为:{train_data_size}")
# 测试数据集的长度
test_data_size = len(test_data)
print(f"测试数据集的长度:{test_data_size}")
#利用DataLoader加载数据集
train_dataloader = DataLoader(test_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)
Files already downloaded and verified
Files already downloaded and verified
训练数据集的长度为:50000
测试数据集的长度:10000

‘’‘创建模型’‘’

以上篇文章《Pytorch损失函数、反向传播和优化器、Sequential使用》中的BS()为例

在这里插入图片描述

'''创建模型'''
class BS(nn.Module):

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3,
                               out_channels=32,
                               kernel_size=5,
                               stride=1,
                               padding=2),  #stride和padding计算得到
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32,
                                   out_channels=32,
                                   kernel_size=5,
                                   stride=1,
                                   padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32,
                                   out_channels=64,
                                   kernel_size=5,
                                   padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten(),  #in_features变为64*4*4=1024
            nn.Linear(in_features=1024, out_features=64),
            nn.Linear(in_features=64, out_features=10),
        )
    
    def forward(self,x):
        x = self.model(x)
        return x
    
bs = BS()
print(bs)
BS(
  (model): Sequential(
    (0): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (4): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Flatten(start_dim=1, end_dim=-1)
    (7): Linear(in_features=1024, out_features=64, bias=True)
    (8): Linear(in_features=64, out_features=10, bias=True)
  )
)

一般来说,会将网络单独存放在一个model.py文件当中,然后利用from model import * 进行导入

'''定义损失函数和优化器'''
# 使用交叉熵损失函数
loss_fn = nn.CrossEntropyLoss()  
# 定义优化器
learning_rate = 1e-2  #学习率0.01
optimizer = torch.optim.SGD(bs.parameters(), lr=learning_rate)
"""
训练循环步骤
"""
# 开始设置训练神经网络的一些参数
# 记录训练的次数
total_train_step = 0
# 记录测试的次数
total_test_step = 0
# 训练的轮数
epoch = 10



writer = SummaryWriter(".logs") #Tensorboard可视化
for i in range(epoch):
    print("----第{}轮训练开始----".format(i))
    #bs.train() # bs.train()#有batchnorm、dropout层需要调用。官方文档见torch.nn.Module
    '''训练步骤开始'''
    for data in train_dataloader:
        imgs, targets = data
        outputs = bs(imgs)
        loss = loss_fn(outputs, targets)
        
        optimizer.zero_grad() # 首先要梯度清零
        loss.backward() #得到梯度
        optimizer.step() #进行优化

        total_train_step = total_train_step + 1
        if total_train_step % 100 == 0:
            print("训练次数:{}, loss:{}".format(total_train_step,loss.item()))
            
            writer.add_scalar("train_loss", loss.item(),total_train_step)
            
    '''测试步骤开始'''
    #bs.eval() # bs.train()#有batchnorm、dropout层需要调用。官方文档见torch.nn.Module
    total_test_loss = 0
    #total_accuracy
    total_accuracy = 0
    with torch.no_grad():#torch.no_grad()是一个上下文管理器,用来禁止梯度的计算,通常用来网络推断中,它可以减少计算内存的使用量。
        for imgs, targets in test_dataloader:
            outputs = bs(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss = total_test_loss + loss.item() #.item()取出数字
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy
    """测试过程的记录和输出"""
    print("整体测试集上损失函数loss:{}".format(total_test_loss))
    print("整体测试集上正确率:{}".format(total_accuracy/test_data_size))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    writer.add_scalar('test_accuracy',total_accuracy/test_data_size)
    total_test_step = total_test_step + 1
    torch.save(bs, "test_{}.pth".format(i))
    print("模型已保存")
"""
结束训练步骤
"""
writer.close()

利用tensoraboard显示:

tensorboar --logdir logs

在这里插入图片描述

补充.item()

  1. .item()
import torch
a = torch.tensor(5)
print(a)
print(a.item())
tensor(5)
5
  1. model.train()和model.eval()
    官方网址见:torch.nn.Module(*args, **kwargs)
    在这里插入图片描述
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/171150.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

桌面运维。

Windows运行命令: color 01/02切换字符颜色cls 清屏ipconfig 设备的ip信息ipconfig /all 设备ip的所有信息 破解系统密码: 进PE系统,使用里面的工具破解 vmware workstation安装 网卡 网卡:ncpa.cpl window远程控制 mstsc …

【JavaEE】Spring核心与设计思想(控制反转式程序演示、IoC、DI)

一、什么是Spring? 通常所说的 Spring 指的是 Spring Framework(Spring 框架),它是⼀个开源框架,有着活跃⽽庞⼤的社区,这就是它之所以能⻓久不衰的原因。Spring ⽀持⼴泛的应⽤场景,它可以让 …

人工智能对我们的生活影响有多大

随着科技的飞速发展,人工智能已经渗透到我们生活的方方面面,并且越来越受到人们的关注。从智能语音助手到自动驾驶汽车,从智能家居系统到医疗诊断,人工智能技术正在改变着我们的生活方式。那么,人工智能对我们的生活影…

【C++上层应用】5. 文件和流

文章目录 【 1. 打开文件 】1.1 open 函数1.2 open 多种模式的结合使用 【 2. 关闭文件 】【 3. 写入 & 读取文件 】【 4. 文件位置指针 】 和 iostream 库中的 cin 标准输入流和 cout 标准输出流类似,C中另一个库 fstream 也存在文件的读取流和标准写入流。fst…

erlang语言为什么天生支持高并发

Erlang 语言天生支持高并发的主要原因可以归结于它的设计哲学和一些核心特性。以下是 Erlang 支持高并发的几个关键方面: 轻量级进程:Erlang 使用轻量级的并发实体,通常称为“进程”,这些进程在Erlang虚拟机内部运行,而…

腾讯云轻量应用服务器三年租用价格表_免去续费困扰

腾讯云服务器续费贵所以一次性买3年或5年,腾讯云轻量应用服务器3年价格有优惠,CVM云服务器5年有特价,腾讯云3年轻量和5年云服务器CVM优惠活动入口,3年轻量应用服务器配置可选2核2G4M和2核4G5M带宽,5年CVM云服务器可以选…

如何给面试官解释什么是分布式和集群?

分布式(distributed) 是指在多台不同的服务器中部署不同的服务模块,通过远程调用协同工作,对外提供服务。 集群(cluster) 是指在多台不同的服务器中部署相同应用或服务模块,构成一个集群&#…

【开源】基于JAVA的社区买菜系统

项目编号: S 011 ,文末获取源码。 \color{red}{项目编号:S011,文末获取源码。} 项目编号:S011,文末获取源码。 目录 一、摘要1.1 项目介绍1.2 项目录屏 二、系统设计2.1 功能模块设计2.1.1 数据中心模块2.1…

redis的高可用

在集群当中有一个非常重要的指标,提供正常服务的时间的百分比(365天)99.9% redis的高可用含义更加宽泛,正常服务是指标之一、数据容量的扩展、数据的安全性。 在redis当中实现高可用技术:持久化、主从复制、哨兵模式、cluster集群 持久化是…

时尚女童千金小套装,爱了爱了

自带氛围感的小套装简直不要太好看 红色吸睛又喜庆,显白不挑人穿 谁穿都好看系列!! 厚实的苏两毛羊毛 保暖性能真的非常好 超柔软的手感和触感 显得很高级穿上也很舒适 一整个小公主style,时尚气质感拉满 里面的连衣裙单穿也很nice …

V8引擎隐藏类(VIP课程)

上一章我们讲了V8如何存储的对象,其中提到了隐藏类,这一章我们来看看隐藏类到底做了什么。 为什么要讲V8???? 隐藏类是V8引擎在运行时自动生成和管理的数据结构,用于跟踪对象的属性和方法 隐藏…

揭开副业的神秘面纱,上班摸鱼搞副业

副业没那么神秘,说白了就是销售。 卖文字、卖知识、卖技能、卖产品 ... 找对了渠道,有人愿意买单,就成了副业。 但是现在搞副业,坑多水深的,太多人栽跟头了,丢几百块都能算少的。 开始细说干货之前&#…

el-date-picker ie模式下 初始化未赋值;未清空

el-date-picker ie模式下 初始化未赋值;未清空 给 dete-picker 加key属性 eg:

【C++】——多态性与模板(其二)

🎃个人专栏: 🐬 算法设计与分析:算法设计与分析_IT闫的博客-CSDN博客 🐳Java基础:Java基础_IT闫的博客-CSDN博客 🐋c语言:c语言_IT闫的博客-CSDN博客 🐟MySQL&#xff1a…

ROS1余ROS2共存的一键安装(全)

ROS1的安装: ROS的一键安装(全)_ros一键安装_牙刷与鞋垫的博客-CSDN博客 ROS2的安装 在开始这一部分的ROS2安装之前,是可以安装ROS1的,当然如果你只需要安装ROS2的话就执行从此处开始的代码即可 我是ubuntu20.4的版…

给新手教师的成长建议

随着教育的不断发展和进步,越来越多的新人加入到教师这个行列中来。从学生到教师,这是一个华丽的转身,需要我们不断地学习和成长。作为一名新手老师,如何才能快速成长呢?以下是一名老师教师给的几点建议: 一…

华为防火墙 Radius认证

实现的功能:本地内网用户上网时必须要进行Radius验证,通过后才能上网 前置工作请按这个配置:华为防火墙 DMZ 设置-CSDN博客 Windows 服务器安装 Radius 实现上网认证 拓扑图如下: 一、服务器配置 WinRadius 1、安装WinRadius …

移动机器人路径规划(五)--- 基于Minimun Snap的轨迹优化

目录 1 我们本节主要介绍的 2 Minimum Snap Optimization 2.1 Differential Flatness(微分平坦) 2 Minimum Snap 3 Closed-form Solution to Minimum Snap 3.1 Decision variable mapping 待优化问题的映射 4 凸优化 及其它问题 1 我们本节主要介…

2023 羊城杯 final

前言 笔者并未参加此次比赛, 仅仅做刷题记录. 题目难度中等偏下吧, 看你记不记得一些利用手法了. arrary_index_bank 考点: 数组越界 保护: 除了 Canary, 其他保护全开, 题目给了后门 漏洞点: idx/one 为 int64, 是带符号数, 所以这里存在向上越界, 并且 buf 为局部变量,…

程序员有必要考个 985 非全日制研究生嘛?

大家好,我是伍六七。 经常有读者问我,非全日制研究生好考嘛?有用嘛?今天我们来聊聊这个问题。 科普一下:什么是非全日制研究生? 非全日制研究生是国家在 2017 年对教育行业的重大改革。 非全日制需要参加…