AlphaFold3 的 AtomAttentionEncoder
类中,init_pair_repr
方法方法负责为原子之间的关系计算成对表示(pair representation),这是原子转变器(atom transformer)模型的关键组成部分,直接影响对蛋白质/分子相互作用的建模。
init_pair_repr
源代码:
def init_pair_repr(
self,
features: Dict[str, Tensor],
atom_cond: Tensor,
z_trunk: Optional[Tensor],
) -> Tensor:
"""Compute the pair representation for the atom transformer.
This is done in a separate function for checkpointing. The intermediate activations due to the
atom pair representations are large and can be checkpointed to reduce memory usage.
Args:
features:
Dictionary of input features.
atom_cond:
[bs, n_atoms, c_atom] The single atom conditioning from init_single_repr
z_trunk:
[bs, n_tokens, n_tokens, c_trunk] the pair representation from the trunk
Returns:
[bs, n_atoms // n_queries, n_queries, n_keys, c_atompair] The pair representation
"""
# Compute offsets between atom reference positions
a = partition_tensor(features['ref_pos'], self.n_queries, self.n_queries) # (bs, n_atoms // 32, 32, 3)
b = partition_tensor(features['ref_pos'], self.n_queries, self.n_keys) # (bs, n_atoms // 32, 128, 3)
offsets = a[:, :, :, None, :] - b[:, :, None, :, :] # (bs, n_atoms // 32, 32, 128, 3)
# Compute the valid mask
ref_space_uid = features['ref_space_uid'].unsqueeze(-1) # (bs, n_atoms, 1)
a = partition_tensor(ref_space_uid, self.n_queries, self.n_queries) # (bs, n_atoms // 32, 32)
b = partition_tensor(ref_space_uid, self.n_queries, self.n_keys) # (bs, n_atoms // 32, 128)
valid_mask = a[:, :, :, None] == b[:, :, None, :] # (bs, n_atoms // 32, 32, 128, 1)
valid_mask = valid_mask.to(offsets.dtype) # convert boolean to binary
# Embed the atom offsets and the valid mask
local_atom_pair = self.linear_atom_offsets(offsets) * valid_mask
# Embed pairwise inverse squared distances, and the valid mask
squared_distances = offsets.pow(2).sum(dim=-1, keepdim=True) # (bs, n_atoms // 32, 32, 128, 1)
inverse_dists = torch.reciprocal(torch.add(squared_distances, 1))
local_atom_pair = local_atom_pair + self.linear_atom_distances(inverse_dists) * valid_mask
local_atom_pair = local_atom_pair + self.linear_mask(valid_mask) * valid_mask
# If provided, add trunk embeddings
if self.trunk_conditioning:
local_atom_pair = local_atom_pair + map_token_pairs_to_local_atom_pairs(
self.proj_trunk_pair(z_trunk),
features['atom_to_token']
)
# Add the combined single conditioning to the pair representation
a = partition_tensor(self.linear_single_to_pair_row(F.relu(atom_cond)), self.n_queries, self.n_queries)
b = partition_tensor(self.linear_single_to_pair_col(F.relu(atom_cond)), self.n_queries, self.n_keys)
local_atom_pair = local_atom_pair + (a[:, :, :, None, :] + b[:, :, None, :, :])
# Run a small MLP on the pair activations
local_atom_pair = self.pair_mlp(local_atom_pair)
return local_atom_pair
init_pair_repr
代码解读:
1. 函数定义与注释
def init_pair_repr(
self,
features: Dict[str, Tensor],
atom_cond: Tensor,
z_trunk: Optional[Tensor],
) -> Tensor:
"""
Compute the pair representation for the atom transformer.
Args:
features: Dictionary of input features.
atom_cond: [bs, n_atoms, c_atom] The single atom conditioning from init_single_repr
z_trunk: [bs, n_tokens, n_tokens, c_trunk] the pair representation from the trunk
Returns:
[bs, n_atoms // n_queries, n_queries, n_keys, c_atompair] The pair representation
"""
-
功能描述:
- 方法用于计算原子之间的成对表示(pair representation),描述原子对之间的相互关系。
- 通过输入特征和条件化单原子表示(
atom_cond
)生成成对表示。 - 如果有 trunk 模块输出(
z_trunk
),进一步将其纳入建模。
-
输入参数:
features
: 包含输入原子特征的字典,例如参考位置、掩码等。atom_cond
: 由init_single_repr
生成的单原子条件表示,提供单原子特征。z_trunk
: 可选的 trunk 模块输出,用于加入全局上下文信息。
-
输出:
- 返回形状为
[bs, n_atoms // n_queries, n_queries, n_keys, c_atompair]
的成对表示张量。
- 返回形状为
2. 计算原子间的位移偏移量
a = partition_tensor(features['ref_pos'], self.n_queries, self.n_queries) # (bs, n_atoms // 32, 32, 3)
b = partition_tensor(features['ref_pos'], self.n_queries, self.n_keys) # (bs, n_atoms // 32, 128, 3)
offsets = a[:, :, :, None, :] - b[:, :, None, :, :] # (bs, n_atoms // 32, 32, 128, 3)
- 功能:
- 通过分块操作,将原子的三维参考位置(
ref_pos
)分为 query 和 key 的两个集合,计算原子对的位移向量offsets
。
- 通过分块操作,将原子的三维参考位置(
- 理论基础:
- 原子间的位移向量是物理意义上的距离关系的基础,直接影响距离计算和相互作用建模。
- 细节:
partition_tensor
将输入张量按块划分,便于后续处理。offsets
形状为[bs, n_atoms // n_queries, n_queries, n_keys, 3]
。
原理解读:
什么是 features['ref_pos']
?
features['ref_pos']
是原子在 3D 空间中的参考坐标,形状为(bs, n_atoms, 3)
。bs
是批量大小(batch size)。n_atoms
是蛋白质中的原子数量。- 每个原子的坐标由 3 个值(x, y, z)表示。
为什么使用 partition_tensor
?
partition_tensor
将输入张量按滑动窗口分区,使得可以对局部子集进行高效计算。- 作用:通过滑动窗口对原子的参考坐标进行局部划分:
- 第一次划分
a
:窗口大小为n_queries
,滑动步长为n_queries
,即每次取 32 个原子的局部坐标。 - 第二次划分
b
:窗口大小为n_keys
,滑动步长为n_queries
,即每次取 128 个原子的局部坐标。
- 第一次划分
- 分区后的结果:
a
:形状为(bs, n_atoms // 32, 32, 3)
,表示每个滑动窗口内的原子局部坐标(32 个)。b
:形状为(bs, n_atoms // 32, 128, 3)
,表示每个滑动窗口内的原子扩展区域(128 个)。
为什么计算 offset