使用合成数据训练语义分割模型

计算机视觉应用现在在各个科技领域无处不在。 这些模型高效且有效,研究人员每年都会尝试新想法。 新的前沿是试图消除深度学习的最大负担:需要大量的标记数据。 正如本文所述,此问题的解决方案是使用合成数据。

从这些研究中获益最多的计算机视觉领域当然是语义分割领域,即预测图像每个像素的标签的任务,以便从图像中检索感兴趣的对象。 正如人们所预料的那样,手动标记训练集是一个昂贵、耗时且容易出错的过程,因此有多种利用合成数据的新方法。

在本文中,我们将看到其中一种方法,它利用生成对抗网络来解决使用合成数据的域适应问题。另一种常用的合成数据生成方法是利用逼真渲染的游戏引擎,例如基于UE5开发的UnrealSynth合成数据生成器:
在这里插入图片描述

在线工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 3D场景编辑器

1、合成数据生成

为了生成语义分割任务的数据,最常见的解决方案是使用与渲染引擎关联的模拟器。 通过这种方式,可以随意生成图像,改变闪电条件、物体的数量和姿势以及它们之间的交互,并始终关联像素完美的语义标签。 例如,一个非常流行的数据集,几乎所有研究都用作基准,是 GTAV [1],其中使用的模拟引擎是同名视频游戏。 该数据集包含从汽车驾驶员的角度拍摄的图像,非常适合自动驾驶等应用。 另一个著名的数据集是 SINTHIA [2],它也包含城市环境的图像。

在这里插入图片描述

图 1.1 — 来自 GTAV 数据集 [2] 的带有标签的图像示例

2、领域适应的生成方法

直接使用合成数据训练模型是不够的,神经网络可能会学习模拟环境中存在的一些不真实的模式,无法很好地概括现实世界的数据。 这称为域适应问题(Domain Adaption Problem)。

为了克服这个问题,模型必须在训练过程中学习重新调整源域 S(合成域)和目标域 T(真实域)之间特征分布的最佳方法。 这可以通过对抗性训练、知识蒸馏和自我监督学习等多种方法来实现。

特别是,对抗性训练的特点是采用生成方法,将源域数据转换为更类似于目标域的分布。 它可以表述如下:

给定源域数据集 Dₛ= {(xᵢˢ, yᵢˢ), i=1…nₛ} 和目标域数据集 Dₜ = {xᵢᵗ, i=1…nₜ},其中 xᵢˢ 和 xᵢᵗ 是输入样本, yᵢˢ 是对应的样本 xᵢˢ 的标签,目标是学习一个映射函数 𝓍ᵢˢ = G(xᵢˢ),称为生成器,它将源域特征映射到目标域特征,以便在转换后的源域图像上训练的深度学习模型可以表现良好 在目标域上。 它是通过判别器来完成的,判别器是一种神经网络,它接收真实图像和变换后的合成图像的输入,并尝试预测输入是否来自真实分布。

网络在对抗性环境中进行训练,只有当鉴别器失败时,生成器才会获胜。 当变换后的图像与真实图像非常相似以至于鉴别器无法区分它们时,该过程会收敛,从而使预测不比随机猜测更好(准确度为 50%)。

3、几何引导的输入输出自适应

各种算法都利用生成方法。 其中之一被称为 GIO-Ada [3],代表几何引导输入输出适应。 该算法相对于简单方法引入了 2 项改进。

它使用可以从模拟引擎轻松检索的另一条信息:深度图。 直觉是,对象的几何信息更好地编码在其深度信息中,而不是其像素的语义标签中。 因此,模型被训练来估计输入图像的深度图,并且这个额外的信息仅在训练期间用作辅助损失。
它在输出级别使用第二个对抗阶段,第二个鉴别器对任务网络的输出(语义标签图和几何深度图)进行操作,经过训练以预测预测的输出来自真实的还是合成的 图像。

在这里插入图片描述

图 1.2 — GIO-Ada 架构概述。 源数据的流向以橙线显示,目标数据的流向以黑线显示

完整的架构由 4 个神经网络组成:生成器(用于转换合成图像)、任务网络(预测真实图像和转换图像的标签和深度图)以及 2 个判别器。 所有网络都经过端到端训练,并采用遵循对抗训练规则的通用优化步骤。

4、Pytorch Lightening实现

为了轻松实现和训练这种复杂的算法,pytorch_lightning 是一个可以提供帮助的库。 这是 pytorch 的包装器,有助于避免重新实现一些与 torch 配合使用所需的样板代码,例如实现训练循环、处理超参数和权重的记录和保存、管理 GPU(或多个 GPU)并执行优化器步骤。 在我们的例子中,最后一个功能不是必需的,因为对抗训练的特殊性恰恰在于生成器和判别器之间优化步骤的交替,并且需要定制。

让我们首先导入库并定义一个实用函数,该函数将用于为鉴别器创建标签。

import itertools
from typing import Iterator

import pytorch_lightning as pl
import torch
from torch import nn
from torchmetrics.classification.jaccard import MulticlassJaccardIndex


def _labels(inputs: torch.Tensor, fill_value: int) -> torch.Tensor:
    return torch.full((inputs.size(0), 1), fill_value).to(inputs)

神经网络被实现为torch模块。 给定 B = 批量大小、C = 图像通道、K = 类数、W、H= 图像的宽度和高度:

  • 任务网络必须处理形状为 B × C × W × H 的批量图像,并返回形状为 B × K × W × H 的标签预测和形状为 B × 1× W × H 的深度预测。一种可能的架构选择是 使用 DeepLabV3+ [4] 作为任务网络,具有两个不同的头,一个用于类别预测,一个用于深度预测。
  • 图像变换网络必须输入所有合成数据,即形状为 B × C × W × H 的图像、形状为 B × K × W × H 的标签和形状为 B × 1× W × H 的深度图,连接起来 它们,并在输出中生成形状为 B × C × W × H 的变换图像。
  • 鉴别器必须采用形状 B × (C 或 C + K + 1) × W × H 的输入,并产生形状 B × 1 的输出,表示样本为真实样本的概率。
class TaskNetwork(nn.Module):
    def __init__(
        self,
        input_channels: int,
        num_classes: int,
        pretrained_backbone: bool = False,
    ) -> None:
        ...

    def forward(self, image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        ...


class ImageTransformNetwork(nn.Module):
    def __init__(
        self,
        input_channels: int,
        output_channels: int,
    ) -> None:
        ...

    def forward(
        self,
        fake_images: torch.Tensor,
        labels: torch.Tensor,
        depths: torch.Tensor,
    ) -> torch.Tensor:
        ...


class Discriminator(nn.Module):
    def __init__(self, input_channels: int) -> None:
        ...

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        ...

其余代码将在 LightningModule 内实现。 在这里,我们在 __init__方法中传递所有超参数,在该方法中我们实例化了 4 个神经网络,以及损失和指标。 卷积层的权重从正态分布初始化,任务网络的权重除外,其中权重可以预先训练,例如使用 ImageNet 数据集。

class GIOAda(pl.LightningModule):
    REAL_LABEL = 1
    FAKE_LABEL = 0

    def __init__(
        self,
        num_classes: int,
        pretrained_backbone: bool,
        init_lr: float,
        betas: tuple[float, float],
        num_epochs: int,
        num_steps_per_epoch: int,
        lam_input: float,
        lam_output: float,
        lam_depth: float,
    ) -> None:
        super().__init__()

        self.save_hyperparameters() # saved in the dictionary self.hparams
        # disabling automatic optimization, as it willl be made manually
        self.automatic_optimization = False

        self.task_network = TaskNetwork(
            input_channels=3,  # RGB Channels
            num_classes=num_classes,  # Classes
            pretrained_backbone=pretrained_backbone,
        )
        self.fake_transformation = ImageTransformNetwork(
            input_channels=num_classes + 4,  # RGB Channels + Classes + Depth
            output_channels=3,  # RGB Channels
        )
        self.input_discriminator = Discriminator(
            input_channels=3,  # RGB Channels
        )
        self.output_discriminator = Discriminator(
            input_channels=num_classes + 1,  # Classes + Depth
        )

        self.depths_loss = nn.L1Loss()
        self.labels_loss = nn.CrossEntropyLoss()
        self.discriminator_loss = nn.BCELoss()

        self.miou_index = MulticlassJaccardIndex(num_classes)

        self.weight_init(pretrained_backbone=pretrained_backbone)

    def weight_init(self, pretrained_backbone: bool = False):
        for name, module in self.named_modules():
            if "task" in name and pretrained_backbone:
                continue
            if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
                module.weight.data.normal_(0, 0.001)
                if module.bias is not None:
                    module.bias.data.zero_()

然后我们定义优化器和学习率调度器。 我们需要一个优化器来处理“生成器”的权重,即生成器和任务网络,以及另一个优化器来处理鉴别器的权重。 作为学习率调度程序,我们将使用 OneCycle 策略,该策略在训练的第一部分通过提高学习率和降低动量来“预热”网络,从而允许早期探索权重空间并找到更好的起点 观点。 然后,在最后部分,通过余弦退火策略降低学习率。

    def configure_optimizers(
        self,
    ) -> tuple[
        list[torch.optim.Adam], list[torch.optim.lr_scheduler.OneCycleLR]
    ]:
        params_g = itertools.chain(
            self.fake_transformation.parameters(),
            self.task_network.parameters(),
        )
        params_d = itertools.chain(
            self.input_discriminator.parameters(),
            self.output_discriminator.parameters(),
        )
        optimizer_g, lr_sched_g = self._optimizer_lr_scheduler(params_g)
        optimizer_d, lr_sched_d = self._optimizer_lr_scheduler(params_d)
        return [optimizer_g, optimizer_d], [lr_sched_g, lr_sched_d]

    def _optimizer_lr_scheduler(
        self,
        parameters: Iterator[torch.nn.Parameter],
    ) -> tuple[torch.optim.Adam, torch.optim.lr_scheduler.OneCycleLR]:
        optimizer = torch.optim.Adam(
            parameters,
            lr=self.hparams["init_lr"],
            betas=self.hparams["betas"],
        )
        lr_sched = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=self.hparams["init_lr"],
            epochs=self.hparams["num_epochs"],
            steps_per_epoch=self.hparams["num_steps_per_epoch"],
            base_momentum=self.hparams["betas"][0],
        )
        return optimizer, lr_sched

训练步骤接收输入:

  • 从真实数据集中采样的一批真实图像
  • 一批合成图像,以及从合成数据集中采样的相应标签和深度图。

然后它执行 2 个操作:

  • 鉴别器的优化步骤,需要所有输入
  • 生成器的优化步骤,仅需要合成数据

步骤的顺序对于确保模型的收敛至关重要。 由于生成器更容易崩溃,我们应该让判别器“引导”训练路径。 这样,在生成器步骤中,鉴别器的工作会更好一点,为生成器留下“更好”的梯度。

    def training_step(self, batch: tuple[torch.Tensor, ...]) -> None:
        optimizer_g, optimizer_d = self.optimizers()  
        real_images, fake_images, labels, depths = batch

        # Update D network: minimize log(D(x)) + log(1 - D(G(z)))
        self.toggle_optimizer(optimizer_d)
        optimizer_d.zero_grad()
        self._discriminator_step(real_images, fake_images, labels, depths)
        optimizer_d.step()
        self.untoggle_optimizer(optimizer_d)

        # Update G network: maximize log(D(G(z))) and minimize task loss
        self.toggle_optimizer(optimizer_g)
        optimizer_g.zero_grad()
        self._generator_step(fake_images, labels, depths)
        optimizer_g.step()
        self.untoggle_optimizer(optimizer_g)

鉴别器步骤只是最小化鉴别器输出的二元交叉熵损失。 首先在真实批次上完成此操作,其中预期标签全部为 1,然后在合成批次上完成,其中预期标签全部为零。

    def _discriminator_step(
        self,
        real_images: torch.Tensor,
        fake_images: torch.Tensor,
        labels: torch.Tensor,
        depths: torch.Tensor,
    ) -> None:
        disc_lab = _labels(real_images, self.REAL_LABEL)
        disc_input = self.input_discriminator(real_images)
        disc_output = self.output_discriminator(
            torch.concat(self.task_network(real_images), dim=1)
        )
        loss_input = (
            self.discriminator_loss(disc_input, disc_lab)
            * self.hparams["lam_input"]
        )
        loss_output = (
            self.discriminator_loss(disc_output, disc_lab)
            * self.hparams["lam_output"]
        )
        self.manual_backward(loss_input + loss_output)

        transformed = self.fake_transformation(fake_images, labels, depths)
        disc_lab = _labels(transformed, self.FAKE_LABEL)
        disc_input = self.input_discriminator(transformed)
        disc_output = self.output_discriminator(
            torch.concat(self.task_network(transformed), dim=1)
        )
        loss_input = (
            self.discriminator_loss(disc_input, disc_lab)
            * self.hparams["lam_input"]
        )
        loss_output = (
            self.discriminator_loss(disc_output, disc_lab)
            * self.hparams["lam_output"]
        )
        self.manual_backward(loss_input + loss_output)
        
        # Log losses and metrics
        # ...

相反,生成器步骤最小化标签的交叉熵损失和深度估计的 L1Loss,并且还最大化鉴别器的二元交叉熵损失。 这是通过使用与之前相反的标签计算损失来完成的,因此所有标签都用于合成输入。 没有必要计算实际输入的损失,因为生成器的权重对此输出没有影响。

    def _generator_step(
        self,
        fake_images: torch.Tensor,
        labels: torch.Tensor,
        depths: torch.Tensor,
    ) -> None:
        # Set disc_lab = REAL in order to maximize the loss for the 
        # discriminator when inputs are all fakes
        disc_lab = _labels(fake_images, self.REAL_LABEL)

        # Forward pass on all the networks to collect gradients for G
        transformed = self.fake_transformation(fake_images, labels, depths)
        fake_mask, fake_depth = self.task_network(transformed)
        disc_input = self.input_discriminator(transformed)
        disc_output = self.output_discriminator(
            torch.concat((fake_mask, fake_depth), dim=1)
        )

        # Calculate losses
        loss_input = (
            self.discriminator_loss(disc_input, disc_lab)
            * self.hparams["lam_input"]
        )
        loss_output = (
            self.discriminator_loss(disc_output, disc_lab)
            * self.hparams["lam_output"]
        )
        loss_depths = (
            self.depths_loss(fake_depth, depths) * self.hparams["lam_depth"]
        )
        loss_labels = self.labels_loss(fake_mask, labels)

        # Calculate Gradients
        self.manual_backward(
            loss_input + loss_output + loss_depths + loss_labels
        )

        # Log losses and metrics
        # ...

5、结束语

事实证明,这里解释的方法在各种数据集上都非常有效。 在下图中,我们可以看到,利用 sintetic 数据训练的模型优于仅在小型 KITTI 数据集上训练的模型。 从大量合成数据中获取的知识使模型能够从真实图像中提取更细粒度的细节。

在这里插入图片描述

图 1.3 — KITTI 数据集上的语义分割定性结果。 从左到右:左:输入图像,中:非自适应结果,右:GIO-Ada 方法的结果。

该算法也有一些缺点。 首先,对抗性训练可能非常不稳定,这可以从之前看到的不寻常的训练步骤中猜测出来。 因此,详尽的超参数搜索对于获得良好结果至关重要。 另一个主要问题是训练生成网络是一项内存非常密集的工作,尤其是对于高分辨率图像。

最新的研究集中在其他方法(例如自学习)上,利用变压器层中注意力机制的强泛化特性以及特定领域的数据增强技术。

尽管如此,生成方法(例如本文中讨论的生成方法)由于易于适应新领域以及生成学习研究的不断发展,继续在该领域占据一席之地。


原文链接:用合成数据进行语义分割 — BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/122775.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

ros1 模拟客户端生成小乌龟服务请求生成小乌龟

模拟客户端生成小乌龟服务请求生成小乌龟 一、话题模型二、创建功能包三 创建客户端Client代码四 配置CMakeLists.txt编译规则:五 测试启动ros 主服务启动小乌龟的服务启动模型客户端服务 一、话题模型 Sever端是海龟仿真器/turtlesim,Client端是待实现…

Android 13.0 Launcher3 app图标长按去掉应用信息按钮

1.前言 在13.0的rom定制化开发中,在Launcher3定制化开发中,对Launcher3的定制化功能中,在Launcher3的app列表页会在长按时,弹出微件和应用信息两个按钮,点击对应的按钮跳转到相关的功能页面, 现在由于产品需求要求禁用应用信息,不让进入到应用信息页面所以要去掉应用信息…

BigDecimal使用的时候需要注意什么?

BigDecimal只要涉及到浮点数运算都会用到BigDecimal,并且面试的时候经常会问到,那么BigDecimal使用的时候需要注意什么? 目录 1.为什么不能用浮点数表示金额?2.十进制转换二进制3.科学记数法4.IEEE 7545.在线浮点数转换二进制6.原…

限流式保护器在养老院火灾预防中的应用

安科瑞 华楠 【摘要】老年人是一个庞大特殊的社会群体。随着我国人口的老龄化,老年人口数量断上升。涉及老年人的火灾越来越多,本文从养老院火灾的案例、成因、预防措施等方面对此类火灾进行了深入的探讨。 【关键词】老年公寓;火灾预防&…

VINS-Mono-后端优化 (一:预积分残差计算-IMU预积分约束)

这里先回顾一下预积分是怎么来的 VINS-Mono-IMU预积分 (三:为什么要预积分预积分推导) 这里贴出预积分的公式 具体含义解释看对对应的文章 整个误差函数如下 预积分 α \alpha α β \beta β γ \gamma γ 是用 IMU 预积分获得的增量&a…

CSS关于默认宽度

所谓的默认宽度&#xff0c;就是不设置width属性时&#xff0c;元素所呈现出来的宽度 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title></title><style>* {margin: 0;padding: 0;}.box {/…

Hive 常用存储、压缩格式

1. Hive常用的存储格式 TEXTFI textfile为默认存储格式 存储方式&#xff1a;行存储 磁盘开销大 数据解析开销大 压缩的text文件 hive 无法进行合拆分 SEQUENCEFILE sequencefile二进制文件&#xff0c;以<key,value>的形式序列到文件中 存储方式&#xff1a;行存储 可…

STM32中断简介

中断系统 中断&#xff1a;在主程序运行过程中&#xff0c;出现了特定的中断触发条件&#xff08;中断源&#xff09;&#xff0c;使得CPU暂停当前正在运行的程序&#xff0c;转而去处理中断程序&#xff0c;处理完成后又返回原来被暂停的位置继续运行&#xff1b; 以上是中断的…

Docker - 镜像

Docker - 镜像 镜像是什么 镜像是一种轻量级&#xff0c;可执行的独立软件包&#xff0c;用来打包软件运行环境和基于运行环境开发的软件&#xff0c;它包含运行某个软件所需的所有内容&#xff0c;包括代码&#xff0c;运行时&#xff0c;库&#xff0c;环境变量和配置文件。…

电脑风扇控制软件 Macs Fan Control Pro mac中文版功能介绍

Macs Fan Control mac是一款专门为 Mac 用户设计的软件&#xff0c;它可以帮助用户控制和监控 Mac 设备的风扇速度和温度。这款软件允许用户手动调整风扇速度&#xff0c;以提高设备的散热效果&#xff0c;减少过热造成的风险。 Macs Fan Control 可以在菜单栏上显示当前系统温…

如何在Python爬虫中使用IP代理以避免反爬虫机制

目录 前言 一、IP代理的使用 1. 什么是IP代理&#xff1f; 2. 如何获取IP代理&#xff1f; 3. 如何使用IP代理&#xff1f; 4. 如何避免IP代理失效&#xff1f; 5. 代理IP的匿名性 二、代码示例 总结 前言 在进行爬虫时&#xff0c;我们很容易会遇到反爬虫机制。网站…

聊一聊被人嘲笑的if err!=nil和golang为什么要必须支持多返回值?

golang多返回值演示 我们知道&#xff0c;多返回值是golang的一个特性&#xff0c;比如下面这段代码,里面的参数名我起了几个比较好区分的 package mainfunc main() {Swap(10999, 10888) }func Swap(saaa, sbbb int) (int, int) {return sbbb, saaa }golang为什么要支持多返回…

IP代理识别API:预防欺诈和保护网络安全的必要工具

引言 随着互联网的快速发展&#xff0c;我们的生活变得越来越依赖于网络。然而&#xff0c;随着网络的发展&#xff0c;网络犯罪和网络欺诈也在不断增加。为了保护自己的网站和客户免受网络欺诈的侵害&#xff0c;许多企业和组织开始使用IP代理识别API作为一种必要工具。 什么…

jenkins结合k8s部署动态slave

1、完成k8s连接 在完成jenkins的部署后现安装kubernets的插件 如果jenkins 是部署在k8s集群中只需要填写一下 如果是非本集群的部署则需要填写证书等 cat ./config echo ‘certificate-authority-data-value’ | base64 -d > ./ca.crt echo ‘client-certificate-data’ |…

第二次pta认证P测试C++

#include <iostream> using namespace std; int f(int n){if (n0){return 1;}if (n1){return 3;}return 4*f(n-1)-f(n-2); } int n; int main() {cin>>n;cout<<f(n);return 0; }第二题 试题编号&#xff1a;2022-13-0302 试题名称&#xff1a;长正整数相加 …

springcloud小说阅读网站源码

开发工具&#xff1a; 大等于jdk1.8&#xff0c;大于mysql5.5&#xff0c;nodejs&#xff0c;idea&#xff08;eclipse&#xff09;&#xff0c;vscode&#xff08;webstorm&#xff09; 技术说明&#xff1a; springcloud springboot mybatis vue elementui 功能介绍&…

2023年9月少儿编程 中国电子学会图形化编程等级考试Scratch编程二级真题解析(判断题)

2023年9月scratch编程等级考试二级真题 判断题(共10题,每题2分,共20分) 26、下列两个程序运行效果一样 答案:对 考点分析:考查积木综合使用,重点考查重复执行和坐标积木 两个程序都是在x=0,y=100的时候停止,所以正确 27、甲、乙和丙,一位是山东人,一位是河南人,…

2023云栖大会,Salesforce终敲开中国CRM市场

2015年被视为中国CRM SaaS元年&#xff0c;众多CRM SaaS创业公司和厂商在Salesforce的榜样作用下涌入了CRM SaaS赛道。在全球市场&#xff0c;Salesforce是CRM SaaS领域的领导厂商&#xff0c;连续多年占据了全球CRM SaaS第一大厂商地位。然而&#xff0c;Salesforce作为业务类…

【送书福利-第二十六期】机械工业出版社《算法秘籍》~

&#x1f60e; 作者介绍&#xff1a;我是程序员洲洲&#xff0c;一个热爱写作的非著名程序员。CSDN全栈优质领域创作者、华为云博客社区云享专家、阿里云博客社区专家博主、前后端开发、人工智能研究生。公粽号&#xff1a;程序员洲洲。 &#x1f388; 本文专栏&#xff1a;本文…