文章目录
- 解释
- 代码举例
解释
torch.max
是 PyTorch 中的一个函数,用于在张量中沿指定维度计算最大值。它有两种用法:
① 如果只提供一个输入张量,则返回该张量中的最大值和对应的索引。
② 如果提供两个输入张量,则返回两个张量中对应位置的较大值。
深度学习中主要使用第一种用法,下面对该用法举例说明:
代码举例
import torch
# 创建一个张量
# tensor = torch.rand(1, 4, 3, 3)
tensor = torch.tensor(
[[[[2, 2, 0.7944],
[2, 0.6368, 0.6928],
[0.9620, 0.5716, 0.3827]],
[[0.6216, 0, 1],
[0.0588, 1, 0.0718],
[1, 0.1084, 0.0462]],
[[0.3117, 0.3333, 0.655],
[0.8207, 0.5918, 3],
[0.6565, 3, 0.2866]],
[[0.6613, 0.1222, 0.0590],
[0.4555, 0.0166, 0.0838],
[0.3797, 0.6666, 4]]]])
# print(tensor)
print("原张量的shape为:", tensor.shape, '\n')
# 计算整个张量中的最大值和对应的索引
max_value, max_indices = torch.max(tensor, dim=1)
print("max_value:\n", max_value) # 输出第二个维度上的最大值
print("max_indices:\n", max_indices, '\n') # 输出第二个维度上最大值的索引
print("max_value.shape为:", max_value.shape) # 输出每行的最大值
print("max_indices.shape为:", max_indices.shape) # 输出每行最大值的索引
运行结果: