动手学图神经网络(2):跆拳道俱乐部案例实战
在深度学习领域,图神经网络(GNNs)能将传统深度学习概念推广到不规则的图结构数据,使神经网络能够处理对象及其关系。将基于 PyTorch Geometric 库,一步步探索图神经网络的奥秘。
安装必要的包
首先, 安装所需的 Python 包。在开始之前, 需要获取当前使用的 PyTorch 版本,
import os
import torch
os.environ['TORCH'] = torch.__version__
print(torch.__version__)
接下来, 可以使用以下命令安装必要的库:
# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q git+https://github.com/pyg-team/pytorch_geometric.git
同时, 还需要一些用于可视化的辅助函数:
%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt
def visualize_graph(G, color):
plt.figure(figsize=(7,7))
plt.xticks([])
plt.yticks([])
nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
node_color=color, cmap="Set2")
plt.show()
def visualize_embedding(h, color, epoch=None, loss=None):
plt.figure(figsize=(7,7))
plt.xticks([])
plt.yticks([])
h = h.detach().cpu().numpy()
plt.scatter(h[:, 0], h[:, 1], s=140, c=color, cmap="Set2")
if epoch is not None and loss is not None:
plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
plt.show()
图神经网络基础介绍
图神经网络(GNNs)旨在将经典深度学习概念推广到不规则结构数据(与图像或文本不同),使神经网络能够推理对象及其关系。遵循简单的神经消息传递方案,在图
G
=
(
V
,
E
)
\mathcal{G} = (\mathcal{V}, \mathcal{E})
G=(V,E) 中,所有节点
v
∈
V
v \in \mathcal{V}
v∈V 的节点特征
x
v
(
ℓ
)
\mathbf{x}_v^{(\ell)}
xv(ℓ) 通过聚合其邻居
N
(
v
)
\mathcal{N}(v)
N(v) 的局部信息来迭代更新:
x
v
(
ℓ
+
1
)
=
f
θ
(
ℓ
+
1
)
(
x
v
(
ℓ
)
,
{
x
w
(
ℓ
)
:
w
∈
N
(
v
)
}
)
\mathbf{x}_v^{(\ell + 1)} = f^{(\ell + 1)}_{\theta} \left( \mathbf{x}_v^{(\ell)}, \left\{ \mathbf{x}_w^{(\ell)} : w \in \mathcal{N}(v) \right\} \right)
xv(ℓ+1)=fθ(ℓ+1)(xv(ℓ),{xw(ℓ):w∈N(v)})
本教程将基于 PyTorch Geometric (PyG) 库 介绍图神经网络的一些基本概念。PyTorch Geometric 是流行深度学习框架 PyTorch 的扩展库,包含各种方法和实用工具,便于实现图神经网络。
将以著名的 Zachary’s karate club network 为例,深入了解图神经网络。这个图描述了一个空手道俱乐部 34 名成员的社交网络,并记录了俱乐部外成员之间的联系。我们的目标是检测由成员互动产生的社区。
加载数据集
from torch_geometric.datasets import KarateClub
dataset = KarateClub()
print(f'Dataset: {dataset}:')
print('======================')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
初始化 KarateClub
数据集后,可以检查其一些属性。可以看到,这个数据集只包含 一个图,每个节点都被分配了一个 34 维的特征向量,用于唯一描述空手道俱乐部的成员。此外,图中正好有 4 个类,代表每个节点所属的社区。
查看图的详细信息
data = dataset[0] # Get the first graph object.
print(data)
print('==============================================================')
# Gather some statistics about the graph.
print(f'Number of nodes: {data.num_nodes}')
print(f'Number of edges: {data.num_edges}')
print(f'Average node degree: {data.num_edges / data.num_nodes:.2f}')
print(f'Number of training nodes: {data.train_mask.sum()}')
print(f'Training node label rate: {int(data.train_mask.sum()) / data.num_nodes:.2f}')
print(f'Has isolated nodes: {data.has_isolated_nodes()}')
print(f'Has self-loops: {data.has_self_loops()}')
print(f'Is undirected: {data.is_undirected()}')
在 PyTorch Geometric 中,每个图都由一个 Data
对象表示,它包含了描述图所需的所有信息。通过 print(data)
可以查看数据对象的属性和形状的简要摘要。
Data(edge_index=[2, 156], x=[34, 34], y=[34], train_mask=[34])
可以看到,这个 data
对象包含 4 个属性:
edge_index
属性保存了 图连接性 的信息,即每个边的源节点和目标节点的索引对。- 节点特征 用
x
表示(34 个节点中的每个节点都被分配了一个 34 维的特征向量)。 - 节点标签 用
y
表示(每个节点都被分配到一个类)。 - 还有一个额外的属性
train_mask
,它描述了哪些节点的社区分配是已知的。
查看图的连接信息
from IPython.display import Javascript # Restrict height of output cell.
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 300})'''))
edge_index = data.edge_index
print(edge_index.t())
通过打印 edge_index
,我们可以了解 PyG 如何在内部表示图的连接性。可以看到,对于每条边,edge_index
包含一个两个节点索引的元组,其中第一个值描述源节点的索引,第二个值描述目标节点的索引。
tensor([[ 0, 1],
[ 0, 2],
[ 0, 3],
[ 0, 4],
[ 0, 5],
[ 0, 6],
[ 0, 7],
[ 0, 8],
[ 0, 10],
[ 0, 11],
[ 0, 12],
[ 0, 13],
[ 0, 17],
[ 0, 19],
[ 0, 21],
[ 0, 31],
[ 1, 0],
[ 1, 2],
[ 1, 3],
[ 1, 7],
[ 1, 13],
[ 1, 17],
[ 1, 19],
[ 1, 21],
[ 1, 30],
[ 2, 0],
[ 2, 1],
[ 2, 3],
[ 2, 7],
[ 2, 8],
[ 2, 9],
[ 2, 13],
[ 2, 27],
[ 2, 28],
[ 2, 32],
[ 3, 0],
[ 3, 1],
[ 3, 2],
[ 3, 7],
[ 3, 12],
[ 3, 13],
[ 4, 0],
[ 4, 6],
[ 4, 10],
[ 5, 0],
[ 5, 6],
[ 5, 10],
[ 5, 16],
[ 6, 0],
[ 6, 4],
[ 6, 5],
[ 6, 16],
[ 7, 0],
[ 7, 1],
[ 7, 2],
[ 7, 3],
[ 8, 0],
[ 8, 2],
[ 8, 30],
[ 8, 32],
[ 8, 33],
[ 9, 2],
[ 9, 33],
[10, 0],
[10, 4],
[10, 5],
[11, 0],
[12, 0],
[12, 3],
[13, 0],
[13, 1],
[13, 2],
[13, 3],
[13, 33],
[14, 32],
[14, 33],
[15, 32],
[15, 33],
[16, 5],
[16, 6],
[17, 0],
[17, 1],
[18, 32],
[18, 33],
[19, 0],
[19, 1],
[19, 33],
[20, 32],
[20, 33],
[21, 0],
[21, 1],
[22, 32],
[22, 33],
[23, 25],
[23, 27],
[23, 29],
[23, 32],
[23, 33],
[24, 25],
[24, 27],
[24, 31],
[25, 23],
[25, 24],
[25, 31],
[26, 29],
[26, 33],
[27, 2],
[27, 23],
[27, 24],
[27, 33],
[28, 2],
[28, 31],
[28, 33],
[29, 23],
[29, 26],
[29, 32],
[29, 33],
[30, 1],
[30, 8],
[30, 32],
[30, 33],
[31, 0],
[31, 24],
[31, 25],
[31, 28],
[31, 32],
[31, 33],
[32, 2],
[32, 8],
[32, 14],
[32, 15],
[32, 18],
[32, 20],
[32, 22],
[32, 23],
[32, 29],
[32, 30],
[32, 31],
[32, 33],
[33, 8],
[33, 9],
[33, 13],
[33, 14],
[33, 15],
[33, 18],
[33, 19],
[33, 20],
[33, 22],
[33, 23],
[33, 26],
[33, 27],
[33, 28],
[33, 29],
[33, 30],
[33, 31],
[33, 32]])
这种表示方式称为 COO 格式(坐标格式),通常用于表示稀疏矩阵。PyG 以稀疏方式表示图,只保存邻接矩阵 A \mathbf{A} A 中非零元素的坐标和值。
可视化图
from torch_geometric.utils import to_networkx
G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)
可以将图转换为 networkx
库的格式,利用其强大的可视化工具来可视化图。
实现图神经网络
在了解了 PyG 的数据处理之后, 实现 第一个图神经网络 ! 将使用最简单的 GNN 算子之一,即 GCN 层 (Kipf et al. (2017)),其定义为:
x
v
(
ℓ
+
1
)
=
W
(
ℓ
+
1
)
∑
w
∈
N
(
v
)
∪
{
v
}
1
c
w
,
v
⋅
x
w
(
ℓ
)
\mathbf{x}_v^{(\ell + 1)} = \mathbf{W}^{(\ell + 1)} \sum_{w \in \mathcal{N}(v) \, \cup \, \{ v \}} \frac{1}{c_{w,v}} \cdot \mathbf{x}_w^{(\ell)}
xv(ℓ+1)=W(ℓ+1)w∈N(v)∪{v}∑cw,v1⋅xw(ℓ)
其中
W
(
ℓ
+
1
)
\mathbf{W}^{(\ell + 1)}
W(ℓ+1) 表示形状为 [num_output_features, num_input_features]
的可训练权重矩阵,
c
w
,
v
c_{w,v}
cw,v 指的是每条边的固定归一化系数。
定义图神经网络模型
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
def __init__(self):
super().__init__()
torch.manual_seed(1234)
self.conv1 = GCNConv(dataset.num_features, 4)
self.conv2 = GCNConv(4, 4)
self.conv3 = GCNConv(4, 2)
self.classifier = Linear(2, dataset.num_classes)
def forward(self, x, edge_index):
h = self.conv1(x, edge_index)
h = h.tanh()
h = self.conv2(h, edge_index)
h = h.tanh()
h = self.conv3(h, edge_index)
h = h.tanh() # Final GNN embedding space.
# Apply a final (linear) classifier.
out = self.classifier(h)
return out, h
model = GCN()
print(model)
在 __init__
方法中,初始化了所有的构建块,并在 forward
方法中定义了网络的计算流程。定义并堆叠了 三个图卷积层,这相当于聚合每个节点周围的 3 跳邻域信息。每个 GCNConv
层后都应用了一个 tanh 非线性激活函数。
应用一个线性变换 (torch.nn.Linear
) 作为分类器,将节点映射到 4 个类/社区之一。
可视化节点嵌入
model = GCN()
_, h = model(data.x, data.edge_index)
print(f'Embedding shape: {list(h.shape)}')
visualize_embedding(h, color=data.y)
即使在训练模型之前,模型产生的节点嵌入已经很好地反映了图的社区结构。相同颜色(社区)的节点在嵌入空间中已经紧密聚集在一起,这表明 GNNs 引入了强大的归纳偏置,使得在输入图中彼此接近的节点具有相似的嵌入。
训练模型
import time
from IPython.display import Javascript # Restrict height of output cell.
display(Javascript('''google.colab.output.setIframeHeight(0, true, {maxHeight: 430})'''))
model = GCN()
criterion = torch.nn.CrossEntropyLoss() # Define loss criterion.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # Define optimizer.
def train(data):
optimizer.zero_grad() # Clear gradients.
out, h = model(data.x, data.edge_index) # Perform a single forward pass.
loss = criterion(out[data.train_mask], data.y[data.train_mask]) # Compute the loss solely based on the training nodes.
loss.backward() # Derive gradients.
optimizer.step() # Update parameters based on gradients.
return loss, h
for epoch in range(401):
loss, h = train(data)
if epoch % 10 == 0:
visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)
time.sleep(0.3)
训练过程与其他 PyTorch 模型类似。定义了损失函数 (CrossEntropyLoss
) 和随机梯度优化器 (Adam
)。在每一轮优化中,进行前向传播和反向传播,计算模型参数相对于损失的梯度,并更新参数。
通过观察节点嵌入的变化,可以看到 3 层 GCN 模型能够很好地线性分类,并正确分类大多数节点。
本教程是对图神经网络和 PyTorch Geometric 的初步介绍。