图像增广:强化深度学习的视觉表现力

目录

摘要:

1. 图像增广简介

2. 图像增广的原理

3. 常见的图像增广技术

4. 如何在实际项目中应用图像增广

5.实际应用


摘要:

当今,深度学习已经在计算机视觉领域取得了令人瞩目的成就。图像增广作为一种数据处理技术,让我们在使用有限的图像数据集时能够充分挖掘图像特征,提高模型的泛化能力。本文将详细介绍图像增广的概念、原理以及如何在实际项目中应用。

1. 图像增广简介

图像增广(Image Augmentation)是一种通过对原始图像进行各种变换来生成新的图像的方法。这些变换包括旋转、翻转、缩放、剪切、色彩变换等。通过图像增广,我们可以扩大数据集的规模,增加模型训练时的输入样本。这有助于提高模型的泛化能力,从而在面对新的、未知的数据时,也能达到较高的准确性。

2. 图像增广的原理

深度学习模型在训练过程中需要大量的数据来学习特征表达。然而,在实际应用中,我们并不总是能获得足够多的数据。图像增广通过对原始图像进行各种变换,创造出具有不同视觉特征的新图像。这样一来,模型在训练时可以接触到更多样的数据,从而学习到更丰富的特征表达,提高泛化能力。

值得注意的是,图像增广并不能完全解决数据不足的问题,但它可以在一定程度上缓解这个问题,提高模型的性能。

3. 常见的图像增广技术

以下是一些常见的图像增广技术:

- **旋转**:将图像按一定的角度进行旋转。
- **翻转**:对图像进行水平或垂直翻转。
- **缩放**:对图像进行放大或缩小。
- **剪切**:在图像上随机选择一块区域,将其裁剪为新的图像。
- **色彩变换**:改变图像的亮度、对比度、饱和度等色彩属性。
- **噪声添加**:在图像中添加随机噪声。
- **仿射变换**:对图像进行平移、旋转、缩放等操作。

4. 如何在实际项目中应用图像增广

许多深度学习框架都提供了图像增广的相关工具,例如 TensorFlow、PyTorch、Keras 等。在使用这些框架时,我们可以轻松地将图像增广技术应用到我们的项目中。以下是一个使用 Keras 进行图像增广的简单示例:

from keras.preprocessing.image import ImageDataGenerator

# 创建一个图像数据生成器
datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True)

# 将数据生成器应用到训练集
train_generator = datagen.flow_from_directory(
    train_data_dir,
    target_size=(img_height, img_width),
    batch_size=batch_size,
    class_mode='categorical')

在上述代码中,我们定义了一个图像数据生成器,并设置了一些增广参数。然后,我们使用这个数据生成器对训练集进行处理。

5.实际应用

%matplotlib inline
import torch
import torchvision
from torch import nn
from d2l import torch as d2l

导入图片: 图片大小为400*500

d2l.set_figsize()
img = d2l.Image.open('./img/1.jpg') #务必将图片放到该根目录
d2l.plt.imshow(img)

 大多数图像增广方法都具有一定的随机性。为了便于观察图像增广的效果,我们下面定义辅助函数apply。 此函数在输入图像img上多次运行图像增广方法aug并显示所有结果。

def apply(img, aug, num_rows=2, num_cols=4, scale=1.5):#aug增广
    Y = [aug(img) for _ in range(num_rows * num_cols)]
    d2l.show_images(Y, num_rows, num_cols, scale=scale)

 左右翻转图像通常不会改变对象的类别。这是最早且最广泛使用的图像增广方法之一。 使用transforms模块来创建RandomFlipLeftRight实例,这样就各有50%的几率使图像向左或向右翻转。

 上下翻转并不常用:

apply(img, torchvision.transforms.RandomVerticalFlip())

 随机剪切函数:剪切大小200*200像素,在10%-100%的范围内剪切,长宽比为1:2

shape_aug = torchvision.transforms.RandomResizedCrop((200,200),scale=(0.1,1),ratio=(0.5,2))
apply(img,shape_aug)

随机改变亮度

apply(img,torchvision.transforms.ColorJitter(brightness=0.5,contrast=0,saturation=0,hue=0))

 

 改变色调:

apply(img, torchvision.transforms.ColorJitter(
    brightness=0, contrast=0, saturation=0, hue=0.5))

 创建一个RandomColorJitter实例,并设置如何同时随机更改图像的亮度(brightness)、对比度(contrast)、饱和度(saturation)和色调(hue)。

color_aug = torchvision.transforms.ColorJitter(
    brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5)
apply(img, color_aug)

 最常用的就是将各种增广方法结合起来:

augs = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(), color_aug, shape_aug])
apply(img, augs)

 下载一个常用数据集:

#下载数据集
all_images = torchvision.datasets.CIFAR10(train=True, root="./data",
                                          download=True)
d2l.show_images([all_images[i][0] for i in range(32)], 4, 8, scale=0.8);

 通常对训练样本只进行图像增广,且在预测过程中不使用随机操作的图像增广。使用ToTensor实例将一批图像转换为深度学习框架所要求的格式,即形状为(批量大小,通道数,高度,宽度)的32位浮点数,取值范围为0~1。

train_augs = torchvision.transforms.Compose([
     torchvision.transforms.RandomHorizontalFlip(),
     torchvision.transforms.ToTensor()]) #变成一个4d矩阵

test_augs = torchvision.transforms.Compose([
     torchvision.transforms.ToTensor()])
def load_cifar10(is_train, augs, batch_size):
    dataset = torchvision.datasets.CIFAR10(root="./data", train=is_train,
                                           transform=augs, download=True)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                    shuffle=is_train, num_workers=d2l.get_dataloader_workers())
    return dataloader

 使用gpu进行训练数据:

#@save
#使用多GPU对模型进行训练和评估
def train_batch_ch13(net, X, y, loss, trainer, devices):
    """用多GPU进行小批量训练"""
    if isinstance(X, list):
        # 微调BERT中所需
        X = [x.to(devices[0]) for x in X]
    else:
        X = X.to(devices[0])
    y = y.to(devices[0])
    net.train()
    trainer.zero_grad()
    pred = net(X)
    l = loss(pred, y)
    l.sum().backward()
    trainer.step()
    train_loss_sum = l.sum()
    train_acc_sum = d2l.accuracy(pred, y)
    return train_loss_sum, train_acc_sum

#@save
def train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,
               devices=d2l.try_all_gpus()):
    """用多GPU进行模型训练"""
    timer, num_batches = d2l.Timer(), len(train_iter)
    animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs], ylim=[0, 1],
                            legend=['train loss', 'train acc', 'test acc'])
    net = nn.DataParallel(net, device_ids=devices).to(devices[0])
    for epoch in range(num_epochs):
        # 4个维度:储存训练损失,训练准确度,实例数,特点数
        metric = d2l.Accumulator(4)
        for i, (features, labels) in enumerate(train_iter):
            timer.start()
            l, acc = train_batch_ch13(
                net, features, labels, loss, trainer, devices)
            metric.add(l, acc, labels.shape[0], labels.numel())
            timer.stop()
            if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:
                animator.add(epoch + (i + 1) / num_batches,
                             (metric[0] / metric[2], metric[1] / metric[3],
                              None))
        test_acc = d2l.evaluate_accuracy_gpu(net, test_iter)
        animator.add(epoch + 1, (None, None, test_acc))
    print(f'loss {metric[0] / metric[2]:.3f}, train acc '
          f'{metric[1] / metric[3]:.3f}, test acc {test_acc:.3f}')
    print(f'{metric[2] * num_epochs / timer.sum():.1f} examples/sec on '
          f'{str(devices)}')

 可以定义train_with_data_aug函数,使用图像增广来训练模型。该函数获取所有的GPU,并使用Adam作为训练的优化算法,将图像增广应用于训练集,最后调用刚刚定义的用于训练和评估模型的train_ch13函数。

batch_size, devices, net = 256, d2l.try_all_gpus(), d2l.resnet18(10, 3)

def init_weights(m):
    if type(m) in [nn.Linear, nn.Conv2d]:
        nn.init.xavier_uniform_(m.weight)

net.apply(init_weights)

def train_with_data_aug(train_augs, test_augs, net, lr=0.001):
    train_iter = load_cifar10(True, train_augs, batch_size)
    test_iter = load_cifar10(False, test_augs, batch_size)
    loss = nn.CrossEntropyLoss(reduction="none")
    trainer = torch.optim.Adam(net.parameters(), lr=lr)
    train_ch13(net, train_iter, test_iter, loss, trainer, 10, devices)
train_with_data_aug(train_augs, test_augs, net)

        总之,图像增广作为一种数据处理技术,在深度学习领域具有重要的意义。通过应用图像增广,我们能够充分挖掘图像特征,提高模型的泛化能力。在实际项目中,我们可以根据需求选择不同的增广技术,从而优化模型的性能。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/37874.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

一.CreateFileMapping实现的共享内存及用法

共享内存概念 1.在32位的Windows系统中,每一个进程都有权访问他自己的4GB(2324294967296)平面地址空间,没有段,没有选择符,没有near和far指针,没有near和far函数调用,也没有内存模式…

修改npm路径

npm config ls如果是第一次使用NPM安装包的话,在配置中只会看到prefix的选项,就是NPM默认的全局安装目录。但是如果有多次使用NPM安装包的话,就会看到cache和prefix两个路径。 新建两个文件夹node_global_modules和node_cache npm config s…

【CesiumJS入门】(7)绘制多段线(动态实时画线)

前言 鼠标左键添加点、右键完成绘制,单击右侧弹窗关闭按钮清空绘制。参考沙盒示例:Drawing on Terrain 直接上代码了 /** Date: 2023-07-12 18:47:18* LastEditors: ReBeX 420659880qq.com* LastEditTime: 2023-07-16 16:26:19* FilePath: \cesium-tyro-blog\s…

Verdi之波形展示nWave

6.nWave 6.1 添加波形文件 1.打开nWave界面,具体操作如下: 2.正式添加波形,使用快捷键G或者点击以下图标,选择需要的信号。 也可以在 n Trace中选中信号后,鼠标中键拖拽,或者ctrlw进行添加; 6…

R和python中dataframe读取方式总结

首先我有一个如图所示的文件 如果在python中读取 import pandas as pd df pd.read_csv("./6group_count.csv",index_col0) df而在R中读取的方式如下 df read.csv("./6group_count.csv",row.names 1)

MySQL---索引

目录 一、索引的分类 二、索引的底层原理是什么? 2.1、Innodb和MyIsAM两种引擎搜索数据时候的区别: 2.2、为什么MySQL(MyIsAM、Innodb)索引选择B树而不是B树呢? 2.3、Innodb的主键索引和二级索引(辅助…

【Ajax】笔记-Ajax案例准备与请求基本操作

案例准备HTML 按钮div <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>AJAX GET 请求</title&g…

2D、3D机器视觉各有优势与局限,融合应用将成工业领域生产新方式

在智能制造的浪潮中&#xff0c;制造行业生产线亟需转型升级&#xff0c;为国内机器视觉市场释放出了惊人的机器视觉技术及产品需求。在自动化工业质量控制和在线检测领域&#xff0c;2D机器视觉与3D机器视觉都具有重要的作用。那在机器视觉自动化场景中该如何选择合适的机器视…

python 乘法口诀

下面是一个用Python打印乘法口诀表的代码&#xff1a; print("乘法口诀表:")for i in range(1, 10):for j in range(1, i1):print(f"{j} {i} {i*j}", end"\t")print()

Blazor前后端框架Known-V1.2.4

V1.2.4 Known是基于C#和Blazor开发的前后端分离快速开发框架&#xff0c;开箱即用&#xff0c;跨平台&#xff0c;一处代码&#xff0c;多处运行。 Gitee&#xff1a; https://gitee.com/known/KnownGithub&#xff1a;https://github.com/known/Known 概述 基于C#和Blazor…

【图像处理】Python判断一张图像是否亮度过低,图片模糊判定

文章目录 亮度判断模糊判断 亮度判断 比如&#xff1a; 直方图&#xff1a; 代码&#xff1a; 这段代码是一个用于判断图像亮度是否过暗的函数is_dark&#xff0c;并对输入的图像进行可视化直方图展示。 首先&#xff0c;通过import语句导入了cv2和matplotlib.pyplot模块…

Element-Plus搭建CMS页面结构 引入第三方图标库iconfont(详细)

Element-Plus组件库使用 element plus组件库是由饿了么前端团队专门针对vue框架开发的组件库&#xff0c;专门用于电脑端网页的。因为里面集成了很多组件&#xff0c;所以使用他可以非常快速的帮我们实现网站的开发。 安装&#xff1a; npm install element-plus --save 引入…

jenkins 采用ssh方式连接gitlab连接不上

一、gitlab 添加jenkins服务器的公钥 jenkins 生成秘钥命令 ssh-keygen -t rsa2.jenkins 秘钥地址&#xff1a; cd /root/.ssh3.复制公钥 到gitlab 添加 cat id_rsa_pub4.添加私钥到jenkins cat id_rsa5.绑定&#xff08;顺利的话到这里就结束了&#xff09; &#xff0…

Linux下Lua和C++交互

前言 lua&#xff08;wiki 中文 官方社区&#xff1a;lua-users&#xff09;是一门开源、简明、可扩展且高效的弱类型解释型脚本语言。 由于其实现遵循C标准&#xff0c;它几乎能在所有的平台&#xff08;windows、linux、MacOS、Android、iOS、PlayStation、XBox、wii等&…

【条件与循环】——matlab入门

目录索引 if&#xff1a;else与elseif&#xff1a; for&#xff1a; if&#xff1a; if 条件语句块 endelse与elseif&#xff1a; if 条件代码块 elseif 条件代码块 else 代码块 endfor&#xff1a; for 条件循环体 end在matlab里面类似的引号操作都是包头又包尾的。上面的c…

postman测试接口出现404

postman测试接口出现404 1.用postman调试接口的过程中&#xff0c;出现404的情况&#xff0c;但是接口明明已调到了&#xff0c;而且数据也已经存入数据库了&#xff0c;这让我感到很疑惑。看网上的解决办法检查了我的路径&#xff0c;提交方式、参数类型等都是正确的&#xf…

安装adobe系列产品,提示错误代码81解决办法

安装adobe系列软件&#xff0c;如Photoshop、Premiere Pro、Illustrator等时&#xff0c;出现如下图提示错误代码81&#xff0c;如何解决呢&#xff1f;一起来看看。 解决方法一 (重启电脑等待5分钟再安装&#xff01;) 解决方法二 应用程序中打开Adobe Creative Cloud 点击…

Linux系统终端窗口ctrl+c,ctrl+z,ctrl+d的区别

时常在Linux系统上&#xff0c;执行某命令停不下来&#xff0c;就这几个ctrl组合键按来按去&#xff0c;今天稍微总结下具体差别&#xff0c;便于以后linux系统运维操作 1、ctrlc强制中断程序&#xff0c;相应进程会被杀死&#xff0c;中断进程任务无法恢复执行 2、ctrlz暂停正…

【运维工程师学习】Centos中MySQL替换MariaDB

【运维工程师学习】Centos8中MySQL替换MariaDB 1、查看已有的mysql2、MySQL官网tar包下载3、找到下载路径解压4、移动解压后的文件夹到/usr/local/mysql5、创建data文件夹&#xff0c;一般用于存放数据库文件数据6、创建用户组7、更改用户文件夹权限8、生成my.cnf文件9、编辑my…

ZooKeeper ZAB

文章首发地址 在接收到一个写请求操作后&#xff0c;追随者会将请求转发给群首&#xff0c;群首将探索性地执行该请求&#xff0c;并将执行结果以事务的方式对状态更新进行广播。一个事务中包含服务器需要执行变更的确切操作&#xff0c;当事务提交时&#xff0c;服务器就会将这…