目录
- 1. 基本知识
- 2. Demo
1. 基本知识
基本的原理知识如下:
-
输入张量和掩码:
masked_fill 接受两个主要参数:一个输入张量和一个布尔掩码
掩码的形状必须与输入张量相同,True 表示需要填充的位置,False 表示保持原值 -
掩码操作:
在执行 masked_fill 操作时,函数会检查掩码中每个元素的值
如果掩码对应的位置为 True,则在输出张量中填充指定的值;
如果为 False,则保留输入张量中对应位置的值 -
输出结果:
最终生成的新张量包含了在掩码位置上被替换的值,其余位置保持原样
在代码逻辑上:
- 创建掩码:
mask 是一个布尔张量,标识了哪些位置需要填充:
[[False, True, False],
[True, False, True],
[False, False, True]]
- 执行 masked_fill:
当调用 tensor.masked_fill(mask, -1) 时,PyTorch 会遍历掩码中的每个元素:对于 mask 中的每个 True 值,tensor 在对应位置的值会被替换为 -1,对于 False 值,保持原值不变
masked_fill 操作是基于 C/C++ 的实现,因此在处理大规模数据时性能较高。常用于深度学习模型中的数据预处理,比如在填充序列、处理缺失值或标记特定条件的数据时
2. Demo
Demo 1: 基本用法
import torch
# 创建一个 3x3 的张量
tensor = torch.tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 创建一个掩码,标记要填充的位置
mask = torch.tensor([[False, True, False],
[True, False, True],
[False, False, True]])
# 使用 masked_fill 填充掩码位置为 -1
result = tensor.masked_fill(mask, -1)
print("原始张量:")
print(tensor)
print("\n填充后的张量:")
print(result)
截图如下:
Demo 2: 与条件结合使用
import torch
# 创建一个随机张量
tensor = torch.randn(3, 3)
# 创建掩码:标记负值的位置
mask = tensor < 0
# 将负值位置填充为 0
result = tensor.masked_fill(mask, 0)
print("原始张量:")
print(tensor)
print("\n填充后的张量 (负值填充为 0):")
print(result)
截图如下:
Demo 3: 结合计算
import torch
# 创建一个张量
tensor = torch.tensor([[10, 20, 30],
[40, 50, 60],
[70, 80, 90]])
# 创建掩码:标记大于 50 的位置
mask = tensor > 50
# 用 999 填充大于 50 的位置
result = tensor.masked_fill(mask, 999)
print("原始张量:")
print(tensor)
print("\n填充后的张量 (大于 50 的位置填充为 999):")
print(result)
截图如下: