ICLR 2022)ODConv:即插即用的动态卷积 (附代码)

论文地址:Omni-Dimensional Dynamic Convolution | OpenReview

代码地址:https://github.com/OSVAI/ODConv/blob/main/modules/odconv.py

1.是什么?

ODConv是一种动态卷积算法,它的原理是在卷积过程中,根据输入数据的特征动态地调整卷积核的形状和大小,以适应不同的输入数据。具体来说,ODConv通过引入一个可学习的形变模块,根据输入数据的特征动态地调整卷积核的形状和大小,从而提高了卷积神经网络的性能。与CondConv和DyConv不同,ODConv不仅考虑了空间维度、输入通道维度和输出通道维度,还考虑了卷积核的形状和大小,因此可以更好地适应不同的输入数据。

2.为什么?

常规卷积只有一个静态卷积核且与输入样本无关。对于动态卷积来说,它对多个卷积核进行线性加权,而加权值则与输入有关,这就使得动态卷积具有输入依赖性。它可以描述如下:

尽管动态卷积的定义很简单,但CondConv与DyConv的实现是不相同的,主要体现在计算a_{wi}的结构\pi_{wi}(x)训练策略以及实施动态卷积的层,这些实现上的差异导致了不同的模型精度、模型大小以及推理效率。

  • 两者均为\pi_{wi}(x)采用了类SE架构,但CondConv采用的是Sigmoid,而DyConv采用的是Softmax;
  • DyConv采用的退化策略进行训练以抑制Softmax的one-hot输出;
  • 对于他们嵌入的CNN架构,CondConv替换了最后几个模块的卷积与全连接层,而DyConv则对除第一个卷积外的其他卷积均进行了替换。

根据动态卷积的公式来看,动态卷积有两个基本元素:

  • 卷积核{W_{1},..,W_{n}}
  • 用于计算注意力{a_{w1,...,a_{wn}}}的注意力函数\phi _{wi}(x)

给定n个卷积核,其对应的核空间有以下四个维度:

  • 空间核尺寸k×k;
  • 输入通道数c_{in}
  • 输出通道数c_{out}
  • 卷积核数量n

然而,对于CondConv与DyConv来说,\phi _{wi}(x)均采用单个注意力标量a_{wi},这就意味着它的的输出滤波器W_{i}^{m}R^{k*k*c_{in}}对于输入具有相同的注意力值。换句话说,卷积核 W_{i}的空间维度、输入通道维度以及输出通道维度均被CondConv与DyConv所忽视了。这就导致了关于核空间的粗糙探索。这可能就是为什么CondConv与DyConv对于大网络的性能增益较低的原因。

此外,相比常规卷积,动态卷积的卷积核参数往往是其n倍。比如CondConv中的n=8,DyConv中的n=4。当动态卷积使用过多时无疑会极大程度提升模型大小。我们发现:当 移除掉CondConv/DyConv中的注意力机制(即a_{wi}=1)后,其性能提升接近于零。比如,对于ResNet18,其性能增益从1.78%/2.51%下降到了0.08%/0.14。

上述发现意味着:动态卷积中的注意力机制起关键性作用,更有效的设计也许可以在模型精度与大小之间得到更好的平衡。

一定程度上讲,ODConv可以视作CondConv的延续,将CondConv中一个维度上的动态特性进行了扩展,同时了考虑了空域、输入通道、输出通道等维度上的动态性,故称之为全维度动态卷积。ODConv通过并行策略采用多维注意力机制沿核空间的四个维度学习互补性注意力。作为一种“即插即用”的操作,它可以轻易的嵌入到现有CNN网络中。ImageNet分类与COCO检测任务上的实验验证了所提ODConv的优异性:即可提升大模型的性能,又可提升轻量型模型的性能,实乃万金油是也!值得一提的是,受益于其改进的特征提取能力,ODConv搭配一个卷积核时仍可取得与现有多核动态卷积相当甚至更优的性能

3 怎么样?

3.1 网络结构

基于前述讨论,ODConv通过并行策略引入一种多维注意力机制以对卷积核空间的四个维度学习更灵活的注意力。上图给出CondConv、DyConv以及ODConv的差异图。

延续动态卷积的定义,ODConv可以描述成如下形式:

其中,a_{wi}表示卷积核W_{i}的注意力标量,a_{si}\epsilon R^{k*k},a_{ci}\epsilon R^{c_{in}},a_{fi}\epsilon R^{c_{out}}表示新引入的三个注意力,分别沿空域维度、输入通道维度以及输出通道维度。这四个注意力采用多头注意力模块\pi_{i}(x)计算得到。

在ODConv中,对于卷积核W_{i}a_{si}对k*k空域位置上的卷积参数赋予不用的注意力值,见上图a;a_{ci}对不同输入通道的卷积滤波器赋予不同的注意力值,见上图b;a_{fi}对不同输出通道的卷积滤波器赋予不同的注意力值,见上图c;而a_{wi}则对n个整体卷积核赋予不同的值,见上图d。

原则上来讲,这四种类型的注意力是互补的,通过渐进式对卷积W_{i}沿位置、通道、滤波器以及核等维度乘以不同的注意力将使得卷积操作对于输入存在各个维度的差异性,提供更好的性能以捕获丰富上下文信息。因此,ODCOnv可以大幅提升卷积的特征提取能力;更重要的是,采用更少卷积核的ODConv可以取得与CondConv、DyConv相当甚至更优的性能。

对比前面两种动态卷积的公式可以发现:ODConv是一种更广义的动态卷积。此外,当设置n=1,a_{s1}=a_{c1}=a_{w1}=1时,ODConv则退化为仅具有滤波器层面的注意力,基于输入对卷积滤波器进行调制后再进行卷积,类似于SE。故SE是ODConv的一个特例。

那么如何实现ODConv的四种类型的注意力值呢?延续CondConv与DyConv,我们同样采用SE风格的注意力模块,但使其具有多个头以计算多种类型注意力,整体结构见上图。具体来说,对于输入先通过GAP收缩为长度为c_{in}的特征向量,然后采用FC与四个头生成不同类型的注意力值。对于四个头,其维度分别为k*k,c_{in}×1,c_{out}×1,n×1。

在训练方面,我们采用了DyConv中的退化策略以加速训练。在具体架构嵌入方面,我们参考DyConv对除第一个卷积外的其他所有卷积进行替换。

3.2 代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd


class Attention(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, groups=1, reduction=0.0625, kernel_num=4, min_channel=16):
        super(Attention, self).__init__()
        attention_channel = max(int(in_planes * reduction), min_channel)
        self.kernel_size = kernel_size
        self.kernel_num = kernel_num
        self.temperature = 1.0

        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Conv2d(in_planes, attention_channel, 1, bias=False)
        self.bn = nn.BatchNorm2d(attention_channel)
        self.relu = nn.ReLU(inplace=True)

        self.channel_fc = nn.Conv2d(attention_channel, in_planes, 1, bias=True)
        self.func_channel = self.get_channel_attention

        if in_planes == groups and in_planes == out_planes:  # depth-wise convolution
            self.func_filter = self.skip
        else:
            self.filter_fc = nn.Conv2d(attention_channel, out_planes, 1, bias=True)
            self.func_filter = self.get_filter_attention

        if kernel_size == 1:  # point-wise convolution
            self.func_spatial = self.skip
        else:
            self.spatial_fc = nn.Conv2d(attention_channel, kernel_size * kernel_size, 1, bias=True)
            self.func_spatial = self.get_spatial_attention

        if kernel_num == 1:
            self.func_kernel = self.skip
        else:
            self.kernel_fc = nn.Conv2d(attention_channel, kernel_num, 1, bias=True)
            self.func_kernel = self.get_kernel_attention

        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def update_temperature(self, temperature):
        self.temperature = temperature

    @staticmethod
    def skip(_):
        return 1.0

    def get_channel_attention(self, x):
        channel_attention = torch.sigmoid(self.channel_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
        return channel_attention

    def get_filter_attention(self, x):
        filter_attention = torch.sigmoid(self.filter_fc(x).view(x.size(0), -1, 1, 1) / self.temperature)
        return filter_attention

    def get_spatial_attention(self, x):
        spatial_attention = self.spatial_fc(x).view(x.size(0), 1, 1, 1, self.kernel_size, self.kernel_size)
        spatial_attention = torch.sigmoid(spatial_attention / self.temperature)
        return spatial_attention

    def get_kernel_attention(self, x):
        kernel_attention = self.kernel_fc(x).view(x.size(0), -1, 1, 1, 1, 1)
        kernel_attention = F.softmax(kernel_attention / self.temperature, dim=1)
        return kernel_attention

    def forward(self, x):
        x = self.avgpool(x)
        x = self.fc(x)
        x = self.bn(x)
        x = self.relu(x)
        return self.func_channel(x), self.func_filter(x), self.func_spatial(x), self.func_kernel(x)


class ODConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 reduction=0.0625, kernel_num=4):
        super(ODConv2d, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.kernel_num = kernel_num
        self.attention = Attention(in_planes, out_planes, kernel_size, groups=groups,
                                   reduction=reduction, kernel_num=kernel_num)
        self.weight = nn.Parameter(torch.randn(kernel_num, out_planes, in_planes//groups, kernel_size, kernel_size),
                                   requires_grad=True)
        self._initialize_weights()

        if self.kernel_size == 1 and self.kernel_num == 1:
            self._forward_impl = self._forward_impl_pw1x
        else:
            self._forward_impl = self._forward_impl_common

    def _initialize_weights(self):
        for i in range(self.kernel_num):
            nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu')

    def update_temperature(self, temperature):
        self.attention.update_temperature(temperature)

    def _forward_impl_common(self, x):
        # Multiplying channel attention (or filter attention) to weights and feature maps are equivalent,
        # while we observe that when using the latter method the models will run faster with less gpu memory cost.
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        batch_size, in_planes, height, width = x.size()
        x = x * channel_attention
        x = x.reshape(1, -1, height, width)
        aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0)
        aggregate_weight = torch.sum(aggregate_weight, dim=1).view(
            [-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])
        output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups * batch_size)
        output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
        output = output * filter_attention
        return output

    def _forward_impl_pw1x(self, x):
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        x = x * channel_attention
        output = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups)
        output = output * filter_attention
        return output

    def forward(self, x):
        return self._forward_impl(x)

 参考:

ODConv详解

ICLR 2022 | 涨点神器!Intel提出ODConv:即插即用的动态卷积

致敬CondConv!Intel提出即插即用的“万金油”动态卷积ODConv

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/112807.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Azure机器学习 - 使用与Azure集成的Visual Studio Code实战教程

本文介绍如何启动远程连接到 Azure 机器学习计算实例的 Visual Studio Code。 借助 Azure 机器学习资源的强大功能,使用 VS Code 作为集成开发环境 (IDE)。 在VS Code中将计算实例设置为远程 Jupyter Notebook 服务器。 关注TechLead,分享AI全维度知识。…

【密评】商用密码应用安全性评估从业人员考核题库(十七)

商用密码应用安全性评估从业人员考核题库(十七) 国密局给的参考题库5000道只是基础题,后续更新完5000还会继续更其他高质量题库,持续学习,共同进步。 4001 多项选择题 网络和通信安全层面的通信主体一般包括哪些&…

Unity屏幕中涂鸦

LineRenderer LineRenderer是Unity中的一个组件,用于在场景中绘制简单的线段。 LineRenderer组件允许你通过设置一系列顶点来定义线段的形状和外观。它会根据这些顶点自动在场景中绘制出线段。 下面是LineRenderer的一些重要属性和方法: positionCou…

C++使用栈实现简易计算器(支持括号)

使用C实现&#xff0c;使用系统自带stac 支持括号处理支持小数计算支持表达式有效性检查支持多轮输入。 运行结果示例&#xff1a; 代码&#xff1a; #include <iostream> #include <stack> #include <string> using namespace std;//判断是否是数字字符 …

基于MFC的串口通信(Mscomm)

1、串口通信的概述&#xff1a; 串口是一种重要的通信资源&#xff0c;例如鼠标口、USB接口都是串口。串行端口是CPU和串行设备间的编码转换器。当数据从CPU经过端口发送出去的时候&#xff0c;字节数据会被转为串行的位&#xff0c;在接收数据时&#xff0c;串行的位被转换为…

k8s调度约束

List-Watch Kubernetes 是通过 List-Watch的机制进行每个组件的协作&#xff0c;保持数据同步的&#xff0c;每个组件之间的设计实现了解耦。 List-Watch机制 工作机制&#xff1a;用户通过 kubectl请求给 APIServer 来建立一个 Pod。APIServer会将Pod相关元信息存入 etcd 中…

【ROS系列】坐标系转换介绍和对齐

一、坐标系简介 本篇文章介绍&#xff1a;ECEF、ENU、UTM、WGS-84坐标系&#xff08;LLA) 1.1、ECEF坐标系 ECEF坐标系也叫地心地固直角坐标系。 原点&#xff1a;地球的质心&#xff0c; x轴&#xff1a;原点延伸通过本初子午线&#xff08;0度经度&#xff09;和赤道&am…

京东数据分析:2023年9月京东洗地机行业品牌销售排行榜

鲸参谋监测的京东平台9月份洗地机市场销售数据已出炉&#xff01; 9月份&#xff0c;洗地机市场的销售额增长。根据鲸参谋电商数据分析平台的相关数据显示&#xff0c;9月京东平台上洗地机的销量为9.2万&#xff0c;销售额将近2.2亿&#xff0c;同比增长约9%。从价格上看&#…

爬虫 | 【实践】百度搜索链接爬取,生成标题词云 | 以“AI换脸”为例

目录 &#x1f4da;链接爬取 &#x1f407;流程梳理 &#x1f407;代码实现 &#x1f407;结果 &#x1f4da;词云生成 &#x1f407;代码实现 &#x1f407;结果 &#x1f4da;链接爬取 &#x1f407;流程梳理 总体流程是&#xff1a;构建搜索链接 -> 发送HTTP请求…

大坝水库安全监测终端MCU,智能化管理的新篇章!

我国目前拥有超过9.8万座水库大坝&#xff0c;其中超过95%为土石坝&#xff0c;这些大坝主要是在上世纪80年代以前建造的。这些水库大坝在保障防洪、发电、供水、灌溉等方面发挥了巨大的作用&#xff0c;但是同时也存在一定的安全风险&#xff0c;比如坝体结构破损、坝基渗漏、…

当你在浏览器地址栏输入一个URL后,将会发生的事情?个人笔记

客户端 在浏览器输入 URL 回车之后发生了什么&#xff08;超详细版&#xff09; - 知乎 (zhihu.com) 大致流程是&#xff1a; URL 解析DNS 查询TCP 连接处理请求接受响应渲染页面 1.URL解析 地址解析&#xff1a; 首先判断你输入是否是一个合法的URL还是一个待搜索的关键…

【C++】set和multiset

文章目录 关联式容器键值对一、set介绍二、set的使用multiset 关联式容器 STL中的部分容器&#xff0c;比如&#xff1a;vector、list、deque、forward_list(C11)等&#xff0c;这些容器统称为序列式容器&#xff0c;因为其底层为线性序列的数据结构&#xff0c;里面存储的是元…

从内存管理的角度来看,Python语言的变量和参数传递情况解析

从内存管理的角度来看&#xff0c;Python语言的变量和参数传递情况解析 概述 从内存管理的角度来看&#xff0c;Python中的变量和参数传递有一些特点&#xff1a; ☆ 变量是对象的引用&#xff1a;在Python中&#xff0c;变量实际上是对象的引用&#xff0c;而不是对象本身。…

vi vim 末尾编辑按GA 在最后一行下方新增一行编辑按Go

vim 快速跳到文件末尾 在最后一行下方新增一行 移到末尾,并且进入文本录入模式 GA (大写G大写A) 在一般模式(刚进入的模式,esc模式) GA 或 Shift ga 先 G 或 shiftg 到最后一行 然后 A 或 shifta 到本行末尾 并且进入文本录入模式 在最后一行下方新增一行 (光标换行,文字不…

【Linux】Nginx安装使用负载均衡及动静分离(前后端项目部署),前端项目打包

一、Nginx导言 1、引言 Nginx 是一款高性能的 Web 服务器和反向代理服务器&#xff0c;也可以充当负载均衡器、HTTP 缓存和安全防护设备。它的特点是内存占用小、稳定性高、并发性强、易于扩展&#xff0c;因此在互联网领域得到了广泛的使用。 总结出以下三点: 负载均衡&#x…

lvsDR模式

LVS-DR模式 是最常用的LVS负载方式 直接路由模式 Lvs调度器 只负责请求和转发 转发到真实服务器 但是响应结果 由后端服务器直接转发给客户端 不需要经过调度器处理 可以减轻Lvs调度器的负担 提高系统性能和稳定性 工作原理&#xff1a; 客户端发送请求到vipLVS调度器接受请求…

miniconda快速安装

目录 一、Linux下miniconda安装 1.1、安装 1.2、miniconda初始化 二、Windows下miniconda安装 三、maOS下miniconda安装 3.1、安装 3.2、miniconda初始化 四、参考&#xff1a; 本文给出windows、macos、linux下快速安装miniconda方法。 对比conda&#xff0c;minicond…

光学仿真|优化汽车内部照明体验

当我们谈论优化人类感知的内部照明时&#xff0c;我们实际上指的是两个重点领域&#xff1a;安全性和驾驶员体验。如果内部照明可以提供尽可能最佳的体验&#xff0c;驾驶员则能够更好地应对颇具挑战性或意外的驾驶状况&#xff0c;并且减轻疲劳感。除了功能优势外&#xff0c;…

【Servlet】 一

本文主要介绍了如何在tomcat部署一个webapp 以及 如何借助maven用servlet编写一个hello world . 一.Tomcat Tomcat是一个Java里广泛使用的http服务器 HTTP服务器有很多实现&#xff0c;其中最知名的是Nginx&#xff1b;而在Java里&#xff0c;最知名的是Tomcat 一个Tomcat服务…

【K8S】二进制安装

常见的K8S安装部署方式 ●Minikube Minikube是一个工具&#xff0c;可以在本地快速运行一个单节点微型K8S&#xff0c;仅用于学习、预览K8S的一些特性使用。 部署地址&#xff1a;https://kubernetes.io/docs/setup/minikube ●Kubeadm☆ Kubeadm也是一个工具&#xff0c;提…