import dgl
import dgl.function as fn
import torch
# 实例化一个异构图
g = dgl.heterograph({
('user', 'follows', 'user'): ([0, 1], [1, 1]),
('game', 'attracts', 'user'): ([0], [1])
})
g.nodes['user'].data['h'] = torch.tensor([[1.], [2.]])
g.nodes['game'].data['h'] = torch.tensor([[1.]])
# 更新所有节点
g.multi_update_all(
{'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')),
'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))},
"sum")
print(g.nodes['user'].data['h']) # 输出: tensor([[0.], [4.]])
# 用户定义的跨类型归约函数,等效于"sum"
def cross_sum(flist):
return torch.sum(torch.stack(flist, dim=0), dim=0) if len(flist) > 1 else flist[0]
# 使用用户定义的跨类型归约函数
g.multi_update_all(
{'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')),
'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))},
cross_sum)
在这段代码中,我们创建了一个包含用户和游戏两种类型节点以及关注和吸引两种类型边的异构图。节点数据通过 multi_update_all 方法更新,使用了两种聚合方法:内置的 “sum” 和用户自定义的 cross_sum 函数。用户自定义的 cross_sum 函数设计为在这种情况下与内置的 “sum” 聚合方法等效。
解释下为什么是print(g.nodes['user'].data['h']) # 输出: tensor([[0.], [4.]])
如下图:
首先'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')
user0的给到user1,user1的也给到自己user1,那就已经是1.0 + 2.0 = 3.0了,
接着 'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))
game0的也给到user1,那但看这条边’attracts’,user1获得了1.0。
最后g.multi_update_all( {'follows': (fn.copy_u('h', 'm'), fn.sum('m', 'h')), 'attracts': (fn.copy_u('h', 'm'), fn.sum('m', 'h'))}, "sum")
的那个"sum",合并了边类型’follows’和’attracts’,那对user1来说,就是’follows’的3.0 + 'attracts’上获得的1.0,3.0 + 1.0 = 4.0。
user0都没有入边,一眼为0,因为没有源节点也就是入边。甭管虽然他原来的1.0。
下面是官方对此函数的解释(直接看example,毕竟我们是工科,实践为主我觉得)
### multi_update_all(self, etype_dict, cross_reducer, apply_node_func=None)
沿所有边发送消息,首先按类型归约,然后跨不同类型归约,接着更新所有节点的节点特征。
#### 参数
- **etype_dict** : dict
针对每种边类型的消息传递参数。键是边的类型,值是消息传递的参数。
允许的键格式有:
- **(str, str, str)** 表示源节点类型,边类型和目标节点类型。
- 或者一个可以唯一确定图中三元组格式的 **str** 边类型名称。
值必须是一个元组 **(message_func, reduce_func, [apply_node_func])**,其中:
- **message_func** : dgl.function.BuiltinFunction 或 callable
用于沿边生成消息的消息函数。
必须是一个 :ref:`api-built-in` 或 :ref:`apiudf`。
- **reduce_func** : dgl.function.BuiltinFunction 或 callable
用于聚合消息的聚合函数。
必须是一个 :ref:`api-built-in` 或 :ref:`apiudf`。
- **apply_node_func** : callable, 可选
在消息归约后,进一步更新节点特征的可选应用函数。
必须是一个 :ref:`apiudf`。
- **cross_reducer** : str 或 callable function
跨类型归约器。可以是 ``"sum"``, ``"min"``, ``"max"``, ``"mean"``, ``"stack"`` 之一,或者是一个可调用的函数。如果提供了可调用的函数,输入参数必须是一个包含来自每种边类型的聚合结果的张量列表,而函数的输出必须是一个单一的张量。
- **apply_node_func** : callable, 可选
在消息按类型归约和跨不同类型归约之后的可选应用函数。
必须是一个 :ref:`apiudf`。
#### 注意
DGL推荐在类型化消息传递参数中使用DGL的内置函数作为 message_func 和 reduce_func,因为在这种情况下,DGL将调用高效的内核,避免将节点特征复制到边特征。