pytorch03:transforms常见数据增强操作

目录

  • 一、数据增强
  • 二、transforms--Crop裁剪
    • 2.1 transforms.CenterCrop
    • 2.2 transforms.RandomCrop
    • 2.3 RandomResizedCrop
    • 2.4 FiveCrop和TenCrop
  • 三、transforms—Flip翻转、旋转
    • 3.1RandomHorizontalFlip和RandomVerticalFlip
    • 3.2 RandomRotation
  • 四、transforms —图像变换
    • 4.1 transforms.Pad
    • 4.2 transforms.ColorJitter
    • 4.3 Grayscale和RandomGrayscale
    • 4.4 RandomAffine
    • 4.5 RandomErasing
  • 五、transforms的操作
    • 5.1 transforms.RandomChoice
    • 5.2 transforms.RandomApply
    • 5.3 transforms.RandomOrder
  • 六、自定义transforms
    • 6.1 自定义transforms要素
    • 6.2 通过类实现多参数传入
    • 6.3 椒盐噪声
    • 6.4 自定义transforms代码实现
  • 七、数据增强策略
    • 数据增强代码实现

一、数据增强

   数据增强又称为数据增广,数据扩增,它是对训练集进行变换,使训练集更丰富,从而让模型更具泛化能力。如下是对一张图片常见的增强操作例如:旋转、裁剪、像素抖动。
在这里插入图片描述

二、transforms–Crop裁剪

2.1 transforms.CenterCrop

功能:从图像中心裁剪图片
• size:所需裁剪图片尺寸

2.2 transforms.RandomCrop

功能:从图片中随机裁剪出尺寸为size的图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• padding:设置填充大小
  当为a时,上下左右均填充a个像素,
  当为(a, b)时,上下填充b个像素,左右填充a个像素,
  当为(a, b, c, d)时,左,上,右,下分别填充a, b, c, d
• pad_if_need:若图像小于设定size,则填充
• padding_mode:填充模式,有4种模式
  1、constant:像素值由fill设定
  2、edge:像素值由图像边缘像素决定
  3、reflect:镜像填充,最后一个像素不镜像,eg:[1,2,3,4] → [3,2,1,2,3,4,3,2]
  4、symmetric:镜像填充,最后一个像素镜像,eg:[1,2,3,4] → [2,1,1,2,3,4,4,3]
• fill:constant时,设置填充的像素值

2.3 RandomResizedCrop

功能:随机大小、长宽比裁剪图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• scale:随机裁剪面积比例, 默认(0.08, 1)
• ratio:随机长宽比,默认(3/4, 4/3)
• interpolation:插值方法
PIL.Image.NEAREST
PIL.Image.BILINEAR
PIL.Image.BICUBIC

2.4 FiveCrop和TenCrop

  功能:在图像的上下左右以及中心裁剪出尺寸为size的5张图片,TenCrop对这5张图片进行水平或者垂直镜像获得10张图片
在这里插入图片描述

• size:所需裁剪图片尺寸
• vertical_flip:是否垂直翻转

三、transforms—Flip翻转、旋转

3.1RandomHorizontalFlip和RandomVerticalFlip

在这里插入图片描述

功能:依概率水平(左右)或垂直(上下)翻转图片
• p:翻转概率

3.2 RandomRotation

功能:随机旋转图片
在这里插入图片描述
在这里插入图片描述

• degrees:旋转角度
  当为a时,在(-a,a)之间选择旋转角度
  当为(a, b)时,在(a, b)之间选择旋转角度
• resample:重采样方法
• expand:是否扩大图片,以保持原图

四、transforms —图像变换

4.1 transforms.Pad

功能:对图片边缘进行填充
在这里插入图片描述
• padding:设置填充大小
  当为a时,上下左右均填充a个像素
  当为(a, b)时,上下填充b个像素,左右填充a个像素
  当为(a, b, c, d)时,左,上,右,下分别填充a, b, c, d
• padding_mode:填充模式,有4种模式,constant、edge、reflect和symmetric
• fill:constant时,设置填充的像素值,(R, G, B) or (Gray)

4.2 transforms.ColorJitter

功能:调整亮度、对比度、饱和度和色相
在这里插入图片描述

• brightness:亮度调整因子
  当为a时,从[max(0, 1-a), 1+a]中随机选择
  当为(a, b)时,从[a, b]中
• contrast:对比度参数,同brightness
• saturation:饱和度参数,同brightness
• hue:色相参数,当为a时,从[-a, a]中选择参数,注: 0<= a <= 0.5
        当为(a, b)时,从[a, b]中选择参数,注:-0.5 <= a <= b <= 0.5

4.3 Grayscale和RandomGrayscale

功能:依概率将图片转换为灰度图
在这里插入图片描述
• num_ouput_channels:输出通道数只能设1或3
• p:概率值,图像被转换为灰度图的概率

4.4 RandomAffine

功能:对图像进行仿射变换,仿射变换是二维的线性变换,由五种基本原子变换构成,分别是旋转、平移、缩放、错切和翻转
在这里插入图片描述
在这里插入图片描述
• degrees:旋转角度设置
• translate:平移区间设置,如(a, b), a设置宽(width),b设置高(height)
    图像在宽维度平移的区间为 -img_width * a < dx < img_width * a
• scale:缩放比例(以面积为单位)
• fill_color:填充颜色设置

4.5 RandomErasing

功能:对图像进行随机遮挡
在这里插入图片描述

• p:概率值,执行该操作的概率
• scale:遮挡区域的面积
• ratio:遮挡区域长宽比
• value:设置遮挡区域的像素值,(R, G, B) or (Gray)

五、transforms的操作

5.1 transforms.RandomChoice

功能:从一系列transforms方法中随机挑选一个

transforms.RandomChoice([transforms1, transforms2, transforms3])

5.2 transforms.RandomApply

功能:依据概率执行一组transforms操作

transforms.RandomApply([transforms1, transforms2, transforms3], p=0.5)

5.3 transforms.RandomOrder

功能:对一组transforms操作打乱顺序

transforms.RandomOrder([transforms1, transforms2, transforms3])

六、自定义transforms

6.1 自定义transforms要素

1.仅接收一个参数,返回一个参数
2.注意上下游的输出与输入
当前transforms的输入是上一个transforms的输出,所以要保证数据类型匹配:
在这里插入图片描述

6.2 通过类实现多参数传入

在这里插入图片描述

在Python中,__call__是一个特殊的方法,用于使一个对象可以像函数一样被调用。如果一个类定义了__call__方法,那么实例化的对象就可以被当作函数一样调用,而调用的实际上是__call__方法。

class CallableClass:
    def __init__(self):
        print("Initializing the CallableClass")

    def __call__(self, *args, **kwargs):
        print("Calling the CallableClass with arguments:", args, kwargs)

# 实例化对象
obj = CallableClass()

# 调用对象,实际上调用了__call__方法
obj(1, 2, 3, keyword_arg="hello")

上面的例子中,CallableClass定义了__call__方法,这意味着实例obj可以像函数一样被调用。当你调用obj(1, 2, 3, keyword_arg=“hello”)时,实际上是在调用obj.call(1, 2, 3, keyword_arg=“hello”)。

6.3 椒盐噪声

椒盐噪声又称为脉冲噪声,是一种随机出现的白点或者黑点, 白点称为盐噪声,黑色为椒噪声
信噪比(Signal-Noise Rate, SNR)是衡量噪声的比例,图像中为图像像素的占比,从下图可以看出,信噪比越小,图片丢失的像素越多。
在这里插入图片描述

6.4 自定义transforms代码实现

class AddPepperNoise(object):
    """增加椒盐噪声

    Args:
        snr (float): Signal Noise Rate 信噪比
        p (float): 概率值,依概率执行该操作

    Attributes:
        snr (float): 信噪比
        p (float): 操作执行的概率
    """

    def __init__(self, snr, p=0.9):
        # 确保传入的snr和p是float类型
        assert isinstance(snr, float) and isinstance(p, float)
        self.snr = snr
        self.p = p

    def __call__(self, img):
        """
        对图像应用椒盐噪声操作。

        Args:
            img (PIL Image): PIL Image对象

        Returns:
            PIL Image: 处理后的PIL Image对象
        """
        # 根据概率决定是否执行噪声操作
        if random.uniform(0, 1) < self.p:
            img_ = np.array(img).copy()
            h, w, c = img_.shape
            signal_pct = self.snr
            noise_pct = (1 - self.snr)
            
            # 生成噪声掩码,表示每个像素是原始图像、盐噪声还是椒噪声
            mask = np.random.choice((0, 1, 2), size=(h, w, 1),
                                    p=[signal_pct, noise_pct / 2., noise_pct / 2.])
            mask = np.repeat(mask, c, axis=2)
            
            # 根据噪声类型修改图像像素值
            img_[mask == 1] = 255  # 盐噪声
            img_[mask == 2] = 0    # 椒噪声
            
            # 将NumPy数组转换回PIL Image对象,并确保数据类型为uint8,颜色通道为RGB
            return Image.fromarray(img_.astype('uint8')).convert('RGB')
        else:
            return img

在这里插入图片描述

七、数据增强策略

原则:让训练集与测试集更接近可以使用下面这些方法
• 空间位置:平移
• 色彩:灰度图,色彩抖动
• 形状:仿射变换
• 上下文场景:遮挡,填充

例如我们训练集白猫比较多,可以改变白猫色彩,让白猫的颜色接近黑猫。
在这里插入图片描述

数据增强代码实现

要求:使用第四套RMB进行训练,要求能对第5套RMB识别正确。

我们只进行普通的图片处理训练好的模型,发现将第五套100元都识别成一元,因为第四套人民币的1元和第五套人民的100元颜色相近,所以会导致识别错误:
在这里插入图片描述
解决方法,将所有训练集颜色都进行灰度处理,代码修改如下:

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.9),  #图片灰度化
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

修改后的预测结果如下:
在这里插入图片描述
训练完整代码如下:

# -*- coding: utf-8 -*-

import os
import random
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torch.optim as optim
from matplotlib import pyplot as plt
from lenet import LeNet
from my_dataset import RMBDataset
from common_tools import transform_invert


def set_seed(seed=1):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


set_seed()  # 设置随机种子
rmb_label = {"1": 0, "100": 1}

# 参数设置
MAX_EPOCH = 10
BATCH_SIZE = 16
LR = 0.01
log_interval = 10
val_interval = 1

# ============================ step 1/5 数据 ============================

split_dir = os.path.join("..", "..", "data", "rmb_split")
train_dir = os.path.join(split_dir, "train")
valid_dir = os.path.join(split_dir, "valid")

norm_mean = [0.485, 0.456, 0.406]
norm_std = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomGrayscale(p=0.9),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])


valid_transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
    transforms.Normalize(norm_mean, norm_std),
])

# 构建MyDataset实例
train_data = RMBDataset(data_dir=train_dir, transform=train_transform)
valid_data = RMBDataset(data_dir=valid_dir, transform=valid_transform)

# 构建DataLoder
train_loader = DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
valid_loader = DataLoader(dataset=valid_data, batch_size=BATCH_SIZE)

# ============================ step 2/5 模型 ============================

net = LeNet(classes=2)
net.initialize_weights()

# ============================ step 3/5 损失函数 ============================
criterion = nn.CrossEntropyLoss()                                                   # 选择损失函数

# ============================ step 4/5 优化器 ============================
optimizer = optim.SGD(net.parameters(), lr=LR, momentum=0.9)                        # 选择优化器
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)     # 设置学习率下降策略

# ============================ step 5/5 训练 ============================
train_curve = list()
valid_curve = list()

for epoch in range(MAX_EPOCH):

    loss_mean = 0.
    correct = 0.
    total = 0.

    net.train()
    for i, data in enumerate(train_loader):

        # forward
        inputs, labels = data
        outputs = net(inputs)

        # backward
        optimizer.zero_grad()
        loss = criterion(outputs, labels)
        loss.backward()

        # update weights
        optimizer.step()

        # 统计分类情况
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).squeeze().sum().numpy()

        # 打印训练信息
        loss_mean += loss.item()
        train_curve.append(loss.item())
        if (i+1) % log_interval == 0:
            loss_mean = loss_mean / log_interval
            print("Training:Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, i+1, len(train_loader), loss_mean, correct / total))
            loss_mean = 0.

    scheduler.step()  # 更新学习率

    # validate the model
    if (epoch+1) % val_interval == 0:

        correct_val = 0.
        total_val = 0.
        loss_val = 0.
        net.eval()
        with torch.no_grad():
            for j, data in enumerate(valid_loader):
                inputs, labels = data
                outputs = net(inputs)
                loss = criterion(outputs, labels)

                _, predicted = torch.max(outputs.data, 1)
                total_val += labels.size(0)
                correct_val += (predicted == labels).squeeze().sum().numpy()

                loss_val += loss.item()

            valid_curve.append(loss_val)
            print("Valid:\t Epoch[{:0>3}/{:0>3}] Iteration[{:0>3}/{:0>3}] Loss: {:.4f} Acc:{:.2%}".format(
                epoch, MAX_EPOCH, j+1, len(valid_loader), loss_val, correct / total))


train_x = range(len(train_curve))
train_y = train_curve

train_iters = len(train_loader)
valid_x = np.arange(1, len(valid_curve)+1) * train_iters*val_interval # 由于valid中记录的是epochloss,需要对记录点进行转换到iterations
valid_y = valid_curve

plt.plot(train_x, train_y, label='Train')
plt.plot(valid_x, valid_y, label='Valid')

plt.legend(loc='upper right')
plt.ylabel('loss value')
plt.xlabel('Iteration')
plt.show()

# ============================ inference ============================

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
test_dir = os.path.join(BASE_DIR, "test_data")

test_data = RMBDataset(data_dir=test_dir, transform=valid_transform)
valid_loader = DataLoader(dataset=test_data, batch_size=1)

for i, data in enumerate(valid_loader):
    # forward
    inputs, labels = data
    outputs = net(inputs)
    _, predicted = torch.max(outputs.data, 1)

    rmb = 1 if predicted.numpy()[0] == 0 else 100

    img_tensor = inputs[0, ...]  # C H W
    img = transform_invert(img_tensor, train_transform)
    plt.imshow(img)
    plt.title("LeNet got {} Yuan".format(rmb))
    plt.show()
    plt.pause(0.5)
    plt.close()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/280986.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

[Angular] 笔记 23:Renderer2 - ElementRef 的生产版本

chatgpt: Renderer2 简介 在 Angular 中&#xff0c;Renderer2 是一个服务&#xff0c;用于处理 DOM 操作的抽象层。它提供了一种安全的方式来操作 DOM&#xff0c;同时与平台无关&#xff0c;有助于维护应用程序的跨浏览器兼容性和安全性。 Renderer2 的作用是在 Angular 组…

【vim 学习系列文章 3.1 -- vim 删除 ^M】

请阅读【嵌入式开发学习必备专栏 之 VIM 专栏】 文章目录 ^M 来源^M 删除 ^M 来源 在 Vim 中打开文件时&#xff0c;您可能会遇到行尾的 ^M 字符&#xff0c;这通常是因为文件使用了 Windows 风格的回车换行符&#xff08;CRLF&#xff09;&#xff0c;而不是 Unix/Linux 风格…

集合基础知识点

集合基础 1. 集合的由来 当 Java 程序中需要存放数据的时候&#xff0c;通常会定义变量来实现数据的存储&#xff0c;但是&#xff0c;当需要存储大量数据的时候该怎么办呢&#xff1f;这时首先想到的是数组&#xff0c;但是&#xff01;数组只能存放同一类型的数据&#xff…

【计算机网络】第五,六章摘要重点

1.运输层协议概述 运输层提供的是进程之间的通信 2. 3.套接字指的是什么 ip地址端口号 4.每一条TCP语句唯一地被通信两端连接的两个端点 5.TCP传输如何实现 以字节为单位的滑动窗口 超时重传 选择确认 6.TCP流量控制和拥塞控制的区别 流量控制&#xff1a;点对点通信…

Android Studio如何创建尺寸大小及API通用的模拟器

目录 前言 一、操作步骤 二、总结 三、更多资源 前言 在开发移动应用程序的过程中&#xff0c;使用模拟器进行测试是一种常见和方便的方式。Android Studio是一款功能强大的集成开发环境&#xff0c;它提供了创建和管理模拟器的功能。在本文中&#xff0c;我们将介绍如何创…

XTU-OJ-1452-完全平方数-笔记

参考博客 XTU-OJ 1452-完全平方数 题意 输入一个奇数&#xff0c;使得 n*(2*an-1)/2是一个完全平方数&#xff0c;求满足条件的最小的a 1<n<1e9 先输入样例数&#xff0c;再输入n 输入 2 1 3 输出 0 2 代码 #include<stdio.h>#define N 1000000010int a…

SpringIOC之ApplicationObjectSupport

博主介绍&#xff1a;✌全网粉丝5W&#xff0c;全栈开发工程师&#xff0c;从事多年软件开发&#xff0c;在大厂呆过。持有软件中级、六级等证书。可提供微服务项目搭建与毕业项目实战&#xff0c;博主也曾写过优秀论文&#xff0c;查重率极低&#xff0c;在这方面有丰富的经验…

数据库视图

使用视图就是执行视图的select操作。 为什么要使用视图&#xff1f; 安全性&#xff1a;只开放表中的部分数据&#xff08;只能看到视图开放的列&#xff09;便捷性&#xff1a;可以将复杂的多表关联查询放到视图中&#xff0c;屏蔽底层的复杂性。 create view viewa_name as…

2023年03月20日_对李开复3月20日线下媒体会的解读

最近这个AI大模型 因为GPT4.0 ChatGPT 文心一言等等这些事情呢 一下子就被推到了风口浪尖 我们也做了来介绍相关的进展 国内呢也不断有一些大佬开始下场 包括王慧文、张朝阳、李彦宏什么的 都开始说自己要搞AI大模型 就在昨天呢 创新工厂的董事长兼CEO李开复 也发朋友…

2022第十届中国互联网测试开发大会(MTSC2022)-核心PPT资料下载

一、峰会简介 MTSC中国互联网测试开发大会 Make Tester Super Cool&#xff08;简称 MTSC&#xff09;是由国内最大的移动测试技术社区 TesterHome 发起的软件测试行业技术会议&#xff0c;大会以“软件质量保障体系和测试研发技术交流”为主要目的。MTSC 大会于 2015 年举办第…

产品经理学习-怎么写PRD文档

目录 瀑布流方法论介绍 产品需求文档&#xff08;PRD&#xff09;介绍 产品需求文档的基本要素 撰写产品需求文档 优先产品需求文档的特点 其他相关文档 瀑布流方法论介绍 瀑布流模型是一种项目的开发和管理的方法论&#xff0c;是敏捷的开发管理方式相对应的另一种方法…

前端算法之双指针之快慢指针(Floyd 判圈法)

双指针与快慢指针快慢指针&#xff08;Floyd 判圈法&#xff09; 简介推导 在链表中&#xff0c;快指针和慢指针都可以指向头节点&#xff0c;然后根据问题的要求进行移动。 快指针通常会比慢指针移动得更快&#xff0c;例如每次移动两步&#xff0c;而慢指针则每次移动一步。…

Linux之磁盘分区,挂载

Linux分区 分区介绍 对linux来说无论有几个分区&#xff0c;分给哪个目录使用&#xff0c;归根结底只有一个根目录&#xff0c;linux中每个分区都是用来组成整个文件系统的一部分。linux采用“载入"的处理方法&#xff0c;他的整个文件系统中包含一整套的文件和目录&…

68.乐理基础-打拍子-大附点与变体

上一节内容&#xff1a;66.乐理基础-打拍子-小切分-CSDN博客&#xff0c;只所以没有67因为67可以不用知道&#xff0c;67节内容在&#xff1a;※-打拍子&#xff08;8&#xff09;-一拍内的变体1-乐理教程-腾讯课堂 (qq.com) 大附点&#xff1a;大附点这个名字不是通用的&…

如何使用python脚本生成redis格式的数据包

用python脚本生成redis格式的数据包 &#xff08;1&#xff09;使用下述网站下载开源的生成gopher协议规则的包的工具 https://github.com/firebroo/sec_tools/tree/master/redis-over-gopher &#xff08;2&#xff09;首先要修改redis.cmd中的内容 flushall config set di…

【面试】 Maven 的八大核心概念

Maven 的八大核心概念 在这里&#xff0c;举出这个标题&#xff0c;自然大家知道Maven是干啥的&#xff0c;就不过多进行赘述&#xff01;我们主要对于Maven的八大核心概念做一个解释补充&#xff0c;这也是我自己的一个学习历程&#xff0c;我们一起共勉&#xff01; 文章概述…

【网络安全常用术语解读】SCAP详解

本文主要介绍什么是SCAP&#xff0c;SCAP的产生背景是怎样的&#xff0c;SCAP有什么用途&#xff0c;有哪些组件&#xff0c;各个组件的用途是什么&#xff1f; SCAP产生背景 由于计算机和网络技术的快速发展&#xff0c;越来越多的软件和系统被应用到企业和机构中&#xff0c…

一文掌握Java注解之@SpringBootApplication知识文集(1)

&#x1f3c6;作者简介&#xff0c;普修罗双战士&#xff0c;一直追求不断学习和成长&#xff0c;在技术的道路上持续探索和实践。 &#x1f3c6;多年互联网行业从业经验&#xff0c;历任核心研发工程师&#xff0c;项目技术负责人。 &#x1f389;欢迎 &#x1f44d;点赞✍评论…

编程羔手解决Maven引入多个版本的依赖包,导致包冲突了

最近升级了些依赖发现有个hutool的方法老报错&#xff0c;java.lang.NoSuchMethodError: cn.hutool.core.util.ObjectUtil.defaultIfNull(Ljava/lang/Object;Ljava/util/function/Supplier;) 在 Maven 项目中&#xff0c;当不同的依赖模块引入 Hutool 的不同版本时&#xff0c…

软件测试/测试开发丨学习笔记之 Python 函数

python 函数 函数的作用 函数是组织好的&#xff0c;可重复使用的&#xff0c;用来实现单一或相关联功能的代码段函数能提高应用的模块性和代码的重复利用率python 内置函数&#xff1a;docs.python.org/zh-cn/3.8/l… 函数定义 def&#xff1a;函数定义关键词function_nam…