完整的模型训练套路(一、二、三)

搭建神经网络

在这里插入图片描述

model

import torch
from torch import nn

#搭建神经网络
class Tudui(nn.Module):
    def __init__(self):
        super(Tudui, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 4 * 4, 64),
            nn.Linear(64, 10)
        )
    def forward(self, x):
        x = self.model(x)
        return x

if __name__ == '__main__':
    tudui = Tudui()
    input = torch.ones((64, 3, 32, 32))
    output = tudui(input)
    print(output.size())  # torch.Size([64, 10])

train

import torchvision
from model import *
from torch.utils.data import DataLoader

#准备数据集
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=torchvision.transforms.ToTensor())
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))  # 50000
print("测试数据集的长度为:{}".format(test_data_size))  # 10000
#利用Dataloader来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

#创建网络模型
tudui = Tudui()
#损失函数
loss_fn = nn.CrossEntropyLoss()
#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

#设置训练网络的一些参数
#记录训练的次数
total_train_step = 0
#记录测试的次数
total_test_step = 0
#训练的轮数
epochs = 10

for epoch in range(epochs):
    print("------第{}轮训练开始------".format(epoch+1))
    #训练步骤开始
    for data in train_dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        loss = loss_fn(outputs, targets)

        #优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 == 0:
            print("训练次数: {}, Loss: {}".format(total_train_step, loss))  # loss.item()

    #测试步骤开始
    total_test_loss = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss
        print("整体测试集上的Loss: {}".format(total_test_loss))

确实每轮有所提升
在这里插入图片描述

添加tensorboard

writer = SummaryWriter(log_dir='./logs_train')
writer.add_scalar('train_loss', loss, total_train_step)
writer.add_scalar('test_loss', total_test_loss, total_test_step)
total_test_step += 1
writer.close()

test_loss train_loss

在这里插入图片描述

保存模型

torch.save(tudui, "tudui_{}.pth".format(epoch+1))
print('模型已保存')

整体代码

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

from model import *
from torch.utils.data import DataLoader

#准备数据集
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=torchvision.transforms.ToTensor())
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))  # 50000
print("测试数据集的长度为:{}".format(test_data_size))  # 10000
#利用Dataloader来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

#创建网络模型
tudui = Tudui()
#损失函数
loss_fn = nn.CrossEntropyLoss()
#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

#设置训练网络的一些参数
#记录训练的次数
total_train_step = 0
#记录测试的次数
total_test_step = 0
#训练的轮数
epochs = 10

#添加tensorboard
writer = SummaryWriter(log_dir='./logs_train')

for epoch in range(epochs):
    print("------第{}轮训练开始------".format(epoch+1))
    #训练步骤开始
    for data in train_dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        loss = loss_fn(outputs, targets)

        #优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 == 0:
            print("训练次数: {}, Loss: {}".format(total_train_step, loss))  # loss.item()
            writer.add_scalar('train_loss', loss, total_train_step)

    #测试步骤开始
    total_test_loss = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss
        print("整体测试集上的Loss: {}".format(total_test_loss))
        writer.add_scalar('test_loss', total_test_loss, total_test_step)
        total_test_step += 1

        torch.save(tudui, "tudui_{}.pth".format(epoch+1))
        print('模型已保存')

writer.close()

预测

在这里插入图片描述

import torch
outputs = torch.tensor([[0.1, 0.2],
                        [0.3, 0.4]])
print(outputs.argmax(dim=1))  # 取最大值的位置;1横着看, 0竖着看

在这里插入图片描述

预测的正确率

import torch

outputs = torch.tensor([[0.1, 0.2],
                        [0.3, 0.4]])
print(outputs.argmax(dim=1))  # 取最大值的位置;1横着看, 0竖着看
preds = outputs.argmax(1)
targets = torch.tensor([0, 1])
print((preds == targets).sum())  # 对应位置相等的个数

在这里插入图片描述

对源代码的进行修改(增正确取率)

主要加了这一句,看分类的正确率

total_accuracy += (outputs.argmax(1) == targets).sum()

完整

import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

from model import *
from torch.utils.data import DataLoader

#准备数据集
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=torchvision.transforms.ToTensor())
test_data = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=torchvision.transforms.ToTensor())
train_data_size = len(train_data)
test_data_size = len(test_data)
print("训练数据集的长度为:{}".format(train_data_size))  # 50000
print("测试数据集的长度为:{}".format(test_data_size))  # 10000
#利用Dataloader来加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

#创建网络模型
tudui = Tudui()
#损失函数
loss_fn = nn.CrossEntropyLoss()
#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(tudui.parameters(), lr=learning_rate)

#设置训练网络的一些参数
#记录训练的次数
total_train_step = 0
#记录测试的次数
total_test_step = 0
#训练的轮数
epochs = 10

#添加tensorboard
writer = SummaryWriter(log_dir='./logs_train')

for epoch in range(epochs):
    print("------第{}轮训练开始------".format(epoch+1))
    #训练步骤开始
    for data in train_dataloader:
        imgs, targets = data
        outputs = tudui(imgs)
        loss = loss_fn(outputs, targets)

        #优化器优化模型
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step += 1
        if total_train_step % 100 == 0:
            print("训练次数: {}, Loss: {}".format(total_train_step, loss))  # loss.item()
            writer.add_scalar('train_loss', loss, total_train_step)

    #测试步骤开始
    total_test_loss = 0
    total_accuracy = 0
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            outputs = tudui(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss
            total_accuracy += (outputs.argmax(1) == targets).sum()

        print("整体测试集上的Loss: {}".format(total_test_loss))
        print("整体测试集上的正确率: {}".format(total_accuracy/test_data_size))
        writer.add_scalar('test_loss', total_test_loss, total_test_step)
        writer.add_scalar('test_accuracy', total_accuracy/test_data_size, total_test_step)
        total_test_step += 1

        torch.save(tudui, "tudui_{}.pth".format(epoch+1))
        print('模型已保存')

writer.close()

正确率是有提升的
在这里插入图片描述
在这里插入图片描述

(三)细节

tudui.train()
tudui.eval()

并不是这样才能开始,仅对部分层有用,比如Dropout

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/311509.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

蓝桥杯练习题(二)

📑前言 本文主要是【算法】——蓝桥杯练习题(二)的文章,如果有什么需要改进的地方还请大佬指出⛺️ 🎬作者简介:大家好,我是听风与他🥇 ☁️博客首页:CSDN主页听风与他 …

单片机中的PWM(脉宽调制)的工作原理以及它在电机控制中的应用。

目录 工作原理 在电机控制中的应用 脉宽调制(PWM)是一种在单片机中常用的控制技术,它通过调整信号的脉冲宽度来控制输出信号的平均电平。PWM常用于模拟输出一个可调电平的数字信号,用于控制电机速度、亮度、电压等。 工作原理 …

【Helm 及 Chart 快速入门】03、Chart 基本介绍

目录 一、Chart 基本介绍 1.1 什么是 Chart 1.2 Chart ⽬录结构 1.3 Chart.yaml ⽂件 二、创建不可配置 Chart 2.1 创建 Chart 2.2 安装 Chart 三、创建可配置的 Chart 3.1 修改 chart 3.2 安装 Chart 一、Chart 基本介绍 1.1 什么是 Chart Helm 部署的应…

List列表操作中的坑

使用 Arrays.asList 把数据转换为 List 的三个坑 在如下代码中,我们初始化三个数字的 int[]数组,然后使用 Arrays.asList 把数组转换为 List: int[] arr {1, 2, 3}; List list Arrays.asList(arr); log.info("list:{} size:{} class…

每天刷两道题——第十二天+第十三天

1.1合并区间 以数组 i n t e r v a l s intervals intervals 表示若干个区间的集合,其中单个区间为 i n t e r v a l s [ i ] [ s t a r t i , e n d i ] intervals[i] [starti, endi] intervals[i][starti,endi] 。请你合并所有重叠的区间,并返回 …

Visual Studio Code 连接远程服务器方法

1、输入用户名和服务器ip连接远程服务器 2、选择配置文件 配置文件路径:C:\Users\Administrator\.ssh\config config的内容大致如下: Host 192.168.134.3HostName 192.168.134.3User zhangshanHost 192.168.134.3HostName 192.168.134.3User lisiHost…

基础篇_快速入门(Java简介,安装JDK,cmd命令行运行Java文件产生乱码问题的解决方式,IDE工具,实用工具)

文章目录 一. Java 简介1. JVM2. JRE3. JDK 二. 安装 JDK1. 下载和安装2. 配置 Path3. 配置 JAVA_HOME(选讲)优化 三. 入门案例1. 第一行代码1) jshell2) 代码解读总结 3) 为何要分成对象与方法 2. 第一份源码1) 源码结构2) 编写 java 源代码3) 编译 jav…

聊一聊 C# 线程切换后上下文都去了哪里

一:背景 1. 讲故事 总会有一些朋友问一个问题,在 Windows 中线程做了上下文切换,请问被切的线程他的寄存器上下文都去了哪里?能不能给我挖出来?这个问题其实比较底层,如果对操作系统没有个体系层面的理解…

groovy XmlParser 递归遍历 xml 文件,修改并保存

使用 groovy.util.XmlParser 解析 xml 文件,对文件进行修改(新增标签),然后保存。 是不是 XmlParser 没有提供方法遍历每个节点,难道要自己写? 什么是递归? 不用说,想必都懂得~ …

基于Pixhawk和ROS搭建自主无人车(一):底盘控制篇

参考 ArduPilot Development超维空间科技乐迪MiniPix车船使用说明书 1. 硬件篇 1.1 底盘构成一览 1.2 底盘接线示意 2. 软件篇 2.1 APM 固件下载 pixhawk 是硬件平台,PX4 是 pixhawk 的原生固件,APM(Ardupilot Mega)是硬件平台…

C++里main函数int main(int argc, char **argv)

C里main函数int main(int argc, char **argv), 这两个参数argc和argv分别是什么

安全帽/反光衣检测AI智能分析网关V4如何查看告警信息并进行处理?

智能分析网关V4属于高性能、低功耗的软硬一体AI边缘计算硬件设备,目前拥有3种型号(8路/16路/32路),支持Caffe / DarkNet / TensorFlow / PyTorch / MXNet / ONNX / PaddlePaddle等主流深度学习框架。硬件内部署了近40种AI算法模型…

9个icon图标网站,海量免费矢量图标库!

​划到最后“阅读原文”——领取工具包(超过1000工具,免费素材网站分享和行业报告) Hi,我是胡猛夫~,专注于分享各类价值网站、高效工具! 更多内容,更多资源,欢迎交流!公…

MacOS安装Miniforge、Tensorflow、Jupyter Lab等(2024年最新)

大家好,我是邵奈一,一个不务正业的程序猿、正儿八经的斜杠青年。 1、世人称我为:被代码耽误的诗人、没天赋的书法家、五音不全的歌手、专业跑龙套演员、不合格的运动员… 2、这几年,我整理了很多IT技术相关的教程给大家&#xff0…

应用案例 | 基于三维机器视觉的自动化无序分拣解决方案

​ 近年来,电商行业蓬勃发展,订单的海量化、订单类型的碎片化,使物流行业朝着“多品种、无边界、分类广”的方向迅速发展。根据许多研究机构的预测,电子商务销售额预计将以每年两位数的速度增长,推动整个行业的规模不…

【排序】快速排序(C语言实现)

文章目录 前言1. Hoare思想2. 挖坑法3. 前后指针法4. 三路划分5. 快速排序的一些小优化5.1 三数取中常规的三数取中伪随机的三数取中 5.2 小区间优化 6. 非递归版本的快排7. 快速排序的特性总结 前言 快速排序是Hoare于1962年提出的一种二叉树结构的交换排序方法,其…

Leetcode 416 分割等和子集

题意理解: 给你一个 只包含正整数 的 非空 数组 nums 。请你判断是否可以将这个数组分割成两个子集,使得两个子集的元素和相等。 即将数组的元素分成两组,每组数值sum(nums)/2 若能分成这样的两组,则返回true,否则返回false 本质上…

国标28181平台的手机视频监控客户端的电子地图功能对比

目 录 一、手机客户端 1、概述 2、具体功能简述 二、电子地图功能 1、经纬度定位 2、附近设备 3、实时浏览功能 4、录像回放 5、缩放功能 三、手机web客户端和CS客户端上的电子地图功能对比 1、对比表 2、测距(PC客户端功能) 3…

【分布式技术】rsync远程同步服务

目录 一、rsync(远程同步) 二、实操rsync远程文件同步 准备一个服务端192.168.20.18以及一个客户端192.168.20.30 1、服务端搭建:先完成服务端配置,启动服务 rsync拓展 1、关于rsyncd服务的端口号 2、rsync和scp的区别 2、测…

在qemu虚拟机环境下,使用kgdb调试kernel

enable kgdb的情况下,使用qemu启动kernel 1,需要先在内核配置中增加kgdb的支持 2,启动qemu虚拟机时,增加参数-s -S,这两个参数会使得kernel在启动之后遇到的第一个指令等待gdb连接 例子: /qemu-project…