GNN入门案例——KarateClub结点分类

文章目录

  • 一、任务描述
  • 二、环境配置
  • 三、加载数据
  • 四、定义网络结构
  • 五、训练模型

一、任务描述

Karate Club 图任务是一个经典的图结构学习问题,通常用于社交网络分析和社区检测。该数据集是由 Wayne W. Zachary 在1977年收集的,描述了一个美国的空手道俱乐部成员间的社交互动。

任务描述如下:

  • 图结构:该数据集包含34个节点(即空手道俱乐部的成员)和78条边(即成员之间的友谊关系)。每条边表示两个成员之间的连接。
  • 节点特征:每个节点可以具有一些基本特征,如成员的身份信息(例如,成员的性别、年龄等)。在某些情况下,节点也可能包含其他的上下文信息。
  • 社区划分:一个重要的任务是检测社交网络中的社区结构。在Zachary的研究中,发现该俱乐部在1970年代末分裂成两个派系。这两个派系通常被称为“Club A”和“Club B”。
  • 任务目标:
    节点分类:根据现有的边连接关系和节点特征来预测某个节点属于哪个社区。

二、环境配置

完成该项目需要安装一个关键的第三方依赖torch_geometric,官方文档如下:torch_geometric,可以通过如下指令一键安装:

conda install pyg -c pyg

除此之外还需要安装pytorch,matplotlib,networkx这三个库,前面的pytorch用于搭建GNN网络结构,后面两个库用来可视化数据。

三、加载数据

通过torch_geometric.datasets直接加载KarateClub数据,查看这个图的大小规模。

from torch_geometric.datasets import KarateClub

dataset = KarateClub()
print(f'Dataset: {dataset}:')   
print("========================================")
print(f'Number of graphs: {len(dataset)}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

输出

Dataset: KarateClub():
========================================
Number of graphs: 1
Number of features: 34
Number of classes: 4

查看数据的具体内容:

data = dataset[0]  # Get the first graph object.
print(data)
print(type(data))
# Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34]) # 
# <class 'torch_geometric.data.data.Data'>

说明这个图里面有34个节点,每个节点34个特征,156条边,34个标签,34个训练掩码,训练掩码的作用是在计算损失时只计算有标签的节点损失。edge_index就是图的邻接矩阵,但因为如果用全矩阵来表示邻接矩阵过于稀疏,这里用一个2*156的矩阵表示,代表有156条边,两个tensor对应的位置即为连接的边。比如0-1,0-2…

edge_index = data.edge_index
print(edge_index)
tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,
          3,  3,  3,  3,  3,  4,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  7,  7,
          7,  7,  8,  8,  8,  8,  8,  9,  9, 10, 10, 10, 11, 12, 12, 13, 13, 13,
         13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 19, 20, 20, 21,
         21, 22, 22, 23, 23, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 27, 27,
         27, 27, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31,
         31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33,
         33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33],
        [ 1,  2,  3,  4,  5,  6,  7,  8, 10, 11, 12, 13, 17, 19, 21, 31,  0,  2,
          3,  7, 13, 17, 19, 21, 30,  0,  1,  3,  7,  8,  9, 13, 27, 28, 32,  0,
          1,  2,  7, 12, 13,  0,  6, 10,  0,  6, 10, 16,  0,  4,  5, 16,  0,  1,
          2,  3,  0,  2, 30, 32, 33,  2, 33,  0,  4,  5,  0,  0,  3,  0,  1,  2,
          3, 33, 32, 33, 32, 33,  5,  6,  0,  1, 32, 33,  0,  1, 33, 32, 33,  0,
          1, 32, 33, 25, 27, 29, 32, 33, 25, 27, 31, 23, 24, 31, 29, 33,  2, 23,
         24, 33,  2, 31, 33, 23, 26, 32, 33,  1,  8, 32, 33,  0, 24, 25, 28, 32,
         33,  2,  8, 14, 15, 18, 20, 22, 23, 29, 30, 31, 33,  8,  9, 13, 14, 15,
         18, 19, 20, 22, 23, 26, 27, 28, 29, 30, 31, 32]])

用network可视化一下这个数据集:

%matplotlib inline
import torch 
import networkx as nx
import matplotlib.pyplot as plt
def visualize_graph(G,color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                     node_color=color, cmap="Set2")
    plt.show()
def visualize_embedding(h,color,epoch=None,loss=None):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    h = h.detach().cpu().numpy()
    plt.scatter(h[:,0],h[:,1],s=140,c=color,cmap="Set2")
    if epoch is not None and loss is not None:
        plt.xlabel(f'Epoch: {epoch}, Loss: {loss.item():.4f}', fontsize=16)
    plt.show()
from torch_geometric.utils import to_networkx
G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

在这里插入图片描述

四、定义网络结构

定义一个简单的GNN模型,包含三个GCNconv和一个分类层,激活函数选择tanh,将GCNConv结果和最后分类的结果都做返回。

import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv

class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        torch.manual_seed(1234)
        self.conv1 = GCNConv(dataset.num_features, 4)
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, dataset.num_classes)
    def forward(self, x, edge_index):
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()
        out = self.classifier(h)
        return out, h
model = GCN()
print(model)

先查看一下未训练模型对于数据处理情况:

model = GCN()
_,h = model(data.x, data.edge_index)
visualize_embedding(h, color=data.y)

在这里插入图片描述

五、训练模型

训练过程大同小异了,定义模型,定义优化器,定义损失函数,计算损失,反向传播更新参数,训练400轮,直到训练完成。可视化GCNconv处理之后的特征。

import time
model = GCN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

def train(data):
    optimizer.zero_grad()
    out, h = model(data.x, data.edge_index)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss, h
for epoch in range(401):
    loss, h = train(data)
    if epoch % 10 == 0:
        visualize_embedding(h, color=data.y, epoch=epoch, loss=loss)
        time.sleep(0.3)

在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/917113.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

173. 二叉搜索树迭代器【 力扣(LeetCode) 】

文章目录 零、原题链接一、题目描述二、测试用例三、解题思路四、参考代码 零、原题链接 173. 二叉搜索树迭代器 一、题目描述 实现一个二叉搜索树迭代器类BSTIterator &#xff0c;表示一个按中序遍历二叉搜索树&#xff08;BST&#xff09;的迭代器&#xff1a; BSTIterato…

【lamafactory BLEU ROUGLE L评测】

1、BLEU/ROUGLE评测界面 2、这个是用BLEU 和ROUGL来评测 目录&#xff1a;saves/Qwen2-7B-Chat/lora/eval_2024-11-14-16-28-19/ 在saves文件夹 生成的文件如下 all_results.json文件 说明模型在这个测试集上是不好的 3、可以查看预测结果的文件 predict_result.json

前端开发中常用的包管理器(npm、yarn、pnpm、bower、parcel)

文章目录 1. npm (Node Package Manager)2. Yarn (Yarn Package Manager)3. pnpm4. Bower5. Parcel总结 前端开发中常用的包管理器主要有以下几个&#xff1a; 1. npm (Node Package Manager) 简介&#xff1a; npm 是 Node.js 的默认包管理器&#xff0c;也是最广泛使用的包…

Python爬虫项目 | 一、网易云音乐热歌榜歌曲

文章目录 1.文章概要1.1 实现方法1.2 实现代码1.3 最终效果 2.具体讲解2.1 使用的Python库2.2 代码说明2.2.1 创建目录保存文件2.2.2 爬取网易云音乐热歌榜单歌曲 2.3 过程展示 3 总结 1.文章概要 学习Python爬虫知识&#xff0c;实现简单的一个小案例&#xff0c;网易云音乐热…

SHELL脚本(Linux)

声明 学习视频来自 B 站UP主泷羽sec&#xff0c;如涉及侵权马上删除文章。 笔记的只是方便各位师傅学习知识&#xff0c;以下网站只涉及学习内容&#xff0c;其他的都与本人无关&#xff0c;切莫逾越法律红线&#xff0c;否则后果自负。 ✍&#x1f3fb;作者简介&#xff1a;致…

reduce-scatter:适合分布式计算;Reduce、LayerNorm和Broadcast算子的执行顺序对计算结果的影响,以及它们对资源消耗的影响

目录 Gather Scatter Reduce reduce-scatter:适合分布式计算 Reduce、LayerNorm和Broadcast算子的执行顺序对计算结果的影响,以及它们对资源消耗的影响 计算结果理论正确性 资源消耗方面 Gather 这个也很好理解,就是把多个进程的数据拼凑在一起。 Scatter 不同于Br…

arkUI:水果选择与管理:基于 ArkUI 的长按编辑功能实现

水果选择与管理&#xff1a;基于 ArkUI 的长按编辑功能实现 1 主要内容说明2 相关内容2.1 相关内容2.1.1 源码1内容的相关说明2.1.1.1 数据结构与状态管理2.1.1.2 添加水果功能2.1.1.3 水果列表展示2.1.1.4 长按进入编辑模式2.1.1.5 复选框的多选功能2.1.1.6 删除水果功能2.1.1…

(时序论文阅读)TimeMixer: Decomposable Multiscale Mixing for Time Series Forecasting

来源论文iclr2024 论文地址&#xff1a;https://arxiv.org/abs/2405.14616 源码地址&#xff1a; https://github.com/kwuking/TimeMixer 背景 数据是有连续性&#xff0c;周期性&#xff0c;趋势性的。我们这篇文章主要围绕的是用MLP结构来预测数据的周期性具体为&#xff…

聊天服务器(8)用户登录业务

目录 登录状态业务层代码数据模型层代码记录用户的连接信息以及线程安全问题客户端异常退出业务 登录状态 登录且状态变为online 业务层代码 #include "chatservice.hpp" #include "public.hpp" #include <string> #include <muduo/base/Loggi…

18.UE5怪物视野、AI感知、攻击范围、散弹技能

2-20 怪物视野、AI感知、攻击范围、散弹技能_哔哩哔哩_bilibili 目录 1.AI感知组件 2.AI感知更新的函数 3.攻击范围 4.散弹技能 4.1创建发射物 4.2创建远程攻击方式 4.3散弹自定义事件的实现 4.4动画通知实现攻击 1.AI感知组件 为怪物蓝图添加AI感知组件&#xff0c;…

单片机智能家居火灾环境安全检测

目录 前言 一、本设计主要实现哪些很“开门”功能&#xff1f; 二、电路设计原理图 电路图采用Altium Designer进行设计&#xff1a; 三、实物设计图 四、程序源代码设计 五、获取资料内容 前言 在现代社会&#xff0c;火灾安全始终是人们关注的重点问题。随着科技的不…

安全见闻1-5

涵盖了编程语言、软件程序类型、操作系统、网络通讯、硬件设备、web前后端、脚本语言、病毒种类、服务器程序、人工智能等基本知识&#xff0c;有助于全面了解计算机科学和网络技术的各个方面。 安全见闻1 1.编程语言简要概述 C语言&#xff1a;面向过程&#xff0c;适用于系统…

【计算机网络】TCP协议特点3

心跳机制 什么是心跳机制 心跳机制是在计算机系统、网络通信和许多其他技术领域广泛应用的一种机制&#xff0c;用于检测两个实体之间的连接是否仍然活跃&#xff0c;或者设备是否还在正常运行。就是每隔一段时间发送一个固定的消息给服务端&#xff0c;服务端回复一个固定…

HarmonyOs DevEco Studio小技巧31--画布组件Canvas

那天我们用画布实现了文字颜色的渐变&#xff0c;实际上画布还有很多好玩的功能&#xff0c;接下来让我们一起试一下画布怎么玩 Canvas 提供画布组件&#xff0c;用于自定义绘制图形。 接口 Canvas Canvas(context?: CanvasRenderingContext2D | DrawingRenderingContext…

蓝桥杯每日真题 - 第10天

题目&#xff1a;&#xff08;班级活动&#xff09; 题目描述&#xff08;14届 C&C B组C题&#xff09; 解题思路&#xff1a; 题目要求我们找到最小的修改次数&#xff0c;使得任意一名同学的 ID 都与另一名同学的 ID 不相同。可以总结出这是一个寻找“重复数字最少修改…

HTML5实现俄罗斯方块小游戏

文章目录 1.设计来源1.1 主界面1.2 皮肤风格1.2 游戏中界面1.3 游戏结束界面 2.效果和源码2.1 动态效果2.2 源代码 源码下载 作者&#xff1a;xcLeigh 文章地址&#xff1a;https://blog.csdn.net/weixin_43151418/article/details/143788449 HTML5实现俄罗斯方块小游戏&#x…

外网访问 WebDav 服务

从外部网络环境&#xff08;比如异地和家中网络&#xff09;来访问公司内网的 WebDav 服务&#xff08;基于 IIS &#xff09;并映射成本地虚拟磁盘。 步骤如下 第一步 在公司内网的电脑上设置 webDav。 1&#xff0c;找到【控制面板】&#xff0c;双击进入。 2&#xff0c…

深度学习每周学习总结J5(DenseNet-121 +SE 算法实战与解析 - 猴痘识别)

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 | 接辅导、项目定制 0. 总结 数据导入及处理部分&#xff1a;本次数据导入没有使用torchvision自带的数据集&#xff0c;需要将原始数据进行处理包括数据导入…

【玩具蛇——DFS】

题目 代码 #include <bits/stdc.h> using namespace std; int g[5][5]; int dx[] {0, 0, -1, 1}, dy[] {-1, 1, 0, 0}; int ans; void dfs(int x, int y, int t) {g[x][y] t;if (t > 16){ans;g[x][y] 0;return;}for (int i 0; i < 4; i){int nx x dx[i], n…

STM32 串口输出调试信息

软硬件信息 CubeMX version 6.12.1Keil uVision V5.41.0.0 注意 串口有多种&#xff1a; TTL232485 串口的相关知识&#xff1a; 01-【HAL库】STM32实现串口打印&#xff08;printf方式) &#xff0c; 内含 TTL 和 232 区别。 我把 232 串口连进 STM32 串口助手收到的信息…