【动手学深度学习】--18.图像增广

文章目录

  • 图像增广
    • 1.常用的图像增广方法
      • 1.1翻转和裁剪
      • 1.2改变颜色
      • 1.3结合多种图像增广方法
    • 2.使用图像增广进行训练
    • 3.训练

图像增广

官方笔记:图像增广
学习视频:数据增广【动手学深度学习v2】

图像增广在对训练图像进行一系列的随机变化之后,生成相似但不同的训练样本,从而扩大了训练集的规模。 此外,应用图像增广的原因是,随机改变训练样本可以减少模型对某些属性的依赖,从而提高模型的泛化能力。 例如,我们可以以不同的方式裁剪图像,使感兴趣的对象出现在不同的位置,减少模型对于对象出现位置的依赖。 我们还可以调整亮度、颜色等因素来降低模型对颜色的敏感度。 可以说,图像增广技术对于AlexNet的成功是必不可少的。

%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

1.常用的图像增广方法

在对常用图像增广方法的探索时,我们将使用下面这个尺寸为400×500的图像作为示例

d2l.set_figsize()
img = d2l.Image.open('F:/pytorch/img/cat.jpg')
d2l.plt.imshow(img);

大多数图像增广方法都具有一定的随机性。为了便于观察图像增广的效果,我们下面定义辅助函数apply。 此函数在输入图像img上多次运行图像增广方法aug并显示所有结果。

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    d2l.show_images(Y, num_rows, num_cols, scale=scale)

1.1翻转和裁剪

左右翻转图像通常不会改变对象的类别。这是最早且最广泛使用的图像增广方法之一。 接下来,我们使用transforms模块来创建RandomFlipLeftRight实例,这样就各有50%的几率使图像向左或向右翻转。

apply(img, torchvision.transforms.RandomHorizontalFlip())

image-20230720105619238

上下翻转图像不如左右图像翻转那样常用。但是,至少对于这个示例图像,上下翻转不会妨碍识别。接下来,我们创建一个RandomFlipTopBottom实例,使图像各有50%的几率向上或向下翻转。

apply(img, torchvision.transforms.RandomVerticalFlip())

image-20230720111122704

在我们使用的示例图像中,猫位于图像的中间,但并非所有图像都是这样。我们解释了池化层可以降低卷积层对目标位置的敏感性。 另外,我们可以通过对图像进行随机裁剪,使物体以不同的比例出现在图像的不同位置。 这也可以降低模型对目标位置的敏感性。

下面的代码将随机裁剪一个面积为原始面积10%到100%的区域,该区域的宽高比从0.5~2之间随机取值。 然后,区域的宽度和高度都被缩放到200像素。 在本节中(除非另有说明),a和b之间的随机数指的是在区间[a,b]中通过均匀采样获得的连续值

shape_aug = torchvision.transforms.RandomResizedCrop(
    (200, 200), scale=(0.1, 1), ratio=(0.5, 2))
apply(img, shape_aug)

image-20230720111219412

1.2改变颜色

另一种增广方法是改变颜色。 我们可以改变图像颜色的四个方面:亮度、对比度、饱和度和色调。 在下面的示例中,我们随机更改图像的亮度,随机值为原始图像的50%(1−0.5)到150%(1+0.5)之间。

apply(img, torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0, saturation=0, hue=0))

image-20230720111305660

同样,我们可以随机更改图像的色调。

apply(img, torchvision.transforms.ColorJitter(
    brightness=0, contrast=0, saturation=0, hue=0.5))

image-20230720111403505

我们还可以创建一个RandomColorJitter实例,并设置如何同时随机更改图像的亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue

color_aug = torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
apply(img, color_aug)

image-20230720111449723

1.3结合多种图像增广方法

在实践中,我们将结合多种图像增广方法。比如,我们可以通过使用一个Compose实例来综合上面定义的不同的图像增广方法,并将它们应用到每个图像。

augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug])
apply(img, augs)

image-20230720111524757

2.使用图像增广进行训练

让我们使用图像增广来训练模型。 这里,我们使用CIFAR-10数据集,而不是我们之前使用的Fashion-MNIST数据集。 这是因为Fashion-MNIST数据集中对象的位置和大小已被规范化,而CIFAR-10数据集中对象的颜色和大小差异更明显。 CIFAR-10数据集中的前32个训练图像如下所示。

all_images = torchvision.datasets.CIFAR10(train=True, root="./data",
                                          download=True)
d2l.show_images([all_images[i][0] for i in range(32)], 4, 8, scale=0.8);
image-20230720111614485

为了在预测过程中得到确切的结果,我们通常对训练样本只进行图像增广,且在预测过程中不使用随机操作的图像增广。 在这里,我们只使用最简单的随机左右翻转。 此外,我们使用ToTensor实例将一批图像转换为深度学习框架所要求的格式,即形状为(批量大小,通道数,高度,宽度)的32位浮点数,取值范围为0~1。

train_augs = torchvision.transforms.Compose([
     torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.ToTensor()])

test_augs = torchvision.transforms.Compose([
     torchvision.transforms.ToTensor()])

接下来,我们定义一个辅助函数,以便于读取图像和应用图像增广。PyTorch数据集提供的transform参数应用图像增广来转化图像

def load_cifar10(is_train, augs, batch_size):
    dataset = torchvision.datasets.CIFAR10(root="../data", train=is_train,
                                           transform=augs, download=True)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                    shuffle=is_train, num_workers=d2l.get_dataloader_workers())
    return dataloader

3.训练

我们定义一个函数,使用GPU对模型进行训练和评估

def train_batch_ch13(net,X,y,loss,trainer,device):
    if isinstance(X, list):
        # Bert微调所需的
        X = [x.to(device) for x in X]
    else:
        X = X.to(device)
    y = y.to(device)
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred,y)
    l.sum().backward()
    trainer.step()
    train_loss_sum = l.sum()
    train_acc_sum = d2l.accuracy(pred,y)
    return train_loss_sum,train_acc_sum

下面的函数中,不进行画图,将每次训练的损失误差、准确率显示出来

def train_ch13(net,train_iter,test_iter,loss,trainer,num_epochs,device):
    timer, num_batches = d2l.Timer(), len(train_iter)
    net.to(device)
    for epoch in range(num_epochs):
        metric = d2l.Accumulator(4)
        for i,(features,labels) in enumerate(train_iter):
            timer.start()
            l,acc = train_batch_ch13(
                net,features,labels,loss,trainer,device
            )
            metric.add(l,acc,labels.shape[0],labels.numel())
            timer.stop()
        test_acc = d2l.evaluate_accuracy_gpu(net,test_iter)
        print(f'loss {metric[0] / metric[2]:.3f}, train acc '
          f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
          f'{str(device)}')

现在,我们可以定义train_with_data_aug函数,使用图像增广来训练模型。该函数获取所有的GPU,并使用Adam作为训练的优化算法,将图像增广应用于训练集,最后调用刚刚定义的用于训练和评估模型的train_ch13函数。

batch_size, device, net = 256, d2l.try_gpu(), d2l.resnet18(10, 3)

def init_weights(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_uniform_(m.weight)

net.apply(init_weights)

def train_with_data_aug(train_augs, test_augs, net, lr=0.001):
    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)
    loss = nn.CrossEntropyLoss(reduction="none")
    trainer = torch.optim.Adam(net.parameters(), lr=lr)
    train_ch13(net, train_iter, test_iter, loss, trainer, 3, devices)

让我们使用基于随机左右翻转的图像增广来训练模型。

train_with_data_aug(train_augs, test_augs, net)

image-20230822195443240

总结:

  • 数据增广通过变形数据来获取多样性从而使得模型泛化性能更好
  • 常见图片增广包括翻转、切割、变色

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/91845.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

实验八 网卡驱动移植

【实验目的】 掌握 Linux 内核配置的基本方法,完成对网卡驱动、NFS 等相关功能的配置 【实验环境】 ubuntu 14.04 发行版FS4412 实验平台交叉编译工具:arm-none-linux-gnueabi- 【注意事项】 实验步骤中以“$”开头的命令表示在 ubuntu 环境下执行&…

21.2 CSS 三大特性与页面布局

1. 开发者工具修改样式 使用开发者工具修改样式, 操作步骤如下: * 1. 打开开发者工具: 在浏览器中右键点击页面, 然后选择检查或者使用快捷键(一般是 F12 或者 CtrlShiftI)来打开开发者工具.* 2. 打开样式编辑器: 在开发者工具中, 找到选项卡或面板, 一般是Elements或者Elemen…

最新AI系统ChatGPT程序源码/微信公众号/H5端+搭建部署教程+完整知识库

一、前言 SparkAi系统是基于国外很火的ChatGPT进行开发的Ai智能问答系统。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。 那么如何搭建部署AI创作ChatGPT?小编这里写一个详细图文教程吧&#xff01…

不系安全带抓拍自动识别

不系安全带抓拍自动识别系统通过yolo系列算法框架模型利用高清摄像头,不系安全带抓拍自动识别算法对高空作业场景进行监控,当检测到人员未佩戴安全带时会自动抓拍并进行告警记录。YOLO系列算法是一类典型的one-stage目标检测算法,其利用ancho…

一生一芯8——在github上添加ssh key

为在github上下载代码框架,这里在github上使用ssh key进行远程连接,方便代码拉取 参照博客https://blog.csdn.net/losthief/article/details/131502734 本机 系统ubuntu22.04 git 版本2.34.1 本人是第一次配置,没有遇到奇奇怪怪的错误&…

Faster RCNN网络数据流总结

前言 在学习Faster RCNN时,看了许多别人写的博客。看了以后,对Faster RCNN整理有了一个大概的了解,但是对训练时网络内部的数据流还不是很清楚,所以在结合这个版本的faster rcnn代码情况下,对网络数据流进行总结。以便…

TiDB 源码编译之 TiProxy 篇

作者: ShawnYan 原文来源: https://tidb.net/blog/3d57f54d TiProxy 简介 TiProxy 是一个基于 Apache 2.0 协议开源的、轻量级的 TiDB 数据库代理,基于 Go 语言编写,支持 MySQL 协议。 TiProxy 支持负载均衡,接收来…

【SpringCloud技术专题】「Gateway网关系列」(2)微服务网关服务的Gateway功能配置指南分析

Spring Cloud Gateway简介 Spring Cloud Gateway是Spring Cloud体系的第二代网关组件,基于Spring 5.0的新特性WebFlux进行开发,底层网络通信框架使用的是Netty,所以其吞吐量高、性能强劲,未来将会取代第一代的网关组件Zuul。Spri…

opencv-gpu版本编译(添加java支持,可选)实现硬解码

目录 opencv gpu版本编译,实现硬解码,加速rtsp视频流读取1、准备文件2、复制 NVCUVID 头文件到 cuda 安装目录 include3、安装相关依赖4、 执行cmake5、编译安装6、测试 opencv gpu版本编译,实现硬解码,加速rtsp视频流读取 前置条…

为什么使用Nacos而不是Eureka(Nacos和Eureka的区别)

文章目录 前言一、Eureka是什么?二、Nacos是什么?三、Nacos和Eureka的区别3.1 支持的CAP3.2连接方式3.3 服务异常剔除3.4 操作实例方式 总结 前言 为什么如今微服务注册中心用Nacos相对比用Eureka的多了?本文章将介绍他们之间的区别和优缺点…

2023前端面试笔记 —— CSS3

系列文章目录 内容链接2023前端面试笔记HTML52023前端面试笔记CSS3 文章目录 系列文章目录前言一、CSS选择器的优先级二、通过 CSS 的哪些方式可以实现隐藏页面上的元素三、px、em、rem之间有什么区别?四、让元素水平居中的方法有哪些五、在 CSS 中有哪些定位方式六…

windows11系统重装步骤及优化技巧

目录 目录 本文目的 Windows11介绍 Windows下载 和win10对比 重装步骤 系统设置调整 系统备份还原 C盘减肥,空间优化技巧 Java开发工具 本文目的 说明windows11的系统重装步骤,大部分步骤也适用于其他windows版本。常用软件的安装与介绍。系统…

Jmeter 如何才能做好接口测试?

现在对测试人员的要求越来越高,不仅仅要做好功能测试,对接口测试的需求也越来越多! 所以也越来越多的同学问,怎样才能做好接口测试? 要真正的做好接口测试,并且弄懂如何测试接口,需要从如下几…

HarmonyOS/OpenHarmony应用开发-ArkTS语言渲染控制LazyForEach数据懒加载

LazyForEach从提供的数据源中按需迭代数据,并在每次迭代过程中创建相应的组件。当LazyForEach在滚动容器中使用了,框架会根据滚动容器可视区域按需创建组件,当组件划出可视区域外时,框架会进行组件销毁回收以降低内存占用。一、接…

Oracle-rolling upgrade升级19c

前言: 本文主要描述Oracle11g升19c rolling upgrade升级测试,通过逻辑DGautoupgrade方式实现rolling upgrade,从而达到在较少停机时间内完成Oracle11g升级到19c的目标 升级介绍: 升级技术: rolling upgrade轮询升级,通过采用跨版…

手把手教你用 ANSYS workbench

ANSYS Workbench ANSYS Workbench是一款基于有限元分析(FEA)的工程仿真软件。其基本概念包括: 工作区(Workspace):工程仿真模块都在此区域内,包括几何建模、网格划分、边界条件设置、分析求解等…

【leetcode 力扣刷题】链表基础知识 基础操作

链表基础知识 基础操作 链表基础操作链表基础知识插入节点删除节点查找节点 707. 设计链表实现:单向链表:实现:双向链表 链表基础操作 链表基础知识 在数据结构的学习过程中,我们知道线性表【一种数据组织、在内存中存储的形式】…

django的简易的图书管理系统jsp书店进销存源代码MySQL

本项目为前几天收费帮学妹做的一个项目,Java EE JSP项目,在工作环境中基本使用不到,但是很多学校把这个当作编程入门的项目来做,故分享出本项目供初学者参考。 一、项目描述 django的简易的图书管理系统 系统有1权限&#xff1a…

Redis的基本操作

文章目录 1.Redis简介2.Redis的常用数据类型3.Redis的常用命令1.字符串操作命令2.哈希操作命令3.列表操作命令4.集合操作命令5.有序集合操作命令6.通用操作命令 4.Springboot配置Redis1.导入SpringDataRedis的Maven坐标2.配置Redis的数据源3.编写配置类,创还能Redis…

如何在VSCode中将html文件打开到浏览器

天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…