10-pytorch-完整模型训练

b站小土堆pytorch教程学习笔记

一、从零开始构建自己的神经网络

1.模型构建
#准备数据集
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

from model import *
from torch.utils.data import DataLoader

train_data=torchvision.datasets.CIFAR10('dataset',train=True,
                                        transform=torchvision.transforms.ToTensor(),
                                        download=True)
test_data=torchvision.datasets.CIFAR10('dataset',train=False,
                                        transform=torchvision.transforms.ToTensor(),
                                        download=True)
#查看训练数据集和测试集大小
train_data_size=len(train_data)
test_data_size=len(test_data)
print('训练数据集长度为:{}'.format(train_data_size))#训练数据集长度为:50000
print('测试数据集长度为:{}'.format(test_data_size))#测试数据集长度为:10000

#利用datalo加载数据集
train_dataloader=DataLoader(train_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)

#搭建神经网络,在model文件中搭建网络,在此文件中引用
han=Han()

#损失函数
loss_fn=nn.CrossEntropyLoss()

#优化器
# learning_rate=0.01
learning_rate=1e-2
optimizer=torch.optim.SGD(han.parameters(),lr=learning_rate)

#设置训练网络的相关参数
total_train_step = 0#记录训练的次数
total_test_step = 0#记录测试的次数
epoch=10#训练轮数

#添加tensorboard
writer=SummaryWriter('logs/train')

for i in range(10):
    print('-------第{}轮训练开始-------'.format(i+1))

    for data in train_dataloader:
        imgs,target=data
        output=han(imgs)
        loss=loss_fn(output,target)

        #优化器优化模型
        optimizer.zero_grad()#梯度清零
        loss.backward()#反向传播计算梯度
        optimizer.step()#参数优化

        total_train_step=total_train_step+1
        if total_train_step % 100==0:#逢100打印
            print('训练次数:{},loss:{}'.format(total_train_step,loss.item()))#loss.item()取出tensor类型的数字
            writer.add_scalar('train_loss',loss.item(),total_train_step)

    #每训练完一轮将在测试集上跑一遍,评估其训练效果
    total_test_loss=0
    with torch.no_grad():
        for data in test_dataloader:
            imgs,target=data
            output=han(imgs)
            loss=loss_fn(output,target)
            total_test_loss=total_test_loss+loss.item()

    print('所有测试集上的损失:{}'.format(total_test_loss))
    writer.add_scalar('test_loss',total_test_loss,total_test_step)
    total_test_step+=1

    #保存每一轮模型
    torch.save(han,'han_{}.pth'.format(i))
    print('模型已保存')
writer.close()
import torch
from torch import nn


class Han(nn.Module):
    def __init__(self):
        super(Han, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )

    def forward(self, x):
        x = self.model(x)
        return x

if __name__ == '__main__':
    han=Han()
    input=torch.ones(64,3,32,32)
    output=han(input)
    print(output.shape)#torch.Size([64, 10])10表示十个类别输出概率

结果如下:
在这里插入图片描述

2.使用argmax计算整体正确率
#每训练完一轮将在测试集上跑一遍,评估其训练效果
    total_test_loss=0
    total_acc=0
    with torch.no_grad():
        for data in test_dataloader:
            imgs,target=data
            output=han(imgs)
            loss=loss_fn(output,target)
            total_test_loss=total_test_loss+loss.item()

            acc=(output.argmax(1)==target).sum()#(1)横着看
            total_acc+=acc
    print('所有测试集上的损失:{}'.format(total_test_loss))
    print('整体测试集上的正确率:{}'.format(total_acc/test_data_size))
    writer.add_scalar('test_loss',total_test_loss,total_test_step)
    writer.add_scalar('test_acc', total_acc/test_data_size, total_test_step)
    total_test_step+=1

整体测试集上的正确率:0.27480000257492065

3.当训练或测试时存在dropout层或batch normal层,则需要在训练训练和测试前加入:
#训练前
han.train()
#测试前
han.eval()

二、使用GPU

网络模型、数据(输入、标注)、损失函数调用cuda()

1.方式1
#模型
if torch.cuda.is_available():
    han=han.cuda()
#损失函数
loss_fn=nn.CrossEntropyLoss()
loss_fn=loss_fn.cuda()
imgs,target=data
imgs=imgs.cuda()
target=target.cuda()
2.方式2
#定义训练设备
device=torch.device('cuda')
han=han.to(device)
imgs = imgs.to(device)
target = target.to(device)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/408235.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

WPF 启动项目 Grid、StackPanel 布局

WPF 启动项目 <!--x:Class"WPF_Study.App" 对应类&#xff1a;WPF_Study.App--> <!--xmlns:local"clr-namespace:WPF_Study" 命名空间&#xff1a;WPF_Study--> <Application x:Class"WPF_Study.App"xmlns"http://schema…

网络原理TCP之“三次握手“

TCP内核中的建立连接 众所周知,TCP是有连接的. 当我们在客户端敲出socket new Socket(serverIp,severPort)时,就在系统内核就在建立连接 真正建立连接是在系统内核中建立的,我们程序员只是调用相关的api. 在此处,我们把TCP的建立连接称为三次握手. 系统在内核建立连接时如上…

【Linux进程】进程状态---进程僵尸与孤儿

&#x1f4d9; 作者简介 &#xff1a;RO-BERRY &#x1f4d7; 学习方向&#xff1a;致力于C、C、数据结构、TCP/IP、数据库等等一系列知识 &#x1f4d2; 日后方向 : 偏向于CPP开发以及大数据方向&#xff0c;欢迎各位关注&#xff0c;谢谢各位的支持 目录 1.进程排队2.进程状态…

OSCP靶场--Nickel

OSCP靶场–Nickel 考点(1.POST方法请求信息 2.ftp&#xff0c;ssh密码复用 3.pdf文件密码爆破) 1.nmap扫描 ┌──(root㉿kali)-[~/Desktop] └─# nmap 192.168.237.99 -sV -sC -p- --min-rate 5000 Starting Nmap 7.92 ( https://nmap.org ) at 2024-02-22 04:06 EST Nm…

机器学习基础(五)监督与非监督学习的结合

导语&#xff1a;上一节我们详细探索非监督学习的进阶应用&#xff0c;详情可见&#xff1a; 机器学习基础&#xff08;四&#xff09;非监督学习的进阶探索-CSDN博客文章浏览阅读613次&#xff0c;点赞15次&#xff0c;收藏13次。非监督学习像一位探险家&#xff0c;挖掘未标…

前后端分离Vue+ElementUI+nodejs蛋糕甜品商城购物网站95m4l

本文主要介绍了一种基于windows平台实现的蛋糕购物商城网站。该系统为用户找到蛋糕购物商城网站提供了更安全、更高效、更便捷的途径。本系统有二个角色&#xff1a;管理员和用户&#xff0c;要求具备以下功能&#xff1a; &#xff08;1&#xff09;用户可以修改个人信息&…

ARMv8-AArch64 的异常处理模型详解之异常向量表vector tables

目录 一&#xff0c;AArch64 异常向量表 二&#xff0c;栈指针以及SP寄存器的选择 三&#xff0c;从异常返回 一&#xff0c;AArch64 异常向量表 异常向量表&#xff08;vector tables&#xff09;是一组存放于普通内存&#xff08;normal memory&#xff09;空间的&#xf…

视频号视频下载(如何把视频号中的视频下载下来)

在如今的信息时代&#xff0c;热点创作者和科技创作者们的素材库越来越丰富&#xff0c;视频号作为一种新兴的媒体形式&#xff0c;其中蕴含的优质内容更是不可或缺。但是&#xff0c;如何将心仪的视频号视频下载下来&#xff0c;进行二次创作并在其他平台发布呢&#xff1f;今…

Vue+SpringBoot打造开放实验室管理系统

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、研究内容2.1 实验室类型模块2.2 实验室模块2.3 实验管理模块2.4 实验设备模块2.5 实验订单模块 三、系统设计3.1 用例设计3.2 数据库设计 四、系统展示五、样例代码5.1 查询实验室设备5.2 实验放号5.3 实验预定 六、免责说明 一、摘…

Kafka集群详解

Kafka集群的目标 1、高并发 2、高可用&#xff08;防数据丢失&#xff09; 3、动态伸缩 Kafka集群规模如何预估 吞吐量&#xff1a; 集群可以提高处理请求的能力。单个Broker的性能不足&#xff0c;可以通过扩展broker来解决。 磁盘空间&#xff1a; 比如&#xff0c;如…

Ansible user 模块 该模块主要是用来管理用户账号

目录 参数语法验证创建用户删除用户验证 删除用户 参数 comment  # 用户的描述信息 createhome  # 是否创建家目录 force  # 在使用stateabsent时, 行为与userdel –force一致. group  # 指定基本组 groups  # 指定附加组&#xff0c;如果指定为(groups)表示删除所有…

Nest.js权限管理系统开发(二)连接MySQL、Redis

安装MySQL及相关依赖 下载dmg文件安装 前往MySQL :: Download MySQL Community Server下载最新版本的MySQL。 打开系统设置&#xff0c;拉到最下方可以看到MySQL&#xff0c;打开看到两个绿点表示安装成功&#xff0c;也可以在这里修改MySQL密码。 配置环境变量 打开终端配…

Windows中的Git Bash运行conda命令:未找到命令的错误(已解决)

在windows中的Gitbash中 打开激活conda环境&#xff0c;并运行&#xff08;前提是你先安装好git&#xff08;自己去官网下载&#xff09;&#xff09;。 要能够在Gitbash上运行Conda&#xff0c; 临时配置 如果你只是临时用一下&#xff0c;就是临时爽一把&#xff0c;那就按…

独立站建站全攻略:从0到1打造专属在线商业平台

独立站建站全攻略&#xff1a;从0到1打造专属在线商业平台 随着互联网的普及和发展&#xff0c;越来越多的企业和个人开始认识到拥有一个独立站的重要性。独立站不仅可以提升品牌形象&#xff0c;还能为企业带来更多的流量和潜在客户。本文将为大家详细介绍独立站建站的全过程…

代码随想录day32--动态规划理论基础

什么是动态规划 动态规划简称DP&#xff0c;如果某一问题有很多重叠子问题&#xff0c;使用动态规划是最有效的。 所以动态规划中每一个状态一定是由上一个状态推导出来的&#xff0c;这一点一定要和贪心区别出来&#xff0c;贪心没有状态推导&#xff0c;而是直接从局部直接…

【深度学习】主要提出者【Hinton】中国大会最新演讲【通往智能的两种道路】

「但我已经老了&#xff0c;我所希望的是像你们这样的年轻有为的研究人员&#xff0c;去想出我们如何能够拥有这些超级智能&#xff0c;使我们的生活变得更好&#xff0c;而不是被它们控制。」 6 月 10 日&#xff0c;在 2023 北京智源大会的闭幕式演讲中&#xff0c;在谈到如…

opengles 顶点坐标变换常用的矩阵(九)

文章目录 前言一、opengles 常用的模型矩阵1. 单位矩阵2. 缩放矩阵3. 位移矩阵4. 旋转矩阵二、第三方矩阵数学库1. glm1.1 ubuntu 上安装 glm 库1.2 glm 使用实例1.2.1 生成一个沿Y轴旋转45度的4x4旋转矩阵, 代码实例如下1.2.2 生成一个将物体移到到Z轴正方向坐标为5处的4x4 vi…

K线实战分析系列之七:行情顶部的看跌信号——黄昏星形态

K线实战分析系列之七&#xff1a;行情顶部的看跌信号——黄昏星形态 一、黄昏星形态二、黄昏线总结 一、黄昏星形态 二、黄昏线总结 黄昏星的高点形成阻力位&#xff0c;启明星的低点形成支撑位中间的星线实体与第一根K线的实体跳空区域比较宽&#xff0c;第三根K线覆盖了第一…

SSM项目集成Spring Security 4.X版本 之 加入DWZ,J-UI框架实现登录和主页菜单显示

目录 前言 一、加入DWZ J-UI框架 二、实现登录页面 三、实现主页面菜单显示 前言 大家好&#xff01;写文章之前先列出几篇相关文章。本文内容也在其项目中接续实现。 一. SSM项目集成Spring Security 4.X版本&#xff08;使用spring-security.xml 配置文件方式&#xff…

旅游组团自驾游拼团系统 微信小程序python+java+node.js+php

随着社会的发展&#xff0c;旅游业已成为全球经济中发展势头最强劲和规模最大的产业之一。为方便驴友出行&#xff0c;寻找旅游伙伴&#xff0c;更好的规划旅游计划&#xff0c;开发一款自驾游拼团小程序&#xff0c;通过微信小程序发起自驾游拼团&#xff0c;吸收有车或无车驴…