Pytorch第一部分数据模块

数据划分:

从数据集中将数据划分为训练集,测试集,验证集

# -*- coding: utf-8 -*-
"""
# @file name  : 1_split_dataset.py
# @author     : tingsongyu
# @date       : 2019-09-07 10:08:00
# @brief      : 将数据集划分为训练集,验证集,测试集
"""

import os
import random
import shutil


def makedir(new_dir):
    if not os.path.exists(new_dir):
        os.makedirs(new_dir)


if __name__ == '__main__':

    random.seed(1)

    dataset_dir = "F:\\depthlearning data\\RMB_data"
    split_dir = "F:\\depthlearning data\\rmb_split"
    train_dir = os.path.join(split_dir, "train")
    valid_dir = os.path.join(split_dir, "valid")
    test_dir = os.path.join(split_dir, "test")

    train_pct = 0.8
    valid_pct = 0.1
    test_pct = 0.1

    for root, dirs, files in os.walk(dataset_dir):
        for sub_dir in dirs:

            imgs = os.listdir(os.path.join(root, sub_dir))
            imgs = list(filter(lambda x: x.endswith('.jpg'), imgs))
            random.shuffle(imgs)
            img_count = len(imgs)

            train_point = int(img_count * train_pct)
            valid_point = int(img_count * (train_pct + valid_pct))

            for i in range(img_count):
                if i < train_point:
                    out_dir = os.path.join(train_dir, sub_dir)
                elif i < valid_point:
                    out_dir = os.path.join(valid_dir, sub_dir)
                else:
                    out_dir = os.path.join(test_dir, sub_dir)

                makedir(out_dir)

                target_path = os.path.join(out_dir, imgs[i])
                src_path = os.path.join(dataset_dir, sub_dir, imgs[i])

                shutil.copy(src_path, target_path)

            print('Class:{}, train:{}, valid:{}, test:{}'.format(sub_dir, train_point, valid_point-train_point,
                                                                 img_count-valid_point))

整体代码:人民币二分类训练,这里只关注数据部分

# -*- coding: utf-8 -*-
"""
# @file name  : train_lenet.py
# @author     : tingsongyu
# @date       : 2019-09-07 10:08:00
# @brief      : 人民币分类模型训练
"""
import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from model.lenet import LeNet
from tools.my_dataset import RMBDataset


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seed()  # 设置随机种子
rmb_label = {"1": 0, "100": 1}

# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 数据 ============================

split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 模型 ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()                                                   # 选择损失函数

# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)     # 设置学习率下降策略

# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()

for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # 统计分类情况
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # 打印训练信息
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # 更新学习率

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            loss_val_epoch = loss_val / len(valid_loader)
            valid_curve.append(loss_val_epoch)
            # valid_curve.append(loss.item())    # 20191022改,记录整个epoch样本的loss,注意要取平均
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val_epoch, correct_val / total_val))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

# ============================ inference ============================

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")

test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1)

for i, data in enumerate(valid_loader):
    # forward
    inputs, labels = data
    outputs = net(inputs)
    _, predicted = torch.max(outputs.data, 1)

    rmb = 1 if predicted.numpy()[0] == 0 else 100
    print("模型获得{}元".format(rmb))






数据部分:

# ============================ step 1/5 数据 ============================

#读取数据路径
split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")


norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

#训练集 数据预处理  缩放 裁剪 转换为张量 
train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])
#验证集 少了裁剪的方法
valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# 构建MyDataset实例  传入数据路径 数据预处理
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

dataloader分为sampler(索引)和dataset(标签)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/563365.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Gamba:将高斯溅射与Mamba结合用于单视图3D重建

Gamba: Marry Gaussian Splatting with Mamba for Single-View 3D Reconstruction Gamba&#xff1a;将高斯溅射与Mamba结合用于单视图3D重建 Qiuhong Shen11  Xuanyu Yi31 Zike Wu31  Pan Zhou2,42 Hanwang Zhang3,5 沈秋红 1 易轩宇 3 吴子可 3 潘周 2,4 2 张汉旺 3,5Shu…

验证线缆(汽车线束、网线、多芯线)破损或断开与正常线缆的区别在哪里?依AEM CV-100 k50测试仪

工厂产线生产的线缆&#xff08;汽车线束、网线、多芯线&#xff09;做成成品&#xff0c;即2端都安装好了模块。在这种情况下如何快速的判定此条线缆是合格的呢&#xff0c;此处的合格为物理层面上的合格&#xff08;不会出现开路、短路&#xff09;&#xff0c;也就是最基本保…

【LAMMPS学习】八、基础知识(3.9)输出结构化数据

8. 基础知识 此部分描述了如何使用 LAMMPS 为用户和开发人员执行各种任务。术语表页面还列出了 MD 术语&#xff0c;以及相应 LAMMPS 手册页的链接。 LAMMPS 源代码分发的 examples 目录中包含的示例输入脚本以及示例脚本页面上突出显示的示例输入脚本还展示了如何设置和运行各…

android开发 多进程的基本了解

目录 如何开启多进程?理解多进程模式的运行机制 如何开启多进程? 给四大组件在androidMenifest中指定android:precess <activityandroid:name".ThreeActivity"android:exported"false"android:process"com.my.process.three.remote" />…

冒泡排序c++

题目描述 编程输入n(1≤n≤20)个小于1000非负整数&#xff0c;然后自动按从大到小的顺序输出。&#xff08;冒泡排序&#xff09; 输入 第一行&#xff0c;数的个数n; 第二行&#xff0c;n个非负整数。 输出 由大到小的n个非负整数&#xff0c;每个数占一行。 样例输入 …

C++异步编程小论

目录 std::async与std::future 其他 std::package_task std::promise Reference 浅论&#xff1a;我看有人写的浅论异步编程的文章实际上在干的是介绍多线程&#xff0c;这里刚好最近对异步编程有所兴趣&#xff1a;我们来看看几个C11新加进来的一些异步编程关键字。 这里…

揭开ChatGPT面纱(3):使用OpenAI进行文本情感分析(embeddings接口)

文章目录 一、embeddings接口解析二、代码实现1.数据集dataset.csv2.代码3.运行结果 openai版本1.6.1 本系列博客源码仓库&#xff1a;gitlab&#xff0c;本博客对应文件夹03 在这一篇博客中我将使用OpenAI的embeddings接口判断21条服装评价是否是好评。 首先来看实现思路&am…

视频教程下载:用ChatGPT的 API 开发AI应用指南

通过这门关于 OpenAI API 和 ChatGPT API 的全面课程&#xff0c;在您的应用中释放人工智能的力量。随着人工智能技术的快速发展&#xff0c;比以往任何时候都更重要的是保持领先地位&#xff0c;并为您的项目利用这些尖端工具。在本课程中&#xff0c;您将深入了解人工智能驱动…

每日论文推荐:Prismatic VLMs VLM设计经验总结

&#x1f4cc; 元数据概览&#xff1a; 标题&#xff1a;“Prismatic VLMs: Investigating the Design Space of Visually-Conditioned Language Models”作者&#xff1a;Siddharth Karamcheti, Suraj Nair, Ashwin Balakrishna, Percy Liang, Thomas Kollar, Dorsa Sadigh&a…

LLM学习笔记-2

在未标记数据上进行预训练 本章概要 在上节的笔记中&#xff0c;因为训练出的效果&#xff0c;并不是特别理想&#xff0c;在本节中&#xff0c;会用数据进行训练&#xff0c;使得模型更加的好&#xff1b; 计算文本生成损失 inputs torch.tensor([[16833, 3626, 6100],…

SpringBoot 操作 Redis

导入对应版本的依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-data-redis</artifactId> </dependency>修改配置文件中的信息 spring:redis:host: 127.0.0.1port: 8888注意: 我这里 xsh…

springboot注解开发如何映射对象型数据

创作灵感 最近在帮学校写一款小程序时&#xff0c;有这样一个数据需要展示&#xff1a;一条申请记录&#xff0c;里面包含了申请时间、申请状态、申请所提供的六条活动记录等待&#xff0c;其中&#xff0c;申请所提供的六条活动记录为一个数组&#xff0c;数组中的每个元素又…

面板数据门槛归回分析,xthreg的安装,xthreg2安装包

我用的是Stata17,数据是不平衡的面板数据,需要用到xthreg2,虽然我找到了xthreg2.ado,但是还需要安装xthreg,因为运行xthreg2需要xthreg包顺带安装的lxthreg.mlib文件。但是!我后来发现还是不行,最后是去买了一个真正能用到xthreg2的lxthreg.mlib文件,才可以运行。 一、…

[图解]软件开发中的糊涂用语-04-为什么要追究糊涂用语

0 00:00:00,030 --> 00:00:05,620 今天呢&#xff0c;我们来说一个为什么要追究糊涂用语的问题 1 00:00:06,310 --> 00:00:06,548 2 00:00:06,548 --> 00:00:11,077 大家知道我们前些天都发了好几个视频 3 00:00:11,077 --> 00:00:13,461 追究这个糊涂用语 4 00…

2024免费专为Mac用户设计的清理和优化工具CleanMyMac X

CleanMyMac X是一款专为Mac用户设计的清理和优化工具。以下是对CleanMyMac X的详细介绍&#xff1a; 一、主要功能 系统清理&#xff1a;CleanMyMac X能够智能扫描Mac的磁盘空间&#xff0c;识别并清理各种垃圾文件&#xff0c;这些垃圾文件包括重复文件、无用的语言安装包、i…

五.音视频编辑-音频混合-应用

引言 音频混合技术是一项强大的工具&#xff0c;可以为应用程序增添丰富的功能和用户体验。在前一篇博客中&#xff0c;我们深入探讨了AVFoundation框架中的音频混合基础知识和实现方法。现在&#xff0c;让我们进一步探索&#xff0c;看看如何将这些技术应用到实际项目中。 …

Django项目无法安装python-ldap依赖解决方案

最近工作中安排了一个Python web项目&#xff0c;使用Pycharm从git拉取代码后&#xff0c;配置号Python的解释器和pip后&#xff0c;Pycharm自动下载安装项目所需的依赖&#xff0c;但是有一个依赖django-auth-ldap4.1.0安装始终失败&#xff0c;最初的异常信息提示是&#xff…

Java客户端如何直接调用es的API

Java客户端如何直接调用es的API 一. 问题二. withJson 前言 这是我在这个网站整理的笔记,有错误的地方请指出&#xff0c;关注我&#xff0c;接下来还会持续更新。 作者&#xff1a;神的孩子都在歌唱 一. 问题 今天做项目的时候&#xff0c;想要直接通过java客户端调用es的api…

解决Django中调页面时出现“Did you forget to register or load this tag”报错

解决Django中调页面时出现“Did you forget to register or load this tag?”报错 1.问题收录 2.分析问题 在HTML文件中&#xff0c;{{title}}&#xff0c;{{lanyy}}&#xff0c;django 默认规定的语法&#xff0c;用{{}}包起来的变量叫做模板变量。 django渲染模板时会将大…

2.1K Star微软开源的高质量 iot库

功能描述 该项目是一个开源的 .NET Core 实现&#xff0c;旨在帮助开发者构建适用于物联网(IoT)设备和场景的应用程序。它提供了与传感器、显示器和输入设备等相互作用所需的 GPIO 引脚、串口等硬件的接口。该仓库包含 System.Device.Gpio 库以及针对各种板卡&#xff08;如 Ra…