即插即用模块之DO-Conv(深度过度参数化卷积层)详解

目录

一、摘要

二、核心创新点

三、代码详解

四、实验结果

4.1Image Classification

4.2Semantic Segmentation

4.3Object Detection 

五、总结


论文:DOConv论文

代码:DOConv代码

一、摘要

卷积层是卷积神经网络(cnn)的核心组成部分。在本文中,我们建议用额外的深度卷积来增强卷积层,其中每个输入通道与不同的二维核进行卷积。这两个卷积的组合构成了一个过度参数化,因为它增加了可学习的参数,而结果的线性操作可以用单个卷积层来表示。我们把这个深度过度参数化的卷积层称为DO-Conv。我们通过大量的实验表明,仅仅用DO-Conv层替换传统的卷积层就可以提高cnn在许多经典视觉任务上的性能,例如图像分类、检测和分割。此外,在推理阶段,深度卷积被折叠成常规卷积,将计算量减少到完全等同于卷积层的计算量,而没有过度参数化。由于DO-Conv在不增加推理计算复杂度的情况下引入了性能提升,我们主张将其作为传统卷积层的替代方案。

二、核心创新点

深度过参数化卷积层(DO-Conv)是一个具有可训练kernel深度卷积和一个具有可训练常规卷积的组合。给定一个输入, DO-Conv算子的输出与卷积层相同,是一个同维特征。DO-Conv算子是深度卷积算子和卷积算子的复合,如图所示,有两种数学上等价的方法来实现复合:特征复合(a)和核复合(b)。

三、代码详解

# 使用 utf-8 编码
# 导入必要的库
import math  # 导入数学库
import torch  # 导入 PyTorch 库
import numpy as np  # 导入 NumPy 库
from torch.nn import init  # 导入 PyTorch 中的初始化函数
from itertools import repeat  # 导入 itertools 库中的 repeat 函数
from torch.nn import functional as F  # 导入 PyTorch 中的函数式接口
from torch._jit_internal import Optional  # 导入 PyTorch 中的可选模块
from torch.nn.parameter import Parameter  # 导入 PyTorch 中的参数类
from torch.nn.modules.module import Module  # 导入 PyTorch 中的模块类
import collections  # 导入 collections 库

# 定义自定义模块 DOConv2d
class DOConv2d(Module):
    """
       DOConv2d 可以作为 torch.nn.Conv2d 的替代。
       接口与 Conv2d 类似,但有一个例外:
            1. D_mul:超参数的深度乘法器。
       请注意,groups 参数在 DO-Conv(groups=1)、DO-DConv(groups=in_channels)、DO-GConv(其他情况)之间切换。
    """
    # 常量声明
    __constants__ = ['stride', 'padding', 'dilation', 'groups',
                     'padding_mode', 'output_padding', 'in_channels',
                     'out_channels', 'kernel_size', 'D_mul']
    # 注解声明
    __annotations__ = {'bias': Optional[torch.Tensor]}

    # 初始化函数
    def __init__(self, in_channels, out_channels, kernel_size, D_mul=None, stride=1,
                 padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
        super(DOConv2d, self).__init__()

        # 将 kernel_size、stride、padding、dilation 转化为二元元组
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)

        # 检查 groups 是否合法
        if in_channels % groups != 0:
            raise ValueError('in_channels 必须能被 groups 整除')
        if out_channels % groups != 0:
            raise ValueError('out_channels 必须能被 groups 整除')
        # 检查 padding_mode 是否合法
        valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
        if padding_mode not in valid_padding_modes:
            raise ValueError("padding_mode 必须为 {} 中的一种,但得到 padding_mode='{}'".format(
                valid_padding_modes, padding_mode))
        
        # 初始化模块参数
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode
        self._padding_repeated_twice = tuple(x for x in self.padding for _ in range(2))

        #################################### 初始化 D & W ###################################
        M = self.kernel_size[0]
        N = self.kernel_size[1]
        self.D_mul = M * N if D_mul is None or M * N <= 1 else D_mul
        self.W = Parameter(torch.Tensor(out_channels, in_channels // groups, self.D_mul))
        init.kaiming_uniform_(self.W, a=math.sqrt(5))

        if M * N > 1:
            self.D = Parameter(torch.Tensor(in_channels, M * N, self.D_mul))
            init_zero = np.zeros([in_channels, M * N, self.D_mul], dtype=np.float32)
            self.D.data = torch.from_numpy(init_zero)

            eye = torch.reshape(torch.eye(M * N, dtype=torch.float32), (1, M * N, M * N))
            d_diag = eye.repeat((in_channels, 1, self.D_mul // (M * N)))
            if self.D_mul % (M * N) != 0:  # 当 D_mul > M * N 时
                zeros = torch.zeros([in_channels, M * N, self.D_mul % (M * N)])
                self.d_diag = Parameter(torch.cat([d_diag, zeros], dim=2), requires_grad=False)
            else:  # 当 D_mul = M * N 时
                self.d_diag = Parameter(d_diag, requires_grad=False)
        ##################################################################################################

        # 初始化偏置参数
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            init.uniform_(self.bias, -bound, bound)
        else:
            self.register_parameter('bias', None)

    # 返回模块配置的字符串表示形式
    def extra_repr(self):
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        return s.format(**self.__dict__)

    # 重新设置状态
    def __setstate__(self, state):
        super(DOConv2d, self).__setstate__(state)
        if not hasattr(self, 'padding_mode'):
            self.padding_mode = 'zeros'

    # 辅助函数,执行卷积操作
    def _conv_forward(self, input, weight):
        if self.padding_mode != 'zeros':
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            _pair(0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    # 前向传播函数
    def forward(self, input):
        M = self.kernel_size[0]
        N = self.kernel_size[1]
        DoW_shape = (self.out_channels, self.in_channels // self.groups, M, N)
        if M * N > 1:
            ######################### 计算 DoW #################
            # (input_channels, D_mul, M * N)
            D = self.D + self.d_diag
            W = torch.reshape(self.W, (self.out_channels // self.groups, self.in_channels, self.D_mul))

            # einsum 输出 (out_channels // groups, in_channels, M * N),
            # 重塑为
            # (out_channels, in_channels // groups, M, N)
            DoW = torch.reshape(torch.einsum('ims,ois->oim', D, W), DoW_shape)
            #######################################################
        else:
            # 在这种情况下 D_mul == M * N
            # 从 (out_channels, in_channels // groups, D_mul) 重塑为 (out_channels, in_channels // groups, M, N)
            DoW = torch.reshape(self.W, DoW_shape)
        return self._conv_forward(input, DoW)

# 定义辅助函数
def _ntuple(n):
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return x
        return tuple(repeat(x, n))

    return parse

# 定义辅助函数,将输入转化为二元元组
_pair = _ntuple(2)

四、实验结果

4.1Image Classification

4.2Semantic Segmentation

4.3Object Detection 

五、总结

DO-Conv是一种深度过参数化卷积层,是一种新颖、简单、通用的提高cnn性能的方法。除了提高现有cnn的训练和最终精度的实际意义之外,在推理阶段不引入额外的计算,我们设想其优势的揭示也可以鼓励进一步探索过度参数化作为网络架构设计的一个新维度。

在未来,对这一相当简单的方法进行理论理解,以在一系列应用中实现令人惊讶的非凡性能改进,将是有趣的。此外,我们希望扩大这些过度参数化卷积层可能有效的应用范围,并了解哪些超参数可以从中受益更多。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/542553.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Java虚拟机】简单易懂的ZGC原理分析

简单易懂的ZGC原理分析 GC垃圾收集器ZGC的特点ZGC相关技术Region染色指针 & 转发表 & 读屏障染色指针转发表读屏障 内存多重映射 ZGC流程详解ZGC与其他垃圾搜集器比较与CMS比较与G1比较 GC垃圾收集器 GC垃圾收集器的作用就是帮我们清理堆内存里面的垃圾&#xff0c;无…

第1章、react基础知识;

一、react学习前期准备&#xff1b; 1、基本概念&#xff1b; 前期的知识准备&#xff1a; 1.javascript、html、css&#xff1b; 2.构建工具&#xff1a;Webpack&#xff1a;https://yunp.top/init/p/v/1 3.安装node&#xff1a;npm&#xff1a;https://yunp.top/init/p/v/1 …

cmake制作并链接动静态库

cmake制作并链接动静态库 制作静态库add_library(库名称 STATIC 源文件1 [源文件2] ...)LIBRARY_OUTPUT_PATH指定库的生成路径 制作动态库add_library(库名称 SHARED 源文件1 [源文件2] ...) 连接动静态库link_libraries连接静态库link_directories到哪个路径去找库target_link…

UnityShader学习计划

1.安装ShaderlabVS,vs的语法提示 2. 常规颜色是fixed 3.FrameDebugger调试查看draw的某一帧的全部信息&#xff0c;能看到变量参数的值

雅马哈电钢琴YDP145

数据线&#xff1a;MIDI 琴可以通过MIDI、线直接连接手机&#xff0c;播放声音 琴通过线连接电脑&#xff0c;不能直接播放声音 https://www.bilibili.com/video/BV1ws4y1M7yw 操作&#xff1a; https://usa.yamaha.com/support/updates/yamaha_steinberg_usb_driver_for_win…

王道汽车4S企业管理系统 SQL注入漏洞复现

0x01 产品简介 王道汽车4S企业管理系统(以下简称“王道4S系统”)是一套专门为汽车销售和维修服务企业开发的管理软件。该系统是博士德软件公司集10余年汽车行业管理软件研发经验之大成,精心打造的最新一代汽车4S企业管理解决方案。 0x02 漏洞概述 王道汽车4S企业管理系统…

完美照片由构图决定,摄影构图基础到进阶

一、资料描述 本套摄影构图资料&#xff0c;大小1.04G&#xff0c;共有51个文件。 二、资料目录 新手必备-摄影构图技巧.doc 无忌版的《摄影构图学》.pdf 完美照片的十大经典拍摄技法.pdf 数码摄影曝光手边书.pdf 数码摄影不求人.30天学会数码摄影构图.pdf 数码单反摄影…

sheng的学习笔记-AI-决策树(Decision Tree)

AI目录&#xff1a;sheng的学习笔记-AI目录-CSDN博客 目录 什么是决策树 划分选择 信息增益 增益率 基尼指数 剪枝处理 预剪枝 后剪枝 连续值处理 另一个例子 基本步骤 排序 计算候选划分点集合 评估分割点 每个分割点都进行评估&#xff0c;找到最大信息增益的…

靠谱的大型相亲交友婚恋平台有哪些?相亲app软件前十名

靠谱交友软件&#xff0c;个人感觉还是要选择大型的&#xff0c;口碑好的进行选择&#xff0c;以下是我用过的婚恋平台&#xff0c;分享给大家 1、丛丛 这是我用的最久的一款脱单小程序&#xff0c;我老公就是在这个小程序找到的&#xff01;&#xff01;&#xff01; 这是一款…

CSS边框

目录 内容区&#xff08;content&#xff09;&#xff1a; 边框&#xff08;border&#xff09;&#xff1a; 前言&#xff1a; 示例&#xff1a; 内容区&#xff08;content&#xff09;&#xff1a; 内容区就是盒子里面用来存放东西的区域&#xff0c;里面你可以随便放如:…

计算机三级数据库技术备考笔记(十三)

第十三章 大规模数据库架构 分布式数据库 分布式数据库系统概述 分布式数据库系统是物理上分散、逻辑上集中的数据库系统。系统中的数据分布在物理位置不同的计算机上&#xff08;通常称为场地、站点或结点&#xff0c;本章均用场地来描述&#xff09;&#xff0c;由通信网络将…

大语言模型总结整理(不定期更新)

《【快捷部署】016_Ollama&#xff08;CPU only版&#xff09;》 介绍了如何一键快捷部署Ollama&#xff0c;今天就来看一下受欢迎的模型。 模型简介gemmaGemma是由谷歌及其DeepMind团队开发的一个新的开放模型。参数&#xff1a;2B&#xff08;1.6GB&#xff09;、7B&#xff…

VMware安装Red Hat7.9

1、下载Red Hat Enterprise Linux7.9版本 【百度网盘下载】 链接&#xff1a;https://pan.baidu.com/s/1567NfZRF48PBXfUqxumvDA 提取码&#xff1a;bm7u 2、在虚拟机中创建Red Hat7.9 【点击创建虚拟机】 【自定义高级】 【选择光盘映像安装】 全名自定义即可 【虚拟机命…

Dual-AMN论文翻译

Boosting the Speed of Entity Alignment 10: Dual Attention Matching Network with Normalized Hard Sample Mining 将实体对齐速度提高 10 倍&#xff1a;具有归一化硬样本挖掘的双重注意力匹配网络 ABSTRACT 寻找多源知识图谱(KG)中的等效实体是知识图谱集成的关键步骤&…

分享 WebStorm 2024 激活的方案,支持JetBrains全家桶

大家好&#xff0c;欢迎来到金榜探云手&#xff01; WebStorm公司简介 JetBrains 是一家专注于开发工具的软件公司&#xff0c;总部位于捷克。他们以提供强大的集成开发环境&#xff08;IDE&#xff09;而闻名&#xff0c;如 IntelliJ IDEA、PyCharm、和 WebStorm等。这些工具…

Zynq学习笔记--AXI 总线概述

目录 1. AXI总线概述 1.1 主要特点 1.2 通道功能 1.3 信号概览 2. AXI Interconnect 2.1 信号说明 2.2 内部结构 3. PS-PL AXI Interface 3.1 AXI FPD/LFP/ACP 3.2 Address Editor 3.3 地址空间 3.4 AXI-DDR 4. 通过ILA观察AXI信号 4.1 AXI 读通道 1. AXI总线概述…

故障诊断 | Matlab实现基于小波包结合鹈鹕算法优化卷积神经网络DWT-POA-CNN实现电缆故障诊断算法

故障诊断 | Matlab实现基于小波包结合鹈鹕算法优化卷积神经网络DWT-POA-CNN实现电缆故障诊断算法 目录 故障诊断 | Matlab实现基于小波包结合鹈鹕算法优化卷积神经网络DWT-POA-CNN实现电缆故障诊断算法分类效果基本介绍程序设计参考资料 分类效果 基本介绍 1.Matlab实现基于小波…

复杂DP算法(动态规划)

复杂DP算法 一、线性DP例题1、鸣人的影分身题目信息思路题解 2、糖果题目信息思路题解 二、区间DP例题密码脱落题目信息思路题解 三、树状DP例题生命之树题目信息思路题解 一、线性DP 例题 1、鸣人的影分身 题目信息 思路 题解 #include <bits/stdc.h> #define endl …

funasr 麦克风实时流语音识别

参考: https://github.com/alibaba-damo-academy/FunASR chunk_size 是用于流式传输延迟的配置。[0,10,5] 表示实时显示的粒度为 1060=600 毫秒,并且预测的向前信息为 560=300 毫秒。每个推理输入为 600 毫秒(采样点为 16000*0.6=960),输出为相应的文本。对于最后一个语音…

Thinkphp6接入PayPal支付

沙盒环境示例 创建扩展封装类 <?php namespace lib;class PayPalApi {//clientIdprivate $clientId;//clientSecretprivate $clientSecret;//服务器地址private $host https://api-m.sandbox.paypal.com/;//主机头private $headers [];//api凭证private $token ;//报文…