卷积神经网络可视化的探索

文章目录

    • 训练LeNet模型
      • 下载FashionMNIST数据
      • 训练
      • 保存模型
    • 卷积神经网络可视化
      • 加载模型
      • 一个测试图像
      • 不同层对图像处理的可视化
      • 第一个卷积层的处理
      • 第二个卷积层的处理

卷积神经网络是利用图像空间结构的一种深度学习网络架构,图像在经过卷积层、激活层、池化层、全连接层等处理后得到输出。

本次想探索一下图像经过每一层都发生了什么变化,比如不同的卷积核(滤波器)都提取了图像的什么特征?越深层是否会对图像更抽象化?

带着这些问题,本文将使用FashionMNIST数据、简单的LeNet模型来探索CNN是如何处理图像的。

训练LeNet模型

首先来训练一个LeNet模型(换成其他卷积神经网络也可以),目的是为了利用训练好的模型参数获得输入图像的各层输出,以供可视化之用。

下载FashionMNIST数据

root:设置下载路径;

train:为True表示下载训练集,反之为测试集;

download:首次下载设为True,下载好后可以改为False。

from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

from matplotlib import pyplot as plt
%matplotlib inline
from PIL import Image

import torch
from torch import nn
import torchvision
from torch.utils import data
from torchvision import transforms

#下载数据
def load_fashion_mnist(batch_size):
    trans = transforms.Compose([transforms.ToTensor()])
    train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=False)
    test = torchvision.datasets.FashionMNIST(root="../data", train=False,transform=trans,download=False)
    
    return (data.DataLoader(train, batch_size, shuffle=True), data.DataLoader(test, batch_size, shuffle=False))

训练

简单地训练网络

#批量大小
batch_size = 512
train_iter, test_iter = load_fashion_mnist(batch_size=batch_size)

#LeNet网络
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
    nn.Linear(120, 84), nn.ReLU(),
    nn.Linear(84, 10))

#参数初始化
def init_network(model, method='xavier'):
    for name, w in model.named_parameters():
        if 'weight' in name:
            if method == 'xavier':
                nn.init.xavier_normal_(w)
            elif method == 'kaiming':
                nn.init.kaiming_normal_(w)
            else:
                nn.init.normal_(w)
        elif 'bias' in name:
            nn.init.constant_(w, 0)
        else:
            pass    
init_network(net)

#损失函数
loss=nn.CrossEntropyLoss()

#优化算法
lr=0.05
updater=torch.optim.SGD(net.parameters(),lr=lr)

#训练
def train(net, train_iter, test_iter, loss, num_epochs, updater,device):
    net.to(device)
    for epoch in range(num_epochs):
        if isinstance(net, torch.nn.Module):
            net.train()
        for X, y in train_iter:
            X,y=X.to(device),y.to(device)         
            y_hat = net(X)        
            l = loss(y_hat, y)
            
            updater.zero_grad() 
            l.backward()                
            updater.step()
           
device = torch.device("mps" if torch.backends.mps.is_available else "cpu") #Mac使用mps
num_epochs = 20
train(net, train_iter, test_iter, loss, num_epochs, updater,device)

保存模型

#保存模型参数
torch.save(net.state_dict(),'LeNet.params')

卷积神经网络可视化

本节将使用上文训练好的模型来可视化卷积神经网络不同层对图像的处理过程。

加载模型

#LeNet网络结构
net = nn.Sequential(
    nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),
    nn.AvgPool2d(kernel_size=2, stride=2),
    nn.Flatten(),
    nn.Linear(16 * 5 * 5, 120), nn.ReLU(),
    nn.Linear(120, 84), nn.ReLU(),
    nn.Linear(84, 10))

#加载模型参数
net.load_state_dict(torch.load('LeNet.params'))
net.eval()

一个测试图像

#批量大小
batch_size = 1
train_iter, test_iter = load_fashion_mnist(batch_size=batch_size)
#随机选择一个图像
for x,y in train_iter:
    _=plt.imshow(x.squeeze(0).permute(1,2,0).numpy())
    break

看上去我们抽到了一件T恤。


在这里插入图片描述


不同层对图像处理的可视化

#计算到给定层的输出
def cnn_net(X,net,l=1):
    for i,layer in enumerate(net[0:l]):
        X=layer(X)  
        if i==l-1:
            print('第%s层:%-10s 输出形状:%s'%(i+1, layer.__class__.__name__, X.shape))        
    return X
  
#可视化
def cnn_visual(imgs,nrows,ncols,scale):
    figsize = (ncols * scale, nrows * scale)
    fig,axes = plt.subplots(nrows=nrows, ncols=ncols,figsize=figsize)
    axes = axes.flatten()

    for i, (ax, img) in enumerate(zip(axes, imgs.squeeze(0))):   
        _ = ax.imshow(img.detach().numpy())
        ax.axes.get_xaxis().set_visible(False)
        ax.axes.get_yaxis().set_visible(False)
    return axes
  
for l in [1,2,3,4,5,6]:
    imgs=cnn_net(x.clone(),net,l=l)
    nrows=2
    ncols=int(imgs.shape[1]/nrows)
    axes=cnn_visual(imgs,nrows,ncols,2)

在这里插入图片描述

第一个卷积层的处理


我们先来看看第一个卷积层中不同卷积核分别从图像中提取了什么信息,第一个卷积层有6个输出通道,因此查看每个通道输出的图像。

从下图可以看出,第一个卷积层提取到了不同轮廓层次信息。


在这里插入图片描述


经过ReLU处理后:


在这里插入图片描述


再经平均池化处理后,变化不大:


在这里插入图片描述


第二个卷积层的处理

第二个卷积层有16个输出通道,随着层次加深,感受野扩大,通道的融合后,从下图看已经比较抽象了,但隐隐约约还能看出点端倪:


在这里插入图片描述


再经ReLU和池化处理后,基本上已经面目全非:


在这里插入图片描述


在这里插入图片描述

以上就是对卷积神经网络可视化的初步探索,感兴趣的读者可以在不同卷积神经网络和图像上多做尝试。
另附一个卷积神经网络可视化网站

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/796854.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Android liveData 监听异常,fragment可见时才收到回调记录

背景&#xff1a;在app的fragment不可见的情况下使用&#xff0c;发现注册了&#xff0c;但是没有回调导致数据一直未更新&#xff0c;只有在fragment可见的时候才收到回调 // 观察通用信息mLightNaviTopViewModel.getUpdateCommonInfo().observe(this, new Observer<Common…

13--memcache与redis

前言&#xff1a;数据库读取速度较慢一直是无法解决的问题&#xff0c;大型网站应对的方式主要是使用缓存服务器来缓解这种情况&#xff0c;减少数据库访问次数&#xff0c;以提高动态Web等应用的速度、提高可扩展性。 1、简介 Memcached/redis是高性能的分布式内存缓存服务器…

JVM:字节码文件

文章目录 一、Java虚拟机的组成二、字节码文件的组成1、基本信息2、常量池3、字段4、方法5、属性 三、常用的字节码工具1、javap -v 命令2、jclasslib插件3、阿里arthas 一、Java虚拟机的组成 二、字节码文件的组成 1、基本信息 魔数、字节码文件对应的Java版本号访问标识&am…

走进linux

1、为什么要使用linux 稳定性和可靠性&#xff1a; Linux内核以其稳定性而闻名&#xff0c;能够持续运行数月甚至数年而不需要重新启动。这对于服务器来说至关重要&#xff0c;因为它们需要保持长时间的稳定运行&#xff0c;以提供持续的服务 安全性&#xff1a; Linux系统…

[Vulnhub] Simple CuteNews-CMS+Kernel权限提升

信息收集 IP AddressOpening Ports192.168.8.104TCP:80 $ nmap -p- 192.168.8.104 --min-rate 1000 -sC -sV PORT STATE SERVICE VERSION 80/tcp open http Apache httpd 2.4.7 ((Ubuntu)) |_http-title: Please Login / CuteNews |_http-server-header: Apache/2.4.7…

【matlab】大数据基础与应用实例

目录 引言 线性回归模型 基本形式 最小二乘法 多元线性回归 线性回归的假设 模型评估 应用 独热编码 原理 应用场景 优点 缺点 数据收集 数据可视化 数据处理与分析 完整代码 引言 线性回归模型 线性回归模型是一种用于预测连续值输出&#xff08;或称为因变…

一文了解MySQL的表级锁

文章目录 ☃️概述☃️表级锁❄️❄️介绍❄️❄️表锁❄️❄️元数据锁❄️❄️意向锁⛷️⛷️⛷️ 介绍 ☃️概述 锁是计算机协调多个进程或线程并发访问某一资源的机制。在数据库中&#xff0c;除传统的计算资源&#xff08;CPU、RAM、I/O&#xff09;的争用以外&#xff0…

【PVE】新增2.5G网卡作为主网卡暨iperf测速流程

【PVE】新增2.5G网卡作为主网卡暨iperf测速流程 新增网卡 新增网卡的首先当然需要关闭PVE母机&#xff0c;把新网卡插上&#xff0c;我用淘宝遥现金搞了个红包&#xff0c;花了26元买了块SSU的2.5G网卡。说实话这个价位连散热片都没有&#xff0c;确实挺丐的。稍后测下速度看…

Spring源码十九:Bean实例化流程二

上一篇我们在Spring源码十八&#xff1a;Bean实例化流程一 中&#xff0c;主要讨论了Spring在实例化前的两重要准备工作&#xff0c;1、获取我们前面注册好的BeanDefinition&#xff0c;将GenericBeanDefinition封装为RootBeanDefinition如果Bean Definition只存在父容器中&…

iNavFlight飞控固件学习-1《开发环境搭建》

目录 文章目录 目录摘要1.官网2.形成Linux开发环境工具2.1 简介2.2 相关工具2.2.1 Ubuntu / Debian系统配置命令2.2.2 Fedora系统配置命令2.2.3 Fedora系统配置命令 2.3 克隆存储库2.4 构建工具2.5 使用cmake2.6 构建固件2.7 清除2.8 cmake 缓存维护2.9 编译通过ninja2.10 更新…

面试内容集合

用例设计方法 &#xff08;一&#xff09;等价类划分  常见的软件测试面试题划分等价类: 等价类是指某个输入域的子集合.在该子集合中,各个输入数据对于揭露程序中的错误都是等效的.并合理地假定:测试某等价类的代表值就等于对这一类其它值的测试.因此,可以把全部输入数据合理…

自注意力简介

在注意力机制中&#xff0c;每个查询都会关注所有的键值对并生成一个注意力输出。如果查询q&#xff0c;键k和值v都来自于同一组输入&#xff0c;那么这个注意力就被称为是自注意力&#xff08;self-attention&#xff09;。自注意力这部分理论&#xff0c;我觉得台大李宏毅老师…

FFmpeg开发环境搭建

FFmpeg是音视频开发必备的库&#xff0c;也是唯一的库。本文主要讲解在ubuntu22和macOS14环境下的编译安装。 为什么要自己编译呢&#xff1f;其中一个很重要的原因就是ffmpeg在编译时可以加入很多插件&#xff0c;这种特定的库网络上可能找不到编译好的版本&#xff0c;另外如…

在linux中查找 / 目录下的以.jar结尾的文件(find / -name *.jar)

文章目录 1、查找 / 目录下的以.jar结尾的文件 1、查找 / 目录下的以.jar结尾的文件 [rootiZuf6332h890vozldoxcprZ ~]# find / -name *.jar /etc/java/java-1.8.0-openjdk/java-1.8.0-openjdk-1.8.0.342.b07-1.el9_0.x86_64/lib/security/policy/limited/US_export_policy.ja…

Spring MVC 中 HttpMessageConverter 转换器

1. Spring MVC 中 HttpMessageConverter 转换器 文章目录 1. Spring MVC 中 HttpMessageConverter 转换器2. 补充&#xff1a;什么是 HTTP 消息3. 转换器3.1 转换器转换的是什么 4. Spring MVC中的 AJAX 请求5. ResponseBody 将服务器端的 return 返回值转化为“字符串(JSON格式…

RIP路由协议概述

RIP【Routing Information Protocol】 RIP是为TCP/IP 环境中开发的第一个路由选择协议标准 RIP是一个【距离——矢量】路由选择协议 当路由器的更新周期为30s到来时&#xff0c;向邻居发送路由表 RIP以跳数作为唯一度量值 RIP工作原理 RIP路由协议向邻居发送整个路由表信息RI…

Linux 复现Docker NAT网络

Linux 复现Docker NAT网络 docker 网络的构成分为宿主机docker0网桥和为容器创建的veth 对构成。这个默认网络命名空间就是我们登陆后日常使用的命名空间 使用ifconfig命令查看到的就是默认网络命名空间&#xff0c;docker0就是网桥&#xff0c;容器会把docker0当成路由&…

JavaDS —— 单链表 与 LinkedList

顺序表和链表区别 ArrayList &#xff1a; 底层使用连续的空间&#xff0c;可以随机访问某下标的元素&#xff0c;时间复杂度为O&#xff08;1&#xff09; 但是在插入和删除操作的时候&#xff0c;需要将该位置的后序元素整体往前或者向后移动&#xff0c;时间复杂度为O&…

代码随想录二刷7.22|977.有序数组的平方

暴力解法&#xff1a; ——如果想暴力解决这个问题的话&#xff0c;可以像题目那样&#xff0c;先将每一个元素平方&#xff0c;然后再排序 双指针&#xff1a; ——从题目中找到的信息&#xff1a;这是一个非递减顺序的整数数组&#xff0c;从例子中&#xff0c;可以容易看…

excel 百分位函数 学习

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 前言一、函数说明PERCENTILE 函数PERCENTILE.inc 函数PERCENTILE.exc 函数QUARTILE.EXC 函数 二、使用步骤总结 前言 excel 百分位函数 Excel提供了几个函数用于…