Pytorch常用的函数(九)torch.gather()用法
torch.gather()
就是在指定维度上收集value。
torch.gather()
的必填也是最常用的参数有三个,下面引用官方解释:
input
(Tensor) – the source tensordim
(int) – the axis along which to indexindex
(LongTensor) – the indices of elements to gather
一句话概括 gather 操作就是:根据 index
,在 input
的 dim
维度上收集 value。
1、举例直观理解
# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
# 2、我们有index_tensor如下
>>> index_tensor = torch.tensor(
[[[0, 0, 0, 0],
[2, 2, 2, 2]],
[[0, 0, 0, 0],
[2, 2, 2, 2]]]
)
# 3、我们通过torch.gather()函数获取out_tensor
>>> out_tensor = torch.gather(input_tensor, dim=1, index=index_tensor)
tensor([[[ 0, 1, 2, 3],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[20, 21, 22, 23]]])
我们以out_tensor中[0,1,0]=8为例,解释下如何利用dim和index
,从input_tensor中获得8。
根据上图,我们很直观的了解根据 index
,在 input
的 dim
维度上收集 value的过程。
- 假设
input
和index
均为三维数组,那么输出 tensor 每个位置的索引是列表[i, j, k]
,正常来说我们直接取input[i, j, k]
作为 输出 tensor 对应位置的值即可; - 但是由于
dim
的存在以及input.shape
可能不等于index.shape
,所以直接取值可能就会报错 ; - 所以我们是将索引列表的相应位置替换为
dim
,再去input
取值。在上面示例中,由于dim=1,那么我们就替换索引列表第1个值,即[i,dim,k]
,因此由原来的[0,1,0]替换为[0,2,0]后,再去input_tensor中取值。 - pytorch官方文档的写法如下,同一个意思。
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
2、反推法再理解
# 1、我们有input_tensor如下
>>> input_tensor = torch.arange(24).reshape(2, 3, 4)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
# 2、假设我们要得到out_tensor如下
>>> out_tensor
tensor([[[ 0, 1, 2, 3],
[ 8, 9, 10, 11]],
[[12, 13, 14, 15],
[20, 21, 22, 23]]])、
# 3、如何知道dim 和 index_tensor呢?
# 首先,我们要记住:out_tensor的shape = index_tensor的shape
# 从 output_tensor 的第一个位置开始:
# 此时[i, j, k]一样,看不出来 dim 应该是多少
output_tensor[0, 0, :] = input_tensor[0, 0, :] = 0
# 同理可知,此时index都为0
output_tensor[0, 0, 1] = input_tensor[0, 0, 1] = 1
output_tensor[0, 0, 2] = input_tensor[0, 0, 2] = 2
output_tensor[0, 0, 3] = input_tensor[0, 0, 3] = 3
# 我们从下一行的第一个位置开始:
# 这里我们看到维度 1 发生了变化,1 变成了 2,所以 dim 应该是 1,而 index 应为 2
output_tensor[0, 1, 0] = input_tensor[0, 2, 0] = 8
# 同理可知,此时index都为2
output_tensor[0, 1, 1] = input_tensor[0, 2, 1] = 9
output_tensor[0, 1, 2] = input_tensor[0, 2, 2] = 10
output_tensor[0, 1, 3] = input_tensor[0, 2, 3] = 11
# 根据上面推导我们易知dim=1,index_tensor为:
>>> index_tensor = torch.tensor(
[[[0, 0, 0, 0],
[2, 2, 2, 2]],
[[0, 0, 0, 0],
[2, 2, 2, 2]]]
)
3、实际案例
在大神何凯明MAE模型(Masked Autoencoders Are Scalable Vision Learners)源码中,多次使用了torch.gather()
函数。
- 论文链接:https://arxiv.org/pdf/2111.06377
- 官方源码:https://github.com/facebookresearch/mae
在MAE中根据预设的掩码比例(paper 中提倡的是 75%),使用服从均匀分布的随机采样策略采样一部分 tokens 送给 Encoder,另一部分mask 掉。采样25%作为unmasked tokens过程中,使用了torch.gather()
函数。
# models_mae.py
import torch
def random_masking(x, mask_ratio=0.75):
"""
Perform per-sample random masking by per-sample shuffling.
Per-sample shuffling is done by argsort random noise.
x: [N, L, D], sequence
"""
N, L, D = x.shape # batch, length, dim
len_keep = int(L * (1 - mask_ratio)) # 计算unmasked的片数
# 利用0-1均匀分布进行采样,避免潜在的【中心归纳偏好】
noise = torch.rand(N, L, device=x.device) # noise in [0, 1]
# sort noise for each sample【核心代码】
ids_shuffle = torch.argsort(noise, dim=1) # ascend: small is keep, large is remove
ids_restore = torch.argsort(ids_shuffle, dim=1)
# keep the first subset
ids_keep = ids_shuffle[:, :len_keep]
# 利用torch.gather()从源tensor中获取25%的unmasked tokens
x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
# generate the binary mask: 0 is keep, 1 is remove
mask = torch.ones([N, L], device=x.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return x_masked, mask, ids_restore
if __name__ == '__main__':
x = torch.arange(64).reshape(1, 16, 4)
random_masking(x)
# x模拟一张图片经过patch_embedding后的序列
# x相当于input_tensor
# 16是patch数量,实际上一般为(img_size/patch_size)^2 = (224 / 16)^2 = 14*14=196
# 4是一个patch中像素个数,这里只是模拟,实际上一般为(in_chans * patch_size * patch_size = 3*16*16 = 768)
>>> x = torch.arange(64).reshape(1, 16, 4)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15],
[16, 17, 18, 19], # 4
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31],
[32, 33, 34, 35],
[36, 37, 38, 39],
[40, 41, 42, 43], # 10
[44, 45, 46, 47],
[48, 49, 50, 51], # 12
[52, 53, 54, 55], # 13
[56, 57, 58, 59],
[60, 61, 62, 63]]])
# dim=1, index相当于index_tensor
>>> index
tensor([[[10, 10, 10, 10],
[12, 12, 12, 12],
[ 4, 4, 4, 4],
[13, 13, 13, 13]]])
# x_masked(从源tensor即x中,随机获取25%(4个patch)的unmasked tokens)
>>> x_masked相当于out_tensor
tensor([[[40, 41, 42, 43],
[48, 49, 50, 51],
[16, 17, 18, 19],
[52, 53, 54, 55]]])