Pytorch学习 day08(最大池化层、非线性激活层)

最大池化层

  • 最大池化,也叫上采样,是池化核在输入图像上不断移动,并取对应区域中的最大值,目的是:在保留输入特征的同时,减小输入数据量,加快训练。
  • 参数设置如下:
    • kernel_size:池化核的高宽(整数或元组),整数时表示高宽都为该整数,元组时表示分别在水平和垂直方向上的长度。
    • stride:池化核每次移动的步长(整数或元组),整数时表示在水平和垂直方向上使用相同的步长。元组时分别表示在水平和垂直方向上的步长。默认为池化核的高宽。
    • padding:控制在输入张量的边界周围添加的零填充的数量(为整数或元组),如果是整数,表示在水平和垂直方向上使用相同的填充数量。如果是元组,分别表示在水平和垂直方向上的填充数量。默认为0
    • dilation:卷积核内部元素之间的距离,空洞卷积,如图:
      在这里插入图片描述
    • ceil_mode:True表示ceil模式,即向上取整,保留未满部分。False表示floor模式,即向下取整,舍去未满部分。默认为False,如图:
      在这里插入图片描述
    • 如下是示意图:
      在这里插入图片描述
  • 以下是代码实现:
    • 注意:部分版本的MaxPool2d不支持int类型,所以需要指定数据类型为浮点数
import torch

input = torch.tensor([[1, 2, 0, 3, 1],
                      [0, 1, 2, 3, 1],
                      [1, 2, 1, 0, 0],
                      [5, 2, 3, 1, 1],
                      [2, 1, 0, 1, 1]], dtype=torch.float32)  # 由于部分版本的MaxPool2d不支持int类型,所以这里需要指定数据类型为float32  

input = torch.reshape(input, (-1, 1, 5, 5)) # 将input从二维张量变成(N, C, H, W)的四维张量
print(input.shape)

class Tudui(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool1 = torch.nn.MaxPool2d(kernel_size=3, ceil_mode=True)  # 最大池化层,池化核大小3*3,向上取整

    def forward(self, input):
        output = self.pool1(input)
        return output

tudui = Tudui()
output = tudui(input)
print(output)

# 输出结果为
# torch.Size([1, 1, 5, 5])
# tensor([[[[2., 3.],
#           [5., 1.]]]])
  • 同样可以通过tensorboard进行展示输入输出结果,代码如下:
    • 由于最大池化层不会改变channel,所以不需要对输出进行reshape()操作
import torch
import torchvision
from torch.utils.tensorboard import SummaryWriter

test_dataset = torchvision.datasets.CIFAR10(root='Dataset', train=False, download=True, transform=torchvision.transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=4, shuffle=False, num_workers=0)



class Tudui(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool1 = torch.nn.MaxPool2d(kernel_size=3, ceil_mode=True)  # 最大池化层,池化核大小3*3,向上取整

    def forward(self, input):
        output = self.pool1(input)
        return output

tudui = Tudui()

writer = SummaryWriter("logs")  # 创建一个SummaryWriter对象,指定日志文件保存路径
step = 0
for data in test_loader:
    imgs, targets = data    # 获取输入数据
    outputs = tudui(imgs)   # 调用网络模型进行前向传播
    writer.add_images("input", imgs, step)  # 将输入数据imgs写入日志文件
    # 由于最大池化层不会改变通道数,所以不需要对outputs进行reshape()操作
    writer.add_images("output", outputs, step)  # 将输出数据outputs写入日志文件
    step += 1

writer.close()
  • 结果如下:
    在这里插入图片描述

非线性激活层

  • 主要目的是向网络中引入一些非线性特征,非线性越多,才能训练出符合复杂数据集的模型,提高模型的泛化性
  • 常用的非线性激活层有:ReLU、Sigmoid,如下:
    • ReLU:当输入大于0时,输出等于输入。当输入小于0时,输出等于0。

    • 注意:输入的第一个位置要是batch_size,之后的不做限制

    • inplace:是否对输入进行结果替换,默认为False(不替换),并返回输出
      在这里插入图片描述
      在这里插入图片描述

    • Sigmoid:将输入经过以下公式,得到输出。

    • 注意:第一个位置也要是batch_size,之后的不做限制:
      在这里插入图片描述

  • ReLU激活函数代码如下:
import torch
from torch import nn

input = torch.tensor([[1, -0.5],
                      [-1, 3]])
# 由于ReLU激活函数需要第一个维度为batch_size,所以需要对输入input进行reshape操作
input = torch.reshape(input, (-1, 1, 2, 2)) 

class Tudui(nn.Module):

    def __init__(self): # 初始化
        super().__init__()  # 继承父类的初始化
        self.relu1 = nn.ReLU()  # ReLU激活函数

    def forward(self, input):
        output = self.relu1(input)  # 调用ReLU激活函数对输入input进行激活
        return output

tudui = Tudui()
output = tudui(input)
print(output)

# 输出结果:
# tensor([[[[1., 0.],
#           [0., 3.]]]])
  • Sigmoid激活函数代码如下:
import torch
import torchvision
from torch import nn
from torch.utils.tensorboard import SummaryWriter


test_dataset = torchvision.datasets.CIFAR10(root='Dataset', train=False, download=True, transform=torchvision.transforms.ToTensor())
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False, num_workers=0)

writer = SummaryWriter("logs")  # 创建一个SummaryWriter对象,指定日志文件保存路径
class Tudui(nn.Module):

    def __init__(self): # 初始化
        super().__init__()  # 继承父类的初始化
        self.sigmoid1 = nn.Sigmoid()  # Sigmoid激活函数

    def forward(self, input):
        output = self.sigmoid1(input)  # 调用Sigmoid激活函数
        return output
step = 0
for data in test_loader:
    imgs, targets = data
    tudui = Tudui() # 实例化网络模型
    writer.add_images("input", imgs, step)  # 将输入数据imgs写入日志文件
    outputs = tudui(imgs)   # 调用网络模型进行前向传播
    writer.add_images("output", outputs, step)  # 将输出数据outputs写入日志文件
    step += 1

writer.close()  # 关闭日志文件
  • 结果如下:
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/443089.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

微信加好友频繁会被封号吗?

微信加好友频繁会被封号吗? 微信规定,每个人每天最多可以加20个好友,但一天之内如果频繁加好友,微信可能会出现异常提示,需要暂停好友添加操作。 面对微信上突如其来的大量好友申请,一定要谨慎处理,以免被…

Golang搭建grpc环境

简介 OS : Windows 11 Golang 版本: go1.22.0 grpc : 1.2 protobuffer: 1.28代理 没有代理国内环境下载不了库七牛CDN (试过可用) go env -w GOPROXYhttps://goproxy.cn,direct阿里云代理(运行grpc时下载包出现报错 ): go env -w GOPROXYhttps://mirr…

CCProxy代理服务器地址的设置步骤

目录 前言 一、下载和安装CCProxy 二、启动CCProxy并设置代理服务器地址 三、验证代理服务器设置是否生效 四、使用CCProxy进行代理设置的代码示例 总结 前言 CCProxy是一款常用的代理服务器软件,可以帮助用户实现网络共享和上网代理。本文将详细介绍CCProxy…

IntelliJ IDEA 2020.2.4试用方法

打开idea,准备好ide-eval-resetter压缩包。 将准备好的压缩包拖入idea中 选中弹窗中的自动重置选项,并点击重置 查看免费试用时长

[数据集][目标检测]变电站缺陷检测数据集VOC+YOLO格式8307张17类别

数据集格式:Pascal VOC格式YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):8307 标注数量(xml文件个数):8307 标注数量(txt文件个数):8307 标注…

汽车大灯汽车尾灯破裂裂纹破损破洞掉角崩角等问题能修复吗?修复后需要注意什么?

汽车灯罩破损修复后,车主需要注意以下几点: 检查修复效果:修复完成后,车主应该仔细检查灯罩的修复效果,确保破损部分已经被填补并恢复原有的透明度和光泽。如果修复效果不理想,需要及时联系维修店进行处理…

问题:前端获取long型数值精度丢失,后面几位都为0

文章目录 问题分析解决 问题 通过接口获取到的数据和 Postman 获取到的数据不一样,仔细看 data 的第17位之后 分析 该字段类型是long类型问题:前端接收到数据后,发现精度丢失,当返回的结果超过17位的时候,后面的全…

什么是工业级物联网智能网关?如何远程控制PLC?

在这个信息爆炸的时代,物联网技术已经逐渐渗透到我们生活的方方面面,而工业级物联网智能网关作为连接工业设备和云端的重要桥梁,更是引领着工业4.0时代的浪潮。那么,究竟什么是工业级物联网智能网关呢?今天&#xff0c…

git删除comimit提交的记录

文章目录 本地的删除远程同步修改上次提交更多详情阅读 本地的删除 例如我的提交历史如下 commit 58211e7a5da5e74171e90d8b90b2f00881a48d3a Author: test <test36nu.com> Date: Fri Sep 22 20:55:38 2017 0800add d.txtcommit 0fb295fe0e0276f0c81df61c4fd853b7a00…

详解DNS服务

华子目录 概述产生原因作用连接方式 因特网的域名结构拓扑分类域名服务器类型划分 DNS域名解析过程分类解析图图过程分析注意 搭建DNS域名解析服务器概述安装软件bind服务中的三个关键文件 配置文件分析主配置文件共4部分组成区域配置文件作用区域配置文件示例分析正向解析反向…

STM32代码调试时遇到的一些error和warning

持续更新 ERROR WARNING 1.Note: object file renamed from “xxx.o“ to “xxx_1.o“ 出现下面这些warning可能的原因&#xff1a; &#xff08;1&#xff09;没有将头文件加入到main.c中&#xff0c;检查一下在编译。 &#xff08;2&#xff09;修改源文件路径的时候忘记…

python学习the sixth day

python函数进阶 一、函数多返回值 二、函数的多种参数使用 1.位置参数 2.关键字参数 3.缺省参数 设置默认值&#xff0c;必须放在最后面 4.不定长参数 4.总结 三、匿名函数 1.函数作为参数传递 这是计算逻辑的传递&#xff0c;而非数据的传递 2.lambda匿名函数 python文件操…

【电路笔记】-PNP晶体管

PNP晶体管 文章目录 PNP晶体管1、概述2、PNP晶体管电路示例3、PNP晶体管识别1、概述 PNP 晶体管与我们在上一篇教程中看到的 NPN 晶体管器件完全相反。 在这种类型的 PNP 晶体管结构中,两个互连的二极管相对于之前的 NPN 晶体管是相反的。 这会产生正-负-正类型的配置,箭头…

JAVA实战开源项目:智能停车场管理系统(Vue+SpringBoot)

目录 一、摘要1.1 项目介绍1.2 项目录屏 二、研究内容A. 车主端功能B. 停车工作人员功能C. 系统管理员功能1. 停车位模块2. 车辆模块3. 停车记录模块4. IC卡模块5. IC卡挂失模块 三、界面展示3.1 登录注册3.2 车辆模块3.3 停车位模块3.4 停车数据模块3.5 IC卡档案模块3.6 IC卡挂…

Android Studio下载gradle超时问题解决

方法一 1. 配置根目录的setting.gradle.kts文件 pluginManagement {repositories {maven { urluri ("https://www.jitpack.io")}maven { urluri ("https://maven.aliyun.com/repository/releases")}maven { urluri ("https://maven.aliyun.com/repos…

基于springboot的家庭装修报价系统设计与实现

目 录 摘 要 I Abstract II 引 言 1 1 相关技术 3 1.1 SpringBoot框架 3 1.2 ECharts 3 1.3 Vue框架 3 1.4 Bootstrap框架 3 1.5 JQuery技术 4 1.6 Ajax技术 4 1.7 本章小结 4 2 系统分析 5 2.1 需求分析 5 2.2 非功能需求 7 2.3 本章小结 8 3 系统设计 9 3.1 系统总体设计 9 …

Python学习日记之学习turtle库(上 篇)

一、初步认识turtle库 turtle 库是 Python 语言中一个很流行的绘制图像的函数库&#xff0c;想象一个小乌龟&#xff0c;在一个横 轴为 x、纵轴为 y 的坐标系原点&#xff0c;(0,0)位置开始&#xff0c;它根据一组函数指令的控制&#xff0c;在这个平面 坐标系中移动&#xff0…

ubuntu 运行opencv_sample遇到的问题

首先我遇到的问题就是摄像头连接不上 勾选最后一个 然后是 usb接口问题 点击虚拟机设置 我的是改为 3 就可以啦

TensorRT是什么,有什么作用,如何使用

TensorRT 是由 NVIDIA 提供的一个高性能深度学习推理&#xff08;inference&#xff09;引擎。它专为生产环境中的部署而设计&#xff0c;用于提高在 NVIDIA GPU 上运行的深度学习模型的推理速度和效率。以下是关于 TensorRT 的详细介绍&#xff1a; TensorRT 是 NVIDIA 推出的…

Facebook广告必坑指南

不明确的目标&#xff1a; 在开始广告活动之前&#xff0c;确保你清楚自己的广告目标。是想提高品牌知名度、促进销售、还是增加网站流量&#xff1f;明确的目标有助于指导广告内容和策略。 忽视目标受众定位&#xff1a; 确定你的目标受众是关键的。使用Facebook广告管理工具…