torch.isclose
是 PyTorch 中用于比较两个张量是否“近似相等”的函数。它主要用于判断两个张量的对应元素在数值上是否接近(考虑了浮点数精度的可能误差)。
函数定义
torch.isclose(input, other, rtol=1e-05, atol=1e-08, equal_nan=False)
参数说明
-
input
(Tensor):- 第一个张量。
-
other
(Tensor):- 第二个张量,和
input
的形状必须相同,或者可以通过广播机制与input
对齐。
- 第二个张量,和
-
rtol
(float, 可选,默认值:1e-05
):- 相对容忍误差(relative tolerance)。比较时的相对误差阈值,定义了两个值相对距离的可接受范围。
-
atol
(float, 可选,默认值:1e-08
):- 绝对容忍误差(absolute tolerance)。比较时的绝对误差阈值,定义了两个值绝对距离的可接受范围。
-
equal_nan
(bool, 可选,默认值:False
):- 是否将 NaN 视为“接近”。
- 如果为
True
,则两个 NaN 会被认为是相等的。
返回值
- 返回一个与
input
和other
形状相同的布尔张量。 - 每个元素表示
input
和other
对应位置的元素是否“近似相等”。
比较规则
两个元素 aa 和 bb 被认为是“近似相等”的条件是:
|a - b|
: 表示input
和other
对应元素之间的绝对差值。atol
: 绝对误差阈值。rtol
: 相对误差阈值。
常见用途
- 比较浮点数是否相等(避免浮点数精度误差)。
- 检查数值计算中张量的结果是否一致或接近。
- 判断两个张量之间的元素是否在某个容忍范围内。
示例
import torch
# 创建两个浮点张量
a = torch.tensor([1.0, 2.0, 3.0001])
b = torch.tensor([1.0, 2.0, 3.0])
# 默认参数下比较
result = torch.isclose(a, b)
print(result) # tensor([True, True, False])
# 调整容忍误差
result = torch.isclose(a, b, rtol=1e-03, atol=1e-05)
print(result) # tensor([True, True, True])
# 比较含有 NaN 的张量
a = torch.tensor([1.0, float('nan')])
b = torch.tensor([1.0, float('nan')])
# 默认不认为 NaN 相等
print(torch.isclose(a, b)) # tensor([ True, False])
# 允许 NaN 相等
print(torch.isclose(a, b, equal_nan=True)) # tensor([ True, True])
注意事项
-
rtol
和atol
的选择:- 如果数值范围较大,可以增加
rtol
。 - 如果数值精度要求较高,可以减小
atol
。
- 如果数值范围较大,可以增加
-
NaN 的比较:
- 默认情况下,
torch.isclose
认为 NaN 是不相等的,除非equal_nan=True
。
- 默认情况下,