分类神经网络1:VGGNet模型复现

目录

分类网络的常见形式

VGG网络架构

VGG网络部分实现代码


分类网络的常见形式

常见的分类网络通常由特征提取部分分类部分组成。

特征提取部分实质就是各种神经网络,如VGG、ResNet、DenseNet、MobileNet等。其负责捕获数据的有用信息,一般是通过堆叠多个卷积层和池化层来实现的,这些层有助于检测图像中的边缘、纹理和特征。

分类部分通常是一个全连接层,负责将特征提取部分输出的信息映射到最终的类别或标签。这些全连接层通常包括一个或多个隐藏层,以及一个输出层,其中输出层的节点数量等于任务中的类别数量。

VGG网络架构

论文原址:https://arxiv.org/pdf/1409.1556v6.pdf

VGG 网络是由牛津大学的Visual Geometry Group 开发的,其结构特点在于使用了多个 3x3 的小卷积核,并通过这些小卷积层的重复堆叠来构建网络,从而能够捕捉到更加复杂和抽象的特征表示。VGG 网络的模型结构如下:

VGG网络的核心架构可以分为以下几个部分:

  1. 输入层:VGG网络接受224x224像素的RGB图像作为输入。
  2. 卷积层:网络的前几层由多个卷积层组成,每个卷积层都使用3x3的卷积核来提取图像的特征。这些卷积层后面通常跟着一个2x2 最大池化,用于逐步减小特征图的空间尺寸,同时增加特征深度。
  3. 池化层:在卷积层之后,网络使用最大池化层来降低特征图的空间分辨率,这有助于减少计算量并提取更加抽象的特征。
  4. 全连接层:经过多个卷积和池化层之后,网络的特征图被展平并通过几个全连接层进行处理。全连接层的作用是将学习到的特征映射到最终的分类结果。
  5. 输出层:VGG网络的最后是一个softmax层,它将网络的输出转换为概率分布,以便进行多类别的分类任务。

VGG网络的一个显著特点是其深度,其相关配置信息如下:

VGG系列不同变体内容如下:

  • VGG A:这是一个基础的配置,没有特别独特的设计。
  • VGG A-LRN:在这个版本中,加入了局部响应归一化(LRN),这是一种在AlexNet中首次使用的技术。不过,LRN在当前的深度学习实践中已经较少被采用。
  • VGG B:相较于A版本,B版本增加了两个卷积层,以增强网络的学习能力。
  • VGG C:在B的基础上,C版本进一步增加了三个卷积层,但这些层使用的是1x1的卷积核。1x1卷积核可以看作是对输入特征图进行线性变换,有助于减少参数数量并增加非线性。
  • VGG D:D版本在C版本的基础上做了调整,将1x1的卷积核替换为3x3的卷积核,这个配置后来被称为VGG16,因为它总共有16层。
  • VGG E:在D版本的基础上,E版本进一步增加了三个3x3的卷积层,形成了VGG19,总共有19层。

从图中可以看出,随着网络深度的加深,模型变得更为复杂。通常来说,增加网络的深度可以增加模型的表示能力,使其能够学习到更复杂的特征和模式,从而在某些任务上取得更好的性能。然而,随着网络深度的增加,模型的参数数量也会增加,导致模型的复杂度增加,训练和推理的计算成本也会增加,同时可能会增加过拟合的风险。

VGG网络部分实现代码

废话不多说,直接上干货

import torch
import torch.nn as nn

__all__ = ["VGG", "vgg11_bn", "vgg13_bn", "vgg16_bn", "vgg19_bn"]

cfg = {
    'A': [64,     'M', 128,      'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],
    'B': [64, 64, 'M', 128, 128, 'M', 256, 256,           'M', 512, 512,           'M', 512, 512,           'M'],
    'C': [64, 64, 'M', 128, 128, 'M', 256, 256, 256,      'M', 512, 512, 512,      'M', 512, 512, 512,      'M'],
    'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class ConvBNReLU(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1,  kernel_size=3, padding=1):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.bn = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class VGG(nn.Module):
    def __init__(self, features, num_classes=1000, init_weights=True):
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        self.classifier = nn.Sequential(
            nn.Linear(512 * 7 * 7, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(4096, num_classes),
        )
        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        for layer in self.features:
            x = layer(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

def make_layers(cfg):
    layers = nn.ModuleList()
    in_channels = 3
    for i in cfg:
        if i == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        else:
            layers.append(ConvBNReLU(in_channels=in_channels, out_channels=i))
            in_channels = i
    return layers

def vgg11_bn(num_classes):
    model = VGG(make_layers(cfg['A']), num_classes=num_classes)
    return model

def vgg13_bn(num_classes):
    model = VGG(make_layers(cfg['B']), num_classes=num_classes)
    return model

def vgg16_bn(num_classes):
    model = VGG(make_layers(cfg['C']), num_classes=num_classes)
    return model

def vgg19_bn(num_classes):
    model = VGG(make_layers(cfg['D']), num_classes=num_classes)
    return model

if __name__=='__main__':
    import torchsummary
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    input = torch.ones(2, 3, 224, 224).to(device)
    net = vgg16_bn(num_classes=4)
    net = net.to(device)
    out = net(input)
    print(out)
    print(out.shape)
    torchsummary.summary(net, input_size=(3, 224, 224))
    # Total params: 134,285,380

这只是一个网络架构部分实现代码,其中 cfg 列表是 VGG 卷积和池化后的通道数,大家可以结合 VGG 的配置信息图一起对比理解。希望对大家有所帮助呀!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/562991.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

创新案例|Amazon.com 2023 年营销策略:电子商务零售巨头商业案例研究

2022 年最后一个季度,亚马逊报告净销售额超过 1,492 亿美元。这种季节性峰值是亚马逊季度报告的典型特征,但增长是不可否认的,因为这是该公司有史以来最高的季度。毫无疑问,这家电商零售巨头继续引领电商增长。本文将介绍我们今天…

Elasticsearch进阶篇(三):ik分词器的使用与项目应用

ik分词器的使用 一、下载并安装1.1 已有作者编译后的包文件1.2 只有源代码的版本1.3 安装ik分词插件 二、ik分词器的模式2.1 ik_smart演示2.2 ik_max_word演示2.3 standard演示 三、ik分词器在项目中的使用四、ik配置文件4.1 配置文件的说明4.2 自定义词库 五、参考链接 一、下…

mysql基础10——函数

数学函数 处理数值数据 取整函数 round(X,D) X表示要处理的数 D表示要保留的小数位数 处理的方式是四舍五入 round(X) 保留0位小数 金额要精确到分 说明保留两位小数 select round(salevalue,2) from demo.transactiondetails where transactionid1 and itemnum1; cei…

matplotlib从起点出发(15)_Tutorial_15_blitting

0 位图传输技术与快速渲染 Blitting,即位图传输、块传输技术是栅格图形化中的标准技术。在Matplotlib的上下文中,该技术可用于(大幅度)提高交互式图形的性能。例如,动画和小部件模块在内部使用位图传输。在这里&#…

记录一个hive中跑insert语句说没创建spark客户端的问题

【背景说明】 我目前搭建离线数仓,并将hive的执行引擎改成了Spark,在将ods层的数据装载到dim层,执行insert语句时报如下错误 【报错】 [42000][40000] Error while compiling statement: FAILED: SemanticException Failed to get a spark…

Rust序列化和反序列化

Rust 编写python 模块 必备库 docker 启动 nginx 服务 NGINX 反向代理配置

RAG技术从入门到精通

LLM之RAG技术从入门到精通 RAG技术介绍诞生背景定义 RAG与微调RAG流程架构RAG三种范式Naive RAGAdvanced RAG预检索过程嵌入后期检索过程RAG管道优化 Modular RAG RAG工作流程企业知识问答知识库RAG评估评价方法独立评估端到端评估 关键指标和能力 RAG优化RAG在企业知识库应用下…

WebSocket 快速入门 - springboo聊天功能

目录 一、概述 1、HTTP(超文本传输协议) 2、轮询和长轮询 3、WebSocket 二、WebSocket快速使用 1、基于Java注解实现WebSocket服务器端 2、JS前端测试 三、WebSocket进阶使用 1、如何获取当前用户信息 2、 后端聊天功能实现 一、概述 HTTP…

Navicat Premium 16最新版激活 mac/win

Navicat Premium 16 for Mac是一款专业的多连接数据库管理工具。它支持连接多种类型的数据库,包括MySQL、MongoDB、Oracle、SQLite、SQL Server、PostgreSQL等,可以同时连接多种数据库,帮助用户轻松地管理和迁移数据。 Navicat Premium 16 fo…

Wpf 使用 Prism 实战开发Day21

配置默认首页 当应用程序启动时&#xff0c;默认显示首页 一.实现思路&#xff0c;通过自定义接口来配置应用程序加载完成时&#xff0c;设置默认显示页 步骤1.创建自定义 IConfigureService 接口 namespace MyToDo.Common {/// <summary>/// 配置默认显示页接口/// <…

Golang那些违背直觉的编程陷阱

目录 知识点1&#xff1a;切片拷贝之后都是同一个元素 知识点2&#xff1a;方法集合决定接口实现&#xff0c;类型方法集合是接口方法集合的超集则认定为实现接口&#xff0c;否则未实现接口 切片拷贝之后都是同一个元素 package mainimport ("encoding/json"&quo…

springboot是什么?

可以应用于Web相关的应用开发。 选择合适的框架&#xff0c;去开发相关的功能&#xff0c;会有更高的效率。 为什么Spring Boot才是你该学的!学java找工作必会技能!在职程序员带你梳理JavaEE框架_哔哩哔哩_bilibili java工程师的必备技能 Spring是Java EE领域的企业级开发宽…

Kafka源码分析(四) - Server端-请求处理框架

系列文章目录 Kafka源码分析-目录 一. 总体结构 先给一张概览图&#xff1a; 服务端请求处理过程涉及到两个模块&#xff1a;kafka.network和kafka.server。 1.1 kafka.network 该包是kafka底层模块&#xff0c;提供了服务端NIO通信能力基础。 有4个核心类&#xff1a;…

华为海思校园招聘-芯片-数字 IC 方向 题目分享——第六套

华为海思校园招聘-芯片-数字 IC 方向 题目分享——第六套 (共9套&#xff0c;有答案和解析&#xff0c;答案非官方&#xff0c;未仔细校正&#xff0c;仅供参考&#xff09; 部分题目分享&#xff0c;完整版获取&#xff08;WX:didadidadidida313&#xff0c;加我备注&#x…

使用python socket搭建Client测试平台

目录 概述 1 背景 2 Client功能实现 2.1 何谓Client 2.2 代码功能介绍 2.3 代码实现 2.3.1 代码介绍 2.3.2 代码内容 3 测试 3.1 PC上创建Server 3.2 同一台PC上运行Client 3.2.1 建立连接 3.2.2 测试数据交互 3.3 Linux 环境下运行Client 3.3.1 建立连接 3.3.…

无限滚动分页加载与下拉刷新技术探析:原理深度解读与实战应用详述

滚动分页加载&#xff08;也称为无限滚动加载、滚动分页等&#xff09;是一种常见的Web和移动端应用界面设计模式&#xff0c;用于在用户滚动到底部时自动加载下一页内容&#xff0c;而无需点击传统的分页按钮。这种设计旨在提供更加流畅、连续的浏览体验&#xff0c;减少用户交…

Redis 如何实现分布式锁

课程地址 单机 Redis naive 版 加锁&#xff1a; SETNX ${lockName} ${value} # set if not exist如果不存在则插入成功&#xff0c;返回 1&#xff0c;加锁成功&#xff1b;否则返回 0&#xff0c;加锁失败 解锁&#xff1a; DEL ${lockName}问题1 2 个线程 A、B&#…

深入理解与实践“git add”命令的作用

文章目录 **git add命令的作用****git add命令的基本作用****高级用法与注意事项** git add命令的作用 引言&#xff1a; 在Git分布式版本控制系统中&#xff0c;git add命令扮演着至关重要的角色&#xff0c;它是将本地工作区的文件变动整合进版本控制流程的关键步骤。本文旨…

使用docker搭建GitLab个人开发项目私服

一、安装docker 1.更新系统 dnf update # 最后出现这个标识就说明更新系统成功 Complete!2.添加docker源 dnf config-manager --add-repohttps://download.docker.com/linux/centos/docker-ce.repo # 最后出现这个标识就说明添加成功 Adding repo from: https://download.…

ConcurrentHashMap 源码分析(一)

一、简述 本文对 ConcurrentHashMap#put() 源码进行分析。 二、源码概览 public V put(K key, V value) {return putVal(key, value, false); }上面是 ConcurrentHashMap#put() 的源码&#xff0c;我们可以看出其核心逻辑在 putVal() 方法中。 final V putVal(K key, V val…