目录
- 一、time embedding调用过程
- 二、time embedding定义过程
- 三、PositionalEmbedding定义过程
time embedding本质上就是把时间步t转换为指定维度的嵌入向量,这个向量由时间步
,向量维度,周期等参数决定,可以简单理解为根据时间步t在正弦余弦函数采样得到的向量。
一、time embedding调用过程
传输时间time,得到嵌入向量
def forward(self, x, time=None, y=None):
if self.time_mlp is not None:
if time is None:
raise ValueError("time conditioning was specified but tim is not passed")
time_emb = self.time_mlp(time)
else:
time_emb = None
二、time embedding定义过程
PositionalEmbedding相当于创建的一个查询表,根据time,返回向量,再经过nn.Linear()->nn.SiLU()->nn.Linear()得到最终的嵌入向量
class UNet(nn.Module):
def __init__(
self,
img_channels,
base_channels,
channel_mults=(1, 2, 4, 8),
num_res_blocks=2,
time_emb_dim=None,
time_emb_scale=1.0,
num_classes=None,
activation=F.relu,
dropout=0.1,
attention_resolutions=(),
norm="gn",
num_groups=32,
initial_pad=0,
):
super().__init__()
self.time_mlp = nn.Sequential(
PositionalEmbedding(base_channels, time_emb_scale),
nn.Linear(base_channels, time_emb_dim),
nn.SiLU(),
nn.Linear(time_emb_dim, time_emb_dim),
) if time_emb_dim is not None else None
三、PositionalEmbedding定义过程
class PositionalEmbedding(nn.Module):
__doc__ = r"""Computes a positional embedding of timesteps.
Input:
x: tensor of shape (N)
Output:
tensor of shape (N, dim)
Args:
dim (int): embedding dimension
scale (float): linear scale to be applied to timesteps. Default: 1.0
"""
def __init__(self, dim, scale=1.0):
super().__init__()
assert dim % 2 == 0
self.dim = dim
self.scale = scale
def forward(self, x):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / half_dim
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
emb = torch.outer(x * self.scale, emb)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
下面是transformer中位置嵌入的一个定义。
接下来是这段代码对应的位置嵌入的一个定义。
上述代码的emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
对应下面这个式子torch.exp(torch.arange(half_dim) / half_dim)*1/10000
也就是