DenseNet分类网络改进(添加SPP)--亲测有效

最近,我在做多分类问题。在针对基模型的选择中,我使用了DenseNet作为基本模型。我在阅读论文时,遇到了一种改进方式:
在这里插入图片描述

如上图所示,在全连接层之前引入SPP模块:
在这里插入图片描述
代码如下:

SPP模块代码:

class SpatialPyramidPooling(nn.Module):
    def __init__(self, pool_sizes: List[int], in_channels: int):
        super(SpatialPyramidPooling, self).__init__()
        self.pool_sizes = pool_sizes
        self.in_channels = in_channels
        self.pool_layers = nn.ModuleList([
            nn.AdaptiveMaxPool2d(output_size=(size, size)) for size in pool_sizes
        ])

    def forward(self, x: Tensor) -> Tensor:
        pools = [pool_layer(x) for pool_layer in self.pool_layers]

        # Resize the output of each pool to have the same number of channels
        pools_resized = [F.adaptive_max_pool2d(pool, (1, 1)) for pool in pools]

        spp_out = torch.cat(pools_resized, dim=1)  # Concatenate the resized pools
        return spp_out

加入SPP代码后的DenseNet网络完整如下:

import re
from typing import List, Tuple, Any
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from torch import Tensor

class _DenseLayer(nn.Module):
    def __init__(self, input_c: int, growth_rate: int, bn_size: int, drop_rate: float, memory_efficient: bool = False):
        super(_DenseLayer, self).__init__()

        self.add_module("norm1", nn.BatchNorm2d(input_c))
        self.add_module("relu1", nn.ReLU(inplace=True))
        self.add_module("conv1", nn.Conv2d(in_channels=input_c, out_channels=bn_size * growth_rate,
                                           kernel_size=1, stride=1, bias=False))
        self.add_module("norm2", nn.BatchNorm2d(bn_size * growth_rate))
        self.add_module("relu2", nn.ReLU(inplace=True))
        self.add_module("conv2", nn.Conv2d(bn_size * growth_rate, growth_rate,
                                           kernel_size=3, stride=1, padding=1, bias=False))
        self.drop_rate = drop_rate
        self.memory_efficient = memory_efficient

    def bn_function(self, inputs: List[Tensor]) -> Tensor:
        concat_features = torch.cat(inputs, 1)
        bottleneck_output = self.conv1(self.relu1(self.norm1(concat_features)))
        return bottleneck_output

    @staticmethod
    def any_requires_grad(inputs: List[Tensor]) -> bool:
        for tensor in inputs:
            if tensor.requires_grad:
                return True
        return False

    @torch.jit.unused
    def call_checkpoint_bottleneck(self, inputs: List[Tensor]) -> Tensor:
        def closure(*inp):
            return self.bn_function(inp)

        return cp.checkpoint(closure, *inputs)

    def forward(self, inputs: Tensor) -> Tensor:
        if isinstance(inputs, Tensor):
            prev_features = [inputs]
        else:
            prev_features = inputs

        if self.memory_efficient and self.any_requires_grad(prev_features):
            if torch.jit.is_scripting():
                raise Exception("memory efficient not supported in JIT")

            bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
        else:
            bottleneck_output = self.bn_function(prev_features)

        new_features = self.conv2(self.relu2(self.norm2(bottleneck_output_with_cbam)))

        if self.drop_rate > 0:
            new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)

        return new_features

class _DenseBlock(nn.ModuleDict):
    _version = 2

    def __init__(self, num_layers: int, input_c: int, bn_size: int, growth_rate: int, drop_rate: float,
                 memory_efficient: bool = False):
        super(_DenseBlock, self).__init__()
        for i in range(num_layers):
            layer = _DenseLayer(input_c + i * growth_rate,
                                growth_rate=growth_rate,
                                bn_size=bn_size,
                                drop_rate=drop_rate,
                                memory_efficient=memory_efficient)
            self.add_module("denselayer%d" % (i + 1), layer)

    def forward(self, init_features: Tensor) -> Tensor:
        features = [init_features]
        for _, layer in self.items():
            new_features = layer(features)
            features.append(new_features)
        return torch.cat(features, 1)



class _Transition(nn.Sequential):
    def __init__(self, input_c: int, output_c: int):
        super(_Transition, self).__init__()
        self.add_module("norm", nn.BatchNorm2d(input_c))
        self.add_module("relu", nn.ReLU(inplace=True))
        self.add_module("conv", nn.Conv2d(input_c, output_c, kernel_size=1, stride=1, bias=False))
        self.add_module("pool", nn.AvgPool2d(kernel_size=2, stride=2))

class SpatialPyramidPooling(nn.Module):
    def __init__(self, pool_sizes: List[int], in_channels: int):
        super(SpatialPyramidPooling, self).__init__()
        self.pool_sizes = pool_sizes
        self.in_channels = in_channels
        self.pool_layers = nn.ModuleList([
            nn.AdaptiveMaxPool2d(output_size=(size, size)) for size in pool_sizes
        ])

    def forward(self, x: Tensor) -> Tensor:
        pools = [pool_layer(x) for pool_layer in self.pool_layers]

        # Resize the output of each pool to have the same number of channels
        pools_resized = [F.adaptive_max_pool2d(pool, (1, 1)) for pool in pools]

        spp_out = torch.cat(pools_resized, dim=1)  # Concatenate the resized pools
        return spp_out

class DenseNet(nn.Module):
    def __init__(self, growth_rate: int = 32, block_config: Tuple[int, int, int, int] = (6, 12, 24, 16),
                 num_init_features: int = 64, bn_size: int = 4, drop_rate: float = 0, num_classes: int = 1000,
                 memory_efficient: bool = False):
        super(DenseNet, self).__init__()

        # First conv+bn+relu+pool
        self.features = nn.Sequential(OrderedDict([
            ("conv0", nn.Conv2d(3, num_init_features, kernel_size=7, stride=2, padding=3, bias=False)),
            ("norm0", nn.BatchNorm2d(num_init_features)),
            ("relu0", nn.ReLU(inplace=True)),
            ("pool0", nn.MaxPool2d(kernel_size=3, stride=2, padding=1)),
        ]))

        # Each dense block
        num_features = num_init_features
        for i, num_layers in enumerate(block_config):
            block = _DenseBlock(num_layers=num_layers,
                                input_c=num_features,
                                bn_size=bn_size,
                                growth_rate=growth_rate,
                                drop_rate=drop_rate,
                                memory_efficient=memory_efficient)
            self.features.add_module("denseblock%d" % (i + 1), block)
            num_features = num_features + num_layers * growth_rate

            if i != len(block_config) - 1:
                trans = _Transition(input_c=num_features,
                                    output_c=num_features // 2)
                self.features.add_module("transition%d" % (i + 1), trans)
                num_features = num_features // 2

        # Final batch norm
        self.features.add_module("norm5", nn.BatchNorm2d(num_features))

        # Spatial Pyramid Pooling (SPP) layer
        spp_pool_sizes = [1, 4, 16]  # You can adjust pool sizes as needed
        self.spp = SpatialPyramidPooling(spp_pool_sizes, in_channels=num_features)

        # FC layer
        self.classifier = nn.Linear(num_features + len(spp_pool_sizes) * num_features, num_classes)

        # Initialize weights
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.constant_(m.bias, 0)
    def forward(self, x: Tensor) -> Tensor:
        features = self.features(x)
        out = F.relu(features, inplace=True)
        # Apply Spatial Pyramid Pooling
        spp_out = self.spp(out)
        # Adjust the number of channels in out to match spp_out
        out = F.adaptive_avg_pool2d(out, (1, 1))
        # Concatenate the original feature map with the SPP output along the channel dimension
        out = torch.cat([spp_out, out], dim=1)
        # Flatten the spatial dimensions of out
        out = torch.flatten(out, 1)
        # FC layer
        out = self.classifier(out)
        return out


def densenet121(**kwargs: Any) -> DenseNet:
    # Top-1 error: 25.35%
    # 'densenet121': 'https://download.pytorch.org/models/densenet121-a639ec97.pth'
    return DenseNet(growth_rate=32,
                    block_config=(6, 12, 24, 16),
                    num_init_features=64,
                    **kwargs)
def load_state_dict(model: nn.Module, weights_path: str) -> None:
    # '.'s are no longer allowed in module names, but previous _DenseLayer
    # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'.
    # They are also in the checkpoints in model_urls. This pattern is used
    # to find such keys.
    pattern = re.compile(
        r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$')

    state_dict = torch.load(weights_path)

    num_classes = model.classifier.out_features
    load_fc = num_classes == 1000

    for key in list(state_dict.keys()):
        if load_fc is False:
            if "classifier" in key:
                del state_dict[key]

        res = pattern.match(key)
        if res:
            new_key = res.group(1) + res.group(2)
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
    model.load_state_dict(state_dict, strict=load_fc)
    print("successfully load pretrain-weights.")

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/222666.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

6 大 Android 照片恢复软件深度评测

当您发现令人难忘的照片在 Android 上被错误删除或丢失时,您可能会感到非常沮丧。然而,当您尝试安装一些恢复应用程序并希望将它们重新恢复时,其中许多应用程序无法再次检测到丢失的照片。 为了节省您逐一尝试此类应用程序的时间&#xff0c…

yum命令详解和使用

目录 一、yum介绍 常见的软件封包类型 RPM的含义 二、 yum命令格式 三、常用命令 3.1 yum清空缓存列表 3.2 yum显示信息 3.2 yum安装 3.3 yum删除 3.4 yum包的更新和升级 四、常用的国内 yum 源 一、yum介绍 Yum(全称为 Yellow dog Updater, Modified),是…

西工大计算机学院计算机系统基础实验一(环境配置)

首先,不要焦虑,稳住心态慢慢来,一点一点做,跟着作者把基础打好,比什么都重要。作者曾经经历过这份痛苦,知道它有多么不好受。当初的作者高中之前甚至都没有自己的一台笔记本,上了大学以后学C语言…

unity旋转选中效果

代码和预制体 函数PlayAnim()中的角度要根据按钮数量手动填好 using System; using DG.Tweening; using DG.Tweening.Core; using DG.Tweening.Plugins.Options; using UnityEngine;// Token: 0x0200001B RID: 27 public class BtnParentScript : Base…

数字化智慧工地管理云平台源码(人工智能、物联网)

​智慧工地优势:"智慧工地”将施工企业现场视频管理、建筑起重机械安全监控、现场从业人员管理、物料管理、进度管理、扬尘噪声监测等现场设备有机、高效、科学、规范的结合起来真正实现工程项目业务流与现场各类监控源数据流的有效结合与深度配合,…

html通过CDN引入使用Vue和ElementUI

html通过CDN引入使用Vue和ElementUI 近期遇到个需求,就是需要在.net MVC的项目中,对已有的项目的首页进行优化,也就是写原生html和js。但是咱是一个写前端的,写html还可以,.net的话,开发也不方便&#xff0…

MybatisPlus中的使用Wrapper自定义SQL

一、条件构造器 条件构造器提供了一种更加简洁和直观的方式来构建复杂的查询条件。它提供了一组静态方法,用于构建各种类型的查询条件,包括等于、不等于、大于、小于、包含等。使用条件构造器可以避免手动拼接SQL语句的麻烦,提高代码的可读性…

Pandas使用过程中的神器加持 你不用不要怪我

Pandas是我们日常处理表格数据最常用的包,但是对于数据分析来说,Pandas的DataFrame还不够直观,所以今天我们将介绍4个和Pandas相关的Python包,可以将Pandas的DataFrame转换交互式表格,让我们可以直接在上面进行数据分析…

mybatis-plus的用法。如果条件不为空则查询。返回结果类型转换

1.如果条件不为空则查询 偶尔看到这个代码。感觉大可不必 点击进入源码就会发现。 内部已经封装好了条件判断,可以说非常的人性化 而且改条件走的非常深 2.返回结果类型转换 page查询的时候,返回的结果只能是表对应的实体类。如果自定义扩展vo,dto需…

“城中村上建高楼”,开启一场数智化时代新修炼

“数字化也好,数智化也罢,你明明白白地告诉我,他们与信息化到底有什么区别?” “我在信息化方面已投入那么多,为什么又要投那么多钱搞数智化?” 中国软件网、海比研究院在《2024中国企业数智服务趋势洞察研…

Kubernetes(K8s)安全认证-10

安全认证 访问控制概述 Kubernetes作为一个分布式集群的管理工具,保证集群的安全性是其一个重要的任务。所谓的安全性其实就是保证对Kubernetes的各种客户端进行认证和鉴权操作。 客户端 在Kubernetes集群中,客户端通常有两类: User Acco…

【AI-Fix】解决地图展示包leafmap在Jupyter NoteBook中地图不显示的问题

1. 问题描述 新创建的环境想使用leafmap在Jupyter中进行地图展示,结果发现运行完成之后不显示,不论怎么重启都没有办法显示出来,以经验来看,多半是缺了包了。 于是去leafmap的官方文档查找原因,一开始并没有发现什么问…

用OpenCV与MFC写一个图像格式转换及简单处理程序

打开不同格式的图形文件,彩色装灰度图像、锐化、高斯滤波、边界检测及将其存储为需求格式是图像处理的最基本的操作。如果单纯用MFC编程,是一个令人头痛的事情,有不少的代码量。可用OpenCV与MFC编程就变得相对简单。下面来详细演示这一编程操…

Altair推出 Altair RapidMiner 2023 平台,提供生成式 AI 功能

Altair推出 Altair RapidMiner 2023 平台,提供生成式 AI 功能 更新包括自动聚类、扩展 SAS、Python 和 R 编程功能等 近日,Altair(纳斯达克股票代码:ALTR)近日宣布其数据分析和 AI 平台 Altair RapidMiner 取得了一系…

Python文件打包成exe可执行文件

我们平常用python写些脚本可以方便我们的学习办公,但限制就是需要有python环境才能运行。 那能不能直接在没有python环境的电脑上运行我们的脚本呢? 当然可以,那就是直接把python脚本打包成exe可执行程序(注针对win系统&#xf…

Linux之进程(一)

目录 一、概念 1、基本概念 2、描述进程的PCB 3、task_struct 二、查看进程 三、获取进程的PID和PPID 通过系统调用获取进程的PID和PPID 四、通过系统调用创建进程 1、fork函数创建子进程 2、用if进行分流 五、进程状态 1、操作系统进程状态 1、新建 2、运行 3、…

windows磁盘扩容

在我们为电脑加装一个新的硬盘之后,需要相应的操作才能将硬盘空间加到磁盘空间中。 1.打开磁盘管理器 windows系统可以选择在搜索中直接搜索计算机管理并将其打开 windows server的打开方式为服务器管理器→工具→计算机管理 打开计算机管理后就可以看到磁盘管理&…

【链表Linked List】力扣-234回文链表

目录 问题描述 解题过程 labuladong题解 问题描述 给你一个单链表的头节点 head ,请你判断该链表是否为回文链表。如果是,返回 true ;否则,返回 false 。 示例 1: 输入:head [1,2,2,1] 输出&#xff…

20、pytest中的参数化

官方实例 # content of test_expectation.pyimport pytestpytest.mark.parametrize("test_input, expected",[("35",8),("24",6),("6*9",42)]) def test_eval(test_input, expected):assert eval(test_input) expected# content of …

unsafe类和varhandle类讲解

Java的Unsafe类是一个非常特殊的类,它提供了一组原始、底层的操作,可以跳过Java的限制,直接操作内存和对象。这些操作可能会破坏Java的安全机制,所以Unsafe类被标记为不安全的。 Unsafe类提供了下列方法: allocateIns…