pytorch神经网络训练(AlexNet)

  • 导包
import os

import torch

import torch.nn as nn

import torch.optim as optim

from torch.utils.data import Dataset, DataLoader

from PIL import Image

from torchvision import models, transforms
  • 定义自定义图像数据集
class CustomImageDataset(Dataset): 

定义一个自定义的图像数据集类,继承自Dataset

def __init__(self, main_dir, transform=None): 

初始化方法,接收主目录和转换方法

        self.main_dir = main_dir 

主目录,包含多个子目录,每个子目录包含同一类别的图像

        self.transform = transform

 图像转换方法,用于对图像进行预处理

        self.files = [] 

存储所有图像文件的路径

        self.labels = [] 

存储所有图像的标签

        self.label_to_index = {} 

创建一个字典,用于将标签映射到索引

        for index, label in enumerate(os.listdir(main_dir)):

 遍历主目录中的所有子目录

 

          self.label_to_index[label] = index 

           label_dir = os.path.join(main_dir, label) 

将标签映射到索引,构建标签子目录的路径

           if os.path.isdir(label_dir): 

               for file in os.listdir(label_dir): 

                    self.files.append(os.path.join(label_dir, file))

                    self.labels.append(label) 

如果是目录,遍历目录中的所有文件,将文件路径添加到列表,将标签添加到列表

def __len__(self):

定义数据集的长度

        return len(self.files) 

返回文件列表的长度

def __getitem__(self, idx): 

定义获取数据集中单个样本的方法

        image = Image.open(self.files[idx]) 

        label = self.labels[idx] 

        if self.transform: 

            image = self.transform(image) 

        return image, self.label_to_index[label] 

打开图像文件,获取图像的标签,如果有转换方法,对图像进行转换,返回图像和对应的标签索引

  • 定义数据转换
transform = transforms.Compose([

    transforms.Resize((227, 227)),  # AlexNet的输入图像大小

    transforms.RandomHorizontalFlip(),  # 随机水平翻转

    transforms.RandomRotation(10),  # 随机旋转

    transforms.ToTensor(),

    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # AlexNet的标准化

])

  • 创建数据集
dataset = CustomImageDataset(main_dir="D:\\图像处理、深度学习\\flowers", transform=transform)
  • 创建数据加载器
data_loader = DataLoader(dataset, batch_size=32, shuffle=True)
  • 加载预训练的AlexNet模型
alexnet_model = models.alexnet(pretrained=True)
  • 修改最后几层以适应新的分类任务
num_ftrs = alexnet_model.classifier[6].in_features

alexnet_model.classifier[6] = nn.Linear(num_ftrs, len(dataset.label_to_index))
  • 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(alexnet_model.parameters(), lr=0.0001)
  • 如果有多个GPU,可以使用nn.DataParallel来并行化模型
if torch.cuda.device_count() > 1:

    alexnet_model = nn.DataParallel(alexnet_model)
  • 将模型发送到GPU(如果可用)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

alexnet_model.to(device)                                                               

  • 模型评估
def evaluate_model(model, data_loader, device):

    model.eval()  # 将模型设置为评估模式

    correct = 0

    total = 0

    with torch.no_grad():  # 在这个块中,所有计算都不会计算梯度

        for images, labels in data_loader:

            images, labels = images.to(device), labels.to(device)

            outputs = model(images)

            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)

            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total

    return accuracy
  • 训练模型
num_epochs = 10

for epoch in range(num_epochs):

    alexnet_model.train()

    running_loss = 0.0

    for images, labels in data_loader:

        images, labels = images.to(device), labels.to(device)

前向传播

        outputs = alexnet_model(images)

        loss = criterion(outputs, labels)

反向传播和优化

        optimizer.zero_grad()

        loss.backward()

        optimizer.step()

        running_loss += loss.item()

在每个epoch结束后评估模型

    train_accuracy = evaluate_model(alexnet_model, data_loader, device)

    print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {running_loss / len(data_loader):.4f}, Train Accuracy: {train_accuracy:.2f}%')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/705862.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Ubuntu22.04 下 pybind11 搭建,示例

Pybind11 是一个轻量级的库,用于在 C 中创建 Python 绑定。Ubuntu22下安装pybind11步骤如下: 1. 安装 pybind11 1.1 pip 命令安装 pip3 install pybind11 1.2 源代码安装 安装依赖库: sudo pip install -i https://pypi.tuna.tsinghua.e…

AVR晶体管测试仪开源项目编译

AVR晶体管测试仪开源项目编译 📍原项目地址:https://github.com/Mikrocontroller-net/transistortester/tree/master🌿 https://github.com/svn2github/transistortester🌿 https://github.com/wagiminator/ATmega-Transistor-Tes…

2. Revit API UI 之 IExternalCommand 和 IExternalApplication

2. Revit API UI 之 IExternalCommand 和 IExternalApplication 上一篇我们大致看了下 RevitAPI 的一级命名空间划分,再简单讲了一下Attributes命名空间下的3个类,并从一个代码样例,提到了Attributes和IExternalCommand ,前者是指…

vite配置unocss

在vue3vitetseslintprettierstylelinthuskylint-stagedcommitlintcommitizencz-git介绍了关于vitevue工程化搭建,现在在这个基础上,我们增加一下unocss unocss官方文档 具体开发中使用遇到的问题可以参考不喜欢原子化CSS得我,还是在新项目中使…

NumPy和数组

1.NumPy是什么 NumPy(Numerical Python的缩写)是一个开源的Python科学计算模块,其中包含了许多实用的数学函数,用来处理数值型数据。NumPy中,最重要和使用最频繁的对象就是N维数组。 为什么要学习NumPy? …

Java高级技术探索:深入理解JVM内存分区与GC机制

文章目录 引言JVM内存分区概览垃圾回收机制(GC)GC算法基础常见垃圾回收器ParNew /Serial old 收集器运行示意图 优化实践结语 引言 Java作为一门广泛应用于企业级开发的编程语言,其背后的Java虚拟机(JVM)扮演着至关重…

TikTok Ads广告综合指南:竞价策略及效果建议

作为全球最受欢迎的应用程序之一,TikTok不仅为用户提供了记录分享生活中美好时刻、交流全球创意的平台,也给全球的企业提供了一个直接触达用户的平台。随着Z时代用户人群的购买力不断上升,出海广告主们也逐渐将目光放在TikTok方面的营销。 上…

【Linux系统编程】线程

Linux线程 文章目录 Linux线程1.进程与线程区别2.线程优点3.API概要4.线程1.线程的创建2.线程等待内存共享验证3.线程退出关于对void** &的理解拓展 4.互斥锁1.创建及销毁互斥锁2.加锁及解锁 5.什么情况下会造成死锁6.条件**1. 创建及销毁条件变量****2. 等待****3. 触发**…

基于大数据的主流电商平台获取商品详情数据SKU数据价格数据

主流电商平台:淘宝 1688 闲鱼 京东 唯品会 蘑菇街 一号店 阿里妈妈 阿里巴巴 苏宁 亚马逊 易贝 速卖通 电子元件 网易考拉 洋码头 VVIC MIC Lazada 拼多多 ​ ​​​​​​​关于电商大数据的介绍: 主流电商大数据的采集:电商API接口的接入…

潮玩宇宙大逃杀APP系统开发成品案例分享指南

这是一款多人游戏,玩家需要选择一个房间躲避杀手。满足人数后,杀手会随机挑选一个房间杀掉里面所有的参与者,其他房间的幸存者将平均瓜分被杀房间的元宝。玩家在选中房间后,倒计时结束前可以自由切换不同房间。 软件项目开发成品…

【Linux】进程控制3——进程程序替换

一,前言 创建子进程的目的之一就是为了代劳父进程执行父进程的部分代码,也就是说本质上来说父子进程都是执行的同一个代码段的数据,在子进程修改数据的时候进行写时拷贝修改数据段的部分数据。 但是还有一个目的——将子进程在运行时指向一个…

自动控制原理【期末复习】(二)

无人机上桨之后可以在调试架上先调试: 1.根轨迹的绘制 /// 前面针对的是时域分析,下面针对频域分析: 2.波特图 3.奈维斯特图绘制 1.奈氏稳定判据 2.对数稳定判据 3.相位裕度和幅值裕度

JavaScript的数组排序

天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…

Sora和快手可灵背后的核心技术 | 3DVAE:通过小批量特征交换实现身体和面部的三维形状变分自动编码器

【摘要】学习3D脸部和身体生成模型中一个解开的、可解释的和结构化的潜在表示仍然是一个开放的问题。当需要控制身份特征时,这个问题尤其突出。在本文中,论文提出了一种直观而有效的自监督方法来训练一个3D形状变分自动编码器(VAE),以鼓励身份特征的解开潜在表示。通过交换不同…

自学网络安全的三个必经阶段(含路线图)

一、为什么选择网络安全? 这几年随着我国《国家网络空间安全战略》《网络安全法》《网络安全等级保护2.0》等一系列政策/法规/标准的持续落地,网络安全行业地位、薪资随之水涨船高。 未来3-5年,是安全行业的黄金发展期,提前踏入…

Python:基础爬虫

Python爬虫学习(网络爬虫(又称为网页蜘蛛,网络机器人,在FOAF社区中间,更经常的称为网页追逐者),是一种按照一定的规则,自动地抓取万维网信息的程序或者脚本。另外一些不常使用的名字…

上海晋名室外危废品暂存柜助力储能电站行业危废品安全储存

近日又有一台SAVEST室外危废暂存柜项目成功验收交付使用,此次项目主要用于储能电站行业废油、废锂电池等危废品的安全储存。 用户单位在日常工作运营中涉及到废油、废锂电池等危废品的室外安全储存问题。4月中旬用户技术总工在寻找解决方案的过程中搜索到上海晋名的…

uniapp地图自定义文字和图标

这是我的结构&#xff1a; <map classmap id"map" :latitude"latitude" :longitude"longitude" markertap"handleMarkerClick" :show-location"true" :markers"covers" /> 记住别忘了在data中定义变量…

46.Python-web框架-Django - 多语言配置

目录 1.Django 多语言基础知识 1.1什么是Django国际化和本地化&#xff1f; 1.2Django LANGUAGE_CODE 1.3关于languages 1.4RequestContext对象针对翻译的变量 2.windows系统下的依赖 3.django多语言配置 3.1settings.py配置 引用gettext_lazy 配置多语言中间件&#x…

(代数:解一元二次方程)可以使用下面的公式求一元二次方程 ax2+bx+c0 的两个根:

(代数:解一元二次方程)可以使用下面的公式求一元二次方程 ax2bxc0 的两个根: b2-4ac 称作一元二次方程的判别式。如果它是正值,那么一元二次方程就有两个实数根。 如果它为 0&#xff0c;方程式就只有一个根。如果它是负值&#xff0c;方程式无实根。 编写程序&#xff0c;提示…