通过类定义一个网络

import torch
from torch import nn

x = torch.ones(2,10)

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.out = nn.Linear(10, 1)
    def forward(self,x):
        return self.out(x)

 1. 代码解析

  • 如何定义一个类?self 又是什么东西?
  • 类是如何继承基类的特性的?nn.module 是个什么对象?
  • 为什么会有一个初始化函数 init,初始化函数中的 super().init 函数是做什么用的?是否必须要有?
  • forward 函数有什么作用?该怎样用?
  • 为什么初始化函数前后有两条下划线?forward 前后为什么没有下划线?
  • nn.Linear 函数是干什么用的?
  • 输入输出张量的形状大小都是怎么对应的?
  • 模型内部的网络参数是怎么定义的,如何查看?

1.1 如何定义一个类 

class MLPSimple(nn.Module):
    def __init__(self):
        super().__init__()
        self.out = nn.Linear(10, 1)

a_s = MLPSimple()
a_s

 self 是指向类自身的一个指针,可以通过该指针引用类自身的成员,默认这个参数是每个成员函数的首个输入参数,如果没有self参数,那么定义的函数将无法引用类自身的成员

1.2 如何继承基类的特性?nn.module 是个什么对象?

定义类的时候,将需要被继承的类的名称作为参数传入,如 class MLPSimple(nn.Module) 这样就是定义了一个新的类 MLPSimple ,这个新类继承了 nn.Module 的所有特性。 以下展示再创建一个类,继承刚刚创建的新类 MLPSimple_p。

class MLPSimple_p(MLPSimple):
    def __init__(self):
        super().__init__()
a_s_p = MLPSimple_p()
a_s_p

nn.Module是PyTorch中的一个类,继承自torch.nn的基类,用于定义神经网络模型、提供前向传播过程所需的基本功能和方法。 在PyTorch中,神经网络模型通常是由多个层组成的,每个层都是一个nn.Module实例。通过继承nn.Module类并实现自己的forward方法,可以定义自己的神经网络模,。在神经网络的训练和推理过程中,PyTorch会自动调用nn.Module的forward方法来计算输出。

1.3 为什么会有一个初始化函数 init,初始化函数中的 super().init 函数是做什么用的?是否必须要有? 

与C++不同,我们自己新定义的python类没有显式的构造函数(python 类有自己的构造函数,该构造函数跟 init 函数一样也是个魔法函数),python类的对init函数的调用,可以被看做是类似于C++类调用构造函数类似的过程,当python中通过类创建对象的时候就会调用init函数对对象进初始化,与c++不同的是如果c++继承了基类,那么构造对象的时候会隐式的调用基类的构造方法,这里python却需要显示的主动调用基类初始化方法super().init()对基类的特性进行初始化。这里的显示调用时必须的,如果漏掉会报错。

1.4 forward 函数有什么作用?该怎样用?

python 的类成员方法中有一个非常特殊的函数叫做 call() 函数,这个函数使得实例化的对象自身可以像一个函数一样被调用,如同样实例对象为 a_s_p ,如果这个对象是在C++中,那么这个对象就单纯的是一个对象而已,要想让这个对象处理一些事情,就必须通过对象去调用它自身的一些方法来实现,如a_s_p.func(),但是在python类中,类定义里面有一个特别的函数叫做call函数,这个函数可以使被实例的对象本身像一个函数一样被直接调用,而在pytorch中这个call函数会默认直接调用创建类的forward函数,forward函数会接受所有传递给call函数的参数,call函数本身也会将forward函数的返回结果直接返回,因此就形成了pytorch中这种可以直接通过对象本身来处理数据的现象。

a_s_p(inputs) 隐含的意思就是 a_s_p.call(inputs) , 而 a_s_p.call(inputs) 本身的定义却类似于以下这样:

class MLPSimple(nn.Module):
    def __call__(self,inputs):
        return self.forward(inputs)
    def forward(self,inputs)
        return outputs

1.5 为什么初始化函数前后有两条下划线?forward 前后为什么没有下划线?

函数前后有两条下划线的方法叫做python的魔法函数,魔法函数本身是指的到了特定状况下会自动被调用的函数,因为其自适应性像魔法一样神奇所以被称为魔法函数,没有下划线的函数指的是普通函数,像forward函数的名称是pytorch的保留字,默认被call函数调用,但它仍然跟普通函数一样,没什么特别之处。其他的魔法函数还有如下这些:

  • init():类的初始化方法,在创建类的实例时自动调用。
  • new():类的构造函数,当使用类的构造函数创建新的类实例时自动调用。
  • str():返回对象的字符串表示,当调用print()函数输出对象时自动调用。
  • del():在对象被删除时自动调用。
  • call():当对象被作为函数调用时自动调用。
  • len():返回对象的长度,当使用len()函数调用对象时自动调用。
  • eq():比较两个对象是否相等,当使用==运算符比较两个对象时自动调用。
  • hash():返回对象的哈希值,当使用hash()函数调用对象时自动调用。
  • getitem():当使用方括号运算符[]访问对象的元素时自动调用。
  • setitem():当使用方括号运算符[]修改对象的元素时自动调用。的元素时自动调用。[]修改对的元素时自动调用。

1.6 nn.Linear 函数是干什么用的?

对输入向量进行线性变换的一个网络层类,,与之类似的类还有以下几个类: 'Bilinear(双线性变换' 'Identity(占位符) 'LazyLinear'(系数矩阵尺寸在第一次被调用时候自动初始化,不需要主动指定)

class Linear(Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = True,device=None, dtype=None) -> None:
        ...

   def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias)

详细介绍见官方文档:Linear — PyTorch 2.0 documentation

初始化时候指定input_tensor[size_in_0,in_features] \output_tensor[size_out_0,out_features],即指定了线性层的系数矩阵尺寸 weight[in_features,out_features],计算时候 out_tensor = input_tensor * weight = [size_in_0,in_features] * [in_features,out_features] = [size_in_0 ,out_features] 在开头的例子中即为 out_tensor = input_tensor * weight = [2,10] * [10,1] = [2 ,1] 以上计算过程中可以发现,矩阵的最后一个维度是样本的特征维度,比如说线性变换中的自变量个数即为 in_features = 10 , 因变量的个数为 out_features = 1 ,这两个个数即为单个样本的特征维度或者说是特征数。倒数第二个维度是样本的批量大小,像本例中输入样本为2,输出样本自然的对应也应该是2,输入样本的数量不需要单独指定,在传入模型处理的时候,模型会自动去识别处理。

1.7 模型内部的网络参数是怎么定义的,如何查看?

访问权重系数 :print(a.out.weight)

访问偏置系数 :print(a.out.bias)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/101115.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Doris架构中包含哪些技术?

Doris主要整合了Google Mesa(数据模型),Apache Impala(MPP Query Engine)和Apache ORCFile (存储格式,编码和压缩)的技术。 为什么要将这三种技术整合? Mesa可以满足我们许多存储需求的需求,但是Mesa本身不提供SQL查询引擎。 Impala是一个…

开发指导—利用CSS动画实现HarmonyOS动效(一)

注:本文内容分享转载自 HarmonyOS Developer 官网文档 一. CSS 语法参考 CSS 是描述 HML 页面结构的样式语言。所有组件均存在系统默认样式,也可在页面 CSS 样式文件中对组件、页面自定义不同的样式。请参考通用样式了解兼容 JS 的类 Web 开发范式支持的…

TCP协议报文

前言 TCP/IP协议簇——打开虚拟世界大门中,已经给大家大致介绍了TCP/IP协议簇的分层。 TCP (Transmission Control Protocol)传输控制协议,在TCP/IP协议簇中,处于传输层。是为了在不可靠的互联网络(IP协议)中&#x…

蒲公英路由器如何设置远程打印?

现如今,打印机已经是企业日常办公中必不可少的设备,无论何时何地,总有需要用到打印的地方,包括资料文件、统计报表等等。 但若人在外地或分公司,有文件急需通过总部的打印机进行打印时,由于不在同一物理网络…

CSS中你不得不知道的盒子模型

目录 1、CSS的盒子模型1.1 css盒子模型有哪些:1.2 css盒子模型的区别1.3 通过css如何转换css盒子模型 1、CSS的盒子模型 1.1 css盒子模型有哪些: 标准盒子模型、怪异盒子模型(IE盒子模型) 1.2 css盒子模型的区别 标准盒子模型&a…

【ES系列】(一)简介与安装

首发博客地址 首发博客地址[1] 系列文章地址[2] 教学视频[3] 为什么要学习 ES? 强大的全文搜索和检索功能:Elasticsearch 是一个开源的分布式搜索和分析引擎,使用倒排索引和分布式计算等技术,提供了强大的全文搜索和检索功能。学习 ES 可以掌…

微机原理 || 第2次测试:汇编指令(加减乘除运算,XOR,PUSH,POP,寻址方式,物理地址公式,状态标志位)(测试题+手写解析)

(一)测试题目: 1.数[X]补1111,1110B,则其真值为 2.在I/O指令中,可用于表示端口地址的寄存器 3. MOV AX,[BXSl]的指令中,源操作数的物理地址应该如何计算 4.执行以下两条指令后,标志寄存器FLAGS的六个状态…

Go 面向对象(匿名字段)

概述 严格意义上说,GO语言中没有类(class)的概念,但是我们可以将结构体比作为类,因为在结构体中可以添加属性(成员),方法(函数)。 面向对象编程的好处比较多,我们先来说一下“继承…

python爬虫-数据解析BeautifulSoup

1、基本简介 BeautifulSoup简称bs4,BeautifulSoup和lxml一样是一个html的解析器,主要功能也是解析和提取数据。 BeautifulSoup和lxml类似,既可以解析本地文件也可以响应服务器文件。 缺点:效率没有lxml的效率高 。 优点:接口设…

stm32之IIC协议

主要通过两个层面来讲:物理层、协议层。 IIC是一个同步半双工串行总线协议。 一、物理层(通信模型) 1、最早是飞利浦公司开发的这个协议,最早应用到其产品上去。 2、两线制(两根信号线) 其中SCL为时钟…

vue的第2篇 第一个vue程序

一 环境的搭建 1.1常见前端开发ide 1.2 安装vs.code 1.下载地址:Visual Studio Code - Code Editing. Redefined 2.进行安装 1.2.1 vscode的中文插件安装 1.在搜索框输入“chinese” 2.安装完成重启,如下变成中文 1.2.2 修改工作区的颜色 选中[浅色]…

MySQL8.0.22安装过程记录(个人笔记)

1.点击下载MySQL 2.解压到本地磁盘(注意路径中不要有中文) 3.在解压目录创建my.ini文件 文件内容为 [mysql] # 设置mysql客户端默认字符集 default-character-setutf8[mysqld] # 设置端口 port 3306 # 设计mysql的安装路径 basedirE:\01.app\05.Tool…

win10安装Docker Desktop,并修改存储目录

安装之前先看看自己电脑c盘剩余容量,如果小于30G,建议先配置下再安装 因为docker 安装时不提供指定安装路径和数据存储路径的选项,且默认是安装在C盘的。C盘比较小的,等docker运行久了,一大堆的东西放在上面容易导致磁…

视频监控人员行为识别算法

视频监控人员行为识别算法通过opencvpython网络模型框架算法,视频监控人员行为识别算法可以识别和判断员工的行为是否符合规范要求,一旦发现不符合规定的行为,视频监控人员行为识别算法将自动发送告警信息。OpenCV的全称是Open Source Comput…

almaLinux 8 安装 xxdiff 5.1

almaLinux 安装 xxdiff XXdiff——比较和合并工具下载安装安装qt5 XXdiff——比较和合并工具 XXdiff是一款免费、强大的文件和目录比较及合并工具,可以在类似Unix的操作系统上运行,比如Linux、Solaris、HP/UX、IRIX和DEC Tru64。XXdiff的一大局限就是不…

栈和队列篇

目录 一、栈 1.栈的概念及结构 1.1栈的概念 1.2栈的结构示意图 2.栈的实现 2.1支持动态增长的栈的结构 2.2压栈(入栈) 2.3出栈 2.4支持动态增长的栈的代码实现 二、队列 1.队列的概念及结构 1.1队列的概念 1.2队列的结构示意图 2.队列的实…

设计模式-适配器

文章目录 一、简介二、适配器模式基础1. 适配器模式定义与分类2. 适配器模式的作用与优势3.UML图 三、适配器模式实现方式1. 类适配器模式2. 对象适配器模式3.类适配器模式和对象适配器模式对比 四、适配器模式应用场景1. 继承与接口的适配2. 跨平台适配 五、适配器模式与其他设…

滑动窗口和双指针

滑动窗口和双指针 一、循环不变量1.1 定义1.2 总结 二、使用循环不变量写对代码2.1 注意2.2 总结 三、滑动窗口3.1 固定长度的滑动窗口(同向交替移动的两个变量)3.2 不定长度的滑动窗口3.2.1 定义3.2.2 总结 3.3 计数问题3.3.1 标准3.3.2 总结 3.4 使用数…

JavaScript【转】

以下内容转载和参考自:w3school的JavaScript学习内容,HTML JavaScript。 JavaScript 使 HTML 页面更具动态性和交互性,前面我们都是在代码中一开始就将元素的值、属性、style样式写死,使用JavaScript 的话就可以对这些内容动态的更…

ChatGPT癌症治疗“困难重重”,真假混讲难辨真假,准确有待提高

近年来,人工智能在医疗领域的应用逐渐增多,其中自然语言处理模型如ChatGPT在提供医疗建议和信息方面引起了广泛关注。然而,最新的研究表明,尽管ChatGPT在许多领域取得了成功,但它在癌症治疗方案上的准确性仍有待提高。…