pytorch神经网络训练(VGG-16)

VGG-16

  1. 导包
import torch

import torch.nn as nn

import torch.optim as optim

import torchvision

from torchvision import datasets, transforms

from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
  1. 数据预处理和增强
transform = transforms.Compose([

transforms.Resize((224, 224)), 

调整图像大小为 224x224像素,符合VGG16输入

    transforms.ToTensor(),

transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), 

])

使用ImageNet的标准化参数

  1. 加载数据集
data_path = "D:\\图像处理、深度学习\\flowers"

dataset = datasets.ImageFolder(data_path, transform=transform)
  1. 划分训练集和测试集
train_size = int(0.8 * len(dataset))

test_size = len(dataset) - train_size

train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
  1. 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
  1. 使用预训练的VGG16模型
model = torchvision.models.vgg16(pretrained=True)
  1. 修改全连接层以适应新的分类任务
num_classes = len(dataset.classes)

model.classifier[6] = nn.Linear(4096, num_classes)
  1. 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters())
  1. 将模型移动到GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
  1. 初始化列表来存储每个epoch的损失和准确率
train_losses = []

train_accuracies = []

  1. 训练模型
num_epochs = 30 

定义训练的轮数为30轮

for epoch in range(num_epochs): 

    model.train() 

    running_loss = 0.0  

    correct = 0 

total = 0 

开始训练循环,将模型设置为训练模式初始化累计损失为0.0,初始化正确预测的数量为0,初始化总样本数量为0

for inputs, labels in train_loader: 

遍历训练数据加载器中的每个批次

        inputs, labels = inputs.to(device), labels.to(device) 

        optimizer.zero_grad() 

将输入和标签移动到定义的设备上(如GPU),清空模型的梯度

        outputs = model(inputs) 

        loss = criterion(outputs, labels) 

将输入传递给模型以获得输出,计算输出和标签之间的损失

        loss.backward() 

        optimizer.step() 

        running_loss += loss.item() 

反向传播损失以计算梯度,更新模型的参数,累计批次损失

        _, predicted = torch.max(outputs.data, 1) 

获取每个样本的最高预测值

        total += labels.size(0) 

        correct += (predicted == labels).sum().item() 

累计总样本数量,累计正确预测的样本数量

epoch_loss = running_loss / len(train_loader) 

计算平均损失

epoch_accuracy = 100 * correct / total 

计算准确率

train_losses.append(epoch_loss) 

train_accuracies.append(epoch_accuracy) 

将损失添加到训练损失列表中,将准确率添加到训练准确率列表中

print(f'Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss}, Accuracy: {epoch_accuracy}%') 

打印轮数、损失和准确率

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/734538.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Flutter-实现头像叠加动画效果

实现头像叠加动画效果 在这篇文章中,我们将介绍如何使用 Flutter 实现一个带有透明度渐变效果和过渡动画的头像叠加列表。通过这种效果,可以在图片切换时实现平滑的动画,使 UI 更加生动和吸引人。 需求 我们的目标是实现一个头像叠加列表&…

基于顺序存储的环形队列算法库构建

学习贺利坚老师基于数组的环形队列 数据结构之自建算法库——顺序环形队列_下空队列q中依次入队列数据元素abc-CSDN博客文章浏览阅读5.2k次,点赞6次,收藏6次。本文针对数据结构基础系列网络课程(3):栈和队列中第9课时环形队列的存储及基本操…

基于ESP8266串口WIFI模块ESP-01S在Station模式(即接收无线信号( WiFi))下实现STC单片机与手机端网路串口助手相互通信功能

基于ESP8266串口WIFI模块ESP-01S在Station模式(即接收无线信号( WiFi))下实现STC单片机与手机端网路串口助手相互通信功能 ESP8266_01S引脚功能图ESP8266_01S原理图ESP8266_01S尺寸图检验工作1、USB-TTL串口工具(推荐使用搭载CP2102芯片的安信可USB-T1串口)与ESP8266_01S…

会声会影2024永久破解和谐版下载 包含激活码序列号

亲爱的创作伙伴们,今天我要分享一个让我的影视编辑生活大放异彩的神器——会声会影2024破解版本!🎉🌟 🌈**功能全面升级**:作为一款专业的视频编辑软件,会声会影2024破解版本不仅继承了之前版本…

KEITHLEY吉时利24611kW 脈沖模式 SourceMeter 手侧

KEITHLEY吉时利24611kW 脈沖模式 SourceMeter 手侧

【vue3|第13期】深入了解Vue3生命周期:管理组件的诞生、成长与消亡

日期:2024年6月22日 作者:Commas 签名:(ง •_•)ง 积跬步以致千里,积小流以成江海…… 注释:如果您觉得有所帮助,帮忙点个赞,也可以关注我,我们一起成长;如果有不对的地方&#xf…

【大数据 复习】第8章 Hadoop架构再探讨

一、概念 1.Hadoop1.0的核心组件(仅指MapReduce和HDFS,不包括Hadoop生态系统内的Pig、Hive、HBase等其他组件),主要存在以下不足: (1)抽象层次低,需人工编码 (2&#xf…

【Linux】基础IO_4

文章目录 六、基础I/O4. 动静态库 未完待续 六、基础I/O 4. 动静态库 既然我们能够成功创建静态库了,接下来我们将这个代码打包成动态库: shared: 表示生成共享库格式 fPIC:产生位置无关码(position independent code) 动态库库名规则&…

Swift 中的动态数组

Swift 的 Array 类型是一种强大而灵活的集合类型,可以根据需要自动扩展或缩减其容量。 动态数组的基本概念 Swift 中的数组是基于动态数组(dynamic array)的概念实现的。动态数组能够根据需要自动调整其容量,以容纳新增的元素&a…

第一题(伏羲六十四卦)

题目: 首先伏羲64卦解密 再用base64解密即可

基于SpringBoot+Vue教材订购系统设计和实现(源码+LW+调试文档+讲解等)

💗博主介绍:✌全网粉丝10W,CSDN作者、博客专家、全栈领域优质创作者,博客之星、平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌💗 🌟文末获取源码数据库🌟 感兴趣的可以先收藏起来,…

Android 14 权限等级剖析

Android 14 权限等级剖析 Android 14 引入了新的权限等级,为应用程序访问敏感信息和功能提供了更细粒度的控制。了解这些新的权限等级及其应用场景对于开发人员至关重要。 1. 概述 Android 一直以来都使用权限机制来控制应用程序对系统资源和用户数据的访问。随着…

Vue81-独享路由守卫

一、 独享路由守卫的定义 当只有某个特定的路由需要做校验的时候,可以在改路由组件规则的地方单独配置独属于改组件的路由守卫。 二、示例 1、需求 系统只在进入新闻路由页面的时候做校验。 2、代码实现 注意: 独享路由守卫,只有前置路由守…

Unity Apple Vision Pro 开发(三):visionOS 应用形态

文章目录 📕教程说明📕常用名词解释📕visionOS 空间类型⭐Shared Space 共享空间⭐Full Space/Immersive Space 独占空间 📕visionOS 渲染框架📕Unity 开发 visionOS 应用的不同模式⭐**窗口模式**⭐VR 模式⭐MR 模式 …

基于matlab的高斯滤波与图像去噪

1 高斯滤波原理 1.1 原理 高斯滤波是一种线性平滑滤波技术,主要用于消除图像中的高斯噪声。它的工作原理可以理解为对整幅图像进行加权平均的过程,即每个像素点的值都由其本身和邻域内的其他像素值经过加权平均后得到。 高斯滤波实质上是一种信号的滤…

动手学深度学习(Pytorch版)代码实践 -卷积神经网络-25使用块的网络VGG

25使用块的网络VGG import torch from torch import nn import liliPytorch as lp import matplotlib.pyplot as plt# 定义VGG块 # num_convs: 卷积层的数量 # in_channels: 输入通道的数量 # out_channels: 输出通道的数量 def vgg_block(num_convs, in_channels, out_channel…

VS Code Arduino编程

①Arduino Arduino是一款便捷灵活、方便上手的开源电子原型平台。包含硬件(各种型号的Arduino板)和软件(Arduino IDE)。 ②VS Code(全称 Visual Studio Code) 是由微软开发的一款开源、轻量级的跨平台现代代码编辑器…

昇思25天学习打卡营第4天|数据变换Transforms

学习内容复盘 1.1 数据变换 什么是数据变换、为何要数据变换 通常情况下,直接加载的原始数据并不能直接送入神经网络进行训练,此时我们需要对其进行数据预处理。MindSpore提供不同种类的数据变换(Transforms),配合数…

Jenkins定时构建自动化(二):Jenkins的定时构建

目录 ​编辑 一、 jenkins定时构建语法: 1. 语法规则: 2. 常见用法举例 3. 再次举例 接上一篇:Jenkins定时构建自动化(一):Jenkins下载安装配置:Jenkins定时构建自动化(一):Jenkins下载安装配置-CSDN博客 …

32 - 判断三角形(高频 SQL 50 题基础版)

32 - 判断三角形 select *,if(xy>z and xz>y and zy > x,Yes,No) triangle fromTriangle;