【Block总结】ODConv动态卷积,适用于CV任务|即插即用

一、论文信息

  • 论文标题:Omni-Dimensional Dynamic Convolution
  • 作者:Chao Li, Aojun Zhou, Anbang Yao
  • 发表会议:ICLR 2022
  • 论文链接:https://arxiv.org/pdf/2209.07947
  • GitHub链接:https://github.com/OSVAI/ODConv
    在这里插入图片描述

二、创新点

Omni-Dimensional Dynamic Convolution(ODConv)提出了一种更为通用且优雅的动态卷积设计,主要创新点包括:

  • 多维动态注意力机制:ODConv通过并行策略在卷积核的四个维度(空间大小、输入通道数、输出通道数和卷积核数量)上学习互补的注意力。这种设计使得卷积核能够根据输入特征动态调整,从而提升特征提取能力。

  • 即插即用的特性:ODConv可以作为常规卷积的替代品,轻松集成到现有的CNN架构中,增强模型的灵活性和适应性。

三、方法

ODConv的实现方法包括以下几个步骤:

  1. 注意力计算

    • ODConv计算四种类型的注意力:空间注意力、输入通道注意力、输出通道注意力和卷积核注意力。这些注意力值用于调节卷积核的输出。
  2. 并行策略

    • 在每个卷积层中,ODConv并行计算上述四种注意力,确保每个卷积核在不同维度上都能获得适当的加权。
  3. 卷积操作

    • 将计算得到的注意力应用于卷积核,进而影响最终的特征图输出。

在这里插入图片描述

ODConv的多维动态注意力机制实现

Omni-Dimensional Dynamic Convolution(ODConv)引入了一种创新的多维动态注意力机制,旨在提升卷积神经网络(CNN)的特征提取能力。该机制通过并行策略在卷积核的四个维度上学习互补的注意力,从而实现更灵活的卷积操作。以下是ODConv多维动态注意力机制的具体实现细节:

1、四个维度的注意力机制

ODConv的多维动态注意力机制主要涉及以下四个维度的注意力学习:

  1. 空间维度注意力(Spatial Attention)

    • 该注意力机制为每个卷积核的不同空间位置分配不同的权重。通过对空间特征的加权,ODConv能够更好地捕捉图像中的局部特征。
  2. 输入通道注意力(Input Channel Attention)

    • 该机制为每个卷积核的输入通道分配不同的权重,允许模型根据输入特征的重要性动态调整卷积操作。这种方式增强了模型对不同输入特征的响应能力。
  3. 输出通道注意力(Output Channel Attention)

    • 该注意力机制为每个卷积核的输出通道分配不同的权重,使得模型能够根据输出特征的重要性进行动态调整,从而优化特征表示。
  4. 卷积核数量注意力(Kernel Attention)

    • 该机制为每个卷积核分配不同的权重,允许模型在多个卷积核之间进行选择,增强了模型的灵活性和适应性。

2、并行策略

ODConv采用并行策略来计算上述四种类型的注意力。具体实现步骤如下:

  • 注意力计算

    • 在每个卷积层中,ODConv并行计算四种注意力,分别对应于卷积核的四个维度。这些注意力值通过多头注意力模块进行计算,以确保每个维度的特征都能得到充分的关注。
  • 注意力加权

    • 计算得到的注意力值被应用于卷积核的输出,进而影响最终的特征图。这种加权机制使得卷积操作能够根据输入特征的不同动态调整,从而提升特征提取的效果。

3、优势与效果

ODConv的多维动态注意力机制带来了显著的性能提升:

  • 增强特征学习能力:通过在多个维度上进行动态调整,ODConv能够更有效地捕捉图像中的重要特征。

  • 减少参数量:即使在使用单个卷积核的情况下,ODConv也能与现有的多核动态卷积方法竞争或超越,显著减少了额外的参数。

  • 广泛适用性:ODConv可以作为常规卷积的替代品,轻松集成到现有的CNN架构中,提升模型的灵活性和适应性。

四、效果

ODConv在多个标准数据集上进行了实验,结果显示其在准确性和效率上均有显著提升:

  • ImageNet:在MobileNetV2和ResNet系列模型上,ODConv分别提升了3.77%至5.71%和1.86%至3.72%的Top-1准确率。

  • MS-COCO:在目标检测任务中,ODConv同样展现了优越的性能,提升了模型对小目标和被遮挡目标的检测能力。

五、实验结果

ODConv的实验结果表明,其在多个主流CNN架构上的表现均优于传统卷积方法。具体实验结果包括:

  • MobileNetV2

    • 原始模型Top-1准确率为71.65%,使用ODConv后提升至74.74%(1×核)和75.29%(4×核)。
  • ResNet系列

    • ResNet50的Top-1准确率从76.23%提升至77.87%(1×核)和78.50%(4×核)。

这些结果表明,ODConv不仅提高了模型的准确性,还在参数量上保持了较低的增长。

六、总结

Omni-Dimensional Dynamic Convolution(ODConv)通过引入多维动态注意力机制,显著提升了卷积神经网络的特征提取能力。其创新的设计使得ODConv能够在多个维度上学习卷积核的动态特性,进而提高模型的性能。实验结果证明,ODConv在多个标准数据集上均表现出色,成为现代深度学习模型中一种有效的卷积替代方案。

代码

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd


class Attention(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, groups=1, reduction=0.0625, kernel_num=4, min_channel=16):
        super(Attention, self).__init__()
        attention_channel = max(int(in_planes * reduction), min_channel)
        self.kernel_size = kernel_size
        self.kernel_num = kernel_num
        self.temperature = 1.0

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(in_planes, attention_channel, 1, bias=False)
        self.bn = nn.BatchNorm2d(attention_channel)
        self.relu = nn.ReLU(inplace=True)

        self.channel_fc = nn.Conv2d(attention_channel, in_planes, 1, bias=True)
        self.func_channel = self.get_channel_attention

        if in_planes == groups and in_planes == out_planes:  # depth-wise convolution
            self.func_filter = self.skip
        else:
            self.filter_fc = nn.Conv2d(attention_channel, out_planes, 1, bias=True)
            self.func_filter = self.get_filter_attention

        if kernel_size == 1:  # point-wise convolution
            self.func_spatial = self.skip
        else:
            self.spatial_fc = nn.Conv2d(attention_channel, kernel_size * kernel_size, 1, bias=True)
            self.func_spatial = self.get_spatial_attention

        if kernel_num == 1:
            self.func_kernel = self.skip
        else:
            self.kernel_fc = nn.Conv2d(attention_channel, kernel_num, 1, bias=True)
            self.func_kernel = self.get_kernel_attention

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def update_temperature(self, temperature):
        self.temperature = temperature

    @staticmethod
    def skip(_):
        return 1.0

    def get_channel_attention(self, x):
        channel_attention = torch.sigmoid(self.channel_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
        return channel_attention

    def get_filter_attention(self, x):
        filter_attention = torch.sigmoid(self.filter_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
        return filter_attention

    def get_spatial_attention(self, x):
        spatial_attention = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size)
        spatial_attention = torch.sigmoid(spatial_attention / self.temperature)
        return spatial_attention

    def get_kernel_attention(self, x):
        kernel_attention = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1)
        kernel_attention = F.softmax(kernel_attention / self.temperature, dim=1)
        return kernel_attention

    def forward(self, x):
        x = self.avgpool(x)
        x = self.fc(x)
        x = self.bn(x)
        x = self.relu(x)
        return self.func_channel(x), self.func_filter(x), self.func_spatial(x), self.func_kernel(x)


class ODConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 reduction=0.0625, kernel_num=4):
        super(ODConv2d, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.kernel_num = kernel_num
        self.attention = Attention(in_planes, out_planes, kernel_size, groups=groups,
                                   reduction=reduction, kernel_num=kernel_num)
        self.weight = nn.Parameter(torch.randn(kernel_num, out_planes, in_planes//groups, kernel_size, kernel_size),
                                   requires_grad=True)
        self._initialize_weights()

        if self.kernel_size == 1 and self.kernel_num == 1:
            self._forward_impl = self._forward_impl_pw1x
        else:
            self._forward_impl = self._forward_impl_common

    def _initialize_weights(self):
        for i in range(self.kernel_num):
            nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu')

    def update_temperature(self, temperature):
        self.attention.update_temperature(temperature)

    def _forward_impl_common(self, x):
        # Multiplying channel attention (or filter attention) to weights and feature maps are equivalent,
        # while we observe that when using the latter method the models will run faster with less gpu memory cost.
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        batch_size, in_planes, height, width = x.size()
        x = x * channel_attention
        x = x.reshape(1, -1, height, width)
        aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0)
        aggregate_weight = torch.sum(aggregate_weight, dim=1).view(
            [-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])
        output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups * batch_size)
        output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
        output = output * filter_attention
        return output

    def _forward_impl_pw1x(self, x):
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        x = x * channel_attention
        output = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups)
        output = output * filter_attention
        return output

    def forward(self, x):
        return self._forward_impl(x)





if __name__ == "__main__":
    dim=256
    # 如果GPU可用,将模块移动到 GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 输入张量 (batch_size, height, width,channels)
    x = torch.randn(2,dim,40,40).to(device)
    # 初始化 HWD 模块

    block = ODConv2d(dim,dim,7,padding=3)
    print(block)
    block = block.to(device)
    # 前向传播
    output = block(x)
    print("输入:", x.shape)
    print("输出:", output.shape)

输出结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/962547.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

英语语法 第一天

I’m a student. 我是个学生 我是个新东方的学生 I’m a student of New Oriental School 我爱你 I love you 我在心中爱你 I love you in my heart. 这是一朵花 This is a flower 这是一朵在公园里的花 This is a flower in the park.(修饰部分在修饰词后面) 主干…

doris:高并发导入优化(Group Commit)

在高频小批量写入场景下,传统的导入方式存在以下问题: 每个导入都会创建一个独立的事务,都需要经过 FE 解析 SQL 和生成执行计划,影响整体性能每个导入都会生成一个新的版本,导致版本数快速增长,增加了后台…

智联出行公司布局中国市场,引领绿色出行新潮流

近日,智联共享科技有限公司(智联出行ZSTL)正式入驻中国香港市场,开启中国地区“合伙人”战略部署,其线上服务平台也于同日上线。 作为共享单车领域的先行者,智联出行公司此举标志着其全球化布局的重要进展&…

PythonFlask框架

文章目录 处理 Get 请求处理 POST 请求应用 app.route(/tpost, methods[POST]) def testp():json_data request.get_json()if json_data:username json_data.get(username)age json_data.get(age)return jsonify({username: username测试,age: age})从 flask 中导入了 Flask…

开源的瓷砖式图像板系统Pinry

简介 什么是 Pinry ? Pinry 是一个开源的瓷砖式图像板系统,旨在帮助用户轻松保存、标记和分享图像、视频和网页。它提供了一种便于快速浏览的格式,适合喜欢整理和分享多种媒体内容的人。 主要特点 图像抓取和在线预览:支持从网页…

前端进阶:深度剖析预解析机制

一、预解析是什么? 在前端开发中,我们常常会遇到一些看似不符合常规逻辑的代码执行现象,比如为什么在变量声明之前访问它,得到的结果是undefined,而不是报错?为什么函数在声明之前就可以被调用&#xff1f…

stm32教程:EXTI外部中断应用

早上好啊大佬们,上一期我们讲了EXTI外部中断的原理以及基础代码的书写,这一期就来尝试一下用它来写一些有实际效能的工程吧。 这一期里,我用两个案例代码来让大家感受一下外部中断的作用和使用价值。 旋转编码器计数 整体思路讲解 这里&…

数据分析系列--⑦RapidMiner模型评价(基于泰坦尼克号案例含数据集)

一、前提 二、模型评估 1.改造⑥ 2.Cross Validation算子说明 2.1Cross Validation 的作用 2.1.1 模型评估 2.1.2 减少过拟合 2.1.3 数据利用 2.2 Cross Validation 的工作原理 2.2.1 数据分割 2.2.2 迭代训练与测试 ​​​​​​​ 2.2.3 结果汇总 ​​​​​​​ …

WPS mathtype间距太大、显示不全、公式一键改格式/大小

1、间距太大 用mathtype后行距变大的原因 mathtype行距变大到底怎么解决-MathType中文网 段落设置固定值 2、显示不全 设置格式: 打开MathType编辑器点击菜单栏中的"格式(Format)"选择"间距(Spacing)"在弹出的对话框中调整"分数间距(F…

【Postman接口测试】Postman的安装和使用

在软件测试领域,接口测试是保障软件质量的关键环节之一,而Postman作为一款功能强大且广受欢迎的接口测试工具,能够帮助测试人员高效地进行接口测试工作。本文将详细介绍Postman的安装和使用方法,让你快速上手这款工具。 一、Pos…

因果推断与机器学习—用机器学习解决因果推断问题

Judea Pearl 将当前备受瞩目的机器学习研究戏谑地称为“仅限于曲线拟合”,然而,曲线拟合的实现绝非易事。机器学习模型在图像识别、语音识别、自然语言处理、蛋白质分子结构预测以及搜索推荐等多个领域均展现出显著的应用效果。 在因果推断任务中,在完成因果效应识别之后,需…

python算法和数据结构刷题[2]:链表、队列、栈

链表 链表的节点定义: class Node():def __init__(self,item,nextNone):self.itemitemself.nextNone 删除节点: 删除节点前的节点的next指针指向删除节点的后一个节点 添加节点: 单链表 class Node():"""单链表的结点&quo…

AJAX案例——图片上传个人信息操作

黑马程序员视频地址&#xff1a; AJAX-Day02-11.图片上传https://www.bilibili.com/video/BV1MN411y7pw?vd_source0a2d366696f87e241adc64419bf12cab&spm_id_from333.788.videopod.episodes&p26 图片上传 <!-- 文件选择元素 --><input type"file"…

deepseek大模型本机部署

2024年1月20日晚&#xff0c;中国DeepSeek发布了最新推理模型DeepSeek-R1&#xff0c;引发广泛关注。这款模型不仅在性能上与OpenAI的GPT-4相媲美&#xff0c;更以开源和创新训练方法&#xff0c;为AI发展带来了新的可能性。 本文讲解如何在本地部署deepseek r1模型。deepseek官…

使用 Ollama 和 Kibana 在本地为 RAG 测试 DeepSeek R1

作者&#xff1a;来自 Elastic Dave Erickson 及 Jakob Reiter 每个人都在谈论 DeepSeek R1&#xff0c;这是中国对冲基金 High-Flyer 的新大型语言模型。现在他们推出了一款功能强大、具有开放权重的思想链推理 LLM&#xff0c;这则新闻充满了对行业意味着什么的猜测。对于那些…

Greenplum临时表未清除导致库龄过高处理

1.问题 Greenplum集群segment后台日志报错 2.回收库龄 master上执行 vacuumdb -F -d cxy vacuumdb -F -d template1 vacuumdb -F -d rptdb 3.回收完成后检查 仍然发现segment还是有库龄报警警告信息发出 4.检查 4.1 在master上检查库年龄 SELECT datname, datfrozen…

栈和队列特别篇:栈和队列的经典算法问题

图均为手绘,代码基于vs2022实现 系列文章目录 数据结构初探: 顺序表 数据结构初探:链表之单链表篇 数据结构初探:链表之双向链表篇 链表特别篇:链表经典算法问题 数据结构:栈篇 数据结构:队列篇 文章目录 系列文章目录前言一.有效的括号(leetcode 20)二.用队列实现栈(leetcode…

记录一次,PyQT的报错,多线程Udp失效,使用工具如netstat来检查端口使用情况。

1.问题 报错Exception in thread Thread-1: Traceback (most recent call last): File "threading.py", line 932, in _bootstrap_inner File "threading.py", line 870, in run File "main.py", line 456, in udp_recv IndexError: list…

论文阅读(十):用可分解图模型模拟连锁不平衡

1.论文链接&#xff1a;Modeling Linkage Disequilibrium with Decomposable Graphical Models 摘要&#xff1a; 本章介绍了使用可分解的图形模型&#xff08;DGMs&#xff09;表示遗传数据&#xff0c;或连锁不平衡&#xff08;LD&#xff09;&#xff0c;各种下游应用程序之…

本地部署DeepSeek开源多模态大模型Janus-Pro-7B实操

本地部署DeepSeek开源多模态大模型Janus-Pro-7B实操 Janus-Pro-7B介绍 Janus-Pro-7B 是由 DeepSeek 开发的多模态 AI 模型&#xff0c;它在理解和生成方面取得了显著的进步。这意味着它不仅可以处理文本&#xff0c;还可以处理图像等其他模态的信息。 模型主要特点:Permalink…