pytorch-16 复现经典网络:LeNet5与AlexNet

一、相关概念

对于(10,3,227,227)数据表示,10张3通道的图,图的大小(特征数)为227*227.
通道数:作为卷积的输入通道数和输出通道数。
特征数:特征图的大小
步长stride和填充padding:线性减小特征图的尺寸
池化pooling:非线性且高效减小特征图的尺寸
计算公式:hout = (hin +2p -k) /s +1

二、LeNet5:现代CNN的奠基者

LeNet的核心思想“卷积+池化+线性”。在PyTorch中实现其架构的代码如下:在这里插入图片描述

import torch
from torch import nn
from torch.nn import functional as F
from torchinfo import summary

data = torch.ones(size=(10,1,32,32))

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1,6,5) #(H+2p-K)/S + 1
        self.pool1 = nn.AvgPool2d(kernel_size=2,stride=2)
        self.conv2 = nn.Conv2d(6,16,5)
        self.pool2 = nn.AvgPool2d(2)
        self.fc1 = nn.Linear(5*5*16,120)
        self.fc2 = nn.Linear(120,84)
    
    def forward(self,x):
        x = F.tanh(self.conv1(x))
        x = self.pool1(x)
        x = F.tanh(self.conv2(x))
        x = self.pool2(x)
        x = x.view(-1,5*5*16) #-1,我不关心-1这个位置上的数是多少,你根据我输入的x的结构帮我自己计算吧
        x = F.tanh(self.fc1(x))
        output = F.softmax(self.fc2(x),dim=1) #(samples, features)

net = Model() #实例化
net(data)  #相当于在执行 net.forward(data) 

net = Model() #实例化
summary(net, input_size=(10,1,32,32))

结果显示:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Model                                    --                        --
├─Conv2d: 1-1                            [10, 6, 28, 28]           156
├─AvgPool2d: 1-2                         [10, 6, 14, 14]           --
├─Conv2d: 1-3                            [10, 16, 10, 10]          2,416
├─AvgPool2d: 1-4                         [10, 16, 5, 5]            --
├─Linear: 1-5                            [10, 120]                 48,120
├─Linear: 1-6                            [10, 84]                  10,164
==========================================================================================
Total params: 60,856
Trainable params: 60,856
Non-trainable params: 0
Total mult-adds (M): 4.22
==========================================================================================
Input size (MB): 0.04
Forward/backward pass size (MB): 0.52
Params size (MB): 0.24
Estimated Total Size (MB): 0.81
==========================================================================================

三、AlexNet:从浅层到深度

AlexNet的架构若用文字来表现,则可以
打包成4个组合:
输入→(卷积+池化)→(卷积+池化)→(卷积x3+池化)→(线性x3)→输出
相对的,LeNet5的架构可以打包成3个组合:
输入→(卷积+池化)→(卷积+池化)→(线性x2)→输出在这里插入图片描述
和只有6层(包括池化层)的LeNet5比起来,AlexNet主要做出了如下改变:
1、相比之下,卷积核更小、网络更深、通道数更多,这代表人们已经认识到了图像数据天生适合于多次
提取特征,“深度”才是卷积网络的未来。LeNet5是基于MNIST数据集创造,MNIST数据集中的图片尺寸
大约只有30*30的大小,LeNet5采用了5x5的卷积核,图像尺寸/核尺寸大约在6:1。而基于ImageNet
数据集训练的AlexNet最大的卷积核只有11x11,且在第二个卷积层就改用5x5,剩下的层中都使用3x3
的卷积核,图像尺寸/核尺寸至少也超过20:1。小卷积核让网络更深,但也让特征图的尺寸变得很小,
为了让信息尽可能地被捕获,AlexNet也使用了更多的通道。小卷积核、多通道、更深的网络,这些都
成为了卷积神经网络后续发展的指导方向。
2、使用了ReLU激活函数,摆脱Sigmoid与Tanh的各种问题。
3、使用了Dropout层来控制模型复杂度,控制过拟合。
4、引入了大量传统或新兴的图像增强技术来扩大数据集,进一步缓解过拟合。
5、使用GPU对网络进行训练,使得“适当的训练“(proper training)成为可能。

1、AlexNet的架构复现

在PyTorch中来复现AlexNet的架构:

import torch
from torch import nn
from torch.nn import functional as F

data = torch.ones(size=(10,3,227,227)) #224 x 224

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        
        # 大卷积核、较大的步长、较多的通道
        # 为了处理尺寸较大的原始图片,先使用11x11的卷积核和较大的步长来快速降低特征图的尺寸
        # 同时,使用比较多的通道数,来弥补降低尺寸造成的数据损失
        self.conv1 = nn.Conv2d(3,96, kernel_size=11, stride=4)
        self.pool1 = nn.MaxPool2d(kernel_size=3, stride=2) #overlap pooling
        
        #卷积核、步长恢复正常大小,进一步扩大通道
        # 已经将特征图尺寸缩小到27x27,计算量可控,可以开始进行特征提取了
        # 卷积核、步长恢复到业界常用的大小,进一步扩大通道来提取数据
        self.conv2 = nn.Conv2d(96,256,kernel_size=5,padding=2)
        self.pool2 = nn.MaxPool2d(kernel_size=3, stride=2)
        
        # 疯狂提取特征,连续用多个卷积层
        # kernel 5, padding 2, kernel 3, padding 1 可以维持住特征图的大小
        self.conv3 = nn.Conv2d(256,384,kernel_size=3, padding =1) 
        self.conv4 = nn.Conv2d(384,384,kernel_size=3, padding =1)
        self.conv5 = nn.Conv2d(384,256,kernel_size=3, padding =1)
        self.pool3 = nn.MaxPool2d(kernel_size=3, stride=2)
        
        # 进入全连接层,进行信息汇总
        self.fc1 = nn.Linear(6*6*256,4096) #上层所有特征图上的所有像素
        self.fc2 = nn.Linear(4096,4096)
        self.fc3 = nn.Linear(4096,1000)
    
    def forward(self,x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)

        x = F.relu(self.conv2(x))
        x = self.pool2(x)

        x = F.relu(self.conv3(x))
        x = F.relu(self.conv4(x))
        x = F.relu(self.conv5(x))
        x = self.pool3(x)

        x = x.view(-1,6*6*256)  #将数据拉平

        x = F.dropout(x,p=0.5)                      #dropout:随机让50%的权重为0
        x = F.relu(F.dropout(self.fc1(x),p=0.5))    #dropout:随机让50%的权重为0 
        x = F.relu(self.fc2(x))
        output = F.softmax(self.fc3(x),dim=1)

net = Model()
net(data)

from torchinfo import summary
summary(net,input_size=(10,3,227,227))

结果显示:

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
Model                                    --                        --
├─Conv2d: 1-1                            [10, 96, 55, 55]          34,944
├─MaxPool2d: 1-2                         [10, 96, 27, 27]          --
├─Conv2d: 1-3                            [10, 256, 27, 27]         614,656
├─MaxPool2d: 1-4                         [10, 256, 13, 13]         --
├─Conv2d: 1-5                            [10, 384, 13, 13]         885,120
├─Conv2d: 1-6                            [10, 384, 13, 13]         1,327,488
├─Conv2d: 1-7                            [10, 256, 13, 13]         884,992
├─MaxPool2d: 1-8                         [10, 256, 6, 6]           --
├─Linear: 1-9                            [10, 4096]                37,752,832
├─Linear: 1-10                           [10, 4096]                16,781,312
├─Linear: 1-11                           [10, 1000]                4,097,000
==========================================================================================
Total params: 62,378,344
Trainable params: 62,378,344
Non-trainable params: 0
Total mult-adds (G): 11.36
==========================================================================================
Input size (MB): 6.18
Forward/backward pass size (MB): 52.74
Params size (MB): 249.51
Estimated Total Size (MB): 308.44
==========================================================================================

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/657979.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Unity入门】认识Unity编辑器

Unity 是一个广泛应用于游戏开发的强大引擎,从 1.0 版本开始到现在,其编辑器的基本框架一直保持稳定。其基于组件架构的设计,使得界面使用起来直观且高效。为了更好地理解 Unity 的界面,我们可以将其比喻为搭建一个舞台。以下是对…

K8s 小白入门|从电影配乐谈起,聊聊容器编排和 K8s

来听听音乐 电影,是我们生活中的重要调味剂。 配乐,是电影中不可或缺的一部分。 有的时候,配乐可以跟剧情共振,让你按捺不住自己的情绪,或眼含热泪、或慷慨激昂、或人仰马翻、或怅然若失; 有的时候&…

AI图书推荐:基于ChatGPT API和Python开发应用程序的详细指南

ChatGPT已经以其革命性的能力引起了人们的关注,利用其API可能会成为你的游戏规则改变者。这不仅仅是关于编码;它是关于为您的创作添加一层智能,将它们提升到之前无法想象的水平。《基于ChatGPT API和Python开发应用程序的详细指南》&#xff…

让EXCEL VBA支持鼠标滚轮,vb6 IDE鼠标滚轮插件原理

vb6 IDE鼠标滚轮插件怎么运行的(适用于VBA) 使用 Spy,我发现代码窗口正在获取 WM_MOUSEWHEEL 事件,但没有触发 WM_VSCROLL 消息。因此,我编写了一个简单的消息钩子,当它捕获鼠标滚轮事件时触发滚动事件。 我从 Spy 得知代码窗口的…

从零自制docker-14-【实现 mydocker commit 打包容器成镜像】

文章目录 目标注意exec.Commandtar代码结果 目标 piveroot切换工作目录到/merged后,通过docker commit将此时工作目录的文件系统保存下来,使用tar包将该文件系统打包为tar文件 命令类似 ./mydocker commit myimage然后当前目录下会得到myimage.tar 注意…

SA316系列音频传输模块-传输距离升级音质不打折

SA316是思为无线研发的一款远距离音频传输模块,音频采样率为48K,传输距离可达200M。为了满足更多用户需求,思为无线在SA316基础上进一步增加传输距离推出SA316F30。相比SA316性能,同样其采用48K采样,-96dBm灵敏度&…

[C][动态内存分配][柔性数组]详细讲解

目录 1.动态内存函数的介绍1.malloc2.free2.calloc4.realloc 2.常见的动态内存错误3.C/C程序的内存开辟4.柔性数组1.是什么?2.柔性数组的特点3.柔性数组的使用4.柔性数组的优势 1.动态内存函数的介绍 1.malloc 函数原型:void* malloc(size_t size)功能…

Linux网络编程:应用层协议|HTTP

前言: 我们知道OSI模型上层分为应用层、会话层和表示层,我们接下来要讲的是主流的应用层协议HTTP,为什么需要这个协议呢,因为在应用层由于操作系统的不同、开发人员使用的语言类型不同,当我们在传输结构化数据时&…

前端开发工程师——AngularJS

一.表达式和语句 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><meta name"viewport" content"widthdevice-w…

AI播客下载:Acquired podcast每个公司都有一个故事

"Acquired Podcast" 是一档专注于深度解析科技行业和企业发展历程的播客节目&#xff0c;由Ben Gilbert和David Rosenthal主持。其口号是&#xff1a;Every company has a story.《Acquired》每一集都围绕一个特定的主题或公司进行讨论。它以独特的视角和深入的分析&…

FFmpeg编解码的那些事(1)

看了网上很多ffmpeg的编解码的文章和代码&#xff0c;发现有很多文章和代码都过时了&#xff0c;主要还是ffmpeg有很多接口都已经发生变化了。 这里简单说一下&#xff0c;什么是编码和解码。 1.视频编码 对于视频来说&#xff0c;可以理解为多张&#xff08;rgb或者yuv&…

Numba 的 CUDA 示例(1/4):踏上并行之旅

按照本系列从头开始使用 Python 学习 CUDA 编程 介绍 GPU&#xff08;图形处理单元&#xff09;&#xff0c;顾名思义&#xff0c;最初是为计算机图形学开发的。从那时起&#xff0c;它们几乎在每个需要高计算吞吐量的领域都无处不在。这一进步得益于 GPGPU&#xff08;通用 G…

从零开始学React--环境搭建

React官网 快速入门 – React 中文文档 1.搭建环境 下载nodejs,双击安装 nodejs下载地址 更新npm npm install -g npm 设置npm源&#xff0c;加快下载速度 npm config set registry https://registry.npmmirror.com 创建一个react应用 npx create-react-app react-ba…

React-组件基础使用

组件是什么 概念&#xff1a;一个组件就是用户界面的一部分&#xff0c;它可以有自己的逻辑和外观&#xff0c;组件之间可以互相嵌套&#xff0c;也可以复用多次 组件化开发可以让开发者像搭积木一样构建一个完整的庞大的应用 React组件 在React中&#xff0c;一个组件就是首…

Rocksdb原理简介

100编程书屋_孔夫子旧书网 Rocksdb作为当下nosql中性能的代表被各个存储组件&#xff08;mysql、tikv、pmdk、bluestore&#xff09;作为存储引擎底座&#xff0c;其基于LSM tree的核心存储结构&#xff08;将随机写通过数据结构转化为顺序写&#xff09;来提供高性能的写吞吐时…

读人工智能时代与人类未来笔记17_人类未来

1. 人类未来 1.1. 在印刷读物出现之前&#xff0c;中世纪的欧洲人主要通过社区传统获取知识 1.1.1. 通过参与收割和季节轮作积累民间智慧 1.1.2. 在礼拜场所践行信仰并遵守圣礼 1.1.3. 加入行业公会&#xff0c;学习技术&…

重置服务器之后 SSH 登录报错:REMOTE HOST IDENTIFICATION HAS CHANGED!

问题原因&#xff1a; 报错是由于远程的主机的公钥发生了变化导致的。ssh服务是通过公钥和私钥来进行连接的&#xff0c;它会把每个曾经访问过计算机或服务器的公钥&#xff08;public key&#xff09;&#xff0c;记录在~/.ssh/known_hosts 中&#xff0c;当下次访问曾经访问…

量子密钥分发系统基础器件(一):光纤干涉仪

干涉仪的基本原理是利用波的叠加来获得波的相位信息&#xff0c;从而获取实验中所关心的物理量。光纤干涉仪是由光学干涉仪发展而来的&#xff0c;利用光纤实现光的干涉&#xff0c;由于光纤取代透镜系统构成的光路具有柔软、形状可随意变化、传输距离远等特点&#xff0c;当前…

MouseBoost Pro for Mac v3.4.7 鼠标右键助手 安装教程【支持M芯片】

MouseBoost Pro for Mac v3.4.7 鼠标右键助手 安装教程【支持M芯片】 原文地址&#xff1a;https://blog.csdn.net/weixin_48311847/article/details/139201501

电源模块测试系统怎么测试输入电压范围?

在现代电子设备中&#xff0c;电源模块的性能直接影响着整个系统的稳定性和效率。其中&#xff0c;电源输入电压范围是指电源能够接受的输入电压的最小值和最大值&#xff0c;它是确保电源正常工作的重要参数。为了提高测试效率和精度&#xff0c;自动化的测试方法逐渐取代了传…