YOLO即插即用---PKIBlock

Poly Kernel Inception Network for Remote Sensing Detection

论文地址

1. 解决的问题

2. 解决方案

3. 解决问题的具体方法

4. 模块的应用

5. 在目标检测任务中的添加位置

6.即插即用代码


论文地址

2403.06258icon-default.png?t=O83Ahttps://arxiv.org/pdf/2403.06258

1. 解决的问题

遥感图像目标检测面临着两大挑战:

  • 目标尺度变化大: 遥感图像通常包含各种尺度差异巨大的目标,例如足球场、建筑物、车辆等。这种尺度变化对目标检测算法提出了挑战,因为传统的目标检测方法往往难以有效地处理不同尺度的目标。

  • 场景背景复杂多样: 遥感图像通常包含复杂的背景信息,例如山脉、河流、植被等。这些背景信息会对目标检测造成干扰,使得算法难以准确地识别和定位目标。

2. 解决方案

PKINet 通过以下两种机制来解决上述挑战:

  • 多尺度卷积核 (PKI 模块):

    • PKI 模块使用并行排列的不同尺寸深度可分离卷积核,提取不同感受野的多尺度纹理特征。这种设计可以有效地捕捉到不同尺度目标的特征,从而提高模型对不同尺度目标的检测能力。

    • 深度可分离卷积可以有效地降低模型的计算量和参数量,使得模型更加轻量化。

  • 上下文锚点注意力机制 (CAA 模块):

    • CAA 模块利用全局平均池化和 1D 线条卷积来捕捉远距离像素之间的关系,并增强中心区域的特征。这种设计可以帮助模型更好地理解目标的上下文信息,从而提高模型的检测精度。

    • 1D 线条卷积可以有效地捕捉长距离像素关系,而不会引入过多的噪声和计算量。

3. 解决问题的具体方法

  • PKI 模块:

    • 局部特征提取: 使用 3x3 卷积提取局部特征,捕获目标的细节信息。

    • 多尺度上下文特征提取: 使用不同尺寸 (3x3, 5x5, 7x7, 9x9, 11x11) 的深度可分离卷积提取不同感受野的上下文特征,捕捉目标周围的环境信息。

    • 特征融合: 使用 1x1 卷积融合局部特征和上下文特征,得到最终的特征表示。

  • CAA 模块:

    • 局部区域特征提取: 使用全局平均池化和 1x1 卷积提取局部区域特征,作为后续计算的基础。

    • 远距离像素关系捕捉: 使用 1D 线条卷积 (水平方向和垂直方向) 代替大核卷积,扩大感受野并捕捉远距离像素关系,从而更好地理解目标的上下文信息。

    • 注意力权重生成: 使用 sigmoid 函数生成注意力权重,并根据注意力权重增强 PKI 模块的输出特征,突出目标区域的特征,抑制背景区域的特征。

4. 模块的应用

PKI 模块和 CAA 模块共同构成了 PKINet 的基本单元,用于提取图像特征。PKINet 可以作为特征提取骨干网络,应用于各种遥感图像目标检测任务,例如检测飞机、舰船、车辆等目标。通过 PKINet 提取的特征可以更好地反映目标的尺度、形状、纹理和上下文信息,从而提高目标检测算法的性能。

5. 在目标检测任务中的添加位置

PKINet 位于目标检测网络的底层,负责提取图像特征。这些特征随后被送入检测头,例如 RPN、RoI Pooling、分类和回归头等,用于识别和定位目标。PKINet 提取的特征可以帮助检测头更好地理解目标的特征,从而提高目标检测算法的精度和鲁棒性

6.即插即用代码

from typing import Optional, Sequence
import torch.nn as nn
import torch
def autopad(kernel_size: int, padding: Optional[int] = None, dilation: int = 1) -> int:
    """Calculate the padding size based on kernel size and dilation."""
    if padding is None:
        padding = (kernel_size - 1) * dilation // 2
    return padding


def make_divisible(value: int, divisor: int = 8) -> int:
    """Make a value divisible by a certain divisor."""
    return int((value + divisor // 2) // divisor * divisor)


class ConvModule(nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: int,
            stride: int = 1,
            padding: int = 0,
            dilation: int = 1,
            groups: int = 1,
            norm_cfg: Optional[dict] = None,
            act_cfg: Optional[dict] = None):
        super().__init__()
        layers = []
        layers.append(nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=(norm_cfg is None)))
        if norm_cfg:
            norm_layer = self._get_norm_layer(out_channels, norm_cfg)
            layers.append(norm_layer)
        if act_cfg:
            act_layer = self._get_act_layer(act_cfg)
            layers.append(act_layer)
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)

    def _get_norm_layer(self, num_features, norm_cfg):
        if norm_cfg['type'] == 'BN':
            return nn.BatchNorm2d(num_features, momentum=norm_cfg.get('momentum', 0.1), eps=norm_cfg.get('eps', 1e-5))
        # Add more normalization types if needed
        raise NotImplementedError(f"Normalization layer '{norm_cfg['type']}' is not implemented.")

    def _get_act_layer(self, act_cfg):
        if act_cfg['type'] == 'ReLU':
            return nn.ReLU(inplace=True)
        if act_cfg['type'] == 'SiLU':
            return nn.SiLU(inplace=True)
        # Add more activation types if needed
        raise NotImplementedError(f"Activation layer '{act_cfg['type']}' is not implemented.")

# Update InceptionBottleneck's constructor call to avoid conflicts
class InceptionBottleneck(nn.Module):
    """Bottleneck with Inception module"""
    def __init__(
            self,
            in_channels: int,
            out_channels: Optional[int] = None,
            kernel_sizes: Sequence[int] = (3, 5, 7, 9, 11),
            dilations: Sequence[int] = (1, 1, 1, 1, 1),
            expansion: float = 1.0,
            add_identity: bool = True,
            with_caa: bool = True,
            caa_kernel_size: int = 11,
            norm_cfg: Optional[dict] = dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg: Optional[dict] = dict(type='SiLU')):
        super().__init__()
        out_channels = out_channels or in_channels
        hidden_channels = make_divisible(int(out_channels * expansion), 8)

        self.pre_conv = ConvModule(in_channels, hidden_channels, 1, 1, 0,
                                   norm_cfg=norm_cfg, act_cfg=act_cfg)

        self.dw_conv = ConvModule(hidden_channels, hidden_channels, kernel_sizes[0], 1,
                                  autopad(kernel_sizes[0], None, dilations[0]),
                                  dilation=dilations[0], groups=hidden_channels,
                                  norm_cfg=None, act_cfg=None)
        self.dw_conv1 = ConvModule(hidden_channels, hidden_channels, kernel_sizes[1], 1,
                                   autopad(kernel_sizes[1], None, dilations[1]),
                                   dilation=dilations[1], groups=hidden_channels,
                                   norm_cfg=None, act_cfg=None)
        self.dw_conv2 = ConvModule(hidden_channels, hidden_channels, kernel_sizes[2], 1,
                                   autopad(kernel_sizes[2], None, dilations[2]),
                                   dilation=dilations[2], groups=hidden_channels,
                                   norm_cfg=None, act_cfg=None)
        self.dw_conv3 = ConvModule(hidden_channels, hidden_channels, kernel_sizes[3], 1,
                                   autopad(kernel_sizes[3], None, dilations[3]),
                                   dilation=dilations[3], groups=hidden_channels,
                                   norm_cfg=None, act_cfg=None)
        self.dw_conv4 = ConvModule(hidden_channels, hidden_channels, kernel_sizes[4], 1,
                                   autopad(kernel_sizes[4], None, dilations[4]),
                                   dilation=dilations[4], groups=hidden_channels,
                                   norm_cfg=None, act_cfg=None)
        self.pw_conv = ConvModule(hidden_channels, hidden_channels, 1, 1, 0,
                                  norm_cfg=norm_cfg, act_cfg=act_cfg)

        if with_caa:
            self.caa_factor = CAA(hidden_channels, caa_kernel_size, caa_kernel_size, None, None)
        else:
            self.caa_factor = None

        self.add_identity = add_identity and in_channels == out_channels

        self.post_conv = ConvModule(hidden_channels, out_channels, 1, 1, 0,
                                    norm_cfg=norm_cfg, act_cfg=act_cfg)

    def forward(self, x):
        x = self.pre_conv(x)

        y = x
        x = self.dw_conv(x)
        x = x + self.dw_conv1(x) + self.dw_conv2(x) + self.dw_conv3(x) + self.dw_conv4(x)
        x = self.pw_conv(x)
        if self.caa_factor is not None:
            y = self.caa_factor(y)
        if self.add_identity:
            y = x * y
            x = x + y
        else:
            x = x * y

        x = self.post_conv(x)
        return x

class CAA(nn.Module):
    """Context Anchor Attention"""
    def __init__(
            self,
            channels: int,
            h_kernel_size: int = 11,
            v_kernel_size: int = 11,
            norm_cfg: Optional[dict] = dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg: Optional[dict] = dict(type='SiLU')):
        super().__init__()
        self.avg_pool = nn.AvgPool2d(7, 1, 3)
        self.conv1 = ConvModule(channels, channels, 1, 1, 0,
                                norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.h_conv = ConvModule(channels, channels, (1, h_kernel_size), 1,
                                 (0, h_kernel_size // 2), groups=channels,
                                 norm_cfg=None, act_cfg=None)
        self.v_conv = ConvModule(channels, channels, (v_kernel_size, 1), 1,
                                 (v_kernel_size // 2, 0), groups=channels,
                                 norm_cfg=None, act_cfg=None)
        self.conv2 = ConvModule(channels, channels, 1, 1, 0,
                                norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.act = nn.Sigmoid()

    def forward(self, x):
        attn_factor = self.act(self.conv2(self.v_conv(self.h_conv(self.conv1(self.avg_pool(x))))))
        return attn_factor

# Testing the InceptionBottleneck
if __name__ == "__main__":

    input = torch.randn(1, 64, 128, 128) #输入B C H W
    block = InceptionBottleneck(in_channels=64, out_channels=128)
    output = block(input)
    print(input.size())
    print(output.size())

对模型改进感兴趣的可以进入交流群,群中有答疑(QQ:828370883

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/907047.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

丝杆支撑座的更换与细节注意事项

丝杆支撑座是支撑连接丝杆和电机的轴承支撑座,分固定侧和支撑侧,它们都有用预压调整的JIS5级的交界处球轴承。在自动化设备中是常用的传动装置,作为核心部件,对设备精度、稳定性和生产效率产生直接影响。在长时间运行中&#xff0…

3D Gaussian Splatting代码详解(一):模型训练、数据加载

1 模型训练 这段代码实现了一个 3D 高斯模型的训练循环,旨在通过逐步优化模型参数,使其能够精确地渲染特定场景。以下是代码的详细解析: def training(dataset, opt, pipe, testing_iterations, saving_iterations, checkpoint_iterations,…

Docker-微服务项目部署

环境准备 1.微服务项目 参考:通过网盘分享的文件:wolf2w_cloud.zip 链接: https://pan.baidu.com/s/1Lr4k6LPIJ59gVNA_DgKM_Q?pwdkjxt 提取码: kjxt 前端项目:trip-mgrsite-ui,trip-website-ui,trip-wenda-ui 服务项…

设计模式讲解01-建造者模式(Builder)

1. 概述 建造者模式也称为:生成器模式 定义:建造者模式是一种创建型设计模式,它允许你将创建复杂对象的步骤与表示方式相分离。 解释:建造者模式就是将复杂对象的创建过程拆分成多个简单对象的创建过程,并将这些简单…

HTML 基础标签——文本内容标签 <ul>、<ol>、<blockquote> 、<code> 等标签的用法详解

文章目录 1. 标题标签2. 段落标签3. 文本格式化标签4. 列表标签4.1 无序列表 `<ul>`4.2 有序列表 `<ol>`5. 引用标签5.1 块引用 `<blockquote>`5.2 行内引用 `<q>`5.3 作品引用 `<cite>`6. 代码和预格式文本标签6.1 代码标签 `<code>`6.2 …

(51)MATLAB迫零均衡器系统建模与性能仿真

文章目录 前言一、迫零均衡器性能仿真说明二、迫零均衡器系统建模与性能仿真代码1.仿真代码2.代码说明3.迫零均衡器zf_equalizer的MATLAB源码 三、仿真结果1.信道的冲击响应2.频率响应3.迫零均衡器的输入和输出 前言 使用MATLAB对迫零均衡器系统进行建模仿真&#xff0c;完整的…

前端请求后端接口报错(blocked:mixed-content),以及解决办法

报错原因&#xff1a;被浏览器拦截了&#xff0c;因为接口地址不是https的。 什么是混合内容&#xff08;Mixed Content&#xff09; 混合内容是指在同一页面中同时包含安全&#xff08;HTTPS&#xff09;和非安全&#xff08;HTTP&#xff09;资源的情况。当浏览器试图加载非…

python 包和模块

一、模块 一个.py 文件就是一个模块&#xff0c;模块是含有一系列数据&#xff0c;函数&#xff0c;类等的程序。 1、模块导入 1.1、impotrt 模块名称 [ as 别名] import nunpy as np 1.2、form 模块名 import 模块内属性名 [ as 别名] from datetime import datetime as d…

Git下载-连接码云-保姆级教学(连接Gitee失败的解决)

Git介绍 码云连接 一、Git介绍 二、Git的工作机制 下载链接&#xff1a;Git - 下载软件包 三、使用步骤 创建一个wss的文件夹&#xff0c;作为‘工作空间’ 四、连接码云账号 五、连接Gitee失败的解决方法 一、Git介绍 Git是一个免费的、开源的分布式版本控制…

https和http的区别,及HTTPS的工作流程

HTTP&#xff08;HyperText Transfer Protocol&#xff09;和HTTPS&#xff08;HyperText Transfer Protocol Secure&#xff09;都是超文本传输协议&#xff0c;但它们之间的关键区别在于安全性。 安全性&#xff1a; HTTP&#xff1a;数据以明文传输&#xff0c;没有加密&…

【Python · Pytorch】人工神经网络 ANN(上)

【Python Pytorch】人工神经网络 ANN&#xff08;上&#xff09; 0. 生物神经网络1. 人工神经网络定义2. 人工神经网络结构2.1 感知机2.2 多层感知机2.3 全连接神经网络2.4 深度神经网络 2. 训练流程※ 数据预处理 (Data Preprocessing) 3. 常见激活函数3.1 Sigmoid / Logisti…

基本查询【MySQL】

文章目录 基本查询插入时是否更新替换查询指定列查询查询字段为表达式为查询结果指定别名结果去重where条件NULL 的查询 结果排序筛选分页结果UpdateDelete截断表聚合函数分组(group by)having && where 基本查询 建表 mysql> create table Student (-> id int…

pandas——数据结构

一、series &#xff08;一&#xff09;创建series import pandas as pd#1.使用列表或数组创建Series # 使用列表创建Series&#xff0c;索引默认从0开始 s1 pd.Series([1, 2, 3]) print(s1) # 使用列表和自定义索引创建Series s2 pd.Series([1, 2, 3], index[a, b, c]) pr…

算法妙妙屋-------1.递归的深邃回响:C++ 算法世界的优雅之旅

前言&#xff1a; 递归是一种在算法中广泛应用的思想&#xff0c;其主体思想是通过将复杂的问题分解为更简单的子问题来求解。具体而言&#xff0c;递归通常包括以下几个要素&#xff1a; 基本情况&#xff08;Base Case&#xff09;&#xff1a;每个递归算法必须有一个或多个…

禾川HCQ1控制器程序编译报错如何解决

1、第一次打开用户程序 2、提示库未安装 3、安装库文件 4、脉冲轴库未安装 5、没有错误 去禾川自动化官网,把可以安装的包和库都安装下,程序编译就没有错误了。 6、下载相关包文件

HarmonyOS:@Watch装饰器:状态变量更改通知

Watch应用于对状态变量的监听。如果开发者需要关注某个状态变量的值是否改变&#xff0c;可以使用Watch为状态变量设置回调函数。 说明 从API version 9开始&#xff0c;该装饰器支持在ArkTS卡片中使用。 从API version 11开始&#xff0c;该装饰器支持在元服务中使用。 一、概…

Windows如何查看自己网卡的MAC地址?

本章教程&#xff0c;主要介绍如何在Windows查看自己的网卡mac地址。 一、查询MAC地址方法 打开使用PowerShell&#xff0c;运行以下命令即可查询到自己的网卡MAC地址。 Get-NetAdapter | Select-Object Name, MacAddress二、MAC地址是什么 MAC地址&#xff08;Media Access Co…

Unknown at rule @tailwindscss(unknownAtRules)

一、前言 整合 tailwindcss 后&#xff0c;发现指令提示警告 Unknown at rule tailwindscss(unknownAtRules)&#xff0c;其实是 vscode 无法识别 tailwindscss 指令&#xff0c;不影响使用&#xff0c;但是对于我这种有编程洁癖的人来说&#xff0c;有点膈应。 二、解决方案…

Python 实现深度学习模型预测控制--预测模型构建

链接&#xff1a;深度学习模型预测控制 &#xff08;如果认为有用&#xff0c;动动小手为我点亮github小星星哦&#xff09;&#xff0c;持续更新中…… 链接&#xff1a;WangXiaoMingo/TensorDL-MPC: DL-MPC(deep learning model predictive control) is a software toolkit…

安宝特案例 | AR技术在院外心脏骤停急救中的革命性应用

00 案例背景 在院外心脏骤停 (OHCA) 的突发救援中&#xff0c;时间与效率直接决定着患者的生命。传统急救模式下&#xff0c;急救人员常通过视频或电话与医院医生进行沟通&#xff0c;以描述患者状况并依照指令行动。然而&#xff0c;这种信息传递方式往往因信息不完整或传递延…