完整网络模型训练(一)

文章目录

    • 一、网络模型的搭建
    • 二、网络模型正确性检验
    • 三、创建网络函数

一、网络模型的搭建

以CIFAR10数据集作为训练例子

准备数据集:

#因为CIFAR10是属于PRL的数据集,所以需要转化成tensor数据集
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)

查看数据集的长度:

train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为{train_data_size}")
print(f"测试数据集的长度为{test_data_size}")

运行结果:
在这里插入图片描述

利用DataLoader来加载数据集:

train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

搭建CIFAR10数据集神经网络:
在这里插入图片描述
卷积层【1】代码解释:
#第一个数字3表示inputs(可以看到图中为3),第二个数字32表示outputs(图中为32)
#第三个数字5为卷积核(图中为5),第四个数字1表示步长(stride)
#第五个数字表示padding,需要计算,计算公式:
在这里插入图片描述

nn.Conv2d(3, 32, 5, 1, 2)

最大池化代码解释:
#数字2表示kernel卷积核

nn.MaxPool2d(2)

读图
卷积层【1】的Inputs 和 Outputs是下图这两个:
在这里插入图片描述

最大池化【1】的Inputs 和 Outputs是下图这两个:
在这里插入图片描述
卷积层【2】的Inputs 和 Outputs是下图这两个:
在这里插入图片描述
以此类推

展平:
在这里插入图片描述
Flatten后它会变成64*4 *4的一个结果

线性输出:
在这里插入图片描述
线性输入是64*4 *4,线性输出是64,故如下代码
nn.LInear(64 *4 *4,64)

继续线性输出
在这里插入图片描述
nn.LInear(64,10)

搭建网络完整代码:

class Sen(nn.Module):
    def __init__(self):
        super(Sen, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1 ,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )
    def forward(self,x):
        x = self.model(x)
        return x

二、网络模型正确性检验

if __name__ == '__main__':
    sen = Sen()
    input = torch.ones((64, 3, 32, 32))
    output = sen(input)
    print(output.shape)

注释:

input = torch.ones((64, 3, 32, 32))

这一行代码的含义是:创建一个大小为 (64, 3, 32, 32) 的全 1 张量,数据类型为 torch.float32。
64:这是批次大小,代表输入有 64 张图片。
3:这是图片的通道数,通常为 RGB 图像的三个通道 (红、绿、蓝)。
32, 32:这是图片的高和宽,表示每张图片的尺寸为 32x32 像素。
torch.ones 函数用于生成一个全 1 的张量,这里的张量形状适合用于输入图像分类或卷积神经网络(CNN)中常见的 CIFAR-10 或类似的 32x32 像素图像数据。

运行结果:
在这里插入图片描述
可以得到成功变成了【64, 10】的结果。

三、创建网络函数

创建网络模型:

sen = Sen()

搭建损失函数:

loss_fn = nn.CrossEntropyLoss()

优化器:

learning_rate = 1e-2
optimizer = torch.optim.SGD(sen.parameters(), lr=learning_rate)

优化器注释:
使用随机梯度下降(SGD)优化器
learning_rate = 1e-2 这里的1e-2代表的是:1 x (10)^(-2) = 1/100 = 0.01

记录训练的次数:

total_train_step = 0

记录测试的次数:

total_test_step = 0

训练的轮数:

epoch= 10

进行循环训练:

for i in range(epoch):
    print(f"第{i+1}轮训练开始")

    for data in train_dataloader:
        imgs, targets = data
        outputs = sen(imgs)
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        print(f"训练次数:{total_train_step},Loss:{loss.item()}")

注释:
imgs, targets = data是解包数据,imgs 是输入图像,targets 是目标标签(真实值)
outputs = sen(imgs)将输入图像传入模型 ‘sen’,得到模型的预测输出 outputs
loss = loss_fn(outputs, targets)计算损失值(Loss),loss_fn 是损失函数,它比较outputs的值与targets 是目标标签(真实值)的误差
optimizer.zero_grad()清除优化器中上一次计算的梯度,以免梯度累积
loss.backward()反向传播,计算损失相对于模型参数的梯度
optimizer.step()使用优化器更新模型的参数,以最小化损失
loss.item() 将张量转换为 Python 的数值
loss.item演示:

import torch
a = torch.tensor(5)
print(a)
print(a.item())

运行结果:
在这里插入图片描述
因此可以得到:item的作用是将tensor变成真实数字5

本章节完整代码展示:

import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader

class Sen(nn.Module):
    def __init__(self):
        super(Sen, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 32, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, 1 ,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, 1, 2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64*4*4, 64),
            nn.Linear(64, 10)
        )
    def forward(self,x):
        x = self.model(x)
        return x
#准备数据集
#因为CIFAR10是属于PRL的数据集,所以需要转化成tensor数据集
train_data = torchvision.datasets.CIFAR10(root="./data", train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root="./data", train=False, transform=torchvision.transforms.ToTensor(),download=True)

#length长度
train_data_size = len(train_data)
test_data_size = len(test_data)
print(f"训练数据集的长度为{train_data_size}")
print(f"测试数据集的长度为{test_data_size}")

train_dataloader = DataLoader(train_data,batch_size=64)
test_dataloader = DataLoader(test_data,batch_size=64)

sen = Sen()

#损失函数
loss_fn = nn.CrossEntropyLoss()

#优化器
learning_rate = 1e-2
optimizer = torch.optim.SGD(sen.parameters(), lr=learning_rate)

#记录训练的次数
total_train_step = 0
#记录测试的次数
total_test_step = 0
#训练的轮数
epoch= 10

for i in range(epoch):
    print(f"第{i+1}轮训练开始")

    for data in train_dataloader:
        imgs, targets = data
        outputs = sen(imgs)
        loss = loss_fn(outputs, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_train_step = total_train_step + 1
        print(f"训练次数:{total_train_step},Loss:{loss.item()}")

运行结果:
在这里插入图片描述
可以看到训练的损失函数在一直进行修正。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/885654.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

YOLO11震撼发布!

非常高兴地向大家介绍 Ultralytics YOLO系列的新模型: YOLO11! YOLO11 在以往 YOLO 模型基础上带来了一系列强大的功能和优化,使其速度更快、更准确、用途更广泛。主要改进包括 增强了特征提取功能,从而可以更精确地捕捉细节以更…

[云]Kubernetes 的基础知识

目标: 实践实验室涵盖 Kubernetes 的基础知识(这个句子的意思是在实验室中通过实践学习 Kubernetes 的基本概念) 在此过程中理解 Kubernetes 概念(这个句子的意思是在学习的过程中理解 Kubernetes 的相关概念) 议程&…

【无人机设计与技术】四旋翼无人机的建模

摘要 本项目的目标是通过 Simulink 建模和仿真,研究四旋翼无人机的建模、姿态控制、定点位置控制及航点规划功能。无人机建模包含了动力单元模型、控制效率模型和刚体模型,并运用这些模型实现了姿态控制和位置控制。姿态控制为无人机的平稳飞行提供基础…

OpenCV normalize() 函数详解及用法示例

OpenCV的normalize函数用于对数组(图像)进行归一化处理,即将数组中的元素缩放到一个指定的范围或具有一个特定的标准(如均值和标准差)。它有两个原型函数, 如下: Normalize()规范化数组的范数或值范围。当normTypeNORM…

制造企业为何需要PLM系统?PLM系统解决方案对制造业重要性分析

制造企业为何需要PLM系统?PLM系统解决方案对制造业重要性分析 新华社9月23日消息,据全国组织机构统一社会信用代码数据服务中心统计,我国制造业企业总量突破600万家。数据显示,2024年1至8月,我国制造业企业数量呈现稳…

简单线性回归分析-基于R语言

本题中&#xff0c;在不含截距的简单线性回归中&#xff0c;用零假设对统计量进行假设检验。首先&#xff0c;我们使用下面方法生成预测变量x和响应变量y。 set.seed(1) x <- rnorm(100) y <- 2*xrnorm(100) &#xff08;a&#xff09;不含截距的线性回归模型构建。 &…

计算机视觉综述

大家好&#xff0c;今天&#xff0c;我们将一起探讨计算机视觉的基本概念、发展历程、关键技术以及未来趋势。计算机视觉是人工智能的一个重要分支&#xff0c;旨在使计算机能够“看”懂图像和视频&#xff0c;从而完成各种复杂的任务。无论你是对这个领域感兴趣的新手&#xf…

Linux操作系统中MongoDB

1、什么是MongoDB 1、非关系型数据库 NoSQL&#xff0c;泛指非关系型的数据库。随着互联网web2.0网站的兴起&#xff0c;传统的关系数据库在处理web2.0网站&#xff0c;特别是超大规模和高并发的SNS类型的web2.0纯动态网站已经显得力不从心&#xff0c;出现了很多难以克服的问…

SpringBoot整合JPA详解

SpringBoot版本是2.0以上(2.6.13) JDK是1.8 一、依赖 <dependencies><!-- jdbc --><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-jdbc</artifactId></dependency><!--…

C# C++ 笔记

第一阶段知识总结 lunix系统操作 1、基础命令 &#xff08;1&#xff09;cd cd /[目录名] 打开指定文件目录 cd .. 返回上一级目录 cd - 返回并显示上一次目录 cd ~ 切换到当前用户的家目录 &#xff08;2&#xff09;pwd pwd 查看当前所在目录路径 pwd -L 打印当前物理…

Unity实战案例全解析:RTS游戏的框选和阵型功能(5)阵型功能 优化

前篇&#xff1a;Unity实战案例全解析&#xff1a;RTS游戏的框选和阵型功能&#xff08;4&#xff09;阵型功能-CSDN博客 本案例来源于unity唐老狮&#xff0c;有兴趣的小伙伴可以去泰克在线观看该课程 我只是对重要功能进行分析和做出笔记分享&#xff0c;并未无师自通&#x…

ARM Process state -- SPSR

Holds the saved process state for the current mode. 保存当前模式的已保存进程状态。 N, bit [31] Set to the value of PSTATE.N on taking an exception to the current mode, and copied to PSTATE.N on executing an exception return operation in the current mod…

袋鼠云数据资产平台:数据模型标准化建表重构升级

数据模型是什么&#xff1f;简单来说&#xff0c;数据模型是用来组织和管理数据的一种方式。它为构建高效且可靠的信息系统提供了基础&#xff0c;不仅决定了如何存储和管理数据&#xff0c;还直接影响系统的性能和可扩展性。 想要建立一个良好的数据模型&#xff0c;设计时需…

链表的基础知识

文章目录 概要整体架构流程 小结 概要 链表是一种常见的数据结构&#xff0c;它通过节点之间的连接关系实现数据的存储和访问。链表由一系列节点&#xff08;Node&#xff09;组成&#xff0c;每个节点包含数据和指向下一个节点的指针。链表的特点是物理存储单元上非连续、非顺…

Qt的互斥量用法

目的 互斥量的概念 互斥量是一个可以处于两态之一的变量:解锁和加锁。这样&#xff0c;只需要一个二进制位表示它&#xff0c;不过实际上&#xff0c;常常使用一个整型量&#xff0c;0表示解锁&#xff0c;而其他所有的值则表示加锁。互斥量使用两个过程。当一个线程(或进程)…

网络编程,端口号,网络字节序,udp

前面一篇我们讲了网络的基础&#xff0c;网络协议栈是什么样的&#xff0c;数据如何流动传输的&#xff1b;接下来这篇&#xff0c;我们将进行实践操作&#xff0c;真正的让数据跨网络进行传输&#xff1b; 1.网络编程储备知识 1.1 初步认识网络编程 首先我们需要知道我们的…

Java基础 3. 面向对象

Java基础 3. 面向对象 文章目录 Java基础 3. 面向对象3.1. 面向对象3.2. 对象的创建和使用3.3. 封装3.4. 构造方法3.5. this关键字3.6. static关键字JVM体系结构 [^现阶段不用掌握]3.7. 单例模式 [^初级]3.8. 继承3.9. 方法覆盖3.10. 多态3.11. super关键字3.12. final关键字3.…

你的虚拟猫娘女友,快来领取!--文心智能体平台

文章目录 一、引言二、赛事介绍2.1 简介2.2 比赛时间2.3 大赛具体链接2.4 第一期赛题 三、智能体创建流程3.1 进入文心智能体平台3.1 创建智能体3.1 虚拟猫娘女友特性3.1 智能体调优 四、引言智能体测试五、结语 一、引言 我是热爱生活的通信汪&#xff0c;今天这篇博文记录一…

[CSP-J 2022] 解密

题目来源&#xff1a;洛谷题库 [CSP-J 2022] 解密 题目描述 给定一个正整数 k k k&#xff0c;有 k k k 次询问&#xff0c;每次给定三个正整数 n i , e i , d i n_i, e_i, d_i ni​,ei​,di​&#xff0c;求两个正整数 p i , q i p_i, q_i pi​,qi​&#xff0c;使 n …

C语言 | Leetcode C语言题解之第448题找到所有数组中消失的数字

题目&#xff1a; 题解&#xff1a; int* findDisappearedNumbers(int* nums, int numsSize, int* returnSize) {for (int i 0; i < numsSize; i) {int x (nums[i] - 1) % numsSize;nums[x] numsSize;}int* ret malloc(sizeof(int) * numsSize);*returnSize 0;for (in…