pytorch -- CIFAR10 完整的模型训练套路

  1. 网络结构
    在这里插入图片描述
  2. 代码
# CIFAR 10
'''
完整的模型训练套路:

'''
import torch.optim
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from model import *

# 1. 准备数据集
train_data = torchvision.datasets.CIFAR10('data',train=True,
                                       transform=torchvision.transforms.ToTensor(),
                                       download=True)
test_data = torchvision.datasets.CIFAR10('data',train=False,
                                         transform=torchvision.transforms.ToTensor(),
                                         download=True)
# 数据集大小
train_data_size = len(train_data)
test_data_size = len(test_data)
print('训练数据集的长度为{}'.format(train_data_size))
print('测试数据集的长度为{}'.format(test_data_size))

# 2 利用DataLoader加载数据集
train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

# 3 搭建神经网络
# 4 创建网络模型
tudui = Tudui()

# 5 损失函数
loss_fn = nn.CrossEntropyLoss()

# 6 优化器 1e-2=1x10^(-2)
learning_rate = 0.01
optimizer = torch.optim.SGD(tudui.parameters(),lr=learning_rate)

# 7 设置训练网络的一些参数
total_train_step = 0 # 记录训练次数
total_test_step = 0 # 记录测试次数
epoch = 10 #训练轮数
# 添加tensorboard
writer = SummaryWriter('logs_model')

for i in range(epoch):
    print('-----------第{}轮训练开始-----------'.format(i+1))
    # 训练开始
    # 训练步骤开始 dropout batchNorm仅对某些层次有作用
    tudui.train()
    for data in train_dataloader:
        imgs, targets = data
        output = tudui(imgs) #训练模型的预测输出
        loss = loss_fn(output,targets)
        # 优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 == 0:
            print('训练次数是{}时,loss是{}'.format(total_train_step,loss.item()))# 加了item() tensor变成了数字
            writer.add_scalar('train_loss',loss.item(),total_train_step)

    # 训练完一轮,看是否训练好,有没有达到想要的需求,测试数据集中跑一篇看准确率或者损失
    # 测试步骤开始
    tudui.eval()
    total_test_loss = 0
    total_accuracy = 0
    # 测试不需要对梯度进行调整
    with torch.no_grad():
        for data in test_dataloader:
            imgs,targets = data
            outputs = tudui(imgs)
            loss = loss_fn(outputs,targets)
            total_test_loss += loss.item()
            # accuracy 正确预测的样本数量
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy
    print('整体测试集上的loss是{}'.format(total_test_loss))
    print('整体测试集上的正确率是{}'.format(total_accuracy/test_data_size))
    writer.add_scalar('test_loss',total_test_loss,total_test_step)
    writer.add_scalar('test_accuracy', total_accuracy, total_test_step)
    total_test_step+=1

    torch.save(tudui,'tudui_{}.pth'.format(i))
    print('模型已保存')

writer.close()

# model.py
import torch
from torch import nn

# 3 搭建神经网络
class Tudui(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024,64),
            nn.Linear(64, 10)
        )
    def forward(self,x):
        x = self.model(x)
        return x

if __name__ == '__main__':
    tudui = Tudui()
    # 验证一下输入输出尺寸
    input = torch.ones((64,3,32,32))
    output = tudui(input)
    print(output.shape)

运行结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/411930.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

TensorFlow2.x 精选笔记(2)自动求导与概率

学习参考: 动手学深度学习2.0Deep-Learning-with-TensorFlow-bookpytorchlightning ①如有冒犯、请联系侵删。 ②已写完的笔记文章会不定时一直修订修改(删、改、增),以达到集多方教程的精华于一文的目的。 ③非常推荐上面(学习参考&#x…

Premiere企业发展历程公司历史回顾视频制作pr时间线模板

企业发展历程公司历史回顾视频制作PR时间线模板mogrt。 项目特点:色彩控件,适用于任何字体,包含视频帮助,适用于Premiere Pro 2021及以上版本。 更多PR资源下载:https://prmuban.com/37916.html

手机和windows的便签怎么共享账号使用

在忙碌的生活中,我经常需要在手机和电脑之间同步记事信息。可是,每次当我在手机上记下一些重要事项后,想要在电脑上查看或继续编辑时,总是遇到各种麻烦。因为手机和电脑上的便签软件各不相同,无法实现账号共享和内容同…

用 React 实现搜索 GitHub 用户功能

用 React 实现搜索 GitHub 用户功能 在本篇博客中,我们将介绍如何在 React 应用中搜索 GitHub 用户并显示他们的信息。 创建 React 应用 首先,我们使用 Create React App 创建一个新的 React 应用。Create React App 是一个快速搭建 React 项目的工具…

好用的IP反查接口

IP-API.com - Geolocation API - Documentation - JSON 自定义返回参数调用(1): http://ip-api.com/json/24.48.0.1?fieldsstatus,message,country,countryCode,region,regionName,cityhttp://ip-api.com/json/24.48.0.1?fieldscountry,co…

用C#开发Excel插件的强大开源工具

推荐一个开源项目,方便我们使用C#为Excel开发插件。 01 项目简介 Excel-DNA是一个.Net开源项目,为开发者提供了一种便利的方法,可以将.Net代码与Excel集成,能够轻松的为Excel创建自定义函数、图表、表单等,一方面不仅…

IT廉连看——Uniapp——页面样式与布局

IT廉连看——Uniapp——页面样式与布局 目标: 了解样式与布局的规范 熟记px和rpx的区别 全局样式与index样式的区别 一、查看uniapp框架简介——尺寸单位 px尺寸单位的使用是贯穿始终的。 [IT廉连看] 二、尺寸单位——实操效果 1、打开Hbuilder X并进入in…

组态王连接施耐德M580PLC

组态王连接施耐德M580 网络架构 网线连接PLC和装组态王软件的PC组态设置帮助 可先查看帮助:菜单栏点击【帮助】->【驱动帮助】,在弹出窗口中PLC系列选择莫迪康PLC的“modbusRtu\ASSCII\TCP”查看组态配置流程: 相关说明: 1、…

Springboot中如何记录好日志

Springboot中如何记录日志 日志体系整体介绍 日志一直在系统中占据这十分重要的地位,他是我们在系统发生故障时用来排查问题的利器,也是我们做操作审计的重要依据。那么如何记录好日志呢?选择什么框架来记录日志,是不是日志打越…

uni-app原生api的promise化以解决异步等待问题分析

相信各位在进行uni-app开发的时候会遇到各种关于异步回调问题,例如要传code给后端以换取session_key,在这之前需要先调用 uni.login,所以执行的顺序是必须同步等待的。在写这篇文章之前对于整体的流程概念需要做一个梳理,以便能更…

SpringCloud-Docker安装与详解

Docker 是一款强大的容器化平台,通过其轻量级的容器技术,使应用程序的开发、部署和管理变得更加便捷和高效。本文将深入探讨 Docker 的安装过程,并详细解析其基本概念、组件及常用命令,以帮助读者充分理解和熟练使用 Docker。企业…

【kubernetes】关于k8s集群中kubectl的陈述式资源管理

目录 一、k8s集群资源管理方式分类: (1)陈述式资源管理方式:增删查比较方便,但是改非常不方便 (2)声明式资源管理方式:yaml文件管理 二、陈述式资源管理方法: 三、ku…

linux查看服务器内核CUP版本相关命令

服务器参考 计算架构:x86-64产品系列:华为云耀云服务器操作系列:CentOS 7 执行uname -a查看服务器内核版本 Linux hecs-82210 3.10.0-1160.92.1.el7.x86_64 #1 SMP Tue Jun 20 11:48:01 UTC 2023 x86_64 x86_64 x86_64 GNU/Linux执行hostn…

【Flink CDC(一)】实现mysql整表与增量读取

文章目录 一. 运行前准备1. 依赖1.1. Maven dependency1.2. SQL Client JAR(推荐) 2. 配置 MySQL 服务器(必须) 二. 功能说明1. 启动模式2. 全量阶段支持 checkpoint3. 关于无主键表Exactly-Once 处理 三. 实战1. 实现mysql整表与…

交叉编译qt到arm平台

使用pkg-config命令查看xxx包是否存在: pkg-config --print-errors xxx pkg-config的搜索路径可以通过环境变量PKG_CONFIG_PATH指定。需要在运行./configure 之前指定。 ./configure -release -qt-libjpeg -qt-libpng -qt-zlib -qt-pcre -xplatform linux-aarch64-…

一文读懂 Python 值传递和引用传递

文章目录 版本前言形参和实参值传递和引用传递Python 变量存储值语义和引用语义值语义引用语义 探讨 Python 值传递和引用传递不可变(immutable)类型可变(mutable)类型案例一案例二 拓展:不可变类型真的不可变&#xf…

华为---RSTP(三)---P/A机制及RSTP的生成树形成过程

目录 1. P/A机制简介 1.1 P/A机制的作用 1.2 P/A协商的前提条件 1.3 RSTP选举思路 2. P/A协商过程 3. 举例说明RSTP的生成树形成过程 3.1 示例环境要求 3.2 RSTP的生成树形成过程 3.2.1 SW和SW1之间链路上抓包分析 3.2.2 SW和SW2之间链路上抓包分析 3.2.3 SW1和SW2之…

LeetCode_Java_转变日期格式、种花问题(题目+思路+代码)

目录 1507. 转变日期格式 605.种花问题 1507. 转变日期格式 给你一个字符串 date ,它的格式为 Day Month Year ,其中: Day 是集合 {"1st", "2nd", "3rd", "4th", ..., "30th", "3…

【计算机图形学】Where2Act: From Pixels to Actions for Articulated 3D Objects

文章目录 1.论文做了件什么事儿2. 论文为什么要做这件事3. 介绍Introduction4. 相关工作预测语义表达推理几何和物体属性从被动观察中学习Affordance从交互中学习感知 5. 问题陈述6. 方法6.1 网络模块主干特征提取器可运动性评分模块运动建议模块运动评分模块 6.2 训练数据收集…

路由器端口映射如何配置?

在网络通信中,路由器是一个重要的设备,它负责将数据包从一个网络传输到另一个网络。路由器的端口映射配置是一种重要的设置,可以使外部网络中的计算机通过访问路由器上的特定端口与内部网络中的计算机进行通信。本文将介绍什么是路由器端口映…