PyTorch搭建AlexNet训练集

本次项目是使用AlexNet实现5种花类的识别。

训练集搭建与LeNet大致代码差不多,但是也有许多新的内容和知识点。

1.导包,不必多说。

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time

2.指定设备

device函数用来指定在训练过程中所使用的设备:如果有可用的GPU,那么使用第一块GPU,如果没有就默认使用cpu。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

 3.数据预处理函数

单独定义出来,当key为“train”或为“val”时,返回数据集要使用的一系列预处理方法。

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),   # 把图片重新裁剪为224*224
                                 transforms.RandomHorizontalFlip(),  # 水平方向随机翻转
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)])}

4.获取数据集的路径

os.getcwd()方法获取当前文件所在的目录

os.path.join()方法将当前路径与上两级路径链接起来

image_path:获取到flower_data所在路径

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = data_root + "/data_set/flower_data"
# train set
train_dataset = datasets.ImageFolder(root=image_path + "/train", # 获取训练集的路径
                                     transform=data_transform["train"])  # 训练预处理
train_num = len(train_dataset)  # 打印训练集有多少张照片

5.加载数据集分类文件 

{'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4} :数据集共分为五类
flower_list = train_dataset.class_to_idx 获取分类的名称所对应的索引值
cla_dict = dict((val, key) for key, val in flower_list.items()) 将字典中键与值的位置对换

?为什么要换位置

=>这样在预测后可以直接通过值给到我们最后的测试类别
json_str = json.dumps(cla_dict, indent=4) :将字典编码成json格式
with open('class_indices,json', 'w') as json_file:
        json_file.write(json_str)  :将键值对保存到json文件中,方便后续在预测时读取信息

下面是生成的json文件

# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# 把文件写入接送文件
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices,json', 'w') as json_file:
    json_file.write(json_str)

 6.载入测试集

代码大致与LeNet网络差不多,载入测试集的图片路径需要自己定义并进行预处理。

在使用matplotlib查看图片时,注意修改为batch_size=4,shuffle=True参数。

batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)
#
validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validata_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=batch_size,
                                              shuffle=False, num_workers=0)

暂时的全部代码,训练集还没有完全实现,我后续会补充上的,因为课真的是太多了。

import torch
import torch.nn as nn
from torchvision import transforms, datasets, utils
import matplotlib as plt
import numpy as np
import torch.optim as optim
from model import AlexNet
import os
import json
import time

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

data_transform = {
    "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                 transforms.RandomHorizontalFlip(),
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
    "val": transforms.Compose([transforms.Resize((224, 224)),
                               transforms.ToTensor(),
                               transforms.Normalize(0.5, 0.5, 0.5), (0.5, 0.5, 0.5)])}

data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))
image_path = data_root + "/data_set/flower_data"
# train set
train_dataset = datasets.ImageFolder(root=image_path + "/train",
                                     transform=data_transform["train"])
train_num = len(train_dataset)


# {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}
flower_list = train_dataset.class_to_idx
cla_dict = dict((val, key) for key, val in flower_list.items())
# 把文件写入接送文件
json_str = json.dumps(cla_dict, indent=4)
with open('class_indices,json', 'w') as json_file:
    json_file.write(json_str)


batch_size = 32
train_loader = torch.utils.data.DataLoader(train_dataset,
                                           batch_size=batch_size, shuffle=True,
                                           num_workers=0)
#
validate_dataset = datasets.ImageFolder(root=image_path + "/val",
                                        transform=data_transform["val"])
val_num = len(validate_dataset)
validata_loader = torch.utils.data.DataLoader(validate_dataset, batch_size=4,
                                              shuffle=True, num_workers=0)

学习碎碎念:

学习的道路上总会是遇到困难和麻烦的,不要心急,不要烦躁,一步一步的解决问题,慢慢来总会好的!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/449412.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

NFTScan | 03.04~03.10 NFT 市场热点汇总

欢迎来到由 NFT 基础设施 NFTScan 出品的 NFT 生态热点事件每周汇总。 周期:2024.03.04~ 2024.03.10 NFT Hot News 01/ 数据:比特币链上 NFT 过去 24 小时销售额超 3100 万美元 3 月 4 日,据数据显示,比特币链上 NFT 过去 24 小…

设计模式十:原型模式

文章目录 1、原型模式1.1 类创建过程1.2 浅拷贝1.3 深拷贝 2、示例2.1 简单形式2.2 复杂形式 3、spring中的原型模式3.1 ArrayList的原型模式3.2 spring中的原型模式 1、原型模式 原型模式就是从一个对象再创建另外一个可定制的对象, 而且不需要知道任何创建的细节。…

Vscode+QT+Python

参考链接:VSCodePyQt之Python界面编写_vscode编写图形化界面-CSDN博客 1.安装库 pip install PyQt5 pip install PyQt5-tools pip install qt5_applications 2.在VSCode里下载并安装PYQT Integration 3.配置pyqt integration 4.打开qt designer 在工程文件的空白…

python自动化之pytest框架以及数据驱动(第五天)

1.pytest框架需要遵循的规则 (1).py 测试文件必须以test 开头(或者以 test结尾) (2)测试类必须以Test开头,并且不能有 init 方法 (3)测试方法必须以test 开头 (4)断言…

分享个好用的GPT网站

目录 一、背景 二、功能描述 1、写代码 2、联网查询 3、AI绘图 一、背景 我现在的开发工作都依靠ChatGPT,效率提升了好几倍。这样一来,我有更多时间来摸鱼,真是嘎嘎香~ ⭐⭐⭐点击直达 ⭐⭐⭐ 二、功能描述 1、写代码 import java.ut…

机器学习之分类回归模型(决策数、随机森林)

回归分析 回归分析属于监督学习方法的一种,主要用于预测连续型目标变量,可以预测、计算趋势以及确定变量之间的关系等。 Regession Evaluation Metrics 以下是一些最流行的回归评估指标: 平均绝对误差(MAE):目标变量的预测值与实际值之间的平均绝对差…

基于PHP+Amaze+JQuery的学习论坛的设计与实现1.99

摘 要 互联网教育服务是在互联网技术、通信技术、计算机技术不断发展融合的基础之上,人们在对以信息为基础的各种各样应用需求快速增长的激励之下,在现在社会信息化的水平日益提高前提之下,迅速发展起来的一种全新大众服务方式。 笔者拟设计…

前端食堂技术周刊第 115 期:Rolldown 正式开源、马斯克宣布 xAI 本周将开源 Grok、如何使用 Copilot 完成 50% 的日常工作?

美味值:🌟🌟🌟🌟🌟 口味:手打柠檬茶 食堂技术周刊仓库地址:https://github.com/Geekhyt/weekly 大家好,我是童欧巴。欢迎来到前端食堂技术周刊,我们先来看…

Docker的安装及镜像加速的配置

文章目录 一.切换到root二.卸载旧版docker三.配置docker的yum库四.安装Docker五.Docker的启动和验证六.配置Docker阿里云镜像加速(全程免费) 该文章文章演示在Linux系统中安装docker,Windows安装docker请参考以下文章 Windows系统中安装docker及镜像加速的配置 一…

基于android的物业管理系统的设计与实现19.8

目录 基于android的物业管理系统的设计与实现 3 摘 要 3 Android property managemengt system 5 Abstract 5 1 绪论 6 1.1 选题背景 6 1.2 课题研究现状 6 1.3 设计研究主要内容 7 1.4 系统主要设计思想 8 2 开发环境 8 2.1 Android系统的结构 8 图2-1 Android系统架构图 9 2…

kibana新增查看更新删除es中的数据

登录kibana,打开开发工具 写入数据 PUT test20240311/person/1 {"name": "张三","mobile":"13011111111" } 查询数据 GET /test20240311/person/_search {"query": {"term": {"mobile": {…

中科数安|公司办公终端、电脑文件数据 \ 资料防泄密系统

#中科数安# 中科数安是一家专注于信息安全技术与产品研发的高新技术企业,其提供的公司办公终端、电脑文件数据及资料防泄密系统(也称为终端数据防泄漏系统或简称DLP系统)主要服务于企业对内部敏感信息的安全管理需求。 www.weaem.com 该系统…

ffmpeg日记4001-原理介绍-视频切割原理

原理 打开输入---->打开输出---->根据输入来创建流---->拷贝流设置---->循环读帧---->判断时间点是否到达切割点,并做设置---->设置pts和dts---->写入---->善后 重点是pts和dts如何设置。参考《ffmpeg学习日记25-pts,dts概念的…

HBase非关系型数据库

HBase非关系型数据库 1 什么是HBase2 HBase的特点3 什么时候需要HBase4 HBase的数据模型5 HBase架构5.1 架构5.2 HBase如何列式储存 6 如何正确设计RowKey 1 什么是HBase HBase – Hadoop Database,是一个高可靠性、高性能、面向列、可伸缩、 实时读写的分布式数据…

Java并发编程: AQS

文章目录 一、前置知识二、什么是AQS三、使用AQS框架的锁和同步器1、ReentrantLock2、ReentrantReadWriteLock3、CountDownLatch4、CyclicBarrier5、Semaphore:信号量 四、锁和同步器的关系1、锁:面向锁的使用者2、同步器:面向锁的实现者 五、…

Material UI 5 学习03-Text Field文本输入框

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 Text Field文本输入框 一、最基本的本文输入框1、基础示例2、一些表单属性3、验证 二、多行文本 一、最基本的本文输入框 1、基础示例 import {Box, TextField} from "…

九、ELMo 语言模型

ELMo(Embeddings from Language Models)兼顾了两个问题:一是词语用法在语义和语法上的复杂特点;二是随着语言环境的改变,这些用法也应该随之改变,解决多义词的问题。 ELMo 语言模型原理图: ELMo…

Matlab如何批量读取Excel数据?科研效率UpUp第3期

上一篇文章中,讲了如何批量统计一组Excel数据中多个站位所有物种的数量之和(Matlab如何高效统计多站数据中各站目标总数?科研效率UpUp第2期)。 进一步,假如我们有多组Excel数据,也就是多个Excel表格&#…

【历年案例分析真题考点汇总】与【专栏文章案例分析高频考点目录】(2024年软考高级系统架构设计师冲刺知识点总结-案例分析篇-先导篇)

专栏系列文章推荐: 2024高级系统架构设计师备考资料(高频考点&真题&经验)https://blog.csdn.net/seeker1994/category_12601310.html 案例分析篇01:软件架构设计考点架构风格及质量属性(2024年软考高级系统…

鸿蒙Harmony应用开发—ArkTS声明式开发(基础手势:Navigation)

Navigation组件是路由导航的根视图容器,一般作为Page页面的根容器使用,其内部默认包含了标题栏、内容区和工具栏,其中内容区默认首页显示导航内容(Navigation的子组件)或非首页显示(NavDestination的子组件…