查看神经网络中间层特征矩阵及卷积核参数

可视化feature maps以及kernel weights,使用alexnet模型进行演示。

1. 查看中间层特征矩阵

alexnet模型,修改了向前传播

import torch
from torch import nn
from torch.nn import functional as F

# 对花图像数据进行分类
class AlexNet(nn.Module):
    def __init__(self,num_classes=1000,init_weights=False, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.conv1 = nn.Conv2d(3,48,11,4,2)
        self.pool1 = nn.MaxPool2d(3,2)
        self.conv2 = nn.Conv2d(48,128,5,padding=2)
        self.pool2 = nn.MaxPool2d(3,2)
        self.conv3 = nn.Conv2d(128,192,3,padding=1)
        self.conv4 = nn.Conv2d(192,192,3,padding=1)
        self.conv5 = nn.Conv2d(192,128,3,padding=1)
        self.pool3 = nn.MaxPool2d(3,2)

        self.fc1 = nn.Linear(128*6*6,2048)
        self.fc2 = nn.Linear(2048,2048)
        self.fc3 = nn.Linear(2048,num_classes)
        # 是否进行初始化
        # 其实我们并不需要对其进行初始化,因为在pytorch中,对我们对卷积及全连接层,自动使用了凯明初始化方法进行了初始化
        if init_weights:
            self._initialize_weights()

    def forward(self,x):
        outputs = []  # 定义一个列表,返回我们要查看的哪一层的输出特征矩阵
        x = self.conv1(x)
        outputs.append(x)
        x = self.pool1(F.relu(x,inplace=True))
        x = self.conv2(x)
        outputs.append(x)
        x = self.pool2(F.relu(x,inplace=True))
        x = self.conv3(x)
        outputs.append(x)
        x = F.relu(x,inplace=True)
        x = F.relu(self.conv4(x),inplace=True)
        x = self.pool3(F.relu(self.conv5(x),inplace=True))
        x = x.view(-1,128*6*6)
        x = F.dropout(x,p=0.5)
        x = F.relu(self.fc1(x),inplace=True)
        x = F.dropout(x,p=0.5)
        x = F.relu(self.fc2(x),inplace=True)
        x = self.fc3(x)

        # for name,module in self.named_children():
        #     x = module(x)
        #     if name == ["conv1","conv2","conv3"]:
        #         outputs.append(x)
        return outputs

    # 初始化权重
    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m,nn.Conv2d):
                # 凯明初始化 - 何凯明
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m,nn.Linear):
                nn.init.normal_(m.weight, 0,0.01)  # 使用正态分布给权重赋值进行初始化
                nn.init.constant_(m.bias,0)

拿到向前传播的结果,对特征图进行可视化,这里,我们使用训练好的模型,直接加载模型参数。

注意,要使用与训练时相同的数据预处理。

import matplotlib.pyplot as plt
from torchvision import transforms
import alexnet_model
import torch
from PIL import Image
import numpy as np
from alexnet_model import AlexNet

# AlexNet 数据预处理
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
# 实例化模型
model = AlexNet(num_classes=5)
weights = torch.load("./alexnet_weight_20.pth", map_location="cpu")
model.load_state_dict(weights)

image = Image.open("./images/yjx.jpg")
image = transform(image)
image = image.unsqueeze(0)

with torch.no_grad():
    output = model(image)

for feature_map in output:
    # (N,C,W,H) -> (C,W,H)
    im = np.squeeze(feature_map.detach().numpy())
    # (C,W,H) -> (W,H,C)
    im = np.transpose(im,[1,2,0])
    plt.figure()
    # 展示当前层的前12个通道
    for i in range(12):
        ax = plt.subplot(3,4,i+1) # i+1: 每个图的索引
        plt.imshow(im[:,:,i],cmap='gray')
    plt.show()

结果:

在这里插入图片描述


2. 查看卷积核参数

import matplotlib.pyplot as plt
import numpy as np
import torch

from AlexNet.model import AlexNet

# 实例化模型
model = AlexNet(num_classes=5)
weights = torch.load("./alexnet_weight_20.pth", map_location="cpu")
model.load_state_dict(weights)

weights_keys = model.state_dict().keys()
for key in weights_keys:
    if "num_batches_tracked" in key:
        continue
    weight_t = model.state_dict()[key].numpy()
    weight_mean = weight_t.mean()
    weight_std = weight_t.std(ddof=1)
    weight_min = weight_t.min()
    weight_max = weight_t.max()
    print("mean is {}, std is {}, min is {}, max is {}".format(weight_mean, weight_std, weight_min, weight_max))

    weight_vec = np.reshape(weight_t,[-1])
    plt.hist(weight_vec,bins=50)
    plt.title(key)
    plt.show()

结果:

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/332061.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java网络编程:概述--快速入门

I. 介绍 1.1 什么是网络编程 - 网络编程是指通过计算机网络实现程序之间的通信。在Java中,网络编程通常涉及到数据的传输、通信协议的使用以及与网络相关的各种操作。 1.2. 为什么学习Java网络编程 - Java网络编程是Java开发者重要的技能之一,因为它允许…

HarmonyOS —— buildMode 设置(对比 Android Build Varient)

前言 在安卓中 Build Variant 主要依赖模块(module)中 build.gradle 的 BuildType 和 ProductFlavor 提供的属性和方法,我们可以使用 Build Type 可以配置不同的构建方式、ProductFlavor 主要用来进行多渠道打包。 在鸿蒙中要做到同样像效果…

Spring Boot 配置文件和日志

目录 配置文件格式 properties配置文件说明 1.properties基本语法 2.读取配置文件 3.properties缺点 yml配置文件说明 1.yml基本语法 2.配置不同数据类型 3.字符串特殊情况 4.配置对象 properties和yml对比 日志 日志的使用 日志级别 日志持久化 Lombok Lombo…

计算机网络课程设计-网络聊天程序的设计与实现

目录 前言 1 实验题目 2 实验目的 3 实验内容 3.1 客户端 3.1.1 步骤 3.1.2 关键代码 3.2 服务器 3.2.1 步骤 3.2.2 关键代码 4 实验结果与分析 5 代码 5.1 客户端 5.2 服务器 前言 本实验为计算机网络课程设计内容,基本上所有代码都是根据指导书给的附…

1.18寒假集训

A: 解题思路&#xff1a; 这题看似很复杂&#xff0c;其实很简单&#xff0c;找规律不难发现就是输出n 1 下面是c代码&#xff1a; #include<iostream> using namespace std; int main() {int n;cin >> n;cout << n 1;return 0; } B: 解题思路&#xf…

3dmax中怎么在模型上开洞?

3dmaxS是Autodesk公司开发的基于PC系统的三维动画渲染和制作软件。我们可以使用它来做各种模型。那么怎么在模型上开洞呢&#xff1f;我们一起来看看吧&#xff01; 1、首先我们打开我们的3damx&#xff0c;这里面我使用的版本为3damxs2012,虽然版本可能各不相同。但是功能并没…

【RT-DETR有效改进】轻量级视觉变换器RepViT改进特征提取网络(轻量化网络)

前言 大家好&#xff0c;我是Snu77&#xff0c;这里是RT-DETR有效涨点专栏。 本专栏的内容为根据ultralytics版本的RT-DETR进行改进&#xff0c;内容持续更新&#xff0c;每周更新文章数量3-10篇。 专栏以ResNet18、ResNet50为基础修改版本&#xff0c;同时修改内容也支持Re…

数据库MySQL----多表查询

二、多表查询 1.创建student和score表 CREATE TABLE student ( id INT(10) NOT NULL UNIQUE PRIMARY KEY , name VARCHAR(20) NOT NULL , sex VARCHAR(4) , birth YEAR, department VARCHAR(20) , address VARCHAR(50) ); CREATE TABLE score ( id INT(10) NOT…

翻遍全网才找到的Oracle19c安装教程(Windows版)

一、下载 官方地址&#xff1a;Database Software Downloads | Oracle 中国 我下载下来了&#xff0c;有需要的可以网盘领取。 链接&#xff1a;https://pan.baidu.com/s/1-OeiuoBX3K4X7Kw7vGxdFA?pwdhssy 提取码&#xff1a;hssy 二、解压 下载下来是一个zip格式的压缩…

Java爬虫爬取图片壁纸

Java爬虫 以sougou图片为例&#xff1a;https://pic.sogou.com/ JDK17、SpringBoot3.2.X、hutool5.8.24实现Java爬虫&#xff0c;爬取页面图片 项目介绍 开发工具&#xff1a;IDEA2023.2.5 JDK&#xff1a;Java17 SpringBoot&#xff1a;3.2.x 通过 SpringBoot 快速构建开发环境…

身份验证遇到问题,登陆ChatGPT时提示:“we ran into an issue while authenticating you…”

oops&#xff01; we ran into an issue while authenticating you, if this issue persists, please contact us through our help center at help.openai.com 说明&#xff1a;哎呀&#xff01;我们在验证您的身份时遇到了一个问题&#xff0c;如果这个问题仍然存在&#xff…

Linux:shell脚本:基础使用(8)《函数局部|全局变量函数传入位置变量return》

基本的函数定义 把一些重复调用的命令写进一个函数里&#xff0c;下次直接调用函数名&#xff0c;这样的既方便修改&#xff0c;又可以让思路清晰 function 函数名(){ 当调用这个函数时候执行的命令...... } 这个是一个基础的函数定义&#xff0c;当然你不加function也是可以的…

从 GPT1 - GPT4 拆解

从 GPT1 - GPT4 拆解 从 GPT1 - GPT4GPT1&#xff1a;更适用于文本生成领域GPT2&#xff1a;扩展数据集、模型参数&#xff0c;实现一脑多用&#xff08;多个任务&#xff09;GPT3&#xff1a;元学习 大力出奇迹InstructGPT&#xff1a;指示和提示学习 人工反馈强化学习 RLHF…

实验二 体系结构

&#x1f57a;作者&#xff1a; 主页 我的专栏C语言从0到1探秘C数据结构从0到1探秘Linux菜鸟刷题集 &#x1f618;欢迎关注&#xff1a;&#x1f44d;点赞&#x1f64c;收藏✍️留言 &#x1f3c7;码字不易&#xff0c;你的&#x1f44d;点赞&#x1f64c;收藏❤️关注对我真的…

Open3D 点云转深度图像

目录 一、算法原理1、算法过程2、主要函数二、代码实现三、结果展示1、点云2、深度图像四、测试数据Open3D 点云转深度图像由CSDN点云侠原创。如果你不是在点云侠的博客中看到该文章,那么此处便是不要脸的爬虫与GPT。<

TypeScript 中的深拷贝和浅拷贝

什么是深拷贝 在JavaScript/TypeScript中&#xff0c;深拷贝是指创建一个对象的副本&#xff0c;而不仅仅是创建对原始对象的引用。对复制的对象进行的任何更改都不会影响原始对象&#xff0c;反之亦然。 这个副本将完全复制基础对象&#xff0c;包括每个嵌套级别的所有属性和…

在vite5和vue3开发环境中使用jodit4富文本编辑器,并添加自定义插件和使用highlight.js实现代码块高亮(附其他自定义配置项和全部代码)

最近富文本编辑器jodit终于更新发布到了4.0版本&#xff0c;加入了css变量、有更好的typescript支持&#xff0c;截止发文时的版本是&#xff1a;4.0.5&#xff0c;看到有了新版本于是便想着将本地项目中的jodit版本也进行升级&#xff0c;琢磨着再丰富和添加一些功能&#xff…

excel 各种用法

excel 各种用法 实现两张表格数据关联 vlookup 实现两张表格数据关联 vlookup 实现两个 sheet 间的关联需要用 vlookup 函数实现 函数第一个参数设置成 sheet1 中的 A 列&#xff08;如果数据很多&#xff0c;可以直接选中 A 列&#xff09; 函数的第二个参数选中 sheet2 中…

android 和 opencv 开发环境搭建

本文详细说明给android项目添加opencv库的详细步骤&#xff0c;并通过实现图片灰度化来查看配置是否成功。 下载OPENCV ANDROID SDK 到官网下载 打开 https://opencv.org/releases/ 选择android&#xff0c;下载完成后解压出下面的文件&#xff1a; 安装android sdk 和 ndk …

YOLOv8全网首发:DCNv4更快收敛、更高速度、更高性能,效果秒杀DCNv3、DCNv2等 ,助力检测

💡💡💡本文独家改进:DCNv4更快收敛、更高速度、更高性能,完美和YOLOv8结合,助力涨点 DCNv4优势:(1) 去除空间聚合中的softmax归一化,以增强其动态性和表达能力;(2) 优化存储器访问以最小化冗余操作以加速。这些改进显著加快了收敛速度,并大幅提高了处理速度,DCN…