深入解析全连接层:PyTorch 中的 nn.Linear、nn.Parameter 及矩阵运算

文章目录

这篇文章会从基础的一个数学概念到对应的代码实现,你将了解到:

  • 为什么nn.Parameter()接受 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)作为参数?
  • 为什么不是torch.matmul(self.weight, x) + self.bias
  • 如何使用torch.matmul()@F.linear() 去等价地实现nn.Linear()的输出。

数学概念(全连接层,线性层)

线性变化是数学中一个基础的概念,它描述了如何通过线性变换将输入映射到输出。在线性代数中,线性变化通常表示为矩阵乘法。在神经网络中,线性层的核心就是实现这样的矩阵运算。

数学公式:

给定一个输入向量 x ∈ R n \mathbf{x} \in \mathbb{R}^n xRn 和一个输出向量 y ∈ R m \mathbf{y} \in \mathbb{R}^m yRm,线性变化通过矩阵 W ∈ R m × n \mathbf{W} \in \mathbb{R}^{m \times n} WRm×n 和偏置项 b ∈ R m \mathbf{b} \in \mathbb{R}^m bRm 进行变换,其公式为:
y = W x + b \mathbf{y} = \mathbf{W} \mathbf{x} + \mathbf{b} y=Wx+b

  • W \mathbf{W} W:是权重矩阵,维度为 m × n m \times n m×n,它决定了输入向量如何线性变换到输出空间;
  • x \mathbf{x} x:是输入向量,维度为 n n n,表示特征数据;
  • b \mathbf{b} b:是偏置向量,维度为 m m m,用来调整线性变换的输出;
  • y \mathbf{y} y:是输出向量,维度为 m m m,是变换后的结果。

例子:

如果输入向量 x \mathbf{x} x 有 3 个特征,输出向量 y \mathbf{y} y 有 2 个特征,则权重矩阵 W \mathbf{W} W 的形状为 2 × 3 2 \times 3 2×3。假设:
W = [ 1 2 3 4 5 6 ] , x = [ 1 2 3 ] , b = [ 0 1 ] \mathbf{W} = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix}, \quad \mathbf{x} = \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix}, \quad \mathbf{b} = \begin{bmatrix} 0 \\ 1 \end{bmatrix} W=[142536],x= 123 ,b=[01]
线性变换计算为:
y = W x + b = [ 1 2 3 4 5 6 ] [ 1 2 3 ] + [ 0 1 ] = [ 14 32 ] + [ 0 1 ] = [ 14 33 ] \mathbf{y} = \mathbf{W} \mathbf{x} + \mathbf{b} = \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} + \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 14 \\ 32 \end{bmatrix} + \begin{bmatrix} 0 \\ 1 \end{bmatrix} = \begin{bmatrix} 14 \\ 33 \end{bmatrix} y=Wx+b=[142536] 123 +[01]=[1432]+[01]=[1433]
矩阵运算过程:
[ 1 2 3 4 5 6 ] [ 1 2 3 ] = [ ( 1 × 1 ) + ( 2 × 2 ) + ( 3 × 3 ) ( 4 × 1 ) + ( 5 × 2 ) + ( 6 × 3 ) ] = [ 14 32 ] \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{bmatrix} \begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix} = \begin{bmatrix} (1 \times 1) + (2 \times 2) + (3 \times 3) \\ (4 \times 1) + (5 \times 2) + (6 \times 3) \end{bmatrix} = \begin{bmatrix} 14 \\ 32 \end{bmatrix} [142536] 123 =[(1×1)+(2×2)+(3×3)(4×1)+(5×2)+(6×3)]=[1432]

nn.Linear()

nn.Linear() 会自动创建一个权重矩阵(Weight)和偏置项(Bias),并将它们应用到输入上。

代码示例:

import torch
import torch.nn as nn

# 定义一个输入为3,输出为2的线性层
linear_layer = nn.Linear(3, 2)

# 打印权重矩阵和偏置项
print("权重矩阵 W:")
print(linear_layer.weight)

print("偏置项 b:")
print(linear_layer.bias)

# 模拟输入向量
input_vector = torch.tensor([1.0, 2.0, 3.0])
output_vector = linear_layer(input_vector)
print("输出向量 y:")
print(output_vector)

image-20240912221728559

在这里,nn.Linear(3, 2) 创建了一个 2×3 的权重矩阵和一个 2 维的偏置向量。通过 linear_layer(input_vector),可以直接获得输入向量经过线性变换后的输出。

nn.Parameter()

在 PyTorch 中,nn.Linear() 自动处理了权重和偏置项的初始化和更新,但有时你可能希望对这些参数自定义一些操作,比如 LoRA。这时,我们可以使用 nn.Parameter() 来自定义权重和偏置,其实 nn.Linear() 本身就是使用的nn.Parameter(),感兴趣的话可以看官方源码。

以自定义线性层为例:

class CustomLinearLayer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(CustomLinearLayer, self).__init__()
        # 使用 nn.Parameter 手动定义权重和偏置
        self.weight = nn.Parameter(torch.randn(output_dim, input_dim))
        self.bias = nn.Parameter(torch.randn(output_dim))

    def forward(self, x):
        # 手动实现线性变换 y = Wx + b
        return torch.matmul(x, self.weight.T) + self.bias

# 使用自定义的线性层
custom_layer = CustomLinearLayer(3, 2)
output = custom_layer(input_vector)
print(output)

image-20240912222625609

在看完代码后,你可能会产生两个疑惑:

Q

1. 为什么 self.weight 的权重矩阵 shape 使用 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)而不是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)?

实际这正是我写这篇博客分享的原因,现在进入正题,因为这个疑惑完全不应该产生。

让我们重新使用 in_features \text{in\_features} in_features out_features \text{out\_features} out_features来重现之前的数学定义:

对于输入向量 x ∈ R in_features \mathbf{x} \in \mathbb{R}^{\text{in\_features}} xRin_features,全连接层的输出为:

y = W x + b \mathbf{y} = W\mathbf{x} + \mathbf{b} y=Wx+b

其中:

  • W ∈ R out_features × in_features W \in \mathbb{R}^{\text{out\_features} \times \text{in\_features}} WRout_features×in_features 是权重矩阵,
  • b ∈ R out_features \mathbf{b} \in \mathbb{R}^{\text{out\_features}} bRout_features 是偏置项。

在线性变换中,输入向量 x \mathbf{x} x 的维度是 in_features \text{in\_features} in_features,而输出向量 y \mathbf{y} y 的维度是 out_features \text{out\_features} out_features。根据矩阵乘法的规则,要将输入 x \mathbf{x} x 映射到输出 y \mathbf{y} y,权重矩阵 W W W 的形状应该是 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features),因为矩阵乘法中 W x W\mathbf{x} Wx的维度要求是:

( out_features × in_features ) × ( in_features × 1 ) = ( out_features × 1 ) (\text{out\_features} \times \text{in\_features}) \times (\text{in\_features} \times 1) = (\text{out\_features} \times 1) (out_features×in_features)×(in_features×1)=(out_features×1)

这保证了输出 y \mathbf{y} y 的维度是 out_features \text{out\_features} out_features

如果权重矩阵的形状是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features),矩阵乘法的维度将不匹配,无法实现线性变换。

现在是不是感觉清晰了?不要 nn.Linear(in_feature, out_feature) 用多了就将权重矩阵当作是 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)遗忘了线性代数的概念,数学才是这一切的基石。

2. 为什么是torch.matmul(x, self.weight.T) + self.bias 而不是torch.matmul(self.weight, x) + self.bias?

主要原因还是在于 输入张量 x 的形状矩阵乘法规则

一般来说,模型的输入 x 实际上并不是 ( in_features , 1 ) (\text{in\_features}, 1) (in_features,1),而是 ( batch_size , in_features ) (\text{batch\_size}, \text{in\_features}) (batch_size,in_features),而权重矩阵 self.weight 的形状是 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)​,我们需要实现的线性变换是:
y = W x + b y = W x + b y=Wx+b
根据矩阵乘法规则,第一个矩阵的列数必须等于第二个矩阵的行数,这意味着我们不能直接计算 torch.matmul(self.weight, x),因为这样会导致维度不匹配:

  • self.weight 形状为 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features)x 形状为 ( batch_size , in_features ) (\text{batch\_size}, \text{in\_features}) (batch_size,in_features)
  • torch.matmul(self.weight, x) 的维度计算规则将要求 x 的形状为 ( in_features , batch_size ) (\text{in\_features}, \text{batch\_size}) (in_features,batch_size),但这与模型的输入不匹配。

因此,正确的矩阵乘法应该是 torch.matmul(x, self.weight.T),其中 self.weight.T 表示 self.weight 的转置矩阵,此时的形状为 ( in_features , out_features ) (\text{in\_features}, \text{out\_features}) (in_features,out_features)

这样,torch.matmul(x, self.weight.T) 的维度计算为:

( batch_size , in_features ) × ( in_features , out_features ) = ( batch_size , out_features ) (\text{batch\_size}, \text{in\_features}) \times (\text{in\_features}, \text{out\_features}) = (\text{batch\_size}, \text{out\_features}) (batch_size,in_features)×(in_features,out_features)=(batch_size,out_features)

这就得到了正确的输出形状 ( batch_size , in_features ) (\text{batch\_size}, \text{in\_features}) (batch_size,in_features)

3. 为什么不直接设置self.weight = nn.Parameter(torch.randn(input_dim, output_dim))

这样不就可以不转置直接使用torch.matmul(x, self.weight)了吗?的确如此,或许是因为 ( out_features , in_features ) (\text{out\_features}, \text{in\_features}) (out_features,in_features) 对于矩阵运算 W x W\mathbf{x} Wx 来讲更符合直觉吧。

计算过程的细分:torch.matmul() vs @ 运算符

在 PyTorch 中,torch.matmul() 用于实现矩阵乘法,而 @ 是其简洁的符号形式,是 Python 的语法糖,二者在功能上是等价的。

示例代码:

import torch

# 定义权重矩阵 W 和输入向量 input_vector
W = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
input_vector = torch.tensor([1.0, 2.0, 3.0])

# 使用 torch.matmul 实现矩阵乘法
result1 = torch.matmul(W, input_vector)

# 使用 @ 运算符
result2 = W @ input_vector

print("使用 torch.matmul 计算的结果:")
print(result1)

print("使用 @ 运算符计算的结果:")
print(result2)

结果:

image-20240912233355773

使用 F.linear()

PyTorch 提供了 F.linear() 作为函数式接口,它与 nn.Linear() 类似,但不需要创建一个线性层对象。F.linear() 可以接受线性层的权重和偏置作为输入,在下一篇有关 LoRA 的文章中,你将看到使用范例。

示例代码:

import torch.nn.functional as F

# 使用 F.linear 进行线性变换
output = F.linear(input_vector, linear_layer.weight, linear_layer.bias)
print(output)

image-20240912233501651

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/875162.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Bev pool 加速(2):自定义c++扩展

文章目录 1. c++扩展2. 案例2.1 案例12. 1.1 代码实现(1) c++ 文件(2) setup.py编写(3) python 代码编写2.2 案例22.2.1 模型搭建2.2.2 c++ 扩展实现(1)c++ 扩展代码(2)setup.py编写(3)python 调用c++扩展在bevfusion论文中,将bev_pooling定义为view transform中的效率瓶…

PROTOTYPICAL II - The Practice of FPGA Prototyping for SoC Design

The Art of the “Start” The semiconductor industry revolves around the “start.” Chip design starts lead to more EDA tool purchases, more wafer starts, and eventually to more product shipments. Product roadmaps develop to extend shipments by integrating…

FloodFill算法

文章目录 1. 图像渲染(733)2. 岛屿数量(200)3. 岛屿的最大面积(695)4. 被围绕的区域(130) 1. 图像渲染(733) 题目描述: 算法原理: …

DAY13信息打点-Web 应用源码泄漏开源闭源指纹识别GITSVNDS备份

#知识点 0、Web架构资产-平台指纹识别 1、开源-CMS指纹识别源码获取方式 2、闭源-习惯&配置&特性等获取方式 3、闭源-托管资产平台资源搜索监控 演示案例: ➢后端-开源-指纹识别-源码下载 ➢后端-闭源-配置不当-源码泄漏 ➢后端-方向-资源码云-源码泄漏 …

1、https的全过程

目录 一、概述二、SSL过程如何获取会话秘钥1、首先认识几个概念:2、没有CA机构的SSL过程:3、没有CA机构下的安全问题4、有CA机构下的SSL过程 一、概述 https是非对称加密和对称加密的过程,首先建立https链接需要经过两轮握手: T…

算法提高模板强连通分量tarjan算法

AC代码&#xff1a; #include<bits/stdc.h>using namespace std;typedef long long ll; const int MOD 998244353; const int N 2e5 10;//强联通分量模板 //tarjan算法 vector<int>e[N]; int n, m, cnt; int dfn[N], low[N], ins[N], idx; int bel[N];//记录每…

Redis高可用,Redis性能管理

文章目录 一&#xff0c;Redis高可用&#xff0c;Redis性能管理二&#xff0c;Redis持久化1.RDB持久化1.1触发条件&#xff08;1&#xff09;手动触发&#xff08;2&#xff09;自动触发 1.2 Redis 的 RDB 持久化配置1.3 RDB执行流程(1) 判断是否有其他持久化操作在执行(2) 父进…

Chainlit集成Langchain并使用通义千问实现文生图网页应用

前言 本文教程如何使用通义千问的大模型服务平台的接口&#xff0c;实现图片生成的网页应用&#xff0c;主要用到的技术服务有&#xff0c;chainlit 、 langchain、 flux。合利用了大模型的工具选择调用能力。实现聊天对话生成图片的网页应用。 阿里云 大模型服务平台百炼 API…

R语言统计分析——功效分析3(相关、线性模型)

参考资料&#xff1a;R语言实战【第2版】 1、相关性 pwr.r.test()函数可以对相关性分析进行功效分析。格式如下&#xff1a; pwr.r.test(n, r, sig.level, power, alternative) 其中&#xff0c;n是观测数目&#xff0c;r是效应值&#xff08;通过线性相关系数衡量&#xff0…

NAT技术+代理服务器+内网穿透

NAT技术 IPv4协议中&#xff0c;会存在IP地址数量不充足的问题&#xff0c;所以不同的子网中会存在相同IP地址的主机。那么就可以理解为私有网络的IP地址并不是唯一对应的&#xff0c;而公网中的IP地址都是唯一的&#xff0c;所以NAT&#xff08;Network Address Translation&…

跨平台集成:在 AI、微服务和 Azure 云之间实现无缝工作流

跨平台集成在现代 IT 架构中的重要性 随着数字化转型的不断加速&#xff0c;对集成各种技术平台的需求也在快速增长。在当今的数字世界中&#xff0c;组织在复杂的环境中执行运营&#xff0c;其中多种技术需要无缝协作。环境的复杂性可能取决于业务的性质和组织提供的服务。具…

Windows更新之后任务栏卡死?桌面不断闪屏刷新?

前言 小白这几天忙于工作&#xff0c;更新就变得异常缓慢。但就算这么忙的情况下&#xff0c;晚上休息的时间还是会给小伙伴们提供咨询和维修服务。 这不&#xff0c;就有一个小伙伴遇到了个很奇怪的问题&#xff1a;电脑Windows更新之后&#xff0c;任务栏点了没反应&#xf…

深入理解Java虚拟机:Jvm总结-Java内存区域与内存溢出异常

第二章 Java内存区域与内存溢出异常 2.1 意义 对于C、C程序开发来说&#xff0c;程序员需要维护每一个对象从开始到终结。Java的虚拟自动内存管理机制&#xff0c;让java程序员不需要手写delete或者free代码&#xff0c;不容易出现内存泄漏和内存溢出问题&#xff0c;但是如果…

概念——二连杆与三连杆解算

https://zhuanlan.zhihu.com/p/161348413 跟着这位大佬学的&#xff0c;侵权删除 计划&#xff1a;先行理解概念以及推演过程&#xff0c;然后自行推导一遍&#xff08;确认理解&#xff09;&#xff0c;之后思考如何实际调用&#xff0c;目前在第一步 目录 01-二连杆 1-正…

腾讯云2024年数字生态大会开发者嘉年华(数据库动手实验)TDSQL-C初体验

在2024年9月5-6日&#xff0c;有幸参加了腾讯云举办的2024年数字生态大会开发者嘉年华。 有幸体验了腾讯的多项黑科技和云计算知识。特别是在“增一行代码”互动展区&#xff0c;体验了腾讯云云计算数据库TDSQL-C技术并进行了动手实验。这些技术充分展示了腾讯在云计算的强大实…

关于OceanBase 多模一体化的浅析

在当今多元化的业务生态中&#xff0c;各行各业对数据库系统的需求各有侧重。举例来说&#xff0c;金融风控领域对数据库的高效事务处理&#xff08;TP&#xff09;和分析处理&#xff08;AP&#xff09;能力有着严格要求&#xff1b;游戏行业则更加注重文档数据库的灵活性和性…

购物车装载状态检测系统源码分享

购物车装载状态检测检测系统源码分享 [一条龙教学YOLOV8标注好的数据集一键训练_70全套改进创新点发刊_Web前端展示] 1.研究背景与意义 项目参考AAAI Association for the Advancement of Artificial Intelligence 项目来源AACV Association for the Advancement of Comput…

有关WSL和docker的介绍

目录标题 如何利用在windows上配置docker实现linux和windows容器修改WSL默认安装&#xff08;也就是linux子系统&#xff09;目录到其他盘 如何利用在windows上配置docker实现linux和windows容器 wsl的基本命令&#xff1a;参考网页 docker入门到实践&#xff1a;参考网页 官方…

tekton pipeline workspaces

tekton pipeline workspace是一种为执行中的管道及其任务提供可用的共享卷的方法。 在pipeline中定义worksapce作为共享卷传递给相关的task。在tekton中定义workspace的用途有以下几点: 存储输入和/或输出在task之间共享数据secret认证的挂载点ConfigMap中保存的配置的挂载点…

影刀RPA实战:自动化批量生成条形码完整指南

今天我们聊聊使用影刀来实现批量生成条形码&#xff0c;条形码在零售行业运用非常广泛&#xff0c;主要作用表现在产品识别&#xff0c;库存管理&#xff0c;销售管理&#xff0c;防伪保护等&#xff0c;这些作用使其成为现代商业和工业环境中不可或缺的工具&#xff0c;它极大…