【深度学习笔记】稠密连接网络(DenseNet)

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

5.12 稠密连接网络(DenseNet)

ResNet中的跨层连接设计引申出了数个后续工作。本节我们介绍其中的一个:稠密连接网络(DenseNet) [1]。 它与ResNet的主要区别如图5.10所示。

在这里插入图片描述

图5.10 ResNet(左)与DenseNet(右)在跨层连接上的主要区别:使用相加和使用连结

图5.10中将部分前后相邻的运算抽象为模块 A A A和模块 B B B。与ResNet的主要区别在于,DenseNet里模块 B B B的输出不是像ResNet那样和模块 A A A的输出相加,而是在通道维上连结。这样模块 A A A的输出可以直接传入模块 B B B后面的层。在这个设计里,模块 A A A直接跟模块 B B B后面的所有层连接在了一起。这也是它被称为“稠密连接”的原因。

DenseNet的主要构建模块是稠密块(dense block)和过渡层(transition layer)。前者定义了输入和输出是如何连结的,后者则用来控制通道数,使之不过大。

5.12.1 稠密块

DenseNet使用了ResNet改良版的“批量归一化、激活和卷积”结构,我们首先在conv_block函数里实现这个结构。

import time
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def conv_block(in_channels, out_channels):
    blk = nn.Sequential(nn.BatchNorm2d(in_channels), 
                        nn.ReLU(),
                        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
    return blk

稠密块由多个conv_block组成,每块使用相同的输出通道数。但在前向计算时,我们将每块的输入和输出在通道维上连结。

class DenseBlock(nn.Module):
    def __init__(self, num_convs, in_channels, out_channels):
        super(DenseBlock, self).__init__()
        net = []
        for i in range(num_convs):
            in_c = in_channels + i * out_channels
            net.append(conv_block(in_c, out_channels))
        self.net = nn.ModuleList(net)
        self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数

    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            X = torch.cat((X, Y), dim=1)  # 在通道维上将输入和输出连结
        return X

在下面的例子中,我们定义一个有2个输出通道数为10的卷积块。使用通道数为3的输入时,我们会得到通道数为 3 + 2 × 10 = 23 3+2\times 10=23 3+2×10=23的输出。卷积块的通道数控制了输出通道数相对于输入通道数的增长,因此也被称为增长率(growth rate)。

blk = DenseBlock(2, 3, 10)
X = torch.rand(4, 3, 8, 8)
Y = blk(X)
Y.shape # torch.Size([4, 23, 8, 8])

5.12.2 过渡层

由于每个稠密块都会带来通道数的增加,使用过多则会带来过于复杂的模型。过渡层用来控制模型复杂度。它通过 1 × 1 1\times1 1×1卷积层来减小通道数,并使用步幅为2的平均池化层减半高和宽,从而进一步降低模型复杂度。

def transition_block(in_channels, out_channels):
    blk = nn.Sequential(
            nn.BatchNorm2d(in_channels), 
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.AvgPool2d(kernel_size=2, stride=2))
    return blk

对上一个例子中稠密块的输出使用通道数为10的过渡层。此时输出的通道数减为10,高和宽均减半。

blk = transition_block(23, 10)
blk(Y).shape # torch.Size([4, 10, 4, 4])

5.12.3 DenseNet模型

我们来构造DenseNet模型。DenseNet首先使用同ResNet一样的单卷积层和最大池化层。

net = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64), 
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

类似于ResNet接下来使用的4个残差块,DenseNet使用的是4个稠密块。同ResNet一样,我们可以设置每个稠密块使用多少个卷积层。这里我们设成4,从而与上一节的ResNet-18保持一致。稠密块里的卷积层通道数(即增长率)设为32,所以每个稠密块将增加128个通道。

ResNet里通过步幅为2的残差块在每个模块之间减小高和宽。这里我们则使用过渡层来减半高和宽,并减半通道数。

num_channels, growth_rate = 64, 32  # num_channels为当前的通道数
num_convs_in_dense_blocks = [4, 4, 4, 4]

for i, num_convs in enumerate(num_convs_in_dense_blocks):
    DB = DenseBlock(num_convs, num_channels, growth_rate)
    net.add_module("DenseBlosk_%d" % i, DB)
    # 上一个稠密块的输出通道数
    num_channels = DB.out_channels
    # 在稠密块之间加入通道数减半的过渡层
    if i != len(num_convs_in_dense_blocks) - 1:
        net.add_module("transition_block_%d" % i, transition_block(num_channels, num_channels // 2))
        num_channels = num_channels // 2

同ResNet一样,最后接上全局池化层和全连接层来输出。

net.add_module("BN", nn.BatchNorm2d(num_channels))
net.add_module("relu", nn.ReLU())
net.add_module("global_avg_pool", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)
net.add_module("fc", nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10))) 

我们尝试打印每个子模块的输出维度确保网络无误:

X = torch.rand((1, 1, 96, 96))
for name, layer in net.named_children():
    X = layer(X)
    print(name, ' output shape:\t', X.shape)

输出:

0  output shape:	 torch.Size([1, 64, 48, 48])
1  output shape:	 torch.Size([1, 64, 48, 48])
2  output shape:	 torch.Size([1, 64, 48, 48])
3  output shape:	 torch.Size([1, 64, 24, 24])
DenseBlosk_0  output shape:	 torch.Size([1, 192, 24, 24])
transition_block_0  output shape:	 torch.Size([1, 96, 12, 12])
DenseBlosk_1  output shape:	 torch.Size([1, 224, 12, 12])
transition_block_1  output shape:	 torch.Size([1, 112, 6, 6])
DenseBlosk_2  output shape:	 torch.Size([1, 240, 6, 6])
transition_block_2  output shape:	 torch.Size([1, 120, 3, 3])
DenseBlosk_3  output shape:	 torch.Size([1, 248, 3, 3])
BN  output shape:	 torch.Size([1, 248, 3, 3])
relu  output shape:	 torch.Size([1, 248, 3, 3])
global_avg_pool  output shape:	 torch.Size([1, 248, 1, 1])
fc  output shape:	 torch.Size([1, 10])

5.12.4 获取数据并训练模型

由于这里使用了比较深的网络,本节里我们将输入高和宽从224降到96来简化计算。

batch_size = 256
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)

lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

输出:

training on  cuda
epoch 1, loss 0.0020, train acc 0.834, test acc 0.749, time 27.7 sec
epoch 2, loss 0.0011, train acc 0.900, test acc 0.824, time 25.5 sec
epoch 3, loss 0.0009, train acc 0.913, test acc 0.839, time 23.8 sec
epoch 4, loss 0.0008, train acc 0.921, test acc 0.889, time 24.9 sec
epoch 5, loss 0.0008, train acc 0.929, test acc 0.884, time 24.3 sec

小结

  • 在跨层连接上,不同于ResNet中将输入与输出相加,DenseNet在通道维上连结输入与输出。
  • DenseNet的主要构建模块是稠密块和过渡层。

参考文献

[1] Huang, G., Liu, Z., Weinberger, K. Q., & van der Maaten, L. (2017). Densely connected convolutional networks. In Proceedings of the IEEE conference on computer vision and pattern recognition (Vol. 1, No. 2).


注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/439905.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

LInux-多线程基础概念

文章目录 前言预备页表详解缺页中断页表的映射 一、多线程是什么?轻量级进程 二、Pthread库pthread_create 前言 从本章的多线程开始,我们开始进入Linux系统的尾声,所以,在学习多线程的过程中,我们也会逐步对之前的内…

第十五课 PCB保姆级规则设置(四)

1.常用的规则设置 2.设置完规则后,重新设置规则检查 将不设置的规则,不勾选 3.布线规则 (1)先走短线(模块之内的线) 模块之内的线需要连接,长线的话提前打孔(占位置)…

微信小程序开发系列(十五)·bind 绑定的事件(冒泡事件)、catch 绑定的事件(非冒泡事件)

目录 1. 事件分类 1.1 bind 绑定的事件(冒泡事件) 1.2 catch 绑定的事件(非冒泡事件) 1.3 代码 1. 事件分类 事件分为 冒泡事件 和 非冒泡事件: 冒泡事件:当一个组件的事件被触发后,…

浅谈字典攻击

一、前言 字典攻击是一种常见的密码破解方法,它使用预先编制的字典文件作为攻击字典,通过尝试猜测密码的方式来破解密码。下面是一个关于字典攻击的博客,希望能够为您了解字典攻击提供帮助。 二、字典攻击概述 字典攻击是一种密码破解方法&…

微信小程序 弹窗禁止底部页面滚动

解决方法: 在弹窗上给添加catchtouchmove catchtouchmovereturn

个人商城系统开源(展示宣传)

原文地址:个人商城系统开源(展示宣传) - Pleasure的博客 下面是正文内容: 前言 由于近期实在没有什么话题可写和一些有趣的项目教程可以分享。所以我只能决定将我自己亲手编写的一个迷你商城系统进行开源。 也就是放在我博客右边…

VR虚拟生态探索介绍|元宇宙文旅|VR设备价格

虚拟现实(VR)虚拟生态探索是一种通过VR技术让用户体验并探索虚拟环境的活动。在这样的环境中,用户可以通过头戴式显示器和其他感知设备沉浸在一个虚拟的世界中,仿佛身临其境。 在VR虚拟生态探索中,用户可以探索各种不同…

数据库备份.....

一.环境准备 数据库备份,数据库为school,素材如下 >create database school; >use school1.创建student和score表CREATE TABLE student ( id INT(10) NOT NULL UNIQUE PRIMARY KEY , name VARCHAR(20) NOT NULL , sex VARCHAR(4) , bi…

【PCL】(二十五)基于Min-Cut的点云分割

Min-Cut Based Segmentation 提出这个方法的论文&#xff1a;Min-Cut Based Segmentation of Point Clouds 最大流最小割问题 min_cut_segmentation.cpp #include <iostream> #include <vector> #include <pcl/io/pcd_io.h> #include <pcl/point_types…

基于ZEM系列台式扫描电镜和EDS能谱仪的岩石元素成分分析研究

上一期里我们对几种不同种类的岩石在特定作用下的破坏特征进行了表征&#xff0c;感兴趣的同学可以阅读《从宏观到微观——泽攸科技ZEM系列台式扫描电镜在岩石分析中的应用》。岩石是地球表层和地壳的基本组成物质&#xff0c;其成因和成岩过程复杂多样&#xff0c;岩石的元素组…

PythonStudio 控件使用常用方式(十)TLabel和TLinkLabel

PythonStudio是一个极强的开发Python的IDE工具&#xff0c;它使用的是Delphi的控件&#xff0c;常用的内容是与Delphi一致的。但是相关文档并一定完整。现在我试试能否逐步把它的控件常用用法写一点点&#xff0c;也作为PythonStudio的参考。 Tlabel是最常用的控件&#xff0c…

【论文阅读】Segment Anything论文梳理

Abstract 我们介绍了Segment Anything&#xff08;SA&#xff09;项目&#xff1a;新的图像分割任务、模型和数据集。高效的数据循环采集&#xff0c;使我们建立了迄今为止最大的分割数据集&#xff0c;在1100万张图像中&#xff0c;共超过10亿个掩码。 该模型被设计和训练为可…

【最详细的python教程】Python基本数据类型及运算符

&#x1f383;Python基本数据类型 &#x1f38b;整数** 整数是表述整数的数值&#xff0c;没有小数部分。在Python中&#xff0c;整数包括正整数、负整数和0&#xff0c;在Python中数值超过计算机自身的计算功能时会自动转用高精度计算。整数类型包括二进制、八进制、十进制和…

摘花生c++

题目 Hello Kitty想摘点花生送给她喜欢的米老鼠。 她来到一片有网格状道路的矩形花生地(如下图)&#xff0c;从西北角进去&#xff0c;东南角出来。 地里每个道路的交叉点上都有种着一株花生苗&#xff0c;上面有若干颗花生&#xff0c;经过一株花生苗就能摘走该它上面所有的…

指针易错点(超详细)

&#x1f4cc; 博客主页 爆打维c 本文将介绍指针基本知识点及易错点&#xff0c;刚入门学习c语言的小伙伴们可以收藏起来&#xff0c;方便找到。 目录 一、指针是什么&#xff1f; 1.const修饰指针 总结: 2.野指针 野指针成因: 3.指针数组与数组指针的区别 3.1指针数…

数据分析师必备:五款数据可视化工具对比与推荐

在数字化时代&#xff0c;数据可视化产品成为了企业和个人进行数据分析、信息呈现的重要工具。市面上涌现了众多数据可视化产品&#xff0c;它们各具特色&#xff0c;功能各异。本文为大家简要介绍五款市面上热门的数据可视化产品。 一、Tableau Tableau是一款功能强大的数据…

精读《React Conf 2019 - Day2》

1 引言 这是继 精读《React Conf 2019 - Day1》 之后的第二篇&#xff0c;补充了 React Conf 2019 第二天的内容。 2 概述 & 精读 第二天的内容更为精彩&#xff0c;笔者会重点介绍比较干货的部分。 Fast refresh Fast refresh 是更好的 react-hot-loader 替代方案&am…

mysql日常优化的总结

文章目录 一、数据表结构相关优化建字段类型注意事项1. int类型的选择2.varchar、char、text类型3.date、datetime、timestamp类型 表规划1. 垂直分表2. 水平分表 二、查询语句优化1.对于字段多的表&#xff0c;避免使用SELECT *2.避免使用!操作符3.避免使用null做条件4.like查…

echarts中toolbox 中文乱码问题

问题描述 本地引用的echarts源文件&#xff0c;页面其他部分编码显示正常&#xff0c;唯独toolbox鼠标悬停在上面时提示信息显示乱码。 如图所示&#xff1a; 尝试过的方法 使用sublime text 3&#xff0c;notepad&#xff0c;记事本更改文件编码为utf-8引入时&#xff0c;在sc…

爱普生宣布开发出独特的宽幅度LVDS输出 —可灵活选择与LSI

爱普生宣布开发出独特的宽幅度LVDS输出 —可灵活选择与LSI -相匹配的低噪声输出 精工爱普生公司(TSE: 6724&#xff0c;“爱普生”)开发了一种新的晶体振荡器差分输出方案。新方案&#xff0c;宽幅低压差分信号(WA-LVDS)&#xff0c;可以灵活选择最适合LSI所需的幅值水平的输出…