Contents
- 1. TimestepEmbedSequential
- 2. The checkpoint mechanism in PyTorch
- 3. AttentionBlock
- 4. use_scale_shift_norm
Compared with nanoDiffusion-main, the improved-diffusion-main code is largely the same, but there are a few parts that are not easy to understand, so I am noting them down here.
1. TimestepEmbedSequential
In the code, class ResBlock inherits from TimestepBlock because it needs to take the timestep embedding in its forward pass; the other modules do not.
from abc import abstractmethod

import torch.nn as nn


class TimestepBlock(nn.Module):
    """
    Any module where forward() takes timestep embeddings as a second argument.
    """

    @abstractmethod
    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """


class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    A sequential module that passes timestep embeddings to the children that
    support it as an extra input.
    """

    def forward(self, x, emb):
        for layer in self:
            if isinstance(layer, TimestepBlock):
                # Timestep-aware children (e.g. ResBlock) also receive emb.
                x = layer(x, emb)
            else:
                # Ordinary layers only see the feature map.
                x = layer(x)
        return x


class ResBlock(TimestepBlock):
    ...
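To see how the dispatch works, here is a small sketch (my own toy example; ToyTimestepBlock is a hypothetical stand-in for ResBlock, and the layer sizes are arbitrary). Only children that are instances of TimestepBlock receive the timestep embedding:

import torch as th
import torch.nn as nn


class ToyTimestepBlock(TimestepBlock):
    """A minimal stand-in for ResBlock: it injects a projection of emb into x."""

    def __init__(self, channels, emb_dim):
        super().__init__()
        self.proj = nn.Linear(emb_dim, channels)

    def forward(self, x, emb):
        # Broadcast the projected embedding over the spatial dimensions.
        return x + self.proj(emb)[..., None, None]


block = TimestepEmbedSequential(
    nn.Conv2d(3, 8, 3, padding=1),   # plain layer: called as layer(x)
    ToyTimestepBlock(8, 32),         # TimestepBlock: called as layer(x, emb)
    nn.Conv2d(8, 8, 3, padding=1),
)

x = th.randn(2, 3, 16, 16)
emb = th.randn(2, 32)
print(block(x, emb).shape)           # torch.Size([2, 8, 16, 16])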
2. The checkpoint mechanism in PyTorch
def checkpoint(func, inputs, params, flag):
    """
    Evaluate a function without caching intermediate activations, allowing for
    reduced memory at the expense of extra compute in the backward pass.

    :param func: the function to evaluate.
    :param inputs: the argument sequence to pass to `func`.
    :param params: a sequence of parameters `func` depends on but does not
                   explicitly take as arguments.
    :param flag: if False, disable gradient checkpointing.
    """
    if flag:
        args = tuple(inputs) + tuple(params)
        return CheckpointFunction.apply(func, len(inputs), *args)
    else:
        return func(*inputs)
checkpoint runs the forward pass of the wrapped operation under torch.no_grad(). This does not change the state of the original leaf tensors: those that require gradients keep requiring them. Only the temporary intermediate tensors derived from those leaves are marked as not requiring gradients, so the gradient chain through them is cut and they are not kept around; the saved memory is paid for by re-running the forward computation during the backward pass.
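The CheckpointFunction it dispatches to is not shown in the post; below is a sketch of how such a custom autograd.Function can be written, consistent with the call CheckpointFunction.apply(func, len(inputs), *args) above (the exact body may differ from the repository's). The forward runs under no_grad, and backward re-runs the function under enable_grad to rebuild the intermediates before calling torch.autograd.grad:

import torch as th


class CheckpointFunction(th.autograd.Function):
    @staticmethod
    def forward(ctx, run_function, length, *args):
        # Split *args back into the real inputs and the extra parameters.
        ctx.run_function = run_function
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])
        with th.no_grad():
            # No graph is built here, so intermediate activations are freed.
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        # Recompute the forward pass with gradients enabled.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with th.enable_grad():
            # view_as() makes shallow copies so that in-place ops inside
            # run_function cannot modify the detached inputs.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = th.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        del ctx.input_tensors, ctx.input_params, output_tensors
        # No gradients for (run_function, length); the rest line up with *args.
        return (None, None) + input_grads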
3. AttentionBlock
# normalization, conv_nd, zero_module and checkpoint are helper functions from the repo's nn.py.
class AttentionBlock(nn.Module):
    def __init__(self, channels, num_heads=1, use_checkpoint=False):
        super().__init__()
        self.channels = channels
        self.num_heads = num_heads
        self.use_checkpoint = use_checkpoint

        self.norm = normalization(channels)
        self.qkv = conv_nd(1, channels, channels * 3, 1)
        self.attention = QKVAttention()
        self.proj_out = zero_module(conv_nd(1, channels, channels, 1))

    def forward(self, x):
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)

    def _forward(self, x):
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])
        h = self.attention(qkv)
        h = h.reshape(b, -1, h.shape[-1])
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
import math

import numpy as np
import torch as th
import torch.nn as nn


class QKVAttention(nn.Module):
    """
    A module which performs QKV attention.
    """

    def forward(self, qkv):
        """
        Apply QKV attention.

        :param qkv: an [N x (C * 3) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x C x T] tensor after attention.
        """
        ch = qkv.shape[1] // 3
        q, k, v = th.split(qkv, ch, dim=1)
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = th.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        weight = th.softmax(weight.float(), dim=-1).type(weight.dtype)
        return th.einsum("bts,bcs->bct", weight, v)

    @staticmethod
    def count_flops(model, _x, y):
        b, c, *spatial = y[0].shape
        num_spatial = int(np.prod(spatial))
        # We perform two matmuls with the same number of ops.
        # The first computes the weight matrix, the second computes
        # the combination of the value vectors.
        matmul_ops = 2 * b * (num_spatial ** 2) * c
        model.total_ops += th.DoubleTensor([matmul_ops])
The following function prepares the qkv tensor in the shape that QKVAttention expects:
def _forward(self, x):
    b, c, *spatial = x.shape
    x = x.reshape(b, c, -1)                   # flatten spatial dims: (b, c, N)
    qkv = self.qkv(self.norm(x))              # 1x1 conv after norm: (b, 3*c, N)
    qkv = qkv.reshape(b * self.num_heads, -1, qkv.shape[2])  # split heads: (b*heads, 3*c/heads, N)
    h = self.attention(qkv)                   # QKV attention: (b*heads, c/heads, N)
    h = h.reshape(b, -1, h.shape[-1])         # merge heads back: (b, c, N)
    h = self.proj_out(h)                      # zero-initialized 1x1 output projection
    return (x + h).reshape(b, c, *spatial)    # residual connection, restore spatial shape
The forward of class QKVAttention is just the standard scaled dot-product attention formula, Attention(Q, K, V) = softmax(QK^T / sqrt(d)) V, where d is the per-head channel count ch; the code applies the factor 1/sqrt(sqrt(ch)) to both q and k instead of dividing the product by sqrt(ch) afterwards, which is more stable in float16.
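As a quick sanity check, here is a minimal sketch (assuming the QKVAttention class above is in scope; the tensor sizes are arbitrary) that verifies the output shape and reproduces the einsum computation with plain matrix multiplications:

import math
import torch as th

b, heads, ch, T = 2, 4, 8, 16           # batch, heads, channels per head, tokens
qkv = th.randn(b * heads, 3 * ch, T)    # the [N x (3*C) x T] layout QKVAttention expects

attn = QKVAttention()
out = attn(qkv)
print(out.shape)                         # torch.Size([8, 8, 16]) == [b*heads, ch, T]

# The same computation with explicit matmuls instead of einsum:
q, k, v = th.split(qkv, ch, dim=1)
weight = th.softmax(q.transpose(1, 2) @ k / math.sqrt(ch), dim=-1)  # [N x T x T]
ref = v @ weight.transpose(1, 2)                                     # [N x C x T]
print(th.allclose(out, ref, atol=1e-5))                              # True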
4. use_scale_shift_norm
def _forward(self, x, emb):
    h = self.in_layers(x)
    emb_out = self.emb_layers(emb).type(h.dtype)
    while len(emb_out.shape) < len(h.shape):
        emb_out = emb_out[..., None]
    if self.use_scale_shift_norm:
        out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
        scale, shift = th.chunk(emb_out, 2, dim=1)
        h = out_norm(h) * (1 + scale) + shift
        h = out_rest(h)
    else:
        h = h + emb_out
        h = self.out_layers(h)
    return self.skip_connection(x) + h
In deep learning, especially in diffusion models or any network that needs fine-grained control over its output features, use_scale_shift_norm introduces a flexible transformation that adjusts a layer's output through a scale and a shift.
use_scale_shift_norm is a boolean (True or False) that decides whether this scale-and-shift normalization is applied. When use_scale_shift_norm is True, the following steps are performed:
- Split the output layers: self.out_layers (a sequence of network layers) is split into two parts. out_norm is the first layer, responsible for the normalization (despite the name, it could be any transformation layer), and out_rest is the remaining layers, which are applied after the scale and shift.
- Extract the scale and shift parameters: the scale and shift are taken from emb_out (the output of the embedding layers). emb_out is sized to hold both sets of parameters, so th.chunk(emb_out, 2, dim=1) splits it along the channel dimension (dim=1) into one half for scale and one half for shift.
- Apply the scale and shift: h (the output of the preceding layers) is passed through out_norm, and the result is adjusted with the parameters extracted from emb_out: out_norm(h) is multiplied by (1 + scale) and shift is added. This is a per-channel affine transformation that gives the conditioning signal extra flexibility and control over the features.
- Pass through the remaining layers: finally, the adjusted h is processed by the remaining layers in out_rest.
A key advantage of this technique is that it adjusts a layer's output in a flexible, data-driven way, without hard-coding a particular normalization or transformation strategy into the model architecture.
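For illustration, here is a minimal, self-contained sketch of the scale-shift branch (my own example, not the real ResBlock; the GroupNorm/SiLU/Linear stand-ins and the sizes are assumptions): the timestep embedding is projected to 2 * channels values, chunked into scale and shift, and used to modulate the normalized feature map before the remaining layers run.

import torch as th
import torch.nn as nn

channels, emb_dim = 64, 256

# Hypothetical stand-ins for out_layers[0], out_layers[1:] and emb_layers.
out_norm = nn.GroupNorm(32, channels)
out_rest = nn.Sequential(nn.SiLU(), nn.Conv2d(channels, channels, 3, padding=1))
emb_layers = nn.Sequential(nn.SiLU(), nn.Linear(emb_dim, 2 * channels))

h = th.randn(4, channels, 8, 8)          # feature map from in_layers
emb = th.randn(4, emb_dim)               # timestep embedding

emb_out = emb_layers(emb)                # (4, 2*channels)
while len(emb_out.shape) < len(h.shape):
    emb_out = emb_out[..., None]         # -> (4, 2*channels, 1, 1), broadcasts over H, W

scale, shift = th.chunk(emb_out, 2, dim=1)   # each (4, channels, 1, 1)
h = out_norm(h) * (1 + scale) + shift        # per-channel affine modulation
h = out_rest(h)
print(h.shape)                               # torch.Size([4, 64, 8, 8])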