【深度学习笔记】5_12稠密连接网络(DenseNet)

注:本文为《动手学深度学习》开源内容,部分标注了个人理解,仅为个人学习记录,无抄袭搬运意图

5.12 稠密连接网络(DenseNet)

ResNet中的跨层连接设计引申出了数个后续工作。本节我们介绍其中的一个:稠密连接网络(DenseNet) [1]。 它与ResNet的主要区别如图5.10所示。

在这里插入图片描述

图5.10 ResNet(左)与DenseNet(右)在跨层连接上的主要区别:使用相加和使用连结

图5.10中将部分前后相邻的运算抽象为模块 A A A和模块 B B B。与ResNet的主要区别在于,DenseNet里模块 B B B的输出不是像ResNet那样和模块 A A A的输出相加,而是在通道维上连结。这样模块 A A A的输出可以直接传入模块 B B B后面的层。在这个设计里,模块 A A A直接跟模块 B B B后面的所有层连接在了一起。这也是它被称为“稠密连接”的原因。

DenseNet的主要构建模块是稠密块(dense block)和过渡层(transition layer)。前者定义了输入和输出是如何连结的,后者则用来控制通道数,使之不过大。

5.12.1 稠密块

DenseNet使用了ResNet改良版的“批量归一化、激活和卷积”结构,我们首先在conv_block函数里实现这个结构。

import time
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..") 
import d2lzh_pytorch as d2l
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def conv_block(in_channels, out_channels):
    blk = nn.Sequential(nn.BatchNorm2d(in_channels), 
                        nn.ReLU(),
                        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
    return blk

稠密块由多个conv_block组成,每块使用相同的输出通道数。但在前向计算时,我们将每块的输入和输出在通道维上连结。

class DenseBlock(nn.Module):
    def __init__(self, num_convs, in_channels, out_channels):
        super(DenseBlock, self).__init__()
        net = []
        for i in range(num_convs):
            in_c = in_channels + i * out_channels
            net.append(conv_block(in_c, out_channels))
        self.net = nn.ModuleList(net)
        self.out_channels = in_channels + num_convs * out_channels # 计算输出通道数

    def forward(self, X):
        for blk in self.net:
            Y = blk(X)
            X = torch.cat((X, Y), dim=1)  # 在通道维上将输入和输出连结
        return X

在下面的例子中,我们定义一个有2个输出通道数为10的卷积块。使用通道数为3的输入时,我们会得到通道数为 3 + 2 × 10 = 23 3+2\times 10=23 3+2×10=23的输出。卷积块的通道数控制了输出通道数相对于输入通道数的增长,因此也被称为增长率(growth rate)。

blk = DenseBlock(2, 3, 10)
X = torch.rand(4, 3, 8, 8)
Y = blk(X)
Y.shape # torch.Size([4, 23, 8, 8])

5.12.2 过渡层

由于每个稠密块都会带来通道数的增加,使用过多则会带来过于复杂的模型。过渡层用来控制模型复杂度。它通过 1 × 1 1\times1 1×1卷积层来减小通道数,并使用步幅为2的平均池化层减半高和宽,从而进一步降低模型复杂度。

def transition_block(in_channels, out_channels):
    blk = nn.Sequential(
            nn.BatchNorm2d(in_channels), 
            nn.ReLU(),
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.AvgPool2d(kernel_size=2, stride=2))
    return blk

对上一个例子中稠密块的输出使用通道数为10的过渡层。此时输出的通道数减为10,高和宽均减半。

blk = transition_block(23, 10)
blk(Y).shape # torch.Size([4, 10, 4, 4])

5.12.3 DenseNet模型

我们来构造DenseNet模型。DenseNet首先使用同ResNet一样的单卷积层和最大池化层。

net = nn.Sequential(
        nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
        nn.BatchNorm2d(64), 
        nn.ReLU(),
        nn.MaxPool2d(kernel_size=3, stride=2, padding=1))

类似于ResNet接下来使用的4个残差块,DenseNet使用的是4个稠密块。同ResNet一样,我们可以设置每个稠密块使用多少个卷积层。这里我们设成4,从而与上一节的ResNet-18保持一致。稠密块里的卷积层通道数(即增长率)设为32,所以每个稠密块将增加128个通道。

ResNet里通过步幅为2的残差块在每个模块之间减小高和宽。这里我们则使用过渡层来减半高和宽,并减半通道数。

num_channels, growth_rate = 64, 32  # num_channels为当前的通道数
num_convs_in_dense_blocks = [4, 4, 4, 4]

for i, num_convs in enumerate(num_convs_in_dense_blocks):
    DB = DenseBlock(num_convs, num_channels, growth_rate)
    net.add_module("DenseBlosk_%d" % i, DB)
    # 上一个稠密块的输出通道数
    num_channels = DB.out_channels
    # 在稠密块之间加入通道数减半的过渡层
    if i != len(num_convs_in_dense_blocks) - 1:
        net.add_module("transition_block_%d" % i, transition_block(num_channels, num_channels // 2))
        num_channels = num_channels // 2

同ResNet一样,最后接上全局池化层和全连接层来输出。

net.add_module("BN", nn.BatchNorm2d(num_channels))
net.add_module("relu", nn.ReLU())
net.add_module("global_avg_pool", d2l.GlobalAvgPool2d()) # GlobalAvgPool2d的输出: (Batch, num_channels, 1, 1)
net.add_module("fc", nn.Sequential(d2l.FlattenLayer(), nn.Linear(num_channels, 10))) 

我们尝试打印每个子模块的输出维度确保网络无误:

X = torch.rand((1, 1, 96, 96))
for name, layer in net.named_children():
    X = layer(X)
    print(name, ' output shape:\t', X.shape)

输出:

0  output shape:	 torch.Size([1, 64, 48, 48])
1  output shape:	 torch.Size([1, 64, 48, 48])
2  output shape:	 torch.Size([1, 64, 48, 48])
3  output shape:	 torch.Size([1, 64, 24, 24])
DenseBlosk_0  output shape:	 torch.Size([1, 192, 24, 24])
transition_block_0  output shape:	 torch.Size([1, 96, 12, 12])
DenseBlosk_1  output shape:	 torch.Size([1, 224, 12, 12])
transition_block_1  output shape:	 torch.Size([1, 112, 6, 6])
DenseBlosk_2  output shape:	 torch.Size([1, 240, 6, 6])
transition_block_2  output shape:	 torch.Size([1, 120, 3, 3])
DenseBlosk_3  output shape:	 torch.Size([1, 248, 3, 3])
BN  output shape:	 torch.Size([1, 248, 3, 3])
relu  output shape:	 torch.Size([1, 248, 3, 3])
global_avg_pool  output shape:	 torch.Size([1, 248, 1, 1])
fc  output shape:	 torch.Size([1, 10])

5.12.4 获取数据并训练模型

由于这里使用了比较深的网络,本节里我们将输入高和宽从224降到96来简化计算。

batch_size = 256
# 如出现“out of memory”的报错信息,可减小batch_size或resize
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96)

lr, num_epochs = 0.001, 5
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
d2l.train_ch5(net, train_iter, test_iter, batch_size, optimizer, device, num_epochs)

输出:

training on  cuda
epoch 1, loss 0.0020, train acc 0.834, test acc 0.749, time 27.7 sec
epoch 2, loss 0.0011, train acc 0.900, test acc 0.824, time 25.5 sec
epoch 3, loss 0.0009, train acc 0.913, test acc 0.839, time 23.8 sec
epoch 4, loss 0.0008, train acc 0.921, test acc 0.889, time 24.9 sec
epoch 5, loss 0.0008, train acc 0.929, test acc 0.884, time 24.3 sec

小结

  • 在跨层连接上,不同于ResNet中将输入与输出相加,DenseNet在通道维上连结输入与输出。
  • DenseNet的主要构建模块是稠密块和过渡层。

参考文献

[1] Huang, G., Liu, Z., Weinberger, K. Q., & van der Maaten, L. (2017). Densely connected convolutional networks. In Proceedings of the IEEE conference on computer vision and pattern recognition (Vol. 1, No. 2).


注:除代码外本节与原书此节基本相同,原书传送门

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/451705.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Python学习:基础语法

版本查看 python --version编码 默认情况下,Python 3 源码文件以 UTF-8 编码,所有字符串都是 unicode 字符串。 特殊情况下,也可以为源码文件指定不同的编码: # -*- coding: cp-1252 -*-标识符 第一个字符必须是字母表中字母或…

Java学习笔记------常用API

Math类 常用方法: 1. publicb static int abs(int a) 获取参数绝对值 2. publicb static double ceil(double a) 向上取整 3. publicb static floor(double a) 向下取整 4.public static int round(float a) 四舍五入 5. publicb static int max…

Vue3全家桶 - VueRouter - 【2】重定向路由

重定向路由 在路由规则数组中,可采用 redirect 来重定向到另一个地址: 通常是将 / 重定向到 某个页面; 示例展示: router/index.js:import { createRouter, createWebHashHistory, createWebHistory } from vue-route…

云桥通SDWAN企业组网的15大应用场景

云桥通SD-WAN企业组网技术在企业网络中有多样化的应用场景,在技术不断迭代升级中,已经越来越匹配现在的互联网环境,其中在这15中常见的应用场景中,使用云桥通SDWAN企业组网可以很好的帮到企业: 分支机构连接优化&#…

蓝桥杯之【01背包模版】牛客例题展示

牛客链接 #include <bits/stdc.h> using namespace std; int n,V; const int N1010; int v[N],w[N]; int dp[N][N]; int main() {cin>>n>>V;for(int i1;i<n;i){cin>>v[i]>>w[i];}for(int i1;i<n;i){for(int j1;j<V;j){dp[i][j]dp[i-1][…

力扣刷题Days16(js)-67二进制求和

目录 1,题目 2&#xff0c;代码 2.1转换进制数 2.2模拟加法 3&#xff0c;学习与总结 Math.floor() 模拟加法思路回顾 重点复习巩固 模拟加法的思路和学习位运算&#xff1b; 今天没精力了&#xff0c;先休息 1,题目 给你两个二进制字符串 a 和 b &#xff0c;以二进制…

CSS 用 flex 布局绘制骰子

<!DOCTYPE html> <html><head><meta charset"utf-8"><title></title><style>.box {height: 100px;width: 100px;border: 2px solid grey;border-radius: 10px;display: flex;justify-content: center; // 水平居中/* alig…

桥接模式以及在JDBC源码剖析

介绍&#xff1a; 1、桥接模式是指&#xff1a;将实现和抽象放在两个不同类层次中&#xff0c;使两个层次可以独立改变 2、是一种结构型设计模式 3、Bridge模式基于类的最小设计原则&#xff0c;通过使用封装、聚合以及继承等行为让不同的类承担不同的职责。 4、特点&#xff1…

【DAY11 软考中级备考笔记】数据结构 查找和排序

数据结构 查找和排序 3月12日 – 天气&#xff1a;晴 1. 顺序查找 顺序查找就是简单的从头一个一个的进行比较&#xff0c;注意它的平均查找长度 2. 折半查找 折半查找和二叉排序树一致&#xff1a; 优点&#xff1a;查找效率很高 缺点&#xff1a;要求必须是循序存储并且表中…

LoadRunner学习:RuntimeSetting、参数化、关联、(unfinished

LoadRunner RuntimeSetting 运行时设置 在Vuser中设置Run-time Settings RunLogic&#xff1a;运行逻辑&#xff0c;决定了脚本真正执行逻辑&#xff0c; Init和End部分代码只能执行一次。决定脚本真正执行逻辑的意思是&#xff0c;在Run中的代码和Number of Iteration决定了…

马斯克放出豪言,他旗下的xAI要把Grok开源了

每周跟踪AI热点新闻动向和震撼发展 想要探索生成式人工智能的前沿进展吗&#xff1f;订阅我们的简报&#xff0c;深入解析最新的技术突破、实际应用案例和未来的趋势。与全球数同行一同&#xff0c;从行业内部的深度分析和实用指南中受益。不要错过这个机会&#xff0c;成为AI领…

WeiPHP5.0远程代码执行漏洞

文章目录 前言声明一、漏洞描述二、影响版本三、漏洞复现四、修复建议 前言 weiphp 是一个开源&#xff0c;高效&#xff0c;简洁的微信开发平台 声明 请勿利用文章内的相关技术从事非法测试&#xff0c;由于传播、利用此文所提供的信息或者工具而造成的任何直接或者间接的后…

Learn OpenGL 08 颜色+基础光照+材质+光照贴图

我们在现实生活中看到某一物体的颜色并不是这个物体真正拥有的颜色&#xff0c;而是它所反射的(Reflected)颜色。物体的颜色为物体从一个光源反射各个颜色分量的大小。 创建光照场景 首先需要创建一个光源&#xff0c;因为我们以及有一个立方体数据&#xff0c;我们只需要进行…

04.JavaScript中的封装和函数

函数 理解函数的封装特性&#xff0c;掌握函数的语法规则 声明和调用 函数可以把具有相同或相似逻辑的代码“包裹”起来&#xff0c;通过函数调用执行这些被“包裹”的代码逻辑&#xff0c;这么做的优势是有利于精简代码方便复用。 声明&#xff08;定义&#xff09; 声明&a…

白嫖的学习资源,程序员绝对相见恨晚!

麦瑟尔夫说&#xff1a;厌学是心灵的癌症。 不学而原地踏步则会让你变成大笨蛋&#xff0c;甚至脑子瓦特......&#xff08;好吧&#xff0c;可以喷了&#xff0c;其实我在危言耸听&#xff09; 但是&#xff01;排除鸡汤的洗脑&#xff0c;学习的重要性是不可否认的&#xff…

初学者必会的Python3文件操作

文件操作的步骤&#xff1a; 打开文件 -> 操作文件 -> 关闭文件 切记&#xff1a;最后要关闭文件。 打开文件 文件句柄 open(文件路径, 模式) 指定文件编码 文件句柄 open(文件路径,模式,encodingutf-8) 为了防止忘记关闭文件&#xff0c;可以使用上下文管理器来…

大模型的整体性

大模型在人工智能领域&#xff0c;体现出一种高度的整体性特征。大模型的整体性表现在其能够跨越多种数据模态&#xff0c;统一表示&#xff0c;应用广泛的知识&#xff0c;以统一的方式处理复杂信息&#xff0c;并在多种场景下保持一致和有效的性能。这种整体性可以分为外部和…

PNG图片合成,带手机外观设置,可自定义金额等

PNG图片合成&#xff0c;带手机外观设置&#xff0c;可自定义金额等 软件界面成品显示免责声明 软件界面 成品显示 免责声明 若因使用代码与官方造成不必要的纠纷&#xff0c;本人盖不负责&#xff0c;存粹技术爱好&#xff0c;若侵犯贵公司的权益&#xff0c;请告知&#xff…

数据结构·二叉树(一)

1. 树概念及结构 1.1 树的概念 树是一种非线性的数据结构&#xff0c;它是由n(n>0)个有限节点组成一个具有层次关系的集合。把它叫做树是因为它看起来像一颗倒挂的树&#xff0c;也就是说它是根朝上&#xff0c;树叶朝下的。 有一个特殊的节点&#xff0c;称为根节点&#…