深度学习Day-19:DenseNet算法实战与解析

 🍨 本文为:[🔗365天深度学习训练营] 中的学习记录博客
 🍖 原作者:[K同学啊 | 接辅导、项目定制]

要求:

  • 根据 Pytorch 代码,编写出 TensorFlow 代码
  • 研究 DenseNet 与 ResNetV 的区别
  • 改进思路是否可以迁移到其他地方

一、前言

        在计算机视觉领域,卷积神经网络(CNN)已经成为最主流的方法,比如GoogLenet,VGG-16,Incepetion等模型。CNN史上的一个里程碑事件是ResNet模型的出现,ResNet可以训练出更深的CNN模型,从而实现更高的准确度。ResNet模型的核心是通过建立前面层与后面层之间的“短路连接”进而训练出更深的CNN网络。

        DenseNet,它的基本思路与ResNet一致,但是它建立的是前面所有层与后面层的密集连接(dense connection),它的名称也是由此而来。DenseNet的另一大特色是通过特征在channel上的连接来实现特征重用(feature reuse)。这些特点让DenseNet在参数和计算成本更少的情形下实现比ResNet更优的性能。

二、设计理念

        相比ResNet,DenseNet提出了一个更激进的密集连接机制:即互相连接所有的层,具体来说就是每个层都会接受其前面所有层作为其额外的输入。

        图1为ResNet网络的残差连接机制,作为对比,图2为DenseNet的密集连接机制。可以看到,ResNet是每个层与前面的某层(一般是2~4层)短路连接在一起,连接方式是通过元素相加。而在DenseNet中,每个层都会与前面所有层在channel维度上连接(concat)在一起(即元素叠加),并作为下一层的输入。

        对于一个 L 层的网络,DenseNet共包含 \frac{L(L+1))}{2} 个连接,相比ResNet,这是一种密集连接。而且DenseNet是直接concat来自不同层的特征图,这可以实现特征重用,提升效率,这一特点是DenseNet与ResNet最主要的区别。

 图1:ResNet网络的短路连接机制(其中+代表的是元素级相加操作)

 图2:DenseNet网络的密集连接机制(其中c代表的是channel级连接操作)

        而对于DesNet,则是采用跨通道concat的形式来连接,会连接前面所有层作为输入,输入和输出的公式是Xl​=Hl​(X0​,X1​,...,Xl−1​)。这里要注意所有的层的输入都来源于前面所有层在channel维度的concat,用一张动图体会一下:

图3 DenseNet的前向过程 

 三、网络结构

图4 DenseNet的网络结构 

        CNN网络一般要经过Pooling或者stride>1的Conv来降低特征图的大小,而DenseNet的密集连接方式需要特征图大小保持一致。为了解决这个问题,DenseNet网络采用DenseBlock+Transition的结构,其中DenseBlock是包含很多层的模块,每个层的特征图大小相同,层与层之间采用密集连接方式。而Transition层是连接两个相邻的DenseBlock,并且通过Pooling使特征图大小降低。图5给出了DenseNet的网路结构,它共包含4个DenseBlock,各个DenseBlock之间通过Transition层连接在一起。

图5 采用DenseBlock+Transition的DenseNet网络 

         在DenseBlock中,各个层的特征图大小一致,可以在channel维度上连接。DenseBlock中的非线性组合函数 H ( ⋅ )的是BN + ReLU + 3x3 Conv的结构,如图所示。另外值得注意的一点是,与ResNet不同,所有DenseBlock中各个层卷积之后均输入k个特征图,即得到的特征图的channel数为k,或者说采用k个卷积核。k在DenseNet称为growth rate,这是一个超参数。一般情况下使用较小的k,就可以得到较佳的性能。

图6 DenseBlock中的非线性转换结构 

        由于后面层的输入会非常大,DenseBlock内部可以采用bottleneck层来减少计算量,主要是原有的结构中增加1x1 Conv,如图7所示,即 BN + ReLU + 1x1 Conv + BN + ReLU + 3x3 Conv,称为DenseNet-B结构。其中1x1 Conv得到4k个特征图,它起到的作用是降低特征数量,从而提升计算的效率。

图7 使用bottleneck层的DenseBlock结构 

        对于Transition层,它主要是连接两个相邻的DenseBlock,并且降低特征图大小。Transition层包括一个1x1的卷积和2x2的AvgPooling,结构为BN+ReLU+1x1Conv+2x2AvgPooling。另外,Transition层可以起到压缩模型的作用。

四、Tensorflow实现

1.代码

import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow import keras

data_dir = "./data/bird_photos"
train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,  # 分割数据集
    subset="training",  # 数据集类型
    seed=123,
    image_size=(224, 224),
    batch_size=32)

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,  # 分割数据集
    subset="validation",  # 数据集类型
    seed=123,
    image_size=(224, 224),
    batch_size=32)


class BottleNeck(keras.Model):
    def __init__(self, growth_rate, bn_size=4, dropout=0.3):
        super().__init__()
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation("relu"),
        self.conv1 = layers.Conv2D(filters=bn_size * growth_rate, kernel_size=(1, 1),
                                   strides=1, padding='same')
        self.bn2 = layers.BatchNormalization()
        self.conv2 = layers.Conv2D(filters=growth_rate, kernel_size=(3, 3),
                                   strides=1, padding='same')
        self.dropout = layers.Dropout(rate=dropout)

        self.listLayers = [
            self.bn1,
            self.relu,
            self.conv1,
            self.bn2,
            self.relu,
            self.conv2,
            self.dropout
        ]

    def call(self, x):
        tem = x
        for layer in self.listLayers.layers:
            x = layer(x)
        return layers.concatenate([tem, x], axis=-1)


class Transition(tf.keras.Model):
    def __init__(self, growth_rate):
        super().__init__()
        self.bn1 = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        self.conv1 = layers.Conv2D(filters=growth_rate, kernel_size=(1, 1),
                                   strides=1, activation='relu', padding='same')
        self.pooling = layers.AveragePooling2D(pool_size=(2, 2), strides=2, padding='same')

        self.listLayers = [
            self.bn1,
            self.relu,
            self.conv1,
            self.pooling
        ]

    def call(self, x):
        for layer in self.listLayers.layers:
            x = layer(x)
        return x


class DenseBlock(tf.keras.Model):
    def __init__(self, num_layer, growth_rate, bn_size=4, dropout=0.3, efficient=False):
        super().__init__()
        self.efficient = efficient
        self.listLayers = []
        if self.efficient:
            _x = tf.recompute_grad(BottleNeck(growth_rate, bn_size=bn_size, dropout=dropout))
        else:
            _x = BottleNeck(growth_rate, bn_size=bn_size, dropout=dropout)
        for _ in range(num_layer):
            self.listLayers.append(BottleNeck(growth_rate, bn_size=bn_size, dropout=dropout))

    def call(self, x):
        for layer in self.listLayers.layers:
            x = layer(x)
        return x


class DenseNet(tf.keras.Model):
    def __init__(self, num_init_feature, growth_rate, block_config, num_classes,
                 bn_size=4, dropout=0.3, compression_rate=0.5, efficient=False):
        super().__init__()
        self.num_channels = num_init_feature
        self.conv = layers.Conv2D(filters=num_init_feature, kernel_size=7,
                                  strides=2, padding='same')
        self.bn = layers.BatchNormalization()
        self.relu = layers.Activation('relu')
        self.max_pool = layers.MaxPool2D(pool_size=3, strides=2, padding='same')

        self.dense_block_layers = []
        for i in block_config[:-1]:
            self.dense_block_layers.append(DenseBlock(num_layer=i, growth_rate=growth_rate,
                                                      bn_size=bn_size, dropout=dropout, efficient=efficient))
            self.num_channels = compression_rate * (self.num_channels + growth_rate * i)
            self.dense_block_layers.append(Transition(self.num_channels))

        self.dense_block_layers.append(DenseBlock(num_layer=block_config[-1], growth_rate=growth_rate,
                                                  bn_size=bn_size, dropout=dropout, efficient=efficient))

        self.avgpool = layers.GlobalAveragePooling2D()
        self.fc = tf.keras.layers.Dense(units=num_classes, activation=tf.keras.activations.softmax)

    def call(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        x = self.max_pool(x)

        for layer in self.dense_block_layers.layers:
            x = layer(x)

        x = self.avgpool(x)
        return self.fc(x)


model = DenseNet(num_init_feature=64,
                 growth_rate=32,
                 block_config=[6, 12, 24, 16],
                 compression_rate=0.5,
                 num_classes=4,
                 dropout=0.0,
                 efficient=True)

x = tf.random.normal((1, 224, 224, 3))
for layer in model.layers:
    x = layer(x)
    print(layer.name, 'output shape:\t', x.shape)

AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().shuffle(1000).prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)
opt = tf.keras.optimizers.Adam(learning_rate=0.002, decay=0.01)

model.compile(optimizer=opt,
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
epochs = 10
history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)

2.输出

Found 565 files belonging to 4 classes.
Using 113 files for validation.
conv2d output shape:	 (1, 112, 112, 64)
batch_normalization output shape:	 (1, 112, 112, 64)
activation output shape:	 (1, 112, 112, 64)
max_pooling2d output shape:	 (1, 56, 56, 64)
dense_block output shape:	 (1, 56, 56, 256)
transition output shape:	 (1, 28, 28, 128)
dense_block_1 output shape:	 (1, 28, 28, 512)
transition_1 output shape:	 (1, 14, 14, 256)
dense_block_2 output shape:	 (1, 14, 14, 1024)
transition_2 output shape:	 (1, 7, 7, 512)
dense_block_3 output shape:	 (1, 7, 7, 1024)
global_average_pooling2d output shape:	 (1, 1024)
dense output shape:	 (1, 4)
Epoch 1/10
15/15 [==============================] - 83s 5s/step - loss: 2.5051 - accuracy: 0.4508 - val_loss: 37955.5703 - val_accuracy: 0.3186
Epoch 2/10
15/15 [==============================] - 75s 5s/step - loss: 1.1275 - accuracy: 0.6922 - val_loss: 3854.2537 - val_accuracy: 0.2301
Epoch 3/10
15/15 [==============================] - 78s 5s/step - loss: 0.6559 - accuracy: 0.7780 - val_loss: 794.8064 - val_accuracy: 0.3186
Epoch 4/10
15/15 [==============================] - 84s 6s/step - loss: 0.5599 - accuracy: 0.7926 - val_loss: 94.6405 - val_accuracy: 0.2655
Epoch 5/10
15/15 [==============================] - 74s 5s/step - loss: 0.7278 - accuracy: 0.7682 - val_loss: 45.8066 - val_accuracy: 0.3186
Epoch 6/10
15/15 [==============================] - 74s 5s/step - loss: 0.3567 - accuracy: 0.8904 - val_loss: 16.1634 - val_accuracy: 0.3186
Epoch 7/10
15/15 [==============================] - 74s 5s/step - loss: 0.2239 - accuracy: 0.9287 - val_loss: 10.3661 - val_accuracy: 0.3186
Epoch 8/10
15/15 [==============================] - 75s 5s/step - loss: 0.1488 - accuracy: 0.9406 - val_loss: 1.8957 - val_accuracy: 0.5044
Epoch 9/10
15/15 [==============================] - 75s 5s/step - loss: 0.1024 - accuracy: 0.9630 - val_loss: 1.1245 - val_accuracy: 0.6018
Epoch 10/10
15/15 [==============================] - 76s 5s/step - loss: 0.0563 - accuracy: 0.9895 - val_loss: 1.2219 - val_accuracy: 0.5133

        具体代码细节在之前的文章中已有涉及,故不再做具体解释。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/683659.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SBOM是如何帮助医疗器械制造商提高产品透明度的?

SBOM(软件物料清单)通过以下方式帮助医疗器械制造商提高产品透明度: 1. 详细记录软件组成 SBOM详细列出了医疗器械所使用的所有软件组件、版本、作者、许可证信息等。这使得制造商能够清晰地了解产品的软件组成,包括每个组件的来…

Plotly的魔力:如何用Python创建令人惊叹的图表?

大家好,在数据分析和可视化领域,图表是不可或缺的工具。它们可以帮助我们更直观地理解数据趋势和模式。今天,我们要介绍的是一个强大的Python库——Plotly,它可以让你轻松创建交互式、漂亮的图表。无论你是数据科学家、分析师&…

C++学习/复习12--vector的实现(三个基本成员函数/迭代器/扩容/插入删除/重载/测试/杨辉三角)

一、构造函数 1.匿名对象与构造函数 在C中,匿名对象是一个临时对象,它没有名称,通常在对象创建后,只使用一次后就被销毁。创建匿名对象的方式是在创建对象时不使用变量名。 下面是创建匿名对象的几种方式: 直接使用…

【Linux】(二)—— 用户和用户组管理

在了解了Linux基础命令格式后,今天我要介绍的是Linux的用户管理 目录 root用户用户管理操作用户ID切换用户查看用户添加用户指定密码修改用户信息删除用户退出登录查看登陆系统的用户 用户组的管理增加新用户组查看用户组删除用户组修改用户组切换用户组 与用户有关…

GAMES101笔记

Games101 文章目录 Games101一、Review of Linear Algebra 线性代数向量向量点乘:判断同向反向、接近程度点乘基本属性点乘坐标系运算向量投影向量叉乘:判定左右、判定内外叉乘基本属性(右手坐标系):不满足交换律叉乘坐…

CSS函数:fit-content与matrix的使用

网格函数 fit-content()属于网格函数,除此之外的网格函数还有:CSS函数: 实现数据限阈的数字函数。顾名思义,这三个函数只能在网格布局中使用。fit-content()函数主要是用于给定布局可用大小,适应内容,其功…

【数据结构】排序(直接插入、折半插入、希尔排序、快排、冒泡、选择、堆排序、归并排序、基数排序)

目录 排序一、插入排序1.直接插入排序2.折半插入排序3.希尔排序 二、交换排序1.快速排序2.冒泡排序 三、选择排序1. 简单选择排序2. 堆排序3. 树排序 四、归并排序(2-路归并排序)五、基数排序1. 桶排序(适合元素关键字值集合并不大)2. 基数排序基数排序的…

十六、【源码】plugins插件

源码地址:https://github.com/mybatis/mybatis-3/ 仓库地址:https://gitcode.net/qq_42665745/mybatis/-/tree/16-plugin plugins插件 plugins功能分为两部分 1.首先是在xml解析时解析出配置的插件类,并实例化放入拦截器链中 2.其次执行…

量化投资分析平台 迅投 QMT(四)获取标的期权的代码

量化投资分析平台 迅投 QMT [迅投 QMT](https://www.xuntou.net/?user_code7NYs7O)我目前在使用有了底层标的如何获取期权的交易代码呢?上代码历史帖子 迅投 QMT 我目前在使用 两个月前(2024年4月)迅投和CQF有一个互动的活动,进…

Java实现2048游戏源代码(启动即可玩)

使用Java语言从零到一开发经典2048游戏的全部源代码及详尽的开发教程。 适合Java初学者、游戏开发爱好者以及想要深入理解Swing GUI框架的开发者学习与参考。 知识领域:游戏开发、Java编程、图形用户界面(GUI)设计 技术关键词:Java Swing、事件监听、多…

“探索‘循环购‘:快消品行业的新商业模式与增长策略“

大家好,我是吴军,来自一家深耕于软件开发和商业模式创新的科技公司。我们的专长在于为各类企业量身打造商城系统,并提供个性化的商业模式解决方案。迄今为止,我们已经助力众多企业成功实施了超过200种前沿的商业模式,实…

阿奇科技 简单java-swing计算器源码(可用于课设等)

此系统用的技术有java swing! 实现的功能: 加减乘除(可以进行小数运算) 清空数据 最小化 小巧方便,功能齐全! 页面截图: 源码地址:点击这里下载源码 获取全套代码,或咨询更多代码…

【JavaEE】Spring Boot 日志详解

一 日志概述 日志是用于记录系统运行状态、用户操作和重大事件的工具。 1.日志的用途 系统监控 监控现在几乎是一个成熟系统的标配, 我们可以通过日志记录这个系统的运行状态, 每⼀个方法的响应时间, 响应状态等, 对数据进行分析, 设置不同的规则, 超过阈值时进行报警. 比如统…

科技赋能,无障碍出行的新纪元

在现代社会,公共设施的建设不仅是衡量城市文明程度的标尺,更是实现社会公平与包容的重要载体。对于盲人群体而言,一个完善的公共设施网络,意味着他们能够更加独立、自信地融入社会,享受与视力健全者同等的公共服务与便…

湖南(品牌控价)源点调研 手机价格管理对品牌的影响分析

前言:手机自发明以来,过去一直是国际品牌占主导地位,从最初的爱立信、摩托罗拉,到后来的诺基亚、三星,苹果在这个手机行业里,竞争激励,没有百年企业,每个品牌的盛衰都有背后的历史背…

软件测试——Java自动化测试Selenium

目录 1.运行环境 2.环境配置 3.第一个浏览器程序 4.浏览器操作 5.元素定位 6.元素操作常用API 7.特殊元素定位与操作 8.元素三大等待 9.iframe操作 10.window操作 11.select选择框 12.js语句执行 13.鼠标操作 14.截图操作 1.运行环境 编译工具:IDEA …

教师自费出书的注意事项有哪些?

备案主编专著的优势:(qkfb88688) 1、副高、正高职称最高学术成果 2、专著可以代替核心 3、周期短、出书快、可重复使用 4、双号齐全:ISBN~CIP 5、版权长期有效 教师自费出书有以下一些注意事项: 关于书稿&…

python---正则表达式

本章目标: 1:能够知道在Python中使用正则要导入的模块; [了解] re模块 2:能够使用re模块匹配单个字符; [重点] \d \w 正则表达式的概述: 基本介绍 正则表达式,也叫做规则表达式,通常会说成[正则] 实际上正则表达式就是指符合一定规则的字符串,同时他能用于检查一段…

MyBatis拦截器使用方法

前言 MyBatis拦截器可以做的工作:SQL修改,分页操作,数据过滤,SQL执行时间性能监控等。 1. 基础介绍 1.1. 核心对象 从MyBatis代码实现的角度来看,MyBatis的主要的核心部件有以下几个: Configuration&am…

OpenStack所支持的虚拟化技术和KVM、Xen、Hyer-V、QEMU、Libvirt说明

OpenStack所支持的虚拟化技术主要包括以下几种: KVM (Kernel-based Virtual Machine): 基于Linux内核的虚拟化技术。在Linux内核中添加一个虚拟化模块来实现虚拟机的运行。是OpenStack用户使用较多的虚拟化技术,支持OpenStack的所有特性。通过QEMU模拟器…