PyTorch之完整的神经网络模型训练

简单的示例:

在PyTorch中,可以使用nn.Module类来定义神经网络模型。以下是一个示例的神经网络模型定义的代码:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        
        # 定义神经网络的层和参数
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.fc1 = nn.Linear(32 * 14 * 14, 128)
        self.fc2 = nn.Linear(128, 10)
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.softmax(x)
        
        return x

在上面的示例中,定义了一个名为MyModel的神经网络模型,继承自nn.Module类。在__init__方法中,我们定义了模型的层和参数。具体来说:

  • 代码定义了一个卷积层,输入通道数为1,输出通道数为32,卷积核大小为3x3,步长为1,填充为1。
  • 定义了一个ReLU激活函数,用于在卷积层之后引入非线性性质。
  • 定义了一个最大池化层,池化核大小为2x2,步长为2。
  • 定义了一个全连接层,输入大小为32x14x14(经过卷积和池化后的特征图大小),输出大小为128。
  • 定义了另一个全连接层,输入大小为128,输出大小为10。
  • 定义了一个softmax函数,用于将模型的输出转换为概率分布。

forward方法中,定义了模型的前向传播过程。具体来说:

  • x = self.conv1(x): 将输入张量传递给卷积层进行卷积操作。
  • x = self.relu(x): 将卷积层的输出通过ReLU激活函数进行非线性变换。
  • x = self.maxpool(x): 将ReLU激活后的特征图进行最大池化操作。
  • x = x.view(x.size(0), -1): 将池化后的特征图展平为一维,以适应全连接层的输入要求。
  • x = self.fc1(x): 将展平后的特征向量传递给第一个全连接层。
  • x = self.relu(x): 将第一个全连接层的输出通过ReLU激活函数进行非线性变换。
  • x = self.fc2(x): 将第一个全连接层的输出传递给第二个全连接层。
  • x = self.softmax(x): 将第二个全连接层的输出通过softmax函数进行归一化,得到每个类别的概率分布。

这个示例展示了一个简单的卷积神经网络模型,适用于处理单通道的图像数据,并输出10个类别的分类结果。可以根据自己的需求和数据特点来定义和修改神经网络模型。

接下来将用于实际的数据集进行训练:

以下是基于CIFAR10数据集的神经网络训练模型:

import torch
import torchvision
from torch import nn
from torch.nn import MaxPool2d
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from nn_mode import *

#准备数据集
train_data=torchvision.datasets.CIFAR10(root='../chap4_Dataset_transforms/dataset',train=True,transform=torchvision.transforms.ToTensor())
test_data=torchvision.datasets.CIFAR10(root='../chap4_Dataset_transforms/dataset',train=False,transform=torchvision.transforms.ToTensor())
#输出数据集的长度
train_data_size=len(train_data)
test_data_size=len(test_data)
print(train_data_size)
print(test_data_size)
#加载数据集
train_loader=DataLoader(dataset=train_data,batch_size=64)
test_loader=DataLoader(dataset=test_data,batch_size=64)
#创建神经网络
sjnet=Sjnet()

#损失函数
loss_fn=nn.CrossEntropyLoss()
#优化器
learn_lr=0.01#便于修改
YHQ=torch.optim.SGD(sjnet.parameters(),lr=learn_lr)

#设置训练网络的参数
train_step=0#训练次数
test_step=0#测试次数
epoch=10#训练轮数

writer=SummaryWriter('wanzheng_logs')

for i in range(epoch):
    print("第{}轮训练".format(i+1))
    #开始训练
    for data in train_loader:
        imgs,targets=data
        outputs=sjnet(imgs)
        loss=loss_fn(outputs,targets)
        #优化器
        YHQ.zero_grad()  # 将神经网络的梯度置零,以准备进行反向传播
        loss.backward()  # 执行反向传播,计算神经网络中各个参数的梯度
        YHQ.step()  # 调用优化器的step()方法,根据计算得到的梯度更新神经网络的参数,完成一次参数更新
        train_step =train_step+1
        if train_step%100==0:
            print('训练次数为:{},loss为:{}'.format(train_step,loss))
            writer.add_scalar('train_loss',loss,train_step)

    #开始测试
    total_loss=0
    with torch.no_grad():#上下文管理器,用于指示在接下来的代码块中不计算梯度。
        for data in test_loader:
            imgs,targets=data
            outputs = sjnet(imgs)
            loss = loss_fn(outputs, targets)#使用损失函数 loss_fn 计算预测输出与目标之间的损失。
            total_loss=total_loss+loss#将当前样本的损失加到总损失上,用于累积所有样本的损失。
    print('整体测试集上的loss:{}'.format(total_loss))
    writer.add_scalar('test_loss', total_loss, test_step)
    test_step = test_step+1

    torch.save(sjnet,'sjnet_{}.pth'.format(i))
    print("模型已保存!")

writer.close()

 其神经网络训练以及测试时的损失值使用TensorBoard进行展示,如图所示:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/453684.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

DPN网络

DPN DPN(Dual Path Networks)是一种网络结构,它结合了DensNet和ResNetXt两种思想的优点。这种结构的目的是通过不同的路径来利用神经网络的不同特性,从而提高模型的效率和性能。 DenseNet 的特点是其稠密连接路径,使…

【Python】新手入门学习:详细介绍开放封闭原则(OCP)及其作用、代码示例

【Python】新手入门学习:详细介绍开放封闭原则(OCP)及其作用、代码示例 🌈 个人主页:高斯小哥 🔥 高质量专栏:Matplotlib之旅:零基础精通数据可视化、Python基础【高质量合集】、PyT…

「哈哥赠书活动 - 50期」-『AI赋能写作:AI大模型高效写作一本通』

⭐️ 赠书 - 《AI赋能写作:AI大模型高效写作一本通》 ⭐️ 内容简介 本书以ChatGPT为科技行业带来的颠覆性革新为起点,深入探讨了人工智能大模型如何为我们的创作提供强大支持。本书旨在帮助创作者更好地理解AI的价值,并充分利用其能力提升写…

复习C的内存管理

来自:漫谈C语言内存管理_c语言内存管理机制-CSDN博客 C语言学习笔记 —— 内存管理_c语言内存管理-CSDN博客 C语言是音视频开发所必须的。 变量是一段连续内存空间的别名。变量的类型是固定内存大小的别名。但是类型不是只确定了变量内存大小,还确定了…

代码随想录-java-栈与队列总结

栈(Stack):是只允许在一端进行插入或删除的线性表。栈是一种线性表,限定这种线性表只能在某一端进行插入和删除操作。进行操作的这一端称为栈顶。 队列(Queue)是只允许在一端进行插入操作,而在另…

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:TextPicker)

滑动选择文本内容的组件。 说明: 该组件从API Version 8开始支持。后续版本如有新增内容,则采用上角标单独标记该内容的起始版本。 子组件 无 接口 TextPicker(options?: {range: string[] | string[][] | Resource | TextPickerRangeContent[] | Te…

自动生成单元测试、外挂开源代码库等新功能,上线JetBrains IDEs的CodeGeeX插件!

CodeGeeX第三代模型发布后,多项基于第三代模型能力的新功能今天也同步上线JetBrains IDEs全家桶。 用户可以在IDEA、PyCharm等JetBrains系的IDE中,搜索下载CodeGeeX v2.5.0版本,深度使用最新功能。 一、新模型加持的代码补全和智能问答 …

【Java基础概述-8】Lambda表达式概述、方法引用、Stream流的使用

1、Lambda表达式概述 什么是Lambda表达式? Lambda表达式是JDK1.8之后的新技术,是一种代码的新语法。 是一种特殊写法。 作用:“核心的目的是为了简化匿名内部类的写法”。 Lambda表达式的格式: (匿名内部类被重写的形参列表){ 被重写的代码 …

[MYSQL数据库]--表内操作(CURD)

前言 作者:小蜗牛向前冲 名言:我可以接受失败,但我不能接受放弃 如果觉的博主的文章还不错的话,还请点赞,收藏,关注👀支持博主。如果发现有问题的地方欢迎❀大家在评论区指正 目录 一、表的 Cre…

AI时代Python金融大数据分析实战:ChatGPT让金融大数据分析插上翅膀【文末送书-38】

文章目录 Python驱动的金融智能:数据分析、交易策略与风险管理Python在金融数据分析中的应用 实战案例:基于ChatGPT的金融事件预测AI时代Python金融大数据分析实战:ChatGPT让金融大数据分析插上翅膀【文末送书-38】 Python驱动的金融智能&…

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:TimePicker)

时间选择组件,根据指定参数创建选择器,支持选择小时及分钟。 说明: 该组件从API Version 8开始支持。后续版本如有新增内容,则采用上角标单独标记该内容的起始版本。 子组件 无 接口 TimePicker(options?: TimePickerOptions)…

【计算机网络】HTTPS协议

🎥 个人主页:Dikz12🔥个人专栏:网络原理📕格言:那些在暗处执拗生长的花,终有一日会馥郁传香欢迎大家👍点赞✍评论⭐收藏 目录 HTTPS是什么? 为什么要加密? …

如何选择最佳文档管理系统?2024年11款企业DMS详细评测!

11款文档管理系统 (DMS)包括:1.PingCode 知识库;2.DocuWare ;3.Dropbox Business ;4.eFileCabinet ;5.Google Drive ;6.Laserfiche ;7.LogicalDOC ;8.M-Files ;9.OnlyOff…

大龄全职宝妈的出路在哪里?怎么突破困境,实现财富自由?

我是电商珠珠 对于女性来说从怀孕到生子起码要经过最低2年的时间,这2年就是空窗期,有的女性需要三年甚至更多的时间去带孩子,别说让财富自由了,人身都很难自由。 特别是现在的大龄全职妈妈,明明还正值壮年&#xff0…

【网络协议】RPC、REST API深入理解及简单demo实现

目录 一、RPC二、REST三、代码3.1 RPC demo3.2 Rest API demo 一、RPC RPC 即远程过程调用(Remote Procedure Call Protocol,简称RPC),像调用本地服务(方法)一样调用服务器的服务(方法)。通常的实现有 XML-RPC , JSON-RPC , 通信…

人工智能测试开发

随着人工智能在各行各业的广泛应用,学习并掌握AI技术在软件测试中的应用变得至关重要。不仅能使你跟上行业的发展趋势,还能提升你的竞争力。而且,市场对具备AI测试技能的测试工程师的需求正日益增长,这使得掌握这些技能能够帮助你…

MySQL知识点极速入门

准备SQL 创建数据库: 创建一个名为emptest的数据库 create database emptest; use emptest; 创建数据表: 设计一张员工信息表,要求如下: 1. 编号(纯数字) 2. 员工工号 (字符串类型,长度不超…

mac启动skywalking报错

这个命令显示已经成功 但是日志报错了以上内容。 然后去修改。vim .bash_profile 查看全局变量,这个jdk却是有2个。所以这个问题没解决。

考研C语言复习进阶(1)

目录 1. 数据类型介绍 1.1 类型的基本归类: 2. 整形在内存中的存储 2.1 原码、反码、补码 2.2 大小端介绍 3. 浮点型在内存中的存储 ​编辑 1. 数据类型介绍 前面我们已经学习了基本的内置类型: char //字符数据类型 short //短整型 int /…

VTK包围盒,AABB包围盒

1. 概述 包围盒是指能够包含三维图形的长方体,常常用于模型的碰撞检测。包围盒可以分成:轴对齐包围盒(AABB),有向包围盒(OBB)和凸包(Convex Hull) 今天就来实践一下AABB包围盒 2.绘制一个锥体 #include "vtkAutoInit.h" VTK_MODULE_INIT(v…