文章目录
- 一、任务描述
- 二、环境配置
- 三、加载数据
- 四、定义网络结构
- 五、训练模型
一、任务描述
Karate Club 图任务是一个经典的图结构学习问题,通常用于社交网络分析和社区检测。该数据集是由 Wayne W. Zachary 在1977年收集的,描述了一个美国的空手道俱乐部成员间的社交互动。
任务描述如下:
- 图结构:该数据集包含34个节点(即空手道俱乐部的成员)和78条边(即成员之间的友谊关系)。每条边表示两个成员之间的连接。
- 节点特征:每个节点可以具有一些基本特征,如成员的身份信息(例如,成员的性别、年龄等)。在某些情况下,节点也可能包含其他的上下文信息。
- 社区划分:一个重要的任务是检测社交网络中的社区结构。在Zachary的研究中,发现该俱乐部在1970年代末分裂成两个派系。这两个派系通常被称为“Club A”和“Club B”。
- 任务目标:
节点分类:根据现有的边连接关系和节点特征来预测某个节点属于哪个社区。
二、环境配置
完成该项目需要安装一个关键的第三方依赖torch_geometric
,官方文档如下:torch_geometric,可以通过如下指令一键安装:
conda install pyg -c pyg
除此之外还需要安装pytorch,matplotlib,networkx这三个库,前面的pytorch用于搭建GNN网络结构,后面两个库用来可视化数据。
三、加载数据
通过torch_geometric.datasets直接加载KarateClub数据,查看这个图的大小规模。
from torch_geometric.datasets import KarateClub
dataset = KarateClub()
print(f'Dataset: {dataset}:')
print("========================================")
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')
输出
Dataset: KarateClub():
========================================
Number of graphs: 1
Number of features: 34
Number of classes: 4
查看数据的具体内容:
data = dataset[0] # Get the first graph object.
print(data)
print(type(data))
# Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34]) #
# <class 'torch_geometric.data.data.Data'>
说明这个图里面有34个节点,每个节点34个特征,156条边,34个标签,34个训练掩码,训练掩码的作用是在计算损失时只计算有标签的节点损失。edge_index就是图的邻接矩阵,但因为如果用全矩阵来表示邻接矩阵过于稀疏,这里用一个2*156的矩阵表示,代表有156条边,两个tensor对应的位置即为连接的边。比如0-1,0-2…
edge_index = data.edge_index
print(edge_index)
tensor([[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,
1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3,
3, 3, 3, 3, 3, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 7, 7,
7, 7, 8, 8, 8, 8, 8, 9, 9, 10, 10, 10, 11, 12, 12, 13, 13, 13,
13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 19, 20, 20, 21,
21, 22, 22, 23, 23, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 27, 27,
27, 27, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31,
31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33,
33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33],
[ 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 17, 19, 21, 31, 0, 2,
3, 7, 13, 17, 19, 21, 30, 0, 1, 3, 7, 8, 9, 13, 27, 28, 32, 0,
1, 2, 7, 12, 13, 0, 6, 10, 0, 6, 10, 16, 0, 4, 5, 16, 0, 1,
2, 3, 0, 2, 30, 32, 33, 2, 33, 0, 4, 5, 0, 0, 3, 0, 1, 2,
3, 33, 32, 33, 32, 33, 5, 6, 0, 1, 32, 33, 0, 1, 33, 32, 33, 0,
1, 32, 33, 25, 27, 29, 32, 33, 25, 27, 31, 23, 24, 31, 29, 33, 2, 23,
24, 33, 2, 31, 33, 23, 26, 32, 33, 1, 8, 32, 33, 0, 24, 25, 28, 32,
33, 2, 8, 14, 15, 18, 20, 22, 23, 29, 30, 31, 33, 8, 9, 13, 14, 15,
18, 19, 20, 22, 23, 26, 27, 28, 29, 30, 31, 32]])
用network可视化一下这个数据集:
%matplotlib inline
import torch
import networkx as nx
import matplotlib.pyplot as plt
def visualize_graph(G,color):
plt.figure(figsize=(7,7))
plt.xticks([])
plt.yticks([])
nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
node_color=color, cmap="Set2")
plt.show()
def visualize_embedding(h,color,epoch=None,loss=None):
plt.figure(figsize=(7,7))
plt.xticks([])
plt.yticks([])
h = h.detach().cpu().numpy()
plt.scatter(h[:,0],h[:,1],s=140,c=color,cmap="Set2")
if epoch is not None and loss is not None:
plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
plt.show()
from torch_geometric.utils import to_networkx
G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)
四、定义网络结构
定义一个简单的GNN模型,包含三个GCNconv和一个分类层,激活函数选择tanh,将GCNConv结果和最后分类的结果都做返回。
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv
class GCN(torch.nn.Module):
def __init__(self):
super(GCN, self).__init__()
torch.manual_seed(1234)
self.conv1 = GCNConv(dataset.num_features, 4)
self.conv2 = GCNConv(4, 4)
self.conv3 = GCNConv(4, 2)
self.classifier = Linear(2, dataset.num_classes)
def forward(self, x, edge_index):
h = self.conv1(x, edge_index)
h = h.tanh()
h = self.conv2(h, edge_index)
h = h.tanh()
h = self.conv3(h, edge_index)
h = h.tanh()
out = self.classifier(h)
return out, h
model = GCN()
print(model)
先查看一下未训练模型对于数据处理情况:
model = GCN()
_,h = model(data.x, data.edge_index)
visualize_embedding(h, color=data.y)
五、训练模型
训练过程大同小异了,定义模型,定义优化器,定义损失函数,计算损失,反向传播更新参数,训练400轮,直到训练完成。可视化GCNconv处理之后的特征。
import time
model = GCN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
def train(data):
optimizer.zero_grad()
out, h = model(data.x, data.edge_index)
loss = criterion(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
return loss, h
for epoch in range(401):
loss, h = train(data)
if epoch % 10 == 0:
visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)
time.sleep(0.3)