图神经网络:在自定义数据集上动手实现图神经网络

文章说明:
1)参考资料:PYG官方文档。超链。
2)博主水平不高,如有错误还望批评指正。

文章目录

  • 自定义数据集动手实现图神经网络
    • 自定义数据集
    • 训验测集拆分,创建Data的数据结构,观察Data的基本信息,可视化图网络
    • 搭建模型,训练前的准备,训练模型得出结果并可视化
  • 结果分析
  • 完整代码
  • 后记

自定义数据集动手实现图神经网络

自定义数据集

导库

from random import randint,sample

数据背景描述:这段代码生成北京化工大学三个学院(国际教育学院(100个学生),数理学院(300个学生),信息学院(500个学生))共计900个学生社交网络。每个学生,学院内部随机认识随机个人;学院外部随机认识随机个人,从而搭建边的关系。具体如何随机只能请看代码,鉴于篇幅原因不再过多赘述。

class dataset:

    def __init__(self):
        self.data_x=[];self.data_y=[]
        for i in range(100):
            lt=[0 for i in range(900)]
            lt[i]=1
            self.data_x.append(lt)
            self.data_y.append(0)
        for i in range(100,400):
            lt=[0 for i in range(900)]
            lt[i]=1
            self.data_x.append(lt)
            self.data_y.append(1)
        for i in range(400,900):
            lt=[0 for i in range(900)]
            lt[i]=1
            self.data_x.append(lt)
            self.data_y.append(2)
        self.data_edge=[[],[]]
        lt1=[i for i in range(100)];lt2=[i for i in range(100,400)];lt3=[i for i in range(400,900)]
        lt4=lt2+lt3;lt5=lt1+lt3;lt6=lt1+lt2;lt7=lt1+lt2+lt3
        for i in range(100):
            j=randint(30,70)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt7[:i]+lt7[i+1:100],j))
            j=randint(0,10)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt4,j))
        for i in range(100,400):
            j=randint(50,100)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt7[100:i]+lt7[i+1:400],j))
            j=randint(0,10)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt5,j))
        for i in range(400,900):
            j=randint(75,125)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt7[400:i]+lt7[i+1:],j))
            j=randint(0,10)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt6,j))
    
    def x(self):
        return self.data_x
    
    def edge(self):
        return self.data_edge
    
    def y(self):
        return self.data_y

训验测集拆分,创建Data的数据结构,观察Data的基本信息,可视化图网络

导库

from torch_geometric.data import Data
import torch

训验测集拆分:对数据集随机拆分。train_val_test_split中,a代表训练集比例,b代表验证集比例,所以测试集比例为1-a-b啦。c代表了至少保证训练集中每个类别至少c个,否则一直循环随机去构建训练集,直到满足每个类别至少c个。最后返回三个长度为900的布尔列表吧。

def caozuo(lt):
    lt_new=[False for i in range(900)]
    for i in lt:
        lt_new[i]=True
    return lt_new

def T_or_F(lt,d):
    a,b,c=0,0,0
    for i in lt:
        if i<100:
            a+=1
        elif i<400:
            b+=1
        elif i<900:
            c+=1
        if a>=d and b>=d and c>=d:
            return True
    return False

def train_val_test_split(a,b,c):
    while True:
        lt=sample([i for i in range(900)],900)
        train_index=lt[:int(900*a)];val_index=lt[int(900*a):int(900*a+900*b)];test_index=lt[int(900*a+900*b):]
        if T_or_F(train_index,c)==False:
            continue
        else:
            return caozuo(train_index),caozuo(val_index),caozuo(test_index)
data=dataset();x=torch.Tensor(data.x());edge=torch.LongTensor(data.edge());y=torch.LongTensor(data.y());lt1,lt2,lt3=train_val_test_split(0.1,0.3,5)

创建Data的数据结构,观察Data的基本信息:简单说明一下 x x x 吧 , x : 900 × 900 x:900 \times 900 x:900×900 单位阵,表示第i个节点在第i维度。这个特征矩阵不算是通常意义上特征矩阵,信息有限。

data=Data(x=x,edge_index=edge,y=y,train_mask=torch.BoolTensor(lt1),val_mask=torch.BoolTensor(lt2),test_mask=torch.BoolTensor(lt3));print(data)
#输出:Data(x=[900, 900], edge_index=[2, 81246], y=[900], train_mask=[900], val_mask=[900], test_mask=[900])

导库

from torch_geometric.utils import to_networkx
import matplotlib.pyplot as plt
import networkx as nx

可视化图网络

def visualize_graph(G,color):
    plt.style.use("seaborn");plt.rcParams['font.family']='SimHei';plt.rcParams['font.sans-serif']=['SimHei']
    plt.figure(figsize=(16,9));plt.title("国教,数理,信息学院学生社交网络",size=20);plt.xticks([]);plt.yticks([])
    nx.draw_networkx(G,pos=nx.spring_layout(G),with_labels=False,node_size=3,node_color=color,width=0.01,edge_color="black",cmap="Set2")
    legend_dict={"red":"国教","blue": "数理","green":"信息"};plt.legend(handles=[plt.Line2D([],[],color=c,label=l,linestyle='None',marker='o') for c, l in zip(legend_dict.keys(),legend_dict.values())],loc='upper right')
    plt.savefig("figure",dpi=1000)
G=to_networkx(data,to_undirected=True);visualize_graph(G,color=["red"]*100+["blue"]*300+["green"]*500)

在这里插入图片描述

搭建模型,训练前的准备,训练模型得出结果并可视化

导库

from torch_geometric.nn import GCNConv
import torch.nn.functional as F

搭建模型

class GCN(torch.nn.Module):
    def __init__(self,hidden_channels):
        super().__init__()
        self.conv1=GCNConv(900,hidden_channels)
        self.conv2=GCNConv(hidden_channels,3)
    def forward(self,x,edge_index):
        x=self.conv1(x,edge_index)
        x=x.relu()
        x=F.dropout(x,p=0.5,training=self.training)
        x=self.conv2(x,edge_index)
        return x

训练前的准备

model=GCN(9);optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4);criterion=torch.nn.CrossEntropyLoss()

训练模型得出结果并可视化

def train():
    model.train()
    optimizer.zero_grad()
    out=model(data.x,data.edge_index)
    loss=criterion(out[data.train_mask],data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss

def test():
    model.eval()
    out=model(data.x,data.edge_index)
    pred=out.argmax(dim=1)
    test_correct=pred[data.test_mask]==data.y[data.test_mask]
    test_acc=int(test_correct.sum())/int(data.test_mask.sum())
    return test_acc
for epoch in range(1,101):
    loss=train()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
test_acc=test()
print(f'Test Accuracy: {test_acc:.4f}')
#这里就只展示测试集的结果
#输出:1.00

损失函数下降曲线
在这里插入图片描述

T-SNE可视化
模型未训练只经过一次正向传播
在这里插入图片描述

训练好后
在这里插入图片描述

结果分析

模型在测试集表现很好为1,那么效果为什么会这么好呢。数据集就很好,创建数据集的时候忽略了很多的现实因素,比如,有的同学它是社交恐怖分子认识全校80%的人,或者有的同学刚刚转院等等现实情况,所以这里创建的数据集十分理想,加之模型本身很好,最终导致了很好的结果。我们可以在数据集构建几个转院同学看看结果但是我就不做了吧。

完整代码

from torch_geometric.utils import to_networkx
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from random import randint,sample
from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import torch.nn.functional as F
import networkx as nx
import numpy as np
import torch

def train():
    model.train()
    optimizer.zero_grad()
    out=model(data.x,data.edge_index)
    loss=criterion(out[data.train_mask],data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss

def test():
    model.eval()
    out=model(data.x,data.edge_index)
    pred=out.argmax(dim=1)
    test_correct=pred[data.test_mask]==data.y[data.test_mask]
    test_acc=int(test_correct.sum())/int(data.test_mask.sum())
    return test_acc

def visualize(h,color,s):
    z=TSNE(n_components=2).fit_transform(h.detach().cpu().numpy())
    plt.style.use("seaborn");plt.figure(figsize=(16,9));plt.xticks([]);plt.yticks([])
    plt.scatter(z[:,0],z[:,1],s=3,c=color)
    plt.savefig("figure"+str(s),dpi=1000)

class GCN(torch.nn.Module):
    def __init__(self,hidden_channels):
        super().__init__()
        self.conv1=GCNConv(900,hidden_channels)
        self.conv2=GCNConv(hidden_channels,3)
    def forward(self,x,edge_index):
        x=self.conv1(x,edge_index)
        x=x.relu()
        x=F.dropout(x,p=0.5,training=self.training)
        x=self.conv2(x,edge_index)
        return x
    
def visualize_graph(G,color):
    plt.style.use("seaborn");plt.rcParams['font.family']='SimHei';plt.rcParams['font.sans-serif']=['SimHei']
    plt.figure(figsize=(16,9));plt.title("国教,数理,信息学院学生社交网络",size=20);plt.xticks([]);plt.yticks([])
    nx.draw_networkx(G,pos=nx.spring_layout(G),with_labels=False,node_size=3,node_color=color,width=0.01,edge_color="black",cmap="Set2")
    legend_dict={"red":"国教","blue": "数理","green":"信息"};plt.legend(handles=[plt.Line2D([],[],color=c,label=l,linestyle='None',marker='o') for c, l in zip(legend_dict.keys(),legend_dict.values())],loc='upper right')
    plt.savefig("figure1",dpi=1000)

def caozuo(lt):
    lt_new=[False for i in range(900)]
    for i in lt:
        lt_new[i]=True
    return lt_new

def T_or_F(lt,d):
    a,b,c=0,0,0
    for i in lt:
        if i<100:
            a+=1
        elif i<400:
            b+=1
        elif i<900:
            c+=1
        if a>=d and b>=d and c>=d:
            return True
    return False

def train_val_test_split(a,b,c):
    while True:
        lt=sample([i for i in range(900)],900)
        train_index=lt[:int(900*a)];val_index=lt[int(900*a):int(900*a+900*b)];test_index=lt[int(900*a+900*b):]
        if T_or_F(train_index,c)==False:
            continue
        else:
            return caozuo(train_index),caozuo(val_index),caozuo(test_index)

class dataset:

    def __init__(self):
        self.data_x=[];self.data_y=[]
        for i in range(100):
            lt=[0 for i in range(900)]
            lt[i]=1
            self.data_x.append(lt)
            self.data_y.append(0)
        for i in range(100,400):
            lt=[0 for i in range(900)]
            lt[i]=1
            self.data_x.append(lt)
            self.data_y.append(1)
        for i in range(400,900):
            lt=[0 for i in range(900)]
            lt[i]=1
            self.data_x.append(lt)
            self.data_y.append(2)
        self.data_edge=[[],[]]
        lt1=[i for i in range(100)];lt2=[i for i in range(100,400)];lt3=[i for i in range(400,900)]
        lt4=lt2+lt3;lt5=lt1+lt3;lt6=lt1+lt2;lt7=lt1+lt2+lt3
        for i in range(100):
            j=randint(30,70)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt7[:i]+lt7[i+1:100],j))
            j=randint(0,10)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt4,j))
        for i in range(100,400):
            j=randint(50,100)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt7[100:i]+lt7[i+1:400],j))
            j=randint(0,10)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt5,j))
        for i in range(400,900):
            j=randint(75,125)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt7[400:i]+lt7[i+1:],j))
            j=randint(0,10)
            for k in range(j):
                self.data_edge[0].append(i)
            self.data_edge[1].extend(sample(lt6,j))
    
    def x(self):
        return self.data_x
    
    def edge(self):
        return self.data_edge
    
    def y(self):
        return self.data_y
    
if __name__=="__main__":
    data=dataset();x=torch.Tensor(data.x());edge=torch.LongTensor(data.edge());y=torch.LongTensor(data.y());lt1,lt2,lt3=train_val_test_split(0.1,0.3,5)
    data=Data(x=x,edge_index=edge,y=y,train_mask=torch.BoolTensor(lt1),val_mask=torch.BoolTensor(lt2),test_mask=torch.BoolTensor(lt3));print(data)
    G=to_networkx(data,to_undirected=True);visualize_graph(G,color=["red"]*100+["blue"]*300+["green"]*500)
    model=GCN(9);optimizer=torch.optim.Adam(model.parameters(),lr=0.01,weight_decay=5e-4);criterion=torch.nn.CrossEntropyLoss()
    model.eval();out=model(data.x,data.edge_index);visualize(out,color=["red"]*100+["blue"]*300+["green"]*500,s=2)
    lt=[]
    for epoch in range(1,101):
        loss=train();lt.append(loss.item())
        print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}')
    test_acc=test()
    print(f'Test Accuracy: {test_acc:.4f}')
    plt.style.use("seaborn");plt.rcParams['font.family']='SimHei';plt.rcParams['font.sans-serif']=['SimHei']
    plt.figure(figsize=(16,9));plt.title("损失函数下降曲线",size=20)
    plt.plot([i for i in range(len(lt))],lt,marker="o",ms=3,color="red",linewidth=1,label="交叉熵")
    plt.xticks([i for i in range(0,len(lt),10)],[i+1 for i in range(0,len(lt),10)]);plt.legend();plt.savefig("figure3",dpi=1000)
    model.eval();out=model(data.x,data.edge_index);visualize(out,color=["red"]*100+["blue"]*300+["green"]*500,s=4)

后记

应该会再写一篇文章吧,我们从底层实现PYG,就不导库了吧,然后具体讲讲GCN是怎么操作。然后,我们结束图神经网络在无向图中节点分类这个话题。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/17447.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

震惊,为了学会泛型类竟做这种事?!

上一节&#xff0c;我们基本学会了Java泛型类的用法。 传送门&#xff1a;彻底弄懂Java的泛型 - 泛型类 这一节&#xff0c;我们转变一下风格&#xff0c;具体是什么风格呢&#xff0c;你马上就懂了。 宝子们&#xff0c;欢迎大家来到我们的泛型直播间&#xff0c;这一讲呢&a…

Ansible的脚本-playbook 剧本

目录 1.剧本&#xff08;playbook&#xff09; 1.playbook介绍 2. playbooks 的组成 3.案例&#xff1a;编写httpd的playbook 4.定义、引用变量 5.指定远程主机sudo切换用户 6.when条件判断 7.迭代 2.playbook的模块 1.Templates 模块 2.tags 模块 3.Roles 模块 1.…

【Linux从入门到精通】vim的基本使用各种操作详解

文章目录 一、vim编辑器简单介绍 二、vim编辑器的四种模式 2、1 正常/普通/命令模式(Normal mode) 2、2 插入模式(Insert mode) 2、3 末行模式(last line mode) 三、命令模式的相关操作实例 3、1 光标的相关操作 3、2 文本操作 四、插入模式下的相关操作 五、末行模式下的相关操…

Java—JDK8新特性—函数式接口

目录 函数式接口 3.1 什么是函数式接口 3.2 functionalinterface注解 源码分析 3.3 Lambda表达式和函数式接口关系 3.4 使用函数式接口 函数式接口 3.1 什么是函数式接口 如果一个接口中只包含一个抽象方法&#xff0c;这个接口称为函数式接口 如果一个接口包含&#xff0…

K8S管理系统项目实战[API开发]-2

后端: gogin 后端代码地址GitHub - yunixiangfeng/k8s-platform: K8s管理系统后端: gogin 5、存储与配置 5.1 ConfigMap 5.2 Secret 5.3 PersistentVolumeClaims 6、工作流 6.1 流程设计 6.2 数据库操作&#xff08;GORM&#xff09; &#xff08;1&#xff09;初始化…

交换机-Exchanges

交换机 Exchanges 概念 RabbitMQ 消息传递模型的核心思想是: 生产者生产的消息从不会直接发送到队列。实际上&#xff0c;通常生产者甚至都不知道这些消息传递传递到了哪些队列中。相反&#xff0c;生产者只能将消息发送到交换机(exchange)&#xff0c;交换机工作的内容非常简…

正则表达式-基本元字符和语法规则

© Ptw-cwl 文章目录 字符匹配元字符.元字符[]元字符[^]元字符*元字符元字符?元字符{}元字符|元字符()元字符^元字符$元字符\元字符\d元字符\w元字符\s元字符\b元字符\B元字符*?、?、??、{n,m}?元字符(?)、(?!)元字符(?:)元字符\1、\2等元字符^、$元字符&#x…

JavaSE基础(二)—— 类型转换、运算符、键盘录入

目录 一、类型转换 1. 自动类型转换 1.1 自动类型转换的底层原理&#xff1a; ​1.2 自动类型转换的其他形式​编辑 2. 表达式的自动类型转换 3. 强制类型转换 3.1 强制类型转换底层原理​编辑 3.2 注意事项 二、运算符 1. 算数运算符 1.1 案例&#xff1a;数值拆分…

PCA主成成分分析例题详解

主成分分析是一种降维算法&#xff0c;它能将多个指标转换为少数几个主成分&#xff0c;这些主成分是原始变量的线性组合&#xff0c;且彼此之间互不相关&#xff0c;其能反映出原始数据的大部分信息 需要了解具体细节可看此视频&#x1f449;&#xff1a;什么是主成成分分析PC…

Linux安装MongoDB数据库,并内网穿透远程连接

文章目录 前言1. 配置Mongodb源2. 安装MongoDB3. 局域网连接测试4. 安装cpolar内网穿透5. 配置公网访问地址6. 公网远程连接7. 固定连接公网地址8. 使用固定地址连接 转载自Cpolar Lisa文章&#xff1a;Linux服务器安装部署MongoDB数据库 - 无公网IP远程连接「内网穿透」 前言 …

SpringBoot访问静态资源

SpringBoot项目中没有WebApp目录&#xff0c;只有src目录。在src/main/resources下面有static和templates两个文件夹。SpringBoot默认在static目录中存放静态资源&#xff0c;而templates中放动态页面。 static目录 SpringBoot通过/resources/static目录访问静态资源&#xff…

完成A轮融资,倍思如何发力场景化为品牌创造广阔未来?

凛冬过后的消费电子正在重新凝聚资本的目光。 近日&#xff0c;深圳市倍思科技有限公司宣布完成由深创投、中金资本联合领投&#xff0c;越秀产业基金、高榕资本跟投&#xff0c;金额数亿元人民币的A轮融资。 分析人士指出&#xff0c;消费电子的行业景气度在逐渐恢复&#x…

中国社科院与美国杜兰大学金融管理硕士项目——迎接立夏,切莫忘记自我成长

五月的风吹走了春季&#xff0c;今天我们迎来立夏。作为夏季的第一个节气&#xff0c;立夏常被人们当做万物蓄满能量&#xff0c;即将加速生长的标志。而在职的我们&#xff0c;也应该跟这世间万物一样&#xff0c;在季节交替之时沉淀自己、努力向上成长。在社科院与杜兰大学金…

“人工智能教父”从谷歌离职 称后悔发展AI,为世人敲响警钟?

在加入谷歌的第十年、深度学习迎来爆发式发展的当下&#xff0c;被誉为“人工智能教父”的Geoffrey Hinton已从谷歌离职&#xff0c;只是为了告诫人们AI已经变得很危险。 公开资料显示&#xff0c;Geoffrey Hinton在2013年加入谷歌&#xff0c;曾任副总裁&#xff0c;研究机器学…

成为数据分析师,需要具备哪些技能?

随着互联网的发展&#xff0c;数据分析师的特点越来越明显&#xff0c;对数据分析师综合素质的要求也较高。 1、较强的数据挖掘、信息整理、和逻辑分析能力 数据分析&#xff0c;也是数据分析师的一个方向。 制作日常性的经营报表&#xff0c;对公司或者行业KPI指标进行拆解…

Mysql索引(3):索引分类

1 索引分类 在MySQL数据库&#xff0c;将索引的具体类型主要分为以下几类&#xff1a;主键索引、唯一索引、常规索引、全文索引。 分类含义特点关键字主键索引针对于表中主键创建的索引 默认自动创建, 只能有一个 PRIMARY 唯一索引 避免同一个表中某数据列中的值重复可以有多…

【Android入门到项目实战-- 8.4】—— 如何解析JSON格式数据

目录 一、准备工作 二、使用JSONObject 三、使用GSON 比起XML&#xff0c;JSON的主要优势在于它的体积更小&#xff0c;在网络上传输的时候可以更省流量&#xff0c;但缺点是语义性较差&#xff0c;看起来不直观。 一、准备工作 还是使用前面文章的方法&#xff0c;在服务器…

每日学术速递4.29

CV - 计算机视觉 | ML - 机器学习 | RL - 强化学习 | NLP 自然语言处理 Subjects: cs.LG 1.A Cookbook of Self-Supervised Learning 标题&#xff1a;自监督学习食谱 作者&#xff1a;Randall Balestriero, Mark Ibrahim, Vlad Sobal, Ari Morcos, Shashank Shekhar, Tom…

【黑马程序员 C++教程从0到1入门编程】【笔记8】 泛型编程——模板

https://www.bilibili.com/video/BV1et411b73Z?p167 C泛型编程是一种编程范式&#xff0c;它的核心思想是编写通用的代码&#xff0c;使得代码可以适用于多种不同的数据类型。 而模板是C中实现泛型编程的一种机制&#xff0c;它允许我们编写通用的代码模板&#xff0c;然后在需…

【Spring篇】IOC/DI注解开发

&#x1f353;系列专栏:Spring系列专栏 &#x1f349;个人主页:个人主页 目录 一、IOC/DI注解开发 1.注解开发定义bean 2.纯注解开发模式 1.思路分析 2.实现步骤 3.注解开发bean作用范围与生命周期管理 1.环境准备 2.Bean的作用范围 3.Bean的生命周期 4.注解开发依赖…