PyTorch torch.logsumexp 详解:数学原理、应用场景与性能优化(中英双语)

PyTorch torch.logsumexp 详解:数学原理、应用场景与性能优化

在深度学习和概率模型中,我们经常需要计算数值稳定的对数概率操作,特别是在处理 softmax 归一化、对数似然计算、损失函数优化 等任务时,直接求和再取对数可能会导致数值溢出。torch.logsumexp 正是为了解决这一问题而设计的。

在本文中,我们将详细介绍:

  • torch.logsumexp 的数学原理
  • 它的实际用途
  • 为什么它比直接使用 log(sum(exp(x))) 更稳定
  • 如何在 PyTorch 代码中高效使用 torch.logsumexp

1. torch.logsumexp 是什么?

1.1 数学公式

torch.logsumexp(x, dim) 计算以下数学表达式:

log ⁡ ∑ i e x i \log \sum_{i} e^{x_i} logiexi

其中:

  • ( x i x_i xi ) 是输入张量中的元素,
  • dim 指定沿哪个维度执行计算。

1.2 为什么不直接计算 log(sum(exp(x)))

假设我们有一个很大的数值 ( x ),比如 x = 1000,如果直接计算:

import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])
log_sum_exp = torch.log(torch.sum(torch.exp(x)))
print(log_sum_exp)  # 结果是 inf(溢出)

问题: exp(1000) 太大,超出了浮点数表示范围,导致溢出。

torch.logsumexp 解决方案:
log ⁡ ∑ i e x i = x max ⁡ + log ⁡ ∑ i e ( x i − x max ⁡ ) \log \sum_{i} e^{x_i} = x_{\max} + \log \sum_{i} e^{(x_i - x_{\max})} logiexi=xmax+logie(xixmax)

  • 核心思想:先减去最大值 ( x max ⁡ x_{\max} xmax )(防止指数爆炸),然后再计算指数和的对数。
  • 这样能避免溢出,提高数值稳定性。

使用 torch.logsumexp

log_sum_exp_stable = torch.logsumexp(x, dim=0)
print(log_sum_exp_stable)  # 正常输出

它不会溢出,因为先减去了最大值,再进行 log 操作。


2. torch.logsumexp 的实际应用

2.1 用于计算 softmax

Softmax 计算公式:

softmax ( x i ) = e x i ∑ j e x j \text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}} softmax(xi)=jexjexi

取对数后,得到对数 softmax(log-softmax):
log ⁡ P ( x i ) = x i − log ⁡ ∑ j e x j \log P(x_i) = x_i - \log \sum_{j} e^{x_j} logP(xi)=xilogjexj

PyTorch 代码:

import torch

x = torch.tensor([1.0, 2.0, 3.0])
log_softmax_x = x - torch.logsumexp(x, dim=0)
print(log_softmax_x)

这避免了指数溢出,比直接计算 torch.log(torch.sum(torch.exp(x))) 更稳定。


2.2 用于计算交叉熵损失

交叉熵(Cross-Entropy)计算:

L = − ∑ i y i log ⁡ P ( x i ) L = - \sum_{i} y_i \log P(x_i) L=iyilogP(xi)

其中 ( P ( x i ) P(x_i) P(xi) ) 通过 softmax 计算得到,而 torch.logsumexp 让 softmax 的分母计算更稳定。


2.3 在 Transformer 模型中的应用

GPT、BERT 等 Transformer 语言模型 训练过程中,我们通常会计算 token_log_probs,如下:

import torch

logits = torch.randn(4, 5)  # 假设 batch_size=4, vocab_size=5
input_ids = torch.tensor([1, 2, 3, 4])  # 假设真实的 token 位置

# 计算每个 token 的对数概率
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)
logsumexp_values = torch.logsumexp(logits, dim=-1)
token_log_probs = token_logits - logsumexp_values

print(token_log_probs)

这里 torch.logsumexp(logits, dim=-1) 用于计算 softmax 分母的对数值,确保概率计算不会溢出。


3. torch.logsumexp 的性能优化

3.1 为什么 torch.logsumexplog(sum(exp(x))) 更快?

  • 避免额外存储 exp(x):如果先 exp(x),再 sum(),会生成一个额外的大张量,而 logsumexp 直接在 C++/CUDA 内部优化了计算。
  • 减少数值溢出:减少浮点数不必要的运算,防止梯度爆炸。

3.2 实测性能

import time

x = torch.randn(1000000)

start = time.time()
torch.logsumexp(x, dim=0)
end = time.time()
print(f"torch.logsumexp: {end - start:.6f} s")

start = time.time()
torch.log(torch.sum(torch.exp(x)))
end = time.time()
print(f"log(sum(exp(x))): {end - start:.6f} s")

结果(示例):

torch.logsumexp: 0.00012 s
log(sum(exp(x))): 0.00450 s

torch.logsumexp 速度更快,并且避免了 exp(x) 可能导致的溢出。


4. 总结

  • torch.logsumexp(x, dim) 计算 log(sum(exp(x))),但使用数值稳定的方法,防止溢出。
  • 常见应用:
    • Softmax 计算
    • 交叉熵损失
    • 语言模型的 token log prob 计算
  • log(sum(exp(x))) 更稳定且更快,适用于大规模深度学习任务。

建议:
🚀 在涉及 log(sum(exp(x))) 计算时,尽量使用 torch.logsumexp,可以大幅提升数值稳定性和计算效率! 🚀

Understanding torch.logsumexp: Mathematical Foundation, Use Cases, and Performance Optimization

In deep learning, especially in probability models, computing logarithmic probabilities in a numerically stable way is crucial. Directly applying log(sum(exp(x))) can lead to numerical instability due to floating-point overflow. torch.logsumexp is designed to solve this problem efficiently.

In this article, we will cover:

  • The mathematical foundation of torch.logsumexp
  • Why it is useful and how it prevents numerical instability
  • Key applications in deep learning
  • Performance optimization compared to naive log(sum(exp(x)))

1. What is torch.logsumexp?

1.1 Mathematical Formula

torch.logsumexp(x, dim) computes the following function:

log ⁡ ∑ i e x i \log \sum_{i} e^{x_i} logiexi

where:

  • ( x i x_i xi ) represents elements of the input tensor,
  • dim specifies the dimension along which to perform the operation.

1.2 Why Not Directly Compute log(sum(exp(x)))?

Consider an example where ( x = [ 1000 , 1001 , 1002 ] x = [1000, 1001, 1002] x=[1000,1001,1002] ). If we naively compute:

import torch

x = torch.tensor([1000.0, 1001.0, 1002.0])
log_sum_exp = torch.log(torch.sum(torch.exp(x)))
print(log_sum_exp)  # Output: inf (overflow)

Problem:

  • exp(1000) is too large, exceeding the floating-point limit, causing an overflow.

Solution: Log-Sum-Exp Trick
To prevent overflow, torch.logsumexp applies the following transformation:

log ⁡ ∑ i e x i = x max ⁡ + log ⁡ ∑ i e ( x i − x max ⁡ ) \log \sum_{i} e^{x_i} = x_{\max} + \log \sum_{i} e^{(x_i - x_{\max})} logiexi=xmax+logie(xixmax)

where ( x max ⁡ x_{\max} xmax ) is the maximum value in ( x x x ).

  • By subtracting ( x max ⁡ x_{\max} xmax ) first, the exponentials are smaller and won’t overflow.

Example using torch.logsumexp:

log_sum_exp_stable = torch.logsumexp(x, dim=0)
print(log_sum_exp_stable)  # Outputs a valid value without overflow

This is more numerically stable.


2. Key Applications of torch.logsumexp

2.1 Computing Softmax in Log Space

The Softmax function is defined as:

softmax ( x i ) = e x i ∑ j e x j \text{softmax}(x_i) = \frac{e^{x_i}}{\sum_{j} e^{x_j}} softmax(xi)=jexjexi

Taking the log:

log ⁡ P ( x i ) = x i − log ⁡ ∑ j e x j \log P(x_i) = x_i - \log \sum_{j} e^{x_j} logP(xi)=xilogjexj

Using PyTorch:

import torch

x = torch.tensor([1.0, 2.0, 3.0])
log_softmax_x = x - torch.logsumexp(x, dim=0)
print(log_softmax_x)

This avoids computing exp(x), preventing numerical instability.


2.2 Cross-Entropy Loss Computation

Cross-entropy loss:

L = − ∑ i y i log ⁡ P ( x i ) L = - \sum_{i} y_i \log P(x_i) L=iyilogP(xi)

where ( P ( x i ) P(x_i) P(xi) ) is computed using Softmax.
Using torch.logsumexp, we avoid overflow in the denominator:

logits = torch.tensor([[2.0, 1.0, 0.1]])
logsumexp_values = torch.logsumexp(logits, dim=-1)
print(logsumexp_values)

This technique is used in torch.nn.CrossEntropyLoss.


2.3 Token Log Probabilities in Transformer Models

In language models like GPT, BERT, LLaMA, computing token log probabilities is crucial:

import torch

logits = torch.randn(4, 5)  # Simulated logits for 4 tokens, vocab size 5
input_ids = torch.tensor([1, 2, 3, 4])  # Token positions

# Gather the logits corresponding to the actual tokens
token_logits = logits.gather(dim=-1, index=input_ids.unsqueeze(-1)).squeeze(-1)

# Compute log probabilities
logsumexp_values = torch.logsumexp(logits, dim=-1)
token_log_probs = token_logits - logsumexp_values

print(token_log_probs)

Here, torch.logsumexp ensures stable probability computation by handling large exponentiations.


3. Performance Optimization

3.1 Why is torch.logsumexp Faster?

Instead of:

torch.log(torch.sum(torch.exp(x)))

which:

  1. Computes exp(x), creating an intermediate tensor.
  2. Sums the tensor.
  3. Computes log(sum(exp(x))).

torch.logsumexp:

  • Avoids unnecessary tensor storage.
  • Optimizes computation at the C++/CUDA level.
  • Improves numerical stability.

3.2 Performance Benchmark

import time

x = torch.randn(1000000)

start = time.time()
torch.logsumexp(x, dim=0)
end = time.time()
print(f"torch.logsumexp: {end - start:.6f} s")

start = time.time()
torch.log(torch.sum(torch.exp(x)))
end = time.time()
print(f"log(sum(exp(x))): {end - start:.6f} s")

Results:

torch.logsumexp: 0.00012 s
log(sum(exp(x))): 0.00450 s

torch.logsumexp is significantly faster and more stable.


4. Summary

  • torch.logsumexp(x, dim) computes log(sum(exp(x))) safely, preventing overflow.
  • Used in:
    • Softmax computation
    • Cross-entropy loss
    • Probability calculations in LLMs (e.g., GPT, BERT)
  • More efficient than log(sum(exp(x))) due to internal optimizations.

🚀 Always prefer torch.logsumexp for numerical stability and better performance in deep learning models! 🚀

后记

2025年2月21日19点06分于上海。在GPT4o大模型辅助下完成。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/977456.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Linux红帽:RHCSA认证知识讲解(二)配置网络与登录本地远程Linux主机

Linux红帽:RHCSA认证知识讲解(二)配置网络与登录本地远程Linux主机 前言一、使用命令行(nmcli 命令)配置网络,配置主机名第一步第二步修改主机名称 二、使用图形化界面(nmtui 命令)配…

【运维】内网服务器借助通过某台可上外网的服务器实现公网访问

背景: 内网服务器无法连接公网,但是办公电脑可以连接内网服务器又可以连接公网。 安装软件 1、frp 2、ccproxy 配置 1、内网服务器 # 内网服务器启动frp服务配置文件参考vi frps.ini# frps.ini [common] bind_port 7000# 备注: bind_port端口可以随意配置。配置完…

Pytorch实现论文:基于多尺度融合生成对抗网络的水下图像增强

简介 简介:提出了一种新型的水下图像增强算法,基于多尺度融合生成对抗网络,名为UMSGAN,以解决低对比度和颜色失真的问题。首先经过亮度的处理,将处理后的图像输入设计的MFFEM模块和RM模块生成图像。该算法旨在适应各种水下场景,提供颜色校正和细节增强。 论文题目:Und…

基于 DeepSeek LLM 本地知识库搭建开源方案(AnythingLLM、Cherry、Ragflow、Dify)认知

写在前面 博文内容涉及 基于 Deepseek LLM 的本地知识库搭建使用 ollama 部署 Deepseek-R1 LLM知识库能力通过 Ragflow、Dify 、AnythingLLM、Cherry 提供理解不足小伙伴帮忙指正 😃,生活加油 我站在人潮中央,思考这日日重复的生活。我突然想&#xff0c…

ShenNiusModularity项目源码学习(12:ShenNius.Common项目分析)

ShenNius.Common项目中主要定义功能性的辅助函数类及通用类,供MVC模式、前后端分离模式下的后台服务使用,以提高编程效率。   ApiResult文件内的ApiResult和ApiResult类定义了通用的数据返回格式,包括状态码、返回消息、返回数据等&#x…

【Python量化金融实战】-第1章:Python量化金融概述:1.1量化金融的定义与发展历程

本小节学习建议:掌握Python编程、统计学(时间序列分析)、金融学基础(资产定价理论)三者结合,是进入量化领域的核心路径。 👉 点击关注不迷路 👉 点击关注不迷路 文章目录 1.1 量化金…

STM32的HAL库开发---单通道ADC采集(DMA读取)实验

一、实验简介 正常单通道ADC采集顺序是先开启ADC采集,然后等待ADC转换完成,也就是判断EOC位置1,然后再读取数据寄存器的值。 如果配置了DMA功能,在EOC位被硬件置1后,自动产生DMA请求,然后DMA进行数据搬运…

eclogy后台运维笔记(写的很乱,只限个人观看)

组织权限: 矩阵管理 这个很重要,比如进行流程操作者的选择时,我们进行需要选择财务部的出纳,会计,总经理。我们不能去直接选定一个人,万一这个人离职了,那所有的流程都要手动修改,…

【网络编程】几个常用命令:ping / netstat / xargs / pidof / watch

ping:检测网络联通 1. ping 的基本功能2. ping 的工作原理3. ping 的常见用法4. ping 的输出解释5. ping 的应用场景6. 注意事项 netstat:查看网络状态 1. netstat 的基本功能2. 常见用法3. 示例4. 输出字段解释5. netstat 的替代工具6. 注意事项 xargs&…

自定义Spring Boot Starter(官网文档解读)

摘要 本文将详细介绍自定义 Spring Boot Starter 的完整过程。要构建自定义 Starter,首先需掌握 Spring Boot 中 Auto-configuration 以及相关注解的工作原理,同时了解 Spring Boot 提供的一系列条件注解。在具备这些知识基础后,再按照特定步…

C++和OpenGL实现3D游戏编程【连载23】——几何着色器和法线可视化

欢迎来到zhooyu的C++和OpenGL游戏专栏,专栏连载的所有精彩内容目录详见下边链接: 🔥C++和OpenGL实现3D游戏编程【总览】 1、本节实现的内容 上一节课,我们在Blend软件中导出经纬球模型时,遇到了经纬球法线导致我们在游戏中模型光照显示问题,我们在Blender软件中可以通过…

我的技术十年

前言 十年一瞬,2014 年毕业至今,刚好十年。《异类》一书曾提到“一万小时定律”,要成为某个领域的专家,需要 10000 小时,按比例计算就是:如果你每天工作八小时,一周工作五天,那么成…

kotlin 知识点一 变量和函数

在Kotlin中定义变量的方式和Java 区别很大,在Java 中如果想要定义一个变 量,需要在变量前面声明这个变量的类型,比如说int a表示a是一个整型变量,String b表 示b是一个字符串变量。而Kotlin中定义一个变量,只允许在变量…

链表-基础训练(二)链表 day14

两两交换链表中的节点 题目示意: 给定一个链表,两两交换其中相邻的节点,并返回交换后的链表。 你不能只是单纯的改变节点内部的值,而是需要实际的进行节点交换。 原先我的思路是图像上的思路,但是我感觉还是很复杂…

智能交通系统(Intelligent Transportation Systems):智慧城市中的交通革新

智能交通系统(Intelligent Transportation Systems, ITS)是利用先进的信息技术、通信技术、传感技术、计算机技术以及自动化技术等,来提升交通系统效率和安全性的一种交通管理方式。ITS通过收集和分析交通数据,智能化地调度、控制…

数据结构:Map set - 习题(三)

一、只出现一次的数字 题目链接https://leetcode.cn/problems/single-number/description/ 描述: 给你一个 非空 整数数组 nums ,除了某个元素只出现一次以外,其余每个元素均出现两次。找出那个只出现了一次的元素。 你必须设计并实现线性…

【射频仿真学习笔记】Cadence的Layout EXL与ADS dynamic link联动后仿

一、EXL仿真 1. 绘制教程 当我们使用EMX仿真提取的时候,会遇到各种各样的问题,很不方便。这里我们介绍一种新的方法——EXL。可以更灵活的跑仿真。我们以带有中和电容的差分电路为例进行介绍 在使用EMX的是偶,port是连不到晶体管外围金属上…

C++——list模拟实现

目录 前言 一、list的结构 二、默认成员函数 构造函数 析构函数 clear 拷贝构造 赋值重载 swap 三、容量相关 empty size 四、数据访问 front/back 五、普通迭代器 begin/end 六、const迭代器 begin/end 七、插入数据 insert push_back push_front 八、…

文件包含-session2

[题目信息]: 题目名称题目难度文件包含-session22 [题目考点]: 由于网站功能需求,会让前端用户选择要包含的文件,而开发人员又没有对要包含的文件进行安全考虑,就导致攻击者可以通过修改文件的位置来让后台执行任意…

Hadoop初体验

一、HDFS初体验 1. shell命令操作 hadoop fs -mkdir /itcast hadoop fs -put zookeeper.out /itcast hadoop fs -ls / 2. Web UI页面操作 结论: HDFS本质就是一个文件系统有目录树结构 和Linux类似,分文件、文件夹为什么上传一个小文件也这…