torch.bincount()
函数是PyTorch中的一个函数,用于计算一维整数张量中每个非负整数值出现的频次
函数的用法 :
torch.bincount(input, weights=None, minlength=0) → Tensor
参数:
input
:输入的一维整数张量weights
(可选):与input
张量相同形状的张量,用于为每个值指定权重minlength
(可选):输出张量的最小长度
返回值:一个具有长度为max(input) + 1
的一维长整型张量,其中索引i
处的值表示i
在输入张量中出现的频次
函数说明:
torch.bincount()
函数 统计输入张量中每个非负整数值的频次。它适用于整数类型的张量,如torch.int8
、torch.int16
、torch.int32
、torch.int64
等- 输入张量可以是CPU上的张量,也可以是CUDA张量(GPU上的张量)
- 输出张量的长度是输入张量中的最大值加1,即
max(input) + 1
- 输出张量中的元素顺序与输入张量中的非负整数值顺序相同
例如:
import torch
input = torch.tensor([1, 2, 3, 2, 1, 1])
counts = torch.bincount(input)
print(counts) # 输出: tensor([0, 3, 2, 1])
rrint(counts[1:]) # 输出:tensor([3,2,1])
在上面示例中,有一个输入张量input
,包含一些非负整数值, 通过调用torch.bincount(input)
,计算了每个值在输入张量中出现的频次,得到了张量counts, counts[0]
为0,因为0在输入张量中没有出现;counts[1]
为3,因为1在输入张量中出现了3次,以此类推
注意:
在使用
torch.bincount()
函数时,它会计算一维整数张量中每个非负整数值的频次,包括最小值到最大值之间的所有整数值,即使某些整数值在输入张量中没有出现在上述的例子中,
input
是一维张量[1, 2, 3, 2, 1, 1],
虽然 0 在input
中没有出现,但torch.bincount(input)
仍会考虑到0的存在 ,输出结果为tensor([0, 3, 2, 1])
,其中索引0 表示 0 这个整数值在input
中出现的次数为0次,索引1出现了3次,索引2出现了2次,索引3出现了1次
torch.bincount()
的输出张量长度与输入张量中的最大整数值相关。对于输入张量input = torch.tensor([1, 2, 3, 2, 1, 1])
,它包含了整数值1、2和3,torch.bincount(input)
的输出张量将具有长度为4,对应索引0到索引3。具体来说,输出张量的长度由输入张量中的最大整数值加1决定在这个例子中,最大整数值是3,因此输出张量的长度为4
如果确保输入张量中不包含0,可以通过对输出进行切片来忽略索引0的值
例如,
counts[1:]
表示忽略索引0后的部分,得到tensor([3, 2, 1])
还可以传入一个与输入张量相同形状的权重张量 weights
,可以为每个值指定权重
weights = torch.tensor([0.1, 0.2, 0.3, 0.4, 0.5, 0.6])
weighted_counts = torch.bincount(input, weights)
print(weighted_counts) # 输出:tensor([0.0000, 1.2000, 0.6000, 0.3000])
# 对于0来说,没有出现就是0
# 对于1来说,出现了三次:第一次出现位置上对应的权重为0.1,第二次出现位置上对应的权重为0.5,第三次出现 # 位置上对应的权重为0.6,所以0.1+0.5+0.6=1.2
# 对2来说,出现两次:第一次出现位置对应的权重为0.2,第二次出现位置对应的权重为0.4,故0.2+0.4=0.6
# 对于3来说,出现了一次:第一次出现位置上对应的权重为0.3,所以为0.3
通过调用 torch.bincount(input, weights)
,计算了每个值在输入张量中出现的加权频次,得到了张量 weighted_counts
此外,可以通过设置 minlength
参数来指定输出张量的最小长度
minlength_counts = torch.bincount(input, minlength=5)
print(minlength_counts) # 输出: tensor([0, 3, 2, 1, 0])
在上面的示例中,我们调用torch.bincount(input, minlength=5)
,将最小长度设置为5,得到了张量 minlength_counts
,它的长度为5,包含了输入张量中每个非负整数值的频次
补充:对于numpy数组有 numpy.bincount( )函数的用法: numpy.bincount( )函数的用法-CSDN博客 可以参考博文对比理解