【Block总结】CPCA,通道优先卷积注意力|即插即用

论文信息

标题: Channel Prior Convolutional Attention for Medical Image Segmentation

论文链接: arxiv.org

代码链接: GitHub
在这里插入图片描述

创新点

本文提出了一种新的通道优先卷积注意力(CPCA)机制,旨在解决医学图像分割中存在的低对比度和显著器官形状变化等问题。CPCA通过在通道和空间维度上动态分配注意力权重,增强了模型对信息丰富通道和重要区域的关注能力。
在这里插入图片描述

方法

CPCA方法的核心在于以下几个方面:

  • 动态注意力分配: 在通道和空间维度上支持动态分配注意力权重,使得模型能够自适应地关注不同的特征。

  • 多尺度深度可分离卷积模块: 该模块有效提取空间关系,同时保持通道优先,增强了特征表示能力。

  • CPCANet网络结构: 基于CPCA的医学图像分割网络CPCANet被设计用于实现更高效的分割性能。

CPCA模块详解

通道优先卷积注意力(Channel Prior Convolutional Attention,CPCA)是一种新型的注意力机制,旨在提升深度学习模型在医学图像分割等任务中的性能。CPCA通过动态分配通道和空间维度的注意力权重,增强了模型对重要特征的关注能力。

CPCA结合了通道注意力和空间注意力的机制,具体实现步骤如下:

  1. 通道注意力(Channel Attention):

    • 特征聚合: 通过全局平均池化和全局最大池化操作,生成两个不同的特征表示。
    • 特征处理: 将这两个特征表示分别通过卷积层和激活函数处理,以提取通道之间的关系。
    • 权重生成: 通过Sigmoid函数生成通道注意力权重,用于动态调整每个通道的重要性。
  2. 空间注意力(Spatial Attention):

    • 空间关系捕捉: 使用多尺度深度可分离卷积模块来提取特征图中不同位置之间的关系。
    • 多尺度卷积: 采用不同大小的卷积核来捕获多尺度信息,从而更好地理解特征图的空间结构。
  3. 整体机制:

    • CPCA通过结合通道和空间注意力,动态分配注意力权重,并保持通道优先。这种结合使得网络能够更好地捕捉重要特征,提高特征的表征能力。

效果

在多个公开数据集上进行的实验表明,CPCANet相较于其他最先进的算法在分割性能上表现更佳,同时所需计算资源更少。具体来说,CPCA在处理低对比度和复杂形状变化的医学图像时,显著提高了分割精度。

实验结果

实验结果显示,CPCANet在两个主要数据集(如ACDC和ISIC2016)上均取得了优于现有方法的分割效果。通过对比分析,CPCA方法在多个指标上均表现出色,尤其是在处理复杂背景和形状变化时,展现了其优越性。

总结

本文通过提出通道优先卷积注意力(CPCA)机制,显著提升了医学图像分割的性能。CPCA不仅增强了模型对重要特征的关注能力,还通过动态分配注意力权重,优化了计算资源的使用。未来的研究可以进一步探索CPCA在其他领域的应用潜力,以及与其他深度学习技术的结合。

代码

import torch
import torch.nn.functional
import torch.nn.functional as F
from torch import nn

class Mlp(nn.Module):
    """ Multilayer perceptron."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    result = nn.Sequential()
    result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                                        kernel_size=kernel_size, stride=stride, padding=padding, groups=groups,
                                        bias=False))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return result


def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    result = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride,
                     padding=padding, groups=groups)
    result.add_module('relu', nn.ReLU())
    return result


def fuse_bn(conv_or_fc, bn):
    std = (bn.running_var + bn.eps).sqrt()
    t = bn.weight / std
    t = t.reshape(-1, 1, 1, 1)

    if len(t) == conv_or_fc.weight.size(0):
        return conv_or_fc.weight * t, bn.bias - bn.running_mean * bn.weight / std
    else:
        repeat_times = conv_or_fc.weight.size(0) // len(t)
        repeated = t.repeat_interleave(repeat_times, 0)
        return conv_or_fc.weight * repeated, (bn.bias - bn.running_mean * bn.weight / std).repeat_interleave(
            repeat_times, 0)


class ChannelAttention(nn.Module):

    def __init__(self, input_channels, internal_neurons):
        super(ChannelAttention, self).__init__()
        self.fc1 = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons, kernel_size=1, stride=1,
                             bias=True)
        self.fc2 = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels, kernel_size=1, stride=1,
                             bias=True)
        self.input_channels = input_channels

    def forward(self, inputs):
        x1 = F.adaptive_avg_pool2d(inputs, output_size=(1, 1))
        # print('x:', x.shape)
        x1 = self.fc1(x1)
        x1 = F.relu(x1, inplace=True)
        x1 = self.fc2(x1)
        x1 = torch.sigmoid(x1)
        x2 = F.adaptive_max_pool2d(inputs, output_size=(1, 1))
        # print('x:', x.shape)
        x2 = self.fc1(x2)
        x2 = F.relu(x2, inplace=True)
        x2 = self.fc2(x2)
        x2 = torch.sigmoid(x2)
        x = x1 + x2
        x = x.view(-1, self.input_channels, 1, 1)
        return x


class RepBlock(nn.Module):

    def __init__(self, in_channels, out_channels,
                 channelAttention_reduce=4):
        super().__init__()

        self.C = in_channels
        self.O = out_channels

        assert in_channels == out_channels
        self.ca = ChannelAttention(input_channels=in_channels, internal_neurons=in_channels // channelAttention_reduce)
        self.dconv5_5 = nn.Conv2d(in_channels, in_channels, kernel_size=5, padding=2, groups=in_channels)
        self.dconv1_7 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 7), padding=(0, 3), groups=in_channels)
        self.dconv7_1 = nn.Conv2d(in_channels, in_channels, kernel_size=(7, 1), padding=(3, 0), groups=in_channels)
        self.dconv1_11 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 11), padding=(0, 5), groups=in_channels)
        self.dconv11_1 = nn.Conv2d(in_channels, in_channels, kernel_size=(11, 1), padding=(5, 0), groups=in_channels)
        self.dconv1_21 = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 21), padding=(0, 10), groups=in_channels)
        self.dconv21_1 = nn.Conv2d(in_channels, in_channels, kernel_size=(21, 1), padding=(10, 0), groups=in_channels)
        self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=(1, 1), padding=0)
        self.act = nn.GELU()

    def forward(self, inputs):
        #   Global Perceptron
        inputs = self.conv(inputs)
        inputs = self.act(inputs)

        channel_att_vec = self.ca(inputs)
        inputs = channel_att_vec * inputs

        x_init = self.dconv5_5(inputs)
        x_1 = self.dconv1_7(x_init)
        x_1 = self.dconv7_1(x_1)
        x_2 = self.dconv1_11(x_init)
        x_2 = self.dconv11_1(x_2)
        x_3 = self.dconv1_21(x_init)
        x_3 = self.dconv21_1(x_3)
        x = x_1 + x_2 + x_3 + x_init
        spatial_att = self.conv(x)
        out = spatial_att * inputs
        out = self.conv(out)
        return out


if __name__ == "__main__":
    dim=64
    # 如果GPU可用,将模块移动到 GPU
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # 输入张量 (batch_size, height, width,channels)
    x = torch.randn(2,dim,40,40).to(device)
    # 初始化 CPCA模块
    block = RepBlock(dim,dim)
    print(block)
    block = block.to(device)
    # 前向传播
    output = block(x)
    print("输入:", x.shape)
    print("输出:", output.shape)

输出结果:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/962812.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

grpc 和 http 的区别---二进制vsJSON编码

gRPC 和 HTTP 是两种广泛使用的通信协议,各自适用于不同的场景。以下是它们的详细对比与优势分析: 一、核心特性对比 特性gRPCHTTP协议基础基于 HTTP/2基于 HTTP/1.1 或 HTTP/2数据格式默认使用 Protobuf(二进制)通常使用 JSON/…

Qt常用控件 输入类控件

文章目录 1.QLineEdit1.1 常用属性1.2 常用信号1.3 例子1,录入用户信息1.4 例子2,正则验证手机号1.5 例子3,验证输入的密码1.6 例子4,显示密码 2. QTextEdit2.1 常用属性2.2 常用信号2.3 例子1,获取输入框的内容2.4 例…

[b01lers2020]Life on Mars1

打开题目页面如下 看了旁边的链接,也没有什么注入点,是正常的科普 利用burp suite抓包,发现传参 访问一下 http://5edaec92-dd87-4fec-b0e3-501ff24d3650.node5.buuoj.cn:81/query?searchtharsis_rise 接下来进行sql注入 方法一&#xf…

前端自动化测试(一):揭秘自动化测试秘诀

目录 [TOC](目录)前言自动化测试 VS 手动测试测试分类何为单元测试单元测试的优缺点优点缺点 测试案例测试代码 测试函数的封装实现 expect 方法实现 test 函数结语 正文开始 , 如果觉得文章对您有帮助,请帮我三连订阅,谢谢💖&…

FFmpeg工具使用基础

一、FFmpeg工具介绍 FFmpeg命令行工具主要包括以下几个部分: ‌ffmpeg‌:编解码工具‌ffprobe‌:多媒体分析器‌ffplay‌:简单的音视频播放器这些工具共同构成了FFmpeg的核心功能,支持各种音视频格式的处理和转换‌ 二、在Ubuntu18.04上安装FFmpeg工具 1、sudo apt-upda…

upload labs靶场

upload labs靶场 注意:本人关卡后面似乎相比正常的关卡少了一关,所以每次关卡名字都是1才可以和正常关卡在同一关 一.个人信息 个人名称:张嘉玮 二.解题情况 三.解题过程 题目:up load labs靶场 pass 1前后端 思路及解题:…

解锁豆瓣高清海报(二) 使用 OpenCV 拼接和压缩

解锁豆瓣高清海报(二): 使用 OpenCV 拼接和压缩 脚本地址: 项目地址: Gazer PixelWeaver.py pixel_squeezer_cv2.py 前瞻 继上一篇“解锁豆瓣高清海报(一) 深度爬虫与requests进阶之路”成功爬取豆瓣电影海报之后,本文将介绍如何使用 OpenCV 对这些海报进行智…

C++:虚函数与多态性习题2

题目内容: 编写程序,声明抽象基类Shape,由它派生出3个派生类:Circle、Rectangle、Triangle,用虚函数分别计算图形面积,并求它们的和。要求用基类指针数组,使它每一个元素指向一个派生类对象。 …

JVM运行时数据区域-附面试题

Java虚拟机在执行Java程序的过程中会把它所管理的内存划分为若干个不同的数据区域。这些区域 有各自的用途,以及创建和销毁的时间,有的区域随着虚拟机进程的启动而一直存在,有些区域则是 依赖用户线程的启动和结束而建立和销毁。 1. 程序计…

电脑优化大师-解决电脑卡顿问题

我们常常会遇到电脑运行缓慢、网速卡顿的情况,但又不知道是哪个程序在占用过多资源。这时候,一款能够实时监控网络和系统状态的工具就显得尤为重要了。今天,就来给大家介绍一款小巧实用的监控工具「TrafficMonitor」。 「TrafficMonitor 」是…

跨组织环境下 MQTT 桥接架构的评估

论文标题 中文标题: 跨组织环境下 MQTT 桥接架构的评估 英文标题: Evaluation of MQTT Bridge Architectures in a Cross-Organizational Context 作者信息 Keila Lima, Tosin Daniel Oyetoyan, Rogardt Heldal, Wilhelm Hasselbring Western Norway …

Baklib揭示内容中台实施最佳实践的策略与实战经验

内容概要 在当前数字化转型的浪潮中,内容中台的概念日益受到关注。它不再仅仅是一个内容管理系统,而是企业提升运营效率与灵活应对市场变化的重要支撑平台。内容中台的实施离不开最佳实践的指导,这些实践为企业在建设高效内容中台时提供了宝…

牛客周赛round78 B,C

B.一起做很甜的梦 题意&#xff1a;就是输出n个数&#xff08;1-n&#xff09;&#xff0c;使输出的序列中任意选连续的小序列&#xff08;小序列长度>2&&<n-1&#xff09;不符合排列&#xff08;例如如果所选长度为2&#xff0c;在所有长度为2 的小序列里不能出…

微机原理与接口技术期末大作业——4位抢答器仿真

在微机原理与接口技术的学习旅程中&#xff0c;期末大作业成为了检验知识掌握程度与实践能力的关键环节。本次我选择设计并仿真一个 4 位抢答器系统&#xff0c;通过这个项目&#xff0c;深入探索 8086CPU 及其接口技术的实际应用。附完整压缩包下载。 一、系统设计思路 &…

Android记事本App设计开发项目实战教程2025最新版Android Studio

平时上课录了个视频&#xff0c;从新建工程到打包Apk&#xff0c;从头做到尾&#xff0c;没有遗漏任何实现细节&#xff0c;欢迎学过Android基础的同学参加&#xff0c;如果你做过其他终端软件开发&#xff0c;也可以学习&#xff0c;快速上手Android基础开发。 Android记事本课…

设计模式Python版 组合模式

文章目录 前言一、组合模式二、组合模式实现方式三、组合模式示例四、组合模式在Django中的应用 前言 GOF设计模式分三大类&#xff1a; 创建型模式&#xff1a;关注对象的创建过程&#xff0c;包括单例模式、简单工厂模式、工厂方法模式、抽象工厂模式、原型模式和建造者模式…

4-图像梯度计算

文章目录 4.图像梯度计算(1)Sobel算子(2)梯度计算方法(3)Scharr与Laplacian算子4.图像梯度计算 (1)Sobel算子 图像梯度-Sobel算子 Sobel算子是一种经典的图像边缘检测算子,广泛应用于图像处理和计算机视觉领域。以下是关于Sobel算子的详细介绍: 基本原理 Sobel算子…

MATLAB实现多种群遗传算法

多种群遗传算法&#xff08;MPGA, Multi-Population Genetic Algorithm&#xff09;是一种改进的遗传算法&#xff0c;它通过将种群分成多个子种群并在不同的子种群之间进行交叉和交换&#xff0c;旨在提高全局搜索能力并避免早期收敛。下面是多种群遗传算法的主要步骤和流程&a…

WebRtc06: 音视频数据采集

音视频采集API 通过getUserMedia这个API去获取视频音频&#xff0c; 通过constraints这个对象去配置偏好&#xff0c;比如视频宽高、音频降噪等 测试代码 index.html <html><head><title>WebRtc capture video and audio</title></head><…

浅析CDN安全策略防范

CDN&#xff08;内容分发网络&#xff09;信息安全策略是保障内容分发网络在提供高效服务的同时&#xff0c;确保数据传输安全、防止恶意攻击和保护用户隐私的重要手段。以下从多个方面详细介绍CDN的信息安全策略&#xff1a; 1. 数据加密 数据加密是CDN信息安全策略的核心之…