【深度学习入门篇 ⑦】PyTorch池化层

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


池化层 (Pooling) 降低维度,缩减模型大小,提高计算速度. 即: 主要对卷积层学习到的特征图进行下采样(SubSampling)处理 。

  • 通过下采样,我们可以提取出特征图中最重要的特征,同时忽略掉一些不重要的细节。
  • 上采样是指增加数据(图像)的尺寸;通常用于图像的分割、超分辨率重建或生成模型中,以便将特征图恢复到原始图像的尺寸或更大的尺寸。 

池化层

池化包含最大池化和平均池化,有一维池化,二维池化,三维池化,在这里以二维池化为例

最大池化

最大池化就是求一个区域中的最大值,来代替该区域。

torch.nn.MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)

假设输入的尺寸是(𝑁,𝐶,𝐻,𝑊),输出尺寸是(𝑁,𝐶,𝐻𝑜𝑢𝑡,𝑊𝑜𝑢𝑡),kernel_size是(𝑘𝐻,𝑘𝑊),可以写成下面形式 :

其中,输入参数 kernel_sizestridepaddingdilation可以是

  • 一个 int :代表长宽使用同样的参数
  • 两个int组成的元组:第一个int用在H维度,第二个int用在W维度
import torch
import torch.nn as nn
#长宽一致的池化,核尺寸为3x3,池化步长为2
ml = nnMaxPool2d(3, stride=2)
#长宽不一致的池化
m2 = nn.MaxPool2d((3,2), stride=(2,1))
input = torch.randn(4,3,24,24)
output1 = m1( input)
output2 = m2( input)
print( "input.shape = " ,input.shape)
print( "output1.shape = " , output1.shape)
print( "output2.shape = " , output2.shape)

 输出:

input.shape = torch.size([4,3,24,24])
output1.shape = torch. size([4,3,11,11])
output2.shape = torch.size([4,3,11,23])
平均池化

平均池化就是用一个区域中的平均数来代替本区域

torch.nn.AvgPool2d(kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True, divisor_override=None)

import torch
import torch.nn as nn
#长宽一致的池化,核尺寸为3x3,池化步长为2
ml = nn. AvgPool2d( 3, stride=2)
#长宽不一致的池化
m2 = nn. AvgPool2d(( 3,2), stride=(2,1) )
input = torch.randn(4,3,24,24)
output1 = m1( input)
output2 = m2( input)
print("input.shape = ",input. shape)
print("output1.shape = " , output1.shape)
print( "output2.shape = ", output2.shape)
  • randn是生成形状为[batch_size, channels, height, width] 

输出:

input.shape = torch.size([4,3,24,24])
output1.shape = torch.size([4,3,11,11])
output2.shape = torch.size([4,3,11,23])

BN层

BN,即Batch Normalization,是对每一个batch的数据进行归一化操作,可以使得网络训练更稳定,加速网络的收敛。

import torch
import torch.nn as nn
#批量归一化层(具有可学习参数)
m_learnable = nn. BatchNorm2d(100)
#批量归一化层(不具有可学习参数)
m_non_learnable = nn.BatchNorm2d(100,affine=False)
#随机生成输入数据
input = torch.randn(20,100,35,45)
#应用具有可学习参数的批量归一化层
output_learnable = m_learnable(input)
#应用不具有可学习参数的批量归一化层
output_non_learnable = m_non_learnable(input)
print( "input.shape = ", input.shape)
print( "output_learnable.shape = ", output_learnable.shape)
print( "output_non_learnable.shape = ", output_non_learnable.shape)

 输出:

input.shape = torch.size([20,100,35,45])
output_learnable.shape = torch.size( [20,100,35,45])
output_non_learnable.shape = torch.size([20,100,35,45])

常见的层就是上面提到的这些,如果这些层结构被反复调用,我们可以将其封装成一个个不同的模块。

案例:复现LeNet

LeNet结构,使用PyTorch进行复现,卷积核大小5x5,最大池化层,核大小2x2

import torch
import torch.nn as nn
from torchsummary import summary
class LeNet( nn . Module):
    def _init_( self,num_classes=10):
        super(Leet, self)._init__()
        self.conv1 = nn.conv2d( in_channels=3,out_channels=6,kernel_size=5)
        self.pool1 = nn. MaxPool2d(kernel_size=2)
        self.conv2 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)
        self.pool2 = nn. MaxPool2d(kernel_size=2)
        self.conv3 = nn.conv2d(in_channels=16,out_channels=120, kernel_size=5)
        self.fc1 = nn.Linear(in_features=120,out_features=84)
        self.fc2 = nn.Linear(in_features=84,out_features=10)
    
    def forward(self, x):
        #通过卷积层、ReLU和池化层
        x = self.conv1(x)
        x = self.pool1(x)
        x = self.conv2(x)
        x = self.pool2(x)
        x = self.conv3(x)
        x = x.view( -1,120)
        x = self.fc1(x)
        x = self.fc2(x)
        return x
#创建网络实例
num_classes = 10
net = LeNet( num_classes)

#创建一个输入
batch_size = 4
input_tensor = torch.randn(batch_size,3,32,32)
# 假设输入是32x32的RGB图像
#将输入Tensor传递给网络
output = net(input_tensor)
# #显示输出Tensor的形状
print(output.shape)
summary(net,(3,32,32))

Sequential: 顺序容器

Sequential属于顺序容器。模块将按照在构造函数中传递的顺序从上到下进行运算。

使用OrderedDict,可以进一步对传进来的层进行重命名。

#使用sequential来创建小模块,当有输入进来,会从上到下依次经过所有模块
model = nn. Sequential(
nn.conv2d(1,20,5),nn.ReLu() ,
nn.conv2d(20,64,5),nn.ReLU()
)
#使用orderedDict,可以对传进来的模块进行命名,实现效果同上
from collections import orderedDict
model = nn. sequential ( orderedDict([
        ( 'conv1 ', nn.Conv2d( 1,20,5)),
        ( 'relu1 ', nn.ReLU( ) ),
        ( 'conv2 ', nn.conv2d(20,64,5)),
        ( 'relu2 ', nn.ReLU())
]))

除此之外,还可以用 ModuleList和 ModuleDict 来存放子模块,但是用的不多,掌握了上面的内容就足够了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/801585.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

鸿蒙开发:Universal Keystore Kit(密钥管理服务)【查询密钥是否存在(ArkTS)】

查询密钥是否存在(ArkTS) HUKS提供了接口供应用查询指定密钥是否存在。 开发步骤 指定密钥别名keyAlias,密钥别名最大长度为64字节。初始化密钥属性集。用于查询时指定密钥的属性TAG,比如查询的密钥范围(全量/单个),当查询单个时&#xff…

工时计算软件比较:哪些功能是必需的?

国内外主流的10款工时计算软件对比:PingCode、Worktile、蓝凌OA、用友软件、金蝶软件、智能人事、Time Doctor、Toggl Track、Clockify、Harvest。 在选择合适的工时计算软件时,很多企业面临的一个主要痛点是如何找到既能满足功能需求又具有成本效益的解…

虚拟机及其Debian(kali)安装

本机电脑为Windows10系统专业版,在此基础上安装VMware和系统(Kali) 步骤如下 一、安装 VMware Workstation Pro v16.2.4 安装步骤可参照网上博客,该步骤较简单,此处不做讲解。文件中共计两个,其中一个是激活…

YOLOv8白皮书-第Y8周:yolov8.yaml文件解读

本文为365天深度学习训练营中的学习记录博客 原作者:K同学啊|接辅导、项目定制 请根据YOLOv8n、YOLOv8s模型的结构输出,手写出YOLOv8l的模型输出 文件位置:./ultralytics/cfg/models/v8/yolov8.yaml 一、参数配置 # Parameters nc: 80 # n…

android APP在蓝牙模拟键盘扫描条码设备开机时闪退

🏆本文收录于《CSDN问答解答》专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收藏&…

下载安装VSCode并添加插件作为仓颉编程入门编辑器

VSCode下载地址:下载 Visual Studio Code - Mac、Linux、Windows 插件下载:GitCode - 全球开发者的开源社区,开源代码托管平台 仓颉社区中下载解压 cangjie.vsix 插件 打开VSCode 按 Ctrl Shift X 弹出下图 按照上图步骤依次点击选中我们下…

【Unity2D 2022:UI】制作主菜单

一、创建主菜单游戏场景 1. 在Scenes文件夹中新建一个游戏场景Main Menu 2. 为场景添加背景 (1)创建画布Canvas (2)在Canvas中创建新的空游戏物体Main Menu (3)在Main Menu中新建一个图像游戏物体Backgrou…

c++基础(1)

c语言是结构化和模块化的语言,用于处理规模较小的程序。当问题需要高度抽象和建模时,c语言不适合。c是基于c语言产生的,既可以进行c语言过程化程序设计,又可以以抽象数据类型为特点的基于对象的程序设计,还可以进行面向…

pico+unity手柄和摄像机控制初级设置

1、摄像头配置 摄像头模式、floor是追踪原点类型(将根据设备检测到地面的高度来计算追踪原点), Device 模式时,为通常理解的 Eye 模式,不会将根据设备检测到地面的高度来计算追踪原点 选择floor时,修改相…

达梦数据库的系统视图v$recover_status

达梦数据库的系统视图v$recover_status 在达梦数据库(DM Database)中,V$RECOVER_STATUS 是一个系统视图,用于显示数据库的恢复状态信息。这个视图对于数据库管理员来说非常重要,尤其是在数据库发生故障需要进行恢复操…

Claude 3.5 Sonnet模型发布,对比ChatGPT4o孰强孰弱

Anthropic 这家生而为打击 OpenAI 安全问题的公司,正式发布了Claude 3.5 Sonnet模型! 用官网的话就是: 今天,我们推出了 Claude 3.5 Sonnet,这是我们即将推出的 Claude 3.5 型号系列中的第一个版本。Claude 3.5 Sonne…

批量提取PDF指定区域内容到 Excel , 根据PDF文件第一行文字来自动重命名v1.3-附思路和代码实现

本次文章更新内容,图片以及扫描的PDF也可以支持批量提取指定区域内容了,主要是通过截图指定区域,然后使用OCR来识别该区域的文字来实现的,所以精度可能会有点不够,但是如果是数字的话,问题不大;…

在线工具--将Json结构映射为另外一种Json结构

具体请前往:在线工具-将json结构映射为另外一个json结构

Home Assistant在windows环境安装

Home Assistant是什么? Home Assistant 是一个开源的智能家居平台,旨在通过集成各种智能设备和服务,提供一个统一的、可自定义的家庭自动化解决方案。它可以允许用户监控、控制和自动化家中的各种设备,包括灯光、温度、安全系统、…

算法 —— 高精度(模拟)

目录 加法高精度 两个正整数相加 两个正小数相加 两正数相加 减法高精度 两个正整数相减 两个正小数相减 两正数相减 加减法总结 乘法高精度 两个正整数相乘 两个正小数相乘 乘法总结 加法高精度 题目来源洛谷:P1601 AB Problem(高精&#x…

JVM知识点

一、java内存区域与内存异常 1、运行时数据区域 1)程序计数器 程序计数器(Program Counter Register)是一块较小的内存空间,它可以看作是当前线程所执行的字节码的行号指示器。在Java虚拟机的概念模型里 [1] ,字节码…

降本增效CRKD:通过跨模态知识蒸馏增强相机与毫米波雷达目标检测精度

Abstract 在自动驾驶的3D目标检测领域,激光雷达-摄像头(LC)融合是表现最好的传感器配置。然而,激光雷达的成本相对较高,这阻碍了该技术在消费者汽车中的普及。相反,摄像头和雷达已经普遍部署在现有车辆上&…

Springboot整合MyBatis实现数据库查询(二)

目录 第一章、准备1.1)准备数据库表1.2)创建springboot项目,添加依赖1.3)使用mybatis逆向工程 第二章、代码开发2.1)建包并编写代码2.2)application配置文件2.3)设置编译位置 第三章、测试访问3…

用HTML和CSS实现提示工具(tooltip)及HTML元素的定位

所谓提示工具,是指将鼠标移动到某个HTML元素(工具)时会显示一些提示内容(提示文本),而鼠标移出工具元素的范围时提示文本就消失了。考虑到提示文本元素应当在鼠标进入工具元素时显示,鼠标离开工…

JDK之使用keytool安装cer证书

可针对https请求缺失证书解决报错: PKIX path building failed: sun.security.provider.certpath.SunCertPathBuilderException: unable to find valid certification path to requested target 解决办法: 先通过浏览器下载证书,再使用JDK自带…