torch.multinomial
用于从一个给定的概率分布中随机抽样,返回的是采样得到的索引。通常用于分类任务中的多项式采样(例如在序列生成模型中选择下一个元素)。它接受概率张量作为输入,并返回根据这些概率分布采样的索引。
基本语法
torch.multinomial(input, num_samples, replacement=False, *, generator=None, out=None) → Tensor
参数解释:
- input:概率张量,要求非负。如果为一维张量,则表示单一概率分布;如果是二维张量,每一行表示一个概率分布。
- num_samples:要采样的元素数量。
- replacement:是否允许重复采样。若为
True
,则可能从同一分布中多次抽样相同的元素。 - generator:用于控制随机数生成的可选随机数生成器。
示例代码
以下是一个简单的示例,展示如何使用 torch.multinomial
从不同的概率分布中进行采样:
import torch
# 定义一个概率分布
probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
# 从单一分布中采样1个元素,不允许重复
sampled_idx = torch.multinomial(probs, num_samples=1, replacement=False)