DataLoader与Dataset

一、人民币二分类在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

二、DataLoader 与 Dataset

DataLoader

torch.utils.data.DataLoader

功能:构建可迭代的数据装载器
(只标注了较为重要的)
• dataset: Dataset类,决定数据从哪读取及如何读取
• batchsize : 批大小
• num_works: 是否多进程读取数据
• shuffle: 每个epoch是否乱序
• drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据

DataLoader(
    dataset,
    batch_size=1,
    shuffle=False,
    sampler=None,
    batch_sampler=None,
    num_workers=0,
    collate_fn=None,
    pin_memory=False,
    drop_last=False,
    timeout=0,
    worker_init_fn=None,
    multiprocessing_context=None
)
  • Epoch: 所有训练样本都已输入到模型中,称为一个Epoch
  • Iteration:一批样本输入到模型中,称之为一个Iteration
  • Batchsize:批大小,决定一个Epoch有多少个Iteration

样本总数:80, Batchsize:8
1 Epoch = 10 Iteration

样本总数:87, Batchsize:8
1 Epoch = 10 Iteration ? drop_last = True
1 Epoch = 11 Iteration ? drop_last = False

根据给定的样本总数和批大小,可以计算出一个Epoch中的Iteration数量。

  1. 样本总数为80,批大小为8:
    • 一个Epoch中的Iteration数量 = 样本总数 / 批大小 = 80 / 8 = 10
  2. 样本总数为87,批大小为8,且设置drop_last = True
    • 一个Epoch中的Iteration数量 = 样本总数 // 批大小 = 87 // 8 = 10
  3. 样本总数为87,批大小为8,且设置drop_last = False
    • 一个Epoch中的Iteration数量 = (样本总数 + 批大小 - 1) // 批大小 = (87 + 8 - 1) // 8 = 11

在第3种情况下,由于样本总数无法被批大小整除,因此在最后一个Epoch中会有一个额外的Iteration来处理剩余的样本。

Dataset

torch.utils.data.Dataset

功能:Dataset抽象类,所有自定义的Dataset需要继承它,并且复写__getitem__()

getitem :接收一个索引,返回一个样本

class Dataset(object):
    def __getitem__(self, index):
        raise NotImplementedError
    def __add__(self, other):
        return ConcatDataset([self, other])

上述代码定义了一个名为Dataset的类,该类是一个抽象基类。它包含了两个特殊方法:

  1. __getitem__(self, index)方法:这是一个抽象方法,需要在子类中实现。它用于根据给定的索引index返回对应的数据样本。在这里,抛出了NotImplementedError异常,表示子类必须覆盖这个方法来提供具体的实现。
  2. __add__(self, other)方法:这是一个特殊方法,用于实现对象的加法操作。在这里,它返回一个ConcatDataset对象,该对象将当前的self和另一个other数据集合并在一起。__add__方法的返回值是一个ConcatDataset对象,表示将当前数据集和另一个数据集进行连接。ConcatDataset是PyTorch中的一个类,用于将多个数据集连接在一起,以便在训练过程中一起使用。

四、模型训练

# -*- coding: utf-8 -*-
"""
# @file name  : train_lenet.py
# @author     : siuserjy
# @date       : 2024-01-03 20:50:38
# @brief      : 人民币分类模型训练
"""
import os

# 获取当前文件的目录路径
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# 导入必要的库和模块
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt

# 定义lenet.py和common_tools.py文件的路径并检查文件是否存在
path_lenet = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "model", "lenet.py"))
path_tools = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "tools", "common_tools.py"))
assert os.path.exists(path_lenet), "{}不存在,请将lenet.py文件放到 {}".format(path_lenet, os.path.dirname(path_lenet))
assert os.path.exists(path_tools), "{}不存在,请将common_tools.py文件放到 {}".format(path_tools, os.path.dirname(path_tools))

# 将自定义模块所在的目录添加到Python路径中
import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__) + os.path.sep + ".." + os.path.sep + "..")
sys.path.append(hello_pytorch_DIR)

# 从自定义模块导入所需内容
from model.lenet import LeNet
from tools.my_dataset import RMBDataset
from tools.common_tools import set_seed

# 设置随机种子
set_seed()

# 定义人民币数据集的标签
rmb_label = {"1": 0, "100": 1}

# 设置训练参数
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 数据 ============================

# 设置数据集路径
split_dir = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "data", "rmb_split"))
if not os.path.exists(split_dir):
    raise Exception(r"数据 {} 不存在, 回到lesson-06\1_split_dataset.py生成数据".format(split_dir))

# 设置训练集和验证集路径
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

# 设置图像的均值和标准差
norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

# 设置训练集的数据预处理
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),  # 将图像大小调整为32x32
    transforms.RandomCrop(32, padding=4),  # 随机裁剪32x32大小的图像
    transforms.ToTensor(),  # 将图像转换为Tensor格式
    transforms.Normalize(norm_mean, norm_std),  # 标准化图像
])

# 设置验证集的数据预处理
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),  # 将图像大小调整为32x32
    transforms.ToTensor(),  # 将图像转换为Tensor格式
    transforms.Normalize(norm_mean, norm_std),  # 标准化图像
])

# 构建训练集和验证集的数据集实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建训练集和验证集的DataLoader
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 模型 ============================
# 构建LeNet模型实例
net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 损失函数 ============================
# 设置损失函数
criterion = nn.CrossEntropyLoss()


# ============================ step 4/5 优化器 ============================
# 设置优化器
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)

# 设置学习率下降策略
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)


# ============================ step 5/5 训练 ============================
train_curve = list()  # 记录训练集的loss值
valid_curve = list()  # 记录验证集的loss值

for epoch in range(MAX_EPOCH):  # 迭代训练多个epoch

    loss_mean = 0.  # 记录每个epoch的平均loss值
    correct = 0.  # 记录分类正确的样本数量
    total = 0.  # 记录总样本数量

    net.train()  # 将模型设置为训练模式
    for i, data in enumerate(train_loader):  # 遍历训练集数据

        # forward
        inputs, labels = data  # 获取输入数据和标签
        outputs = net(inputs)  # 将输入数据输入模型,得到输出结果

        # backward
        optimizer.zero_grad()  # 将模型参数的梯度置零
        loss = criterion(outputs, labels)  # 计算损失值
        loss.backward()  # 反向传播,计算梯度

        # update weights
        optimizer.step()  # 更新模型参数

        # 统计分类情况
        _, predicted = torch.max(outputs.data, 1)  # 获取预测结果
        total += labels.size(0)  # 累计总样本数量
        correct += (predicted == labels).squeeze().sum().numpy()  # 累计分类正确的样本数量

        # 打印训练信息
        loss_mean += loss.item()  # 累计每个batch的loss值
        train_curve.append(loss.item())  # 将每个batch的loss值记录下来
        if (i+1) % log_interval == 0:  # 每隔一定的batch数打印一次训练信息
            loss_mean = loss_mean / log_interval  # 计算平均loss值
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.  # 重置loss_mean

    scheduler.step()  # 更新学习率

    # validate the model
    if (epoch+1) % val_interval == 0:  # 每隔一定的epoch数进行一次验证

        correct_val = 0.  # 记录验证集分类正确的样本数量
        total_val = 0.  # 记录验证集总样本数量
        loss_val = 0.  # 记录验证集的loss值
        net.eval()  # 将模型设置为评估模式
        with torch.no_grad():  # 不计算梯度
            for j, data in enumerate(valid_loader):  # 遍历验证集数据
                inputs, labels = data  # 获取输入数据和标签
                outputs = net(inputs)  # 将输入数据输入模型,得到输出结果
                loss = criterion(outputs, labels)  # 计算损失值

                _, predicted = torch.max(outputs.data, 1)  # 获取预测结果
                total_val += labels.size(0)  # 累计验证集总样本数量
                correct_val += (predicted == labels).squeeze().sum().numpy()  # 累计验证集分类正确的样本数量

                loss_val += loss.item()  # 累计验证集的loss值

            loss_val_epoch = loss_val / len(valid_loader)  # 计算验证集每个epoch的平均loss值
            valid_curve.append(loss_val_epoch)  # 将验证集每个epoch的平均loss值记录下来
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_epoch, correct_val / total_val))

# 绘制训练曲线和验证曲线
train_x = range(len(train_curve))  # 训练曲线的x轴
train_y = train_curve  # 训练曲线的y轴

train_iters = len(train_loader)  # 训练集的迭代次数
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval - 1  # 验证曲线的x轴,将epoch转换为iteration
valid_y = valid_curve  # 验证曲线的y轴

plt.plot(train_x, train_y, label='Train')  # 绘制训练曲线
plt.plot(valid_x, valid_y, label='Valid')  # 绘制验证曲线

plt.legend(loc='upper right')  # 设置图例位置
plt.ylabel('loss value')  # 设置y轴标签
plt.xlabel('Iteration')  # 设置x轴标签
plt.show()  # 显示图像


# ============================ inference ============================


# 设置基本路径
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")

# 创建测试数据集
test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)

# 创建验证数据加载器
valid_loader = DataLoader(dataset=test_data, batch_size=1)

# 遍历验证数据集
for i, data in enumerate(valid_loader):
    # 前向传播
    inputs, labels = data
    outputs = net(inputs)
    _, predicted = torch.max(outputs.data, 1)

    # 判断预测结果是1元还是100元
    rmb = 1 if predicted.numpy()[0] == 0 else 100

    # 打印模型获得的金额
    print("模型获得{}元".format(rmb))


在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/288999.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

广告公司项目管理工具推荐:高效管理订单与项目的利器

由于广告公司每年都需要处理数千个制作订单,包括促销广告预告片和海报等,迫切需要一款能够高效管理和生产广告流程的软件。 在实施采购项目管理工具之前,广告公司的流程相当繁复,员工们需要使用约10到15个不同的Excel表格&#xf…

vue-element-admin请求接口时报错:431 Request Header Fields Too Large

vue-element-admin请求接口时报错:431 Request Header Fields Too Large 1、使用vue-element-admin框架开发运行后报错431,在网上也找了一些解决方案,有说是生成的Token太大导致的报错,也说有可能是NGINX后台设置的提交大小限制。…

QT的信号与槽

QT的信号与槽 文章目录 QT的信号与槽前言一、QT 打印"hello QT"的dome二、信号和槽机制?二、信号与槽的用法1、QT5的方式1. 无参的信号与槽的dome2.带参的信号与槽dome 2、QT4的方式3、C11的语法 Lambda表达式1、函数对象参数2、操作符重载函数参数3、可修…

为什么JAVA_HOME修改后Java版本不变

今天的实验需要对java project进行降版本后重构。于是去Oracle官网下载了jdk1.7。然后将系统环境变量JAVA_HOME改成了安装后的jdk1.7路径。即 C:\Program Files\Java\jdk1.7.0_80 系统变量Path中直接引用了%JAVA_HOME%\bin。 但是当我查看版本,却出现了javac改了…

stable diffusion 基础教程-提示词之光的用法

基图 prompt: masterpiece,best quality,1girl,solo,looking at viewer,brown hair,hair between eyes,bangs,very long hair,red eyes,blush,bare shoulders,(white sundress),full body,leaning forward,medium breasts,unbuttoned clothes,Negative prompt: EasyNegativ…

GraalVM Native学习及使用

概述 在开发Spring Boot 应用或者其他JAVA程序的过程中,启动慢、内存占用大是比较头疼的问题,往往需要更多的资源去部署,成本大幅提高。为了优化上述问题,常常使用优化程序、使用更小消耗的JVM、使用容器等措施。 现在有一个叫做…

工作流入门这篇就够了!

总概 定义:工作流是在计算机支持下业务流程的自动或半自动化,其通过对流程进行描述以及按一定规则执行以完成相应工作。 应用:随着计算机技术的发展以及工业生产、办公自动化等领域的需求不断提升,面向事务审批、材料提交、业务…

在Cadence中单独添加或删除器件与修改网络的方法

首先需要在设置中使能 ,添加或修改逻辑选项。 添加或删除器件,点击logic-part,选择需要添加或删除的器件,这里的器件必须是PCB中已经有的器件,Refdes中输入添加或删除的器件标号,点击Add添加。 添加完成后就会显示在R1…

学习笔记240102 --- 表单无法输入,是否data中没有提前声明导致的

前端框架 &#xff1a;vue2.x 第三方ui组件&#xff1a;ElementUI 操作系统&#xff1a;windows 浏览器&#xff1a;chrome 谷歌 问题描述 表单使用中&#xff0c;没有在data中提前声明参数&#xff0c;当数据回显时&#xff0c;表单无法输入 <el-form :model"queryPa…

【力扣100】39.组合总和

添加链接描述 class Solution:def combinationSum(self, candidates: List[int], target: int) -> List[List[int]]:def backtrack(path,target,res,index):if target0:res.append(path[:])returnif target<0:return for i in range(index,len(candidates)):if target&g…

3个.NET开源简单易用的任务调度框架

前言 今天分享3个.NET开源、简单、易用的任务调度框架&#xff0c;帮助大家在做定时任务调度框架技术选型的时候有一个参考。 Quartz.Net Quartz.NET是一个功能齐全的开源作业调度系统&#xff0c;可用于从最小的应用程序到大规模企业系统。 Quartz.NetUI Quartz.NetUI是一…

算法导论复习——CHP25 多源最短路

问题描述 给定一个带权重的有向图G(V,E)&#xff0c;其权重函数为ω:E→R。 在图中&#xff0c;对所有的结点对 u,v∈V&#xff0c;找出从结点u到结点v的最短路径。 该问题的解以表格&#xff08;二维数组&#xff09;的形式给出&#xff1a;第u行第v列给出从结点u到结…

计算机毕业设计 基于SpringBoot的工作量统计系统的设计与实现 Java实战项目 附源码+文档+视频讲解

博主介绍&#xff1a;✌从事软件开发10年之余&#xff0c;专注于Java技术领域、Python人工智能及数据挖掘、小程序项目开发和Android项目开发等。CSDN、掘金、华为云、InfoQ、阿里云等平台优质作者✌ &#x1f345;文末获取源码联系&#x1f345; &#x1f447;&#x1f3fb; 精…

MySQL中的事务到底是怎么一回事儿

简单来说&#xff0c;事务就是要保证一组数据库操作&#xff0c;要么全部成功&#xff0c;要么全部失败。在MySQL中&#xff0c;事务支持是在引擎层实现的&#xff0c;但并不是所有的引擎都支持事务&#xff0c;如MyISAM引擎就不支持事务&#xff0c;这也是MyISAM被InnoDB取代的…

多任务并行处理相关面试题

我自己面试时被问过两次多任务并行相关的问题&#xff1a; 假设现在有10个任务&#xff0c;要求同时处理&#xff0c;并且必须所有任务全部完成才返回结果 这个面试题的难点是&#xff1a; 既然要同时处理&#xff0c;那么肯定要用多线程。怎么设计多线程同时处理任务呢&…

leetcode递归算法题总结

递归本质是找重复的子问题 本章目录 1.汉诺塔2.合并两个有序链表3.反转链表4.两两交换链表中的节点5.Pow(x,n) 1.汉诺塔 汉诺塔 //面试写法 class Solution { public:void hanota(vector<int>& a, vector<int>& b, vector<int>& c) {dfs(a,b…

基于Spring Cloud + Spring Boot的企业电子招标采购系统源码

随着企业的快速发展&#xff0c;招采管理逐渐成为企业运营中的重要环节。为了满足公司对内部招采管理提升的要求&#xff0c;建立一个公平、公开、公正的采购环境至关重要。在这个背景下&#xff0c;我们开发了一款电子招标采购软件&#xff0c;以最大限度地控制采购成本&#…

Python等高线图的绘制(Matplotlib篇-11)

Python等高线图的绘制(Matplotlib篇-11)         🍹博主 侯小啾 感谢您的支持与信赖。☀️ 🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ🌹꧔ꦿ…

Redis(二)

1、redis的持久化 "Redis 如何将数据写入磁盘"&#xff0c;首先要明白的时候&#xff0c;我们使用的redis的数据保存在内存上的&#xff0c;也就是说&#xff0c;只要我们的电脑关机或者重启服务器&#xff0c;那么在内存中的数据就会消失&#xff0c;所以要想持久化…

(一)CarPlay集成开发之概述与环境篇

系列文章目录 第一章 CarPlay集成开发之概述与环境篇 文章目录 系列文章目录概述开发环境依赖项总结 概述 CarPlay是由苹果公司开发的一款集成在iOS系统中&#xff0c;用于运行在已完成对接该系统的汽车中控台&#xff0c;仪表盘上的车载系统&#xff0c;该系统通过USB或者WI…