Densenet+SE

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊# 前言

前言

这周开始学习关于经典模型的改进如加注意力机制,这周学习Densenet加通道注意力即SE注意力机制。

##SE注意力机制简介
SE(Squeeze-and-Excitation)注意力机制是一种用于增强卷积神经网络(CNN)性能的注意力机制,特别适用于图像分类任务。该机制由Jie Hu等人于2018年提出,旨在通过动态调整特征图中每个通道的权重,从而增强模型对重要特征的感知能力。

SE注意力机制的基本思想是利用全局信息来动态调整特征图中各个通道的重要性。具体来说,该机制包括两个关键步骤:Squeeze(压缩)和Excitation(激励)。

Squeeze(压缩):
在Squeeze阶段,通过全局池化操作(通常是全局平均池化)来压缩特征图在空间维度上的信息,将每个通道的特征图转换为单个数字。
假设输入特征图的尺寸为 𝐻×𝑊×𝐶其中 𝐻和 𝑊 分别是特征图的高度和宽度,𝐶是通道数。通过全局平均池化操作,将每个通道的特征图转换为长度为 𝐶的向量,表示每个通道的全局重要性。
Excitation(激励):
在Excitation阶段,利用一个小型的全连接(或者称为多层感知机)网络来学习每个通道的激励权重。将Squeeze阶段得到的长度为 𝐶 的向量输入到一个两层的全连接网络中。通过激活函数(如ReLU)和sigmoid函数,输出每个通道的激励权重(或者称为通道注意力权重)。这些权重用于对原始特征图进行加权,从而增强对重要特征的感知能力。

在这里插入图片描述

神经网络插入注意力机制

以resnet为例
在这里插入图片描述

SE插入Densenet

from collections import OrderedDict
import torch.utils.checkpoint as cp
import torch
import torch.nn as nn
import torch.nn.functional as F

def _bn_function_factory(norm, relu, conv):
    def bn_function(*inputs):
        concated_features = torch.cat(inputs, 1)
        bottleneck_output = conv(relu(norm(concated_features)))
        return bottleneck_output

    return bn_function


class _DenseLayer(nn.Module):
    def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, efficient=False):
        super(_DenseLayer, self).__init__()
        self.add_module('norm1', nn.BatchNorm2d(num_input_features)),
        self.add_module('relu1', nn.ReLU(inplace=True)),
        self.add_module('conv1', nn.Conv2d(num_input_features, bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False)),
        self.add_module('norm2', nn.BatchNorm2d(bn_size * growth_rate)),
        self.add_module('relu2', nn.ReLU(inplace=True)),
        self.add_module('conv2', nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False)),

        self.add_module('SE_Block', SE_Block(growth_rate, reduction=16))
        self.drop_rate = drop_rate
        self.efficient = efficient

    def forward(self, *prev_features):
        bn_function = _bn_function_factory(self.norm1, self.relu1, self.conv1)
        if self.efficient and any(prev_feature.requires_grad for prev_feature in prev_features):
            bottleneck_output = cp.checkpoint(bn_function, *prev_features)
        else:
            bottleneck_output = bn_function(*prev_features)
        new_features = self.SE_Block(self.conv2(self.relu2(self.norm2(bottleneck_output))))
        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
        return new_features


class _Transition(nn.Sequential):
    def __init__(self, num_input_features, num_output_features):
        super(_Transition, self).__init__()
        self.add_module('norm', nn.BatchNorm2d(num_input_features))
        self.add_module('relu', nn.ReLU(inplace=True))
        self.add_module('conv', nn.Conv2d(num_input_features, num_output_features,
                                          kernel_size=1, stride=1, bias=False))
        self.add_module('pool', nn.AvgPool2d(kernel_size=2, stride=2))


class _DenseBlock(nn.Module):
    def __init__(self, num_layers, num_input_features, bn_size, growth_rate, drop_rate, efficient=False):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(
                num_input_features + i * growth_rate,
                growth_rate=growth_rate,
                bn_size=bn_size,
                drop_rate=drop_rate,
                efficient=efficient,
            )
            self.add_module('denselayer%d' % (i + 1), layer)

    def forward(self, init_features):
        features = [init_features]
        for name, layer in self.named_children():
            new_features = layer(*features)
            features.append(new_features)
        return torch.cat(features, 1)


class SE_Block(nn.Module):
    def __init__(self, ch_in, reduction=16):
        super(SE_Block, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # 全局自适应池化
        self.fc = nn.Sequential(
            nn.Linear(ch_in, ch_in // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(ch_in // reduction, ch_in, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)  # squeeze操作
        y = self.fc(y).view(b, c, 1, 1)  # FC获取通道注意力权重,是具有全局信息的
        return x * y.expand_as(x)  # 注意力作用每一个通道上


class DenseNet(nn.Module):
    def __init__(self, growth_rate, block_config, num_init_features=24, compression=0.5, bn_size=4, drop_rate=0,
                 num_classes=10, small_inputs=True, efficient=False):

        super(DenseNet, self).__init__()
        assert 0 < compression <= 1, 'compression of densenet should be between 0 and 1'

        # First convolution
        if small_inputs:
            self.features = nn.Sequential(OrderedDict([
                ('conv0', nn.Conv2d(3, num_init_features, kernel_size=3, stride=1, padding=1, bias=False)),
            ]))
        else:
            self.features = nn.Sequential(OrderedDict([
                ('conv0', nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ]))
            self.features.add_module('norm0', nn.BatchNorm2d(num_init_features))
            self.features.add_module('relu0', nn.ReLU(inplace=True))
            self.features.add_module('pool0', nn.MaxPool2d(kernel_size=3, stride=2, padding=1,
                                                           ceil_mode=False))

        # Each denseblock
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(
                num_layers=num_layers,
                num_input_features=num_features,
                bn_size=bn_size,
                growth_rate=growth_rate,
                drop_rate=drop_rate,
                efficient=efficient,
            )
            self.features.add_module('denseblock%d' % (i + 1), block)
            num_features = num_features + num_layers * growth_rate
            if i != len(block_config) - 1:
                trans = _Transition(num_input_features=num_features,
                                    num_output_features=int(num_features * compression))
                self.features.add_module('transition%d' % (i + 1), trans)
                num_features = int(num_features * compression)
            # self.features.add_module('SE_Block%d' % (i + 1),SE_Block(num_features, reduction=16))

        # Final batch norm
        self.features.add_module('norm_final', nn.BatchNorm2d(num_features))

        # Linear layer
        self.classifier = nn.Linear(num_features, num_classes)

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.adaptive_avg_pool2d(out, (1, 1))
        out = torch.flatten(out, 1)
        out = self.classifier(out)
        return out

验证效果

以猴痘病识别为例
在这里插入图片描述
在这里插入图片描述

总结

SE注意力机制的优点在于它能够有效地捕捉不同特征通道之间的关系,并且只引入了少量的额外参数和计算成本。通过引入SE模块,可以显著提升CNN模型在图像分类等任务上的性能,使其更加关注重要的特征,从而提高模型的泛化能力和准确性。SE后面还有许多注意力机制如将通道和空间注意机制结合的CBAM注意力机制,SE的改进ECA注意力机制等,后续将继续深入学习。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/605385.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

自定义shell

1、首先我们的程序要打印出命令行 命令行》用户名【主机名】当前路劲$:命令字符串 用户名、主机名、当前路径可以通过系统调用函数getenv()得到&#xff1a; 2、获取命令字符串 把输入的命令字符串放到一个指针数组中 但是我们发现用scanf函数输入的话&#xff0c;遇到空…

【数据结构】-- 链表专题

链表的分类 前面我们实现了单链表&#xff0c;单链表只是链表的一种。可以根据以下几个标准来判断链表的类型&#xff1a; 1.单向或者双向 如图所示&#xff0c;单向链表中一个节点的指针域只储存了下一个节点的指针&#xff0c;能通过前一个节点访问后一个节点&#xff0c;无…

Vue 3.3 编译宏 vue3.3新增了一些语法糖和宏,包括泛型组件、defineSlots、defineEmits、defineOptions

Vue 3.3新增了一些语法糖和宏&#xff0c;包括泛型组件、defineSlots、defineEmits、defineOptions defineProps 父组件传参 <template><Child name"my"></Child> </template> <script setup lang"ts"> import Child fro…

使用Docker安装Yapi接口管理工具

简介&#xff1a; YAPI 是由去哪儿网移动架构组开发的一款可视化接口管理工具。它具有可视化管理、高效易用、功能强大等特点。它提供了便捷的接口创建、发布和维护方式&#xff0c;开发人员可以通过简单的操作实现接口管理。 YAPI 还支持类似 postman 的接口调试&#xff0c;对…

06-数组

1. 为什么需要数组 一个养鸡场有6只鸡&#xff0c;它们的体重分别是3kg&#xff0c;4kg&#xff0c;1kg&#xff0c;2kg&#xff0c;6kg&#xff0c;3kg。 没有数组&#xff0c;就需要定义六个变量&#xff0c;一个变量代表一只鸡的体重。 使用数组&#xff0c;就可以定义一…

TypeScript学习日志-第二十三天(装饰器Decorator)

装饰器Decorator 一、类装饰器 ClassDecorator 其中返回的 target 是 Http 的构造函数&#xff0c;有了构造函数就不会去破坏其自身原有的结构&#xff0c;当我们 Http 里面有多个属性或者方法的&#xff0c;当是我们不想看或者改变它&#xff0c;这时候可以在构造函数中增加即…

【Mybatis操作数据库】入门(一)

个人主页&#xff1a;兜里有颗棉花糖 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 兜里有颗棉花糖 原创 收录于专栏【MyBatis框架】 本专栏旨在分享MyBatis框架的学习笔记&#xff0c;如有错误定当洗耳恭听&#xff0c;欢迎大家在评论区交流讨论&#x1f…

59岁前TVB男拳王内地登台疑黑面 被批耍大牌

现年59岁的郭政鸿在2015年离巢TVB后转往内地发展&#xff0c;密密拍剧、登台及直播带货&#xff0c;短短几年就已经储够钱&#xff0c;斥资过千万买楼&#xff0c;成功上车做业主&#xff0c;可见收入丰厚。 早前郭政鸿现身顺德&#xff0c;在酒吧登台唱歌&#xff0c;有网民上…

《铁路出行更便捷:火车票预定审批系统的设计与应用》

在现代化的铁路交通管理中&#xff0c;火车票预定审批系统扮演着至关重要的角色。它不仅能够有效管理员工出差、培训等需要乘坐火车的行程&#xff0c;还能够提高审批效率&#xff0c;减少人力成本&#xff0c;确保出行安全。本文将探讨火车票预定审批系统的设计原则和应用场景…

计算机毕业设计Python+Spark知识图谱医生推荐系统 医生门诊预测系统 医生数据分析 医生可视化 医疗数据分析 医生爬虫 大数据毕业设计 机器学习

摘 要 随着我国社会经济发展水平的不断提高&#xff0c;人们的物质生活水平也有了很大的改善&#xff0c;越来越多的人不满足于当前的医疗服务质量&#xff0c;由于地域和空间的限制&#xff0c;医疗资源不平衡&#xff0c;无法实现全民共享。针对当今社会中存在的求医难的问题…

智能家居2 -- 实现网络控制模块

这一模块的思路和前面的语言控制模块很相似&#xff0c;差别只是调用TCP 去控制 废话少说&#xff0c;放码过来 增添/修改代码 socket_interface.c #include <pthread.h>#include "socket_interface.h" #include "control.h" #include "socke…

【教程】超简单!如何将“在VSCode中打开”添加到右键菜单中

按照以下步骤进行操作&#xff1a; 打开注册表编辑器&#xff1a; 按下 Win R 组合键打开运行对话框。输入 regedit 并按下 Enter 键打开注册表编辑器。 导航到适当的注册表项&#xff1a; 转到以下注册表项&#xff1a;HKEY_CLASSES_ROOT\Directory\Background\shell 创建…

26版SPSS操作教程(高级教程第十九章)

目录 前言 粉丝及官方意见说明 第十九章一些学习笔记 第十九章一些操作方法 树模型、随机森林与最近邻元素法 树模型 数据准备 具体操作 结果解释 对案例的进一步分析 结果解释 考虑应用模型时的成本与收益 保存新数据 在选项中看错误分类成本和利润 结果解释…

【管理篇】如何管理情绪?

目录标题 为什么要特别关注激动和愤怒两种情绪呢&#xff1f;管理自己的情绪大致的步骤三层脑结构爬行脑情绪脑视觉脑 大家说的情绪管理&#xff0c;基本上都是对于情绪激动、生气甚至是愤怒的管理&#xff1b;日常所说的情绪化&#xff0c;一般也是指某个人特别容易情绪激动&a…

Gitlab自动化测试的配置

1. 代码分支命名规范检测 Setting → Repository → Push rules → Branch name&#xff0c;添加分支命名规范对应的正则表达式。如&#xff1a; ^(Release|Tag|Develop|Feature)_._.|Main$ 表示分支名只能以以下关键字之一开头&#xff1a;Release、Tag、Develop和Feature。 …

基于模糊控制的AMT自动变速汽车换档智能控制系统simulink建模与仿真

目录 1.课题概述 2.系统仿真结果 3.核心程序与模型 4.系统原理简介 5.完整工程文件 1.课题概述 基于模糊控制的AMT自动变速汽车换档智能控制系统simulink建模与仿真。 2.系统仿真结果 输入的V&#xff0c;Ac&#xff0c;a 输出的档位&#xff1a; 3.核心程序与模型 版…

【BST】Behavior Sequence Transformer for E-commerceRecommendation in Alibaba

一、提出背景 传统的Embedding&MLP模型结构将原始特征嵌入到低维向量中&#xff0c;然后将其concat后输入MLP进行最终推荐。DIN提出使用注意力机制来捕获候选项与用户先前点击的项之间的相似性。 然而&#xff0c;大多数这些工作只是连接不同的特征&#xff0c;而没有捕获用…

通过 Java 操作 redis -- hash 哈希表基本命令

目录 使用命令 hset&#xff0c;hget 使用命令 hexists 使用命令 hdel 使用命令 hkeys&#xff0c;hvals 使用命令 hmget&#xff0c;hmset 关于 redis hash 哈希表类型的相关命令推荐看Redis - hash 哈希表 要想通过 Java 操作 redis&#xff0c;首先要连接上 redis 服务…

AVL Cruise与Simulink联合仿真(通过MATLAB DLL方式)

最近毕业设计需要用到AVL Cruise与Simulink进行联合仿真&#xff0c;分析汽车模型的经济性。下面介绍一下我所知的AVL Cruise与Simulink联合仿真的几种方式&#xff0c;它们各自的优缺点&#xff0c;以及DLL方式联合仿真的具体配置过程。我这里用的MATLAB软件版本是2021a&#…

运行Spring Boot项目失败?显示java: 无法访问org.springframework.boot.SpringApplication,让我来看看~

idea项目运行报错截图&#xff1a; &#xff08;1&#xff09;查看错误提示“类文件具有错误的版本 61.0, 应为 52.0”&#xff0c;61.0对应的是jdk17&#xff0c;52.0对应1.8。 通过这个网址可以查询版本&#xff1a; https://stackoverflow.com/questions/9170832/list-of-ja…