RPN(Region Proposal Networks)候选区域网络算法解析(附PyTorch代码)

0. 前言

按照国际惯例,首先声明:本文只是我自己学习的理解,虽然参考了他人的宝贵见解及成果,但是内容可能存在不准确的地方。如果发现文中错误,希望批评指正,共同进步。

本文基于论文Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks对RPN候选区域网络进行解析说明,并基于PyTorch库对RPN网络进行编程。

创作本文的背景:

在我的专栏【PyTorch实例实战演练】中,已经有了4篇左右文章专门介绍图像分类相关的深度学习模型算法,对这类任务已经比较熟悉了。更进一步地,我想再学习掌握目标检测任务。

目标检测任务相比图像分类任务更加复杂,因为前者不仅要像后者一样找到特征并分类,而且还要定位不同特征的位置。由于图像分类算法已经有比较成熟的研究基础,因此目标检测任务的难点不在于分类,而在于定位!

1. 区域候选算法(RPA,Region Proposal Algorithm)

为了定位图像中的目标分类物体位置,我们可以在整个图像或者放大后的部分图像上移动一个固定大小的矩形窗口,对于每一个窗口位置,都会运行分类器来判断窗口内是否包含目标分类的对象。这种方法称为滑动窗口法,滑动窗口法是一种早期广泛应用于目标检测的传统计算机视觉技术。

滑动窗口法的缺点十分明显——过于简单粗暴!会在大量的没有意义的位置上浪费检测目标的时间。如果不采用一些“聪明的”方法先预选一些比较有可能出现检测目标的位置,或者说如果不让分类器“注意力”集中在可能出现检测目标的位置,实时的目标检测方法就不可能,我们现在诸如依赖机器视觉的自动驾驶等技术也就也可能。

我们把此类能“预选可能出现检测目标位置”的算法称为RPA(Region Proposal Algorithm)区域候选算法,常见的RPA有:

Region Proposal Algorithms (RPA) 是计算机视觉领域中的重要组件,它们主要用于目标检测任务中生成候选区域(region proposals),这些候选区域可能是潜在的对象位置。以下是几个常见的Region Proposal方法:

  1. Selective Search, SS:通过图像金字塔和图像分块策略,进行一系列的过分割、合并操作,生成高质量的候选框。
  2. Edge Boxes, EB:该算法根据边界框内的边缘密度来生成候选框,倾向于选择包含显著物体边界的框。
  3. Region Proposal Network, RPN:首次在Faster R-CNN中引入,RPN是一个全卷积网络,它可以并行地生成多个候选框及其相应的置信度评分,极大地提高了目标检测的速度和性能。

本文要介绍的算法就是RPN。

2. 区域候选网络(RPN,Region Proposal Algorithm)

在正式详细说明RPN之前,我想先总结下Faster R-CNN: Towards Real-Time Object Detection with Region Proposal Networks这篇文章的几点精华,提炼要点:

  • RPN是一个全卷积网络,能同时预测候选区域的边界置信度
  • RPN的提出目的是为了提升目标检测的效率(速度),使用RPN后与检测任务相比,候选区域的选择几乎是不耗费计算资源(nearly cost-free);
  • RPN能提升目标检测效率的原因是RPN与分类器共用要检测的原图的卷积特征图
  • 首次提出锚框(anchor)的概念,在锚框提出之前候选框的选择只能通过各种尺寸的图像(pyramids of images)或者各种尺寸的算子(pyramids of filters),这也是为了提升目标检测的效率。
2.1 RPN背景说明

在当时(2016年),先进的目标检测网络如SPPnet、Fast R-CNN提出后,减少了目标识别任务的计算时间,使得区域候选成为了计算瓶颈。当时已有的区域候选算法:Selective Search, SS仅能做到2s完成一张图像的区域候选;EdgeBoxes虽然更快,能做到0.2s完成,但是这是以牺牲计算质量为代价的。

通过发现:给识别器(detector)做图像识别的卷积特征图也可以共用于生成候选区域,这样做可以减少计算生成候选区域的时间(少到每张图像的候选区域生成仅需要10ms左右),于是基于此提出了RPN网络。使用RPN区域候选+深度学习图像分类组成的目标检测识别算法能做到每秒完成5个目标检测任务,大大提升了检测速率。

2.2 RPN架构

RPN在目标检测算法Faster R-CNN中的架构如下图:

Faster R-CNN = Fast R-CNN + RPN

RPN是整个 Faster R-CNN 模型的重要组成部分,原始图像首先通过一系列卷积层进行特征提取和处理得到特征图,RPN基于该特征图生成可能包含感兴趣对象的候选区域,以便后续的分类器能把“注意力”集中在这些后续区域内,提升整个目标检测任务的效率。

然后我们再来看下RPN本身的架构:

RPN是一个全卷积网络,模型计算过程如下:

  1. 使用一个n×n滑动窗口(sliding window)在共享的卷积特征图上滑动并提取特征。每个滑动窗口的中间位置都对应有k个锚框(anchors),在本文中使用3种缩放比例×3种长宽比,共k=9种锚框
  2. 对于每一个滑动窗口,通过卷积计算将其转换为一个固定维度向量(例如基础CNN为ZFnet时,该向量长度为256;基础CNN为VGG时,该向量长度为512),通常称为"中间层"或"隐藏层";
  3. 然后使用这个中间层作为输入,分别经过两个孪生全连接层分支(two sibling fully connected layers):分类层和回归层;
  4. 分类层输出的是每个格子属于各个类别(如人、车等)的概率分布,需要注意的是这里输出的是有物体和无物体,即正样本(前景)和负样本(背景)两个概率,因此输出长度为2k;而回归层则输出的是每个格子中心点相对于真实边界框偏移的距离值,中心坐标偏移加上长宽偏移共4个值,因此输出长度为4k

当然,分类层也可以只输出有物体的概率,输出长度为k。原文有一行注释:For simplicity we implement the cls layer as a two-class softmax layer. Alternatively, one may use logistic regression to produce k scores.

RPN的一个重要特性就是平移不变性(translation invariant),带来的好处就是减小模型,更小的参数量能进一步提高计算速度。

2.2 RPN的损失函数

RPN的损失函数为:

L = \frac{1}{N_{cls}}\sum L_{cls}(p_i, p^*_i) + \lambda \frac{1}{N_{reg}}\sum p^*_i L_{reg}(t_i, t^*_i)

  • N_{cls}, N_{reg}为归一化参数,λ为平衡分类损失和回归损失的系数;
  • L_{cls}为分类损失函数,L_{reg}为回归损失函数;
  • 角标i代表第i个锚框;
  • p_i为模型预测的第i个锚框中有物体的概率;
  • p^*_i为第i个锚框中是否有物体的ground truth,当①锚框与真实物体的IoU=0.7 或②锚框与真实物体的IoU达到最大为正样本,p^*_i = 1。当③锚框与真实物体的IoU<0.3(且不为最大)为负样本,p^*_i = 0

如果p^*_i不满足①②③中任意条件,比如p^*_i为0.4,且不为最大值怎么办呢?

答:这种样本不会被选来训练。在训练RPN的时候,每张图像会随机选取256个锚框进行训练,其中正样本和负样本为1:1,选取的样本肯定满足①②③条件之一。

  • t_i是一个四维向量,代表预测的bounding box相对于第i个锚框x, y, w, h的偏置;
  • t^*_i是一个四维向量,代表ground truth相对于第i个锚框x, y, w, h的偏置;

t_i = [t_x, t_y, t_w, t_h]其中t_x = (x - x_a)/w_at_y = (y - y_a)/h_at_w = log(w/w_a)t_h = log(h/h_a)

t^*_i = [t^*_x, t^*_y, t^*_w, t^*_h]其中t^*_x = (x^* - x_a)/w_at^*_y = (y^* - y_a)/h_at^*_w = log(w^*/w_a)t^*_h = log(h^*/h_a)

x, y 为框的中心点坐标,w, h为框的宽和高。无角标、角标为a、角标为*分别代表预测的bounding box、锚框、ground truth的框的几何尺寸。

2.3 训练RPN

在Faster R-CNN的目标检测框架中,交替训练(Alternating training)是一种分阶段优化RPN(Region Proposal Network)和Fast R-CNN检测器的策略。以下是详细的交替训练过程:

第一步:初始化与训练RPN
首先,仅使用卷积特征层训练RPN网络。RPN是一个全卷积网络,它直接从基础CNN网络(如VGG或ResNet)提取的特征图上生成候选区域(region proposals)。RPN通过学习调整一组预先定义好的锚框来预测每个锚框是否包含对象以及调整其边界框回归参数。

第二步:使用RPN生成提议并训练Fast R-CNN
使用第一步训练得到的RPN生成大量的候选区域提议,并从中采样固定数量的高质量提议用于训练Fast R-CNN分类器和边框回归器。Fast R-CNN是一个端到端可训练的模型,它对这些候选区域进行分类和精调边界框位置。

第三步:使用Fast R-CNN初始化RPN并重新训练
将经过Fast R-CNN训练后更新了卷积特征层的整个网络(包括基础CNN部分)作为初始化权重,再次启动RPN的训练。在这个迭代过程中,Fast R-CNN的部分被冻结(即保持参数不变),只训练RPN部分,利用更高质量的特征来进一步优化RPN的提议生成能力。

第四步:重复迭代
这个过程可以多次迭代,每次都是将更新后的模型反过来影响RPN的训练,然后再用改进后的RPN提议去提升Fast R-CNN的表现,直到两个子网络收敛达到较好的联合性能。

通过这种方式,RPN和Fast R-CNN可以互相促进,共同优化目标检测的整体性能。在实际训练过程中,可能需要几个交替迭代周期才能使得RPN产生的提议质量和最终检测结果达到最优状态。

3. 基于PyTorch框架的RPN

注意:由于RPN与其他网络模型关系十分密切,难易独立分割开来。以下代码只是一个非常基础的示意,并未包含诸如锚框生成、前向传播中的空间尺寸变换、损失函数定义以及后处理步骤(如非极大值抑制NMS)等内容。在实际项目中,请参考Faster R-CNN论文或其他开源实现来完善整个RPN模块的功能。

import torch
import torch.nn as nn
from torchvision.models import vgg16
import torch.nn.functional as F

# RPN网络定义
class RPN(nn.Module):  #softmax???   长宽计算方法??
    def __init__(self, in_channels=512):
        super(RPN, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 512, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.cls_score = nn.Conv2d(512, 2 * 9, kernel_size=1)  #k=9
        self.bbox_pred = nn.Conv2d(512, 4 * 9, kernel_size=1)

    def forward(self, x):
        x = self.relu(self.conv1(x))
        rpn_cls_score = self.cls_score(x)  #生成预测正负样本(有无物体的概率)[pos_label, neg_label]
        rpn_cls_score_softmax = F.softmax(rpn_cls_score)
        rpn_bbox_pred = self.bbox_pred(x)  #生成bounding box的[tx, ty, tw, th]
        return rpn_cls_score_softmax, rpn_bbox_pred

# 使用预训练的VGG16作为特征提取器
backbone = vgg16(pretrained=True).features[:-1]  # 模型微调,去掉最后一个池化层
rpn = RPN(in_channels=512)  # VGG16最后一层输出维度为512

# 定义损失函数和优化器
cls_loss_func = nn.CrossEntropyLoss()
bbox_loss_func = nn.SmoothL1Loss()
optimizer = torch.optim.SGD(rpn.parameters(), lr=0.001, momentum=0.9)

# 假设我们有一个训练数据加载器,格式为images, (gt_boxes, gt_labels)
data_loader = ...

# 训练过程,这里仅说明RPN的训练过程,即交替训练的第一步!
num_epochs = 10
for epoch in range(num_epochs):
    for images, (gt_boxes, gt_labels) in enumerate(data_loader):
        # 前向传播
        features = backbone(images)
        rpn_cls_scores_softmax, rpn_bbox_preds = rpn(features)

        # 数据预处理,将预测结果调整到与ground truth相匹配的格式
        # 这部分会根据你的具体实现有所不同,这里仅作示例
        rpn_cls_scores_view = rpn_cls_scores_softmax.permute(0, 2, 3, 1).contiguous().view(-1, 2)
        rpn_bbox_preds_view = rpn_bbox_preds.permute(0, 2, 3, 1).contiguous().view(-1, 4)
        gt_labels_view = gt_labels.view(-1)

        # 计算损失
        rpn_cls_loss = cls_loss_func(rpn_cls_scores_view, gt_labels_view)
        rpn_bbox_loss = bbox_loss_func(rpn_bbox_preds_view, gt_boxes)

        # 总损失,这里暂时忽略Ncls,Nreg,λ
        loss = rpn_cls_loss + rpn_bbox_loss

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 打印训练信息
        if (images + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{images+1}/{len(data_loader)}], Loss: {loss.item()}')

# 训练结束
print('Training finished.')

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/318195.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

大一统的监控探针采集器 cprobe

需求背景 监控数据采集领域,比如 Prometheus 生态有非常多的 Exporter,虽然生态繁荣,但是无法达到开箱即用的大一统体验,Exporter 体系的核心问题有: 良莠不齐:有的 Exporter 写的非常棒,有的…

uniapp怎么开发插件并发布

今天耳机坏了,暂时内卷不了,所以想开发几个插件玩玩,也好久没写博客了,就拿这个来写了 首先,发布插件时需要你有项目 这里先拿uniapp创建一个项目, 如下,创建好的项目长这样 然后根据uniapp官网上说的,我们发布插件时,需要在uni_modules里面编写和发布 ps:还需要使用uniapp…

计算机组成原理期末复习

文章目录 第一章:计算机系统漫游编译系统进程线程之间的关系存储器层次结构虚拟地址 第二章:信息的表示和处理大端与小端整数运算浮点数运算 第三章:程序的机器级表示栈的压入和弹出算数与逻辑运算操作指令条件判断与循环 第六章:…

U盘抜太快打不开恢复方法

U盘是一种常用的存储设备,由于其便携性和大容量等特点,被广泛应用于数据存储和传输。然而,有时候我们会遇到U盘拔出后无法再次使用的问题,这通常是由于U盘拔出太快导致的。本文将深入探讨U盘拔太快打不开的原因,并提供…

【elfboard linux 开发板】10. 设备树与烧录

1. 设备树介绍 设备树由一系列被命名的node和property组成 可以描述如下信息: CPU的数量和类别内存基地址和大小总线和桥外设连接中断控制器和中断使用情况GPIO控制器和GPIO使用情况Clock 控制器和 Clock 使用情况 由dts文件以文本方式对系统设备树进行描述&…

GPT-4V的图片识别和分析能力原创

GPT-4V是OpenAI开发的大型语言模型,是GPT-4的升级版本。GPT-4V在以下几个方面进行了改进: 模型规模更大:GPT-4V的参数量达到了1.37T,是GPT-4的10倍。 训练数据更丰富:GPT-4V的训练数据包括了1.56T的文本和代码数据。 …

Ubuntu 22.04 安装prometheus

服务器监控和报警软件有很多,为什么我们会选择Prometheus而不是其他软件呢? 因为它有以下优点: 自带简易web监控页面,用户可以很方便地查看监控数据和使用仪表盘。能实时收集数据并根据自定义警报规则推送告警;具有丰…

力扣双周赛第三题----2857. 统计距离为 k 的点对

这题我们的暴力做法就是o(n^2),但是根据数据量这样会超时,所以我们不能用暴力解法去解决 那么想一想双指针可以吗,不可以。为什么呢?因为他没有一个特性可以让他双指针跳过前面或者后面一个点。比如他们数组有顺序的情况下,还有一…

【JaveWeb教程】(23) MySQL数据库开发之事务与索引 详细代码示例讲解(最全面)

目录 2. 事务2.1 介绍2.2 操作2.3 四大特性 3. 索引3.1 介绍3.2 结构3.3 语法 2. 事务 场景:学工部整个部门解散了,该部门及部门下的员工都需要删除了。 操作: -- 删除学工部 delete from dept where id 1; -- 删除成功-- 删除学工部的员工…

某厂校招一道关于C的笔试题

一、笔试原题 题目:在Linux x86 _ 54 gcc环境下,下面的程序会出现什么问题?运行结果是什么?为什么? 程序如下图: 通过在gcc的环境下编译运行,发现运行结果为不断死循环打印0-17的数字 我们…

遥感影像-语义分割数据集:高分卫星-云数据集详细介绍及训练样本处理流程

原始数据集详情 简介:该云数据集包括RGB三通道的高分辨率图像,包含高分一、高分二及宽幅数据集。 KeyValue卫星类型高分系列覆盖区域未知场景未知分辨率1m、2m、8m数量12000单张尺寸1024*1024原始影像位深8位标签图片位深8位原始影像通道数三通道标签图…

云卷云舒:AI for DB、DB for AI

云卷云舒:算力网络云原生(下):云数据库发展的新篇章-CSDN博客https://blog.csdn.net/bishenghua/article/details/135050556 随着数据库和AI技术的分支同向演进,AI 和数据库间的关联越发紧密了。 大模型的演进发展&a…

为何资深程序员都离不开 requirements.txt?你还在为环境配置发愁吗?

requirements.txt 文件是一个用于记录 Python 包依赖的文件,它能够帮助我们快速配置开发环境。在迁移到新的开发环境时,通常需要逐个使用 pip install 命令安装各种包,这个过程既耗时又可能出现错误。 而 requirements.txt 文件可以让我们一…

通过IP地址识别风险用户

随着互联网的迅猛发展,网络安全成为企业和个人关注的焦点之一。识别和防范潜在的风险用户是维护网络安全的关键环节之一。IP数据云将探讨通过IP地址识别风险用户的方法和意义。 IP地址的基本概念:IP地址是互联网上设备的独特标识符,它分为IP…

靶场实战(14):OSCP备考之VulnHub SUNSET NOONTIDE

打靶思路 资产发现 主机发现服务发现漏洞发现(获取权限) irc服务提升权限 server用户 sudosuidcron内核提权信息收集 1、资产发现 1.1、主机发现 本次靶场SUNSET: NOONTIDE[1]指定IP,不涉及主机发现过程。 1.2、服务发现 使用命令sudo -u roo…

Kubernetes 集群管理—日志架构

日志架构 应用日志可以让你了解应用内部的运行状况。日志对调试问题和监控集群活动非常有用。 大部分现代化应用都有某种日志记录机制。同样地,容器引擎也被设计成支持日志记录。 针对容器化应用,最简单且最广泛采用的日志记录方式就是写入标准输出和标…

调用openai实现聊天功能

📑前言 本文主要是【聊天机器人】——调用openai实现聊天功能的文章,如果有什么需要改进的地方还请大佬指出⛺️ 🎬作者简介:大家好,我是听风与他🥇 ☁️博客首页:CSDN主页听风与他 &#x1f3…

编译和链接(2)

3. 预处理详解 3.2#define 3.2.5带副作用的宏参数 当宏参数在宏的定义中出现超过一次的时候,如果参数带有副作用,那么你在使用这个宏的时候就可能 出现危险,导致不可预测的后果。副作用就是表达式求值的时候出现的永久性效果。 例如&…

day16 二叉树的最大深度 n叉树的最大深度 二叉树的最小深度 完全二叉树的节点数

题目1:104 二叉树的最大深度 题目链接:104 二叉树的最大深度 题意 二叉树的根节点是root,返回其最大深度(从根节点到最远叶子节点的最长路径上的节点数) 递归 根节点的的高度就是二叉树的最大深度 所以使用后序遍…

【Minio】常见问题解决思路

检查存储服务器对应的端口与应用服务器是否能够互通,通过ping|telnet命令检查、查看防火墙端口是否开放,检查防火墙端口linux系统和windows系统各有不同。检查电脑上的杀毒软件是否限制了网络端口和文件权限问题。检查minio配置信息是否正确,…