7、深入剖析PyTorch nn.Module源码

文章目录

  • 1. 重要类
  • 2. add_modules
  • 3. Apply(fn)
  • 4. register_buffer
  • 5. nn.Parameters®ister_parameters
  • 6. 后续测试

1. 重要类

  • nn.module --> 所有神经网络的父类,自定义神经网络需要继承此类,并且自定义__init__,forward函数即可:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :MyModelNet.py
# @Time      :2024/11/20 13:38
# @Author    :Jason Zhang
import torch
from torch import nn


class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork,self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


if __name__ == "__main__":
    run_code = 0
    x_row = 28
    x_column = 28
    x_total = x_row * x_column
    x = torch.arange(x_total, dtype=torch.float).reshape((1, x_row, x_column))
    my_net = NeuralNetwork()
    y = my_net(x)
    print(f"y.shape={y.shape}")
    print(my_net)
  • 结果:
y.shape=torch.Size([1, 10])
NeuralNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)

2. add_modules

通过add_modules在旧的网络里面添加新的网络

  • 重点: 用nn.ModuleList自带的insert,新的网络继承自老网络中,直接用按位置插入
  • python
import torch
from torch import nn
from pytorch_model_summary import summary

torch.manual_seed(2323)


class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.flatten = nn.Flatten()
        self.block = nn.ModuleList([
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        ])

    def forward(self, x):
        x = self.flatten(x)
        for layer in self.block:
            x = layer(x)
        return x


class MyNewNet(MyModel):
    def __init__(self):
        super(MyNewNet, self).__init__()
        self.block.insert(2, nn.Linear(512, 256))  # 插入新层
        self.block.insert(3, nn.ReLU())  # 插入新的激活函数
        self.block.insert(4, nn.Linear(256, 512))  # 插入另一层
        self.block.insert(5, nn.ReLU())  # 插入激活函数


if __name__ == "__main__":
    # 测试原始模型
    my_model = MyModel()
    print("Original Model:")
    print(summary(my_model, torch.ones((1, 28, 28))))

    # 测试新模型
    my_new_model = MyNewNet()
    print("\nNew Model:")
    print(summary(my_new_model, torch.ones((1, 28, 28))))
  • 结果:
Original Model:
-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
=======================================================================
         Flatten-1            [1, 784]               0               0
          Linear-2            [1, 512]         401,920         401,920
            ReLU-3            [1, 512]               0               0
          Linear-4             [1, 10]           5,130           5,130
=======================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
-----------------------------------------------------------------------

New Model:
-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
=======================================================================
         Flatten-1            [1, 784]               0               0
          Linear-2            [1, 512]         401,920         401,920
            ReLU-3            [1, 512]               0               0
          Linear-4            [1, 256]         131,328         131,328
            ReLU-5            [1, 256]               0               0
          Linear-6            [1, 512]         131,584         131,584
            ReLU-7            [1, 512]               0               0
          Linear-8             [1, 10]           5,130           5,130
=======================================================================
Total params: 669,962
Trainable params: 669,962
Non-trainable params: 0
-----------------------------------------------------------------------

3. Apply(fn)

模型权重weight,bias 的初始化

  • python
import torch.nn as nn
import torch


class MyAwesomeModel(nn.Module):
    def __init__(self):
        super(MyAwesomeModel, self).__init__()
        self.fc1 = nn.Linear(3, 4)
        self.fc2 = nn.Linear(4, 5)
        self.fc3 = nn.Linear(5, 6)


# 定义初始化函数
@torch.no_grad()
def init_weights(m):
    print(m)
    if type(m) == nn.Linear:
        m.weight.fill_(1.0)
        print(m.weight)


# 创建神经网络实例
model = MyAwesomeModel()

# 应用初始化权值函数到神经网络上
model.apply(init_weights)
  • 结果:
Linear(in_features=3, out_features=4, bias=True)
Parameter containing:
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], requires_grad=True)
Linear(in_features=4, out_features=5, bias=True)
Parameter containing:
tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]], requires_grad=True)
Linear(in_features=5, out_features=6, bias=True)
Parameter containing:
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]], requires_grad=True)
MyAwesomeModel(
  (fc1): Linear(in_features=3, out_features=4, bias=True)
  (fc2): Linear(in_features=4, out_features=5, bias=True)
  (fc3): Linear(in_features=5, out_features=6, bias=True)
)

Process finished with exit code 0

4. register_buffer

将模型中添加常数项。比如加1

  • python:
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :RegisterBuffer.py
# @Time      :2024/11/23 19:21
# @Author    :Jason Zhang
import torch
from torch import nn


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()
        self.register_buffer("my_buffer_a", torch.ones(2, 3))

    def forward(self, x):
        x = x + self.my_buffer_a
        return x


if __name__ == "__main__":
    run_code = 0
    my_test = MyNet()
    in_x = torch.arange(6).reshape((2, 3))
    y = my_test(in_x)
    print(f"x=\n{in_x}")
    print(f"y=\n{y}")
  • 结果:
x=
tensor([[0, 1, 2],
        [3, 4, 5]])
y=
tensor([[1., 2., 3.],
        [4., 5., 6.]])

5. nn.Parameters&register_parameters

  • python
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# @FileName  :ParameterTest.py
# @Time      :2024/11/23 19:37
# @Author    :Jason Zhang
import torch
from torch import nn


class MyModule(nn.Module):
    def __init__(self, in_size, out_size):
        self.in_size = in_size
        self.out_size = out_size
        super(MyModule, self).__init__()
        self.test = torch.rand(self.in_size, self.out_size)
        self.linear = nn.Linear(self.in_size, self.out_size)

    def forward(self, x):
        x = self.linear(x)
        return x


class MyModuleRegister(nn.Module):
    def __init__(self, in_size, out_size):
        self.in_size = in_size
        self.out_size = out_size
        super(MyModuleRegister, self).__init__()
        self.test = torch.rand(self.in_size, self.out_size)
        self.linear = nn.Linear(self.in_size, self.out_size)

    def forward(self, x):
        x = self.linear(x)
        return x


class MyModulePara(nn.Module):
    def __init__(self, in_size, out_size):
        self.in_size = in_size
        self.out_size = out_size
        super(MyModulePara, self).__init__()
        self.test = nn.Parameter(torch.rand(self.in_size, self.out_size))
        self.linear = nn.Linear(self.in_size, self.out_size)

    def forward(self, x):
        x = self.linear(x)
        return x


if __name__ == "__main__":
    run_code = 0
    test_in = 4
    test_out = 6
    my_test = MyModule(test_in, test_out)
    my_test_para = MyModulePara(test_in, test_out)
    test_list = list(my_test.named_parameters())
    test_list_para = list(my_test_para.named_parameters())
    my_test_register = MyModuleRegister(test_in, test_out)
    para_register = nn.Parameter(torch.rand(test_in, test_out))
    my_test_register.register_parameter('para_add_register', para_register)
    test_list_para_register = list(my_test_register.named_parameters())

    print(f"*" * 50)
    print(f"test_list=\n{test_list}")
    print(f"*" * 50)
    print(f"*" * 50)
    print(f"test_list_para=\n{test_list_para}")
    print(f"*" * 50)
    print(f"*" * 50)
    print(f"test_list_para_register=\n{test_list_para_register}")
    print(f"*" * 50)
  • 结果:
**************************************************
test_list=
[('linear.weight', Parameter containing:
tensor([[ 0.3805, -0.3368,  0.2348,  0.4525],
        [-0.4557, -0.3344,  0.1368, -0.3471],
        [-0.3961,  0.3302,  0.1904, -0.0111],
        [ 0.4542, -0.3325, -0.3782,  0.0376],
        [ 0.2083, -0.3113, -0.3447, -0.1503],
        [ 0.0343,  0.0410, -0.4216, -0.4793]], requires_grad=True)), ('linear.bias', Parameter containing:
tensor([-0.3465, -0.4510,  0.4919,  0.1967, -0.1366, -0.2496],
       requires_grad=True))]
**************************************************
**************************************************
test_list_para=
[('test', Parameter containing:
tensor([[0.1353, 0.9934, 0.0462, 0.2103, 0.3410, 0.0814],
        [0.7509, 0.2573, 0.8030, 0.0952, 0.1381, 0.5360],
        [0.1972, 0.1241, 0.5597, 0.2691, 0.3226, 0.0660],
        [0.3333, 0.8031, 0.9226, 0.4290, 0.3660, 0.6159]], requires_grad=True)), ('linear.weight', Parameter containing:
tensor([[-0.0633, -0.4030, -0.4962,  0.1928],
        [-0.1707,  0.2259,  0.0373, -0.0317],
        [ 0.4523,  0.2439, -0.1376, -0.3323],
        [ 0.3215,  0.1283,  0.0729,  0.3912],
        [ 0.0262, -0.1087,  0.4721, -0.1661],
        [-0.1055, -0.2199, -0.4974, -0.3444]], requires_grad=True)), ('linear.bias', Parameter containing:
tensor([ 0.3702, -0.0142, -0.2098, -0.0910, -0.2323, -0.0546],
       requires_grad=True))]
**************************************************
**************************************************
test_list_para_register=
[('para_add_register', Parameter containing:
tensor([[0.2428, 0.1388, 0.6612, 0.4215, 0.0215, 0.2618],
        [0.4234, 0.0160, 0.8947, 0.4784, 0.4403, 0.4800],
        [0.8845, 0.1469, 0.6894, 0.7050, 0.5911, 0.7702],
        [0.7694, 0.0491, 0.3583, 0.4451, 0.2282, 0.4293]], requires_grad=True)), ('linear.weight', Parameter containing:
tensor([[ 0.1358, -0.4704, -0.4181, -0.4504],
        [ 0.0903,  0.3235, -0.3164, -0.4163],
        [ 0.1342,  0.3108,  0.0612, -0.2910],
        [ 0.3527,  0.3397, -0.0414, -0.0408],
        [-0.4877,  0.1925, -0.2912, -0.2239],
        [-0.0081, -0.1730,  0.0921, -0.4210]], requires_grad=True)), ('linear.bias', Parameter containing:
tensor([-0.2194,  0.2233, -0.4950, -0.3260, -0.0206, -0.0197],
       requires_grad=True))]
**************************************************

6. 后续测试

  • register_module
  • get_submodule
  • get_parameter

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/922130.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Gate学习(6) 指令学习3

一、/particle/ 目录及其子目录下的命令 在 `/particle/` 命令目录及其子目录下,可以控制和管理粒子相关的属性和过程。以下是每个命令目录和命令的简要解释: ### `/particle/` 这是粒子控制命令的主目录,包括选择粒子、列出粒子名称、查找粒子编码、创建所有离子和同位旋等…

【Git】:Git基本操作

目录 创建、配置本地仓库 创建本地仓库 配置本地仓库 认识工作区、暂存区、版本库 修改文件 版本回退 撤销修改 删除文件 创建、配置本地仓库 创建本地仓库 我们通常可以通过以下两种方式之一获取 Git 存储库: 自己在本地目录创建一个本地仓库 从其它服务…

android 性能分析工具(03)Android Studio Profiler及常见性能图表解读

说明:主要解读Android Studio Profiler 和 常见性能图表。 Android Studio的Profiler工具是一套功能强大的性能分析工具集,它可以帮助开发者实时监控和分析应用的性能,包括CPU使用率、内存使用、网络活动和能耗等多个方面。以下是对Android …

LabVIEW配电网谐波在线监测与分析系统

统利用LabVIEW与NI数据采集卡,结合高精度谐波分析算法,实现了配电网谐波的实时监测与分析。通过虚拟仪器技术的灵活性和扩展性,显著提高电网运行的可靠性与电能质量,提供了一套有效的技术解决方案。 项目背景 随着非线性负载&am…

git使用(二)

git使用(二) git常用基本操作命令git clonegit loggit remotegit statusgit addgit commitgit pushgit branchgit pull git常用基本操作命令 git clone 项目开发中项目负责人会在github上创建一个远程仓库,我们需要使用git clone将远程仓库…

Excel求和如何过滤错误值

一、问题的提出 平时,我们在使用Excel时,最常用的功能就是求和了,一说到求和你可能想到用sum函数,但是如果sum的求和区域有#value #Div等错误值怎么办?如下图,记算C列中工资的总和。 直接用肯定会报错&…

【数据分享】2024年我国省市县三级的住宿服务设施数量(8类住宿设施/Excel/Shp格式)

宾馆酒店、旅馆招待所等住宿服务设施的配置情况是一个城市公共基础设施完善程度的重要体现,一个城市住宿服务设施种类越丰富,数量越多,通常能表示这个城市的公共服务水平越高! 本次我们为大家带来的是我国各省份、各地级市、各区…

自制Windows系统(十)

上图 (真的不是Windows破解版) 开源地址:仿Windows

Ubuntu20.04下安装向日葵

向日葵远程控制app官方下载 - 贝锐向日葵官网 下载Ununtu版的图形版本的安装deb包SunloginClient_15.2.0.63064_amd64.deb 直接执行 sudo dpkg -i SunloginClient_15.2.0.63064_amd64.deb 的话会报错: 如果在Ubuntu20.04里直接执行sudo apt install libgconf-2-4安装libgco…

vitepress博客模板搭建

vitepress博客搭建 个人博客技术栈更新,快速搭建一个vitepress自定义博客 建议去博客查看文章,观感更佳。原文地址 模板仓库: vitepress-blog-template 前言 服务器过期快一年了,博客也快一年没更新了,最近重新搭…

修改一下达梦disql 提示符

经常用disql的有时某些信息希望提示一下,默认的只显示SQL> 为了方便使用,可以在 glogin.sql 中增加些内容。 vi $DM_HOME/bin/disql_conf/glogin.sql增加以下几行 set time on set lineshow offcol global_name new_value global_name SELECT ins…

TCP vs UDP:如何选择适合的网络传输协议?

在网络通信中,TCP(Transmission Control Protocol)和UDP(User Datagram Protocol)是两种非常重要的传输层协议。它们各有特点,适用于不同类型的应用场景。本文将详细探讨TCP和UDP协议的结构、优缺点及应用&…

支付宝租赁小程序的优势与应用前景分析

内容概要 在这个快节奏的时代,租赁服务越来越成为大家生活中的一部分。而支付宝租赁小程序正是这个大潮流中的一颗璀璨明珠。通过简单易用的界面和强大的功能,这个小程序不仅让用户在租赁过程中获得了前所未有的便利,也为商家提供了新的商业…

前端面试笔试(五)

最近面试中手撕题以及笔试中总遇到递归回溯类题目,于是去牛客上找典型题目。这里浅浅列一道。 目录 题目 解决 代码 详情 变量初始化 递归函数dg 递归终止条件 递归主体 初始调用和返回结果 题目 没有重复项数字的全排列(递归回溯,js…

【数据分享】中国价格统计年鉴(2013-2024) PDF

数据介绍 犹如一座珍贵的宝库,全面而系统地记录了中国在这一时期的价格变动情况。它涵盖了丰富的内容,包括宏观经济指标、商品价格、居民消费价格以及城市物价监测等多个方面。 在宏观经济指标方面,年鉴中收录了 GDP、CPI、PPI 等重要数据&…

数据集-目标检测系列- 花卉 鸡蛋花 检测数据集 frangipani >> DataBall

数据集-目标检测系列- 花卉 鸡蛋花 检测数据集 frangipani >> DataBall DataBall 助力快速掌握数据集的信息和使用方式,会员享有 百种数据集,持续增加中。 贵在坚持! 数据样例项目地址: * 相关项目 1)数据集…

manin动画编程(安装+入门)

文章目录 1.基本介绍2.效果展示3.安装步骤3.1安装manba软件3.2配置环境变量3.3查看是否成功3.4什么是mamba3.5创建虚拟环境3.6尝试进入虚拟环境 4.vscode操作4.1默认配置文件 5.安装ffmpeg6.安装manim软件6.vscode制作7.我的学习收获 1.基本介绍 这个manim就是一款软件&#x…

《Python编程实训快速上手》第十天--处理CSV文件和JSON数据

CSV:简化的电子表格,被保存为纯文本文件 JSON:是一种数据交换格式,易于人阅读和编写,同时也易于机器解析和生成,以JavaScript源代码的形式将信息保存在纯文本文件中 一、csv模块 CSV文件中的每行代表电…

利用FileZilla搭建ftp服务器

一 利用windows自带的ftp服务搭建服务器,要复杂一些,好处是无需借用外部软件。 也有一些好的工具,如FileZilla的Server版,构建过程简单,好用。 下面看看。 二 安装FileZilla Server 当前下载版本是0.9.43&#xf…

P1 练习卷(C++4道题)

1.纷繁世界 内存限制:256MB 时间限制:1s 问题描述 这是一个纷繁复杂的世界。 某一天清晨你起床很迟,没有吃上早饭。于是你骑着自行车去超市,但是你又发现商店的工作人员已经重新贴上了价格标签,零食价格都涨了50%。你…