【CS224W】(task12)GAT GNN training tips

note

  • GAT使用attention对线性转换后的节点进行加权求和:利用自身节点的特征向量分别和邻居节点的特征向量,进行内积计算score。
  • 异质图的消息传递和聚合: h v ( l + 1 ) = σ ( ∑ r ∈ R ∑ u ∈ N v r 1 c v , r W r ( l ) h u ( l ) + W 0 ( l ) h v ( l ) ) \mathbf{h}_v^{(l+1)}=\sigma\left(\sum_{r \in R} \sum_{u \in N_v^r} \frac{1}{c_{v, r}} \mathbf{W}_r^{(l)} \mathbf{h}_u^{(l)}+\mathbf{W}_0^{(l)} \mathbf{h}_v^{(l)}\right) hv(l+1)=σ rRuNvrcv,r1Wr(l)hu(l)+W0(l)hv(l)

文章目录

  • note
  • 一、GAT model
  • 二、GNN模型训练要点
    • 1. Graph Manipulation
    • 2. GNN training
      • (1)Node-level
      • (2)Edge-level
      • (3)Graph-level
    • 3. Issue of Global pooling
      • (1)Global pooling的毛病
      • (2)DidffPool 社群分层池化:
  • 三、GNN training tips
    • 3.1 Spliting Graphs is special
    • 3.2 异质图 Heterogeneous graph
  • 附:时间安排
  • Reference

一、GAT model

图注意神经网络(GAT)来源于论文 Graph Attention Networks。其数学定义为,
x i ′ = α i , i Θ x i + ∑ j ∈ N ( i ) α i , j Θ x j , \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j}, xi=αi,iΘxi+jN(i)αi,jΘxj,
GAT和所有的attention mechanism一样,GAT的计算也分为两步走:
(1)计算注意力系数(attention coefficient):(下图来自《GRAPH ATTENTION NETWORKS》)
其中注意力系数 α i , j \alpha_{i,j} αi,j的计算方法为,
α i , j = exp ⁡ ( L e a k y R e L U ( a ⊤ [ Θ x i   ∥   Θ x j ] ) ) ∑ k ∈ N ( i ) ∪ { i } exp ⁡ ( L e a k y R e L U ( a ⊤ [ Θ x i   ∥   Θ x k ] ) ) . \alpha_{i,j} = \frac{ \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j] \right)\right)} {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top} [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k] \right)\right)}. αi,j=kN(i){i}exp(LeakyReLU(a[ΘxiΘxk]))exp(LeakyReLU(a[ΘxiΘxj])).
在这里插入图片描述

(2)加权求和(aggregate):根据(1)的系数,把特征加权求和(aggregate) h v ( l ) = σ ( ∑ u ∈ N ( v ) α v u W ( l ) h u ( l − 1 ) ) \mathbf{h}_v^{(l)}=\sigma\left(\sum_{u \in N(v)} \alpha_{v u} \mathbf{W}^{(l)} \mathbf{h}_u^{(l-1)}\right) hv(l)=σ uN(v)αvuW(l)hu(l1)

二、GNN模型训练要点

1. Graph Manipulation

在这里插入图片描述

  • feature manipulation:feature augmentation, such as we can use cycle count as augmented node features
  • struture manipulation:
    • sparse graph: add virtual nodes or edges
    • dense graph: sample neighbors when doing message passing
    • large graph: sample subgraphs to compute embeddings

2. GNN training

在这里插入图片描述

(1)Node-level

  • After GNN computation, we have d d d-dim node
     embeddings:  { h v ( L ) ∈ R d , ∀ v ∈ G } \text { embeddings: }\left\{\mathbf{h}_v^{(L)} \in \mathbb{R}^d, \forall v \in G\right\}  embeddings: {hv(L)Rd,vG}
  • such as k-way prediction:
  • y ^ v = Head ⁡ node  ( h v ( L ) ) = W ( H ) h v ( L ) \widehat{\boldsymbol{y}}_v=\operatorname{Head}_{\text {node }}\left(\mathbf{h}_v^{(L)}\right)=\mathbf{W}^{(H)} \mathbf{h}_v^{(L)} y v=Headnode (hv(L))=W(H)hv(L)
    • W ( H ) ∈ R k ∗ d \mathbf{W}^{(H)} \in \mathbb{R}^{k * d} W(H)Rkd : We map node embeddings from h v ( L ) ∈ R d \mathbf{h}_v^{(L)} \in \mathbb{R}^d hv(L)Rd to y ^ v ∈ R k \widehat{y}_v \in \mathbb{R}^k y vRk
    • compute the loss

(2)Edge-level

  • use pairs of node embeddings
  • such as k-way prediction: y ^ u v = Head ⁡ edge  ( h u ( L ) , h v ( L ) ) \widehat{\boldsymbol{y}}_{u v}=\operatorname{Head}_{\text {edge }}\left(\mathbf{h}_u^{(L)}, \mathbf{h}_v^{(L)}\right) y uv=Headedge (hu(L),hv(L))
    • Concatenation + Linear: y ^ u v = Linear ⁡ ( Concat ⁡ ( h u ( L ) , h v ( L ) ) ) \hat{\boldsymbol{y}}_{u v}=\operatorname{Linear}\left(\operatorname{Concat}\left(\mathbf{h}_u^{(L)}, \mathbf{h}_v^{(L)}\right)\right) y^uv=Linear(Concat(hu(L),hv(L))),and Linear ⁡ \operatorname{Linear} Linear can map 2d-dim embeddings to k-dim embeddings
    • Dot product : y ^ u v = ( h u ( L ) ) T h v ( L ) \hat{\boldsymbol{y}}_{\boldsymbol{u} v}=\left(\mathbf{h}_u^{(L)}\right)^T \mathbf{h}_v^{(L)} y^uv=(hu(L))Thv(L)
      • this approach only applies to 1-way prediction(预测边是否存在)
      • k-way prediction:
      • 在这里插入图片描述

(3)Graph-level

  • use all the node embeddings in our graph
  • such as k-way prediction: y ^ G = Head ⁡ graph ⁡ ( { h v ( L ) ∈ R d , ∀ v ∈ G } ) \widehat{\boldsymbol{y}}_G=\operatorname{Head}_{\operatorname{graph}}\left(\left\{\mathbf{h}_v^{(L)} \in \mathbb{R}^d, \forall v \in G\right\}\right) y G=Headgraph({hv(L)Rd,vG})
    • Head ⁡ graph ⁡ \operatorname{Head}_{\operatorname{graph}} Headgraph ≈ AGG(`) in a GNN layer
    • Gloal pooling:use Gloal mean or max or sum pooling instead of Head ⁡ graph ⁡ \operatorname{Head}_{\operatorname{graph}} Headgraph

3. Issue of Global pooling

(1)Global pooling的毛病

  • Useing global pooling over a large graph will lose information
  • toy example(1-dim node embeddings):
    • Node embeddings for G 1 : { − 1 , − 2 , 0 , 1 , 2 } G_1:\{-1,-2,0,1,2\} G1:{1,2,0,1,2}, global sum pooling ans:0
    • Node embeddings for G 2 : { − 10 , − 20 , 0 , 10 , 20 } G_2:\{-10,-20,0,10,20\} G2:{10,20,0,10,20},global sum pooling ans:0
  • 特点:只看均值,不看方差
  • so we can use hierarchical pooling 分层池化
  • toy example:We will aggregate via ReLU ⁡ ( Sum ⁡ ( ⋅ ) ) \operatorname{ReLU}(\operatorname{Sum}(\cdot)) ReLU(Sum())
    • We first separately aggregate the first 2 nodes and last 3 nodes;Then we aggregate again to make the final prediction
    • G 1 G_1 G1 node embeddings: { − 1 , − 2 , 0 , 1 , 2 } \{-1,-2,0,1,2\} {1,2,0,1,2}
      • Round 1: y ^ a = ReLU ⁡ ( Sum ⁡ ( { − 1 , − 2 } ) ) = 0 , y ^ b = \hat{y}_a=\operatorname{ReLU}(\operatorname{Sum}(\{-1,-2\}))=0, \hat{y}_b= y^a=ReLU(Sum({1,2}))=0,y^b=
        ReLU ⁡ ( Sum ⁡ ( { 0 , 1 , 2 } ) ) = 3 \quad \operatorname{ReLU}(\operatorname{Sum}(\{0,1,2\}))=3 ReLU(Sum({0,1,2}))=3
      • Round 2:  ⁡ y ^ G = ReLU ⁡ ( Sum ⁡ ( { y a , y b } ) ) = 3 \operatorname{Round~2:~} \hat{y}_G=\operatorname{ReLU}\left(\operatorname{Sum}\left(\left\{y_a, y_b\right\}\right)\right)=3 Round 2: y^G=ReLU(Sum({ya,yb}))=3
    • G 2 G_2 G2 node embeddings: { − 10 , − 20 , 0 , 10 , 20 } \{-10,-20,0,10,20\} {10,20,0,10,20}
      • Round 1:  ⁡ y ^ a = ReLU ⁡ ( Sum ⁡ ( { − 10 , − 20 } ) ) = 0 , y ^ b = 2 = \operatorname{Round~1:~} \hat{y}_a=\operatorname{ReLU}(\operatorname{Sum}(\{-10,-20\}))=0, \hat{y}_b={ }^2= Round 1: y^a=ReLU(Sum({10,20}))=0,y^b=2=
        ReLU ⁡ ( Sum ⁡ ( { 0 , 10 , 20 } ) ) = 30 \quad \operatorname{ReLU}(\operatorname{Sum}(\{0,10,20\}))=30 ReLU(Sum({0,10,20}))=30
      • Round 2: ⁡ y ^ G = ReLU ⁡ ( Sum ⁡ ( { y a , y b } ) ) = 30 \operatorname{Round~2:} \hat{y}_G=\operatorname{ReLU}\left(\operatorname{Sum}\left(\left\{y_a, y_b\right\}\right)\right)=30 Round 2:y^G=ReLU(Sum({ya,yb}))=30

(2)DidffPool 社群分层池化:

在这里插入图片描述
每层(将每个社群当作一层,进行社群检测)利用两个独立的GNN层(可以联合训练):

  • GNN 1:计算节点embedding
  • GNN 2:计算一个节点属于的社群
  • 之前的图分类方法是先生成每个节点的embedding,对所有节点的embedding进行全局的pooling;而DidffPool(微分池化)通过逐渐压缩信息方式进行图分类,上一层GNN的节点进行聚类结果,作为下一层GNN的输入。

三、GNN training tips

3.1 Spliting Graphs is special

  • 像图片和文本分类的样本,每个数据样本之间满足独立同分布
  • 但GNN数据中不同节点可能会互相影响(消息传递)
    • transductive 直推式学习:
      • 划分数据集时,让图结构还是能看到,可以只根据节点label进行划分。在训练和验证阶段,都是使用全图信息,如下图,利用一二节点及其label进行训练,在验证阶段也是利用整图信息,利用三四节点及其label进行验证。
      • 只适合于节点or边分类任务
    • inductive 归纳式学习:
      • 拆分边,得到多重图
      • 适合于节点or边or图分类

在这里插入图片描述

3.2 异质图 Heterogeneous graph

异质图比同构图多了两个属性, R 、 T R 、 T RT, 其中 R R R 表示边的类型、 T T T 表示节点的类型, 最后整张图可以表示为:
G = ( V , E , R , T ) G=(V, E, R, T) G=(V,E,R,T)
同质图的聚合: h v ( l ) = σ ( ∑ u ∈ N ( v ) W ( l ) h u ( l − 1 ) ∣ N ( v ) ∣ ) \mathbf{h}_v^{(l)}=\sigma\left(\sum_{u \in N(v)} \mathbf{W}^{(l)} \frac{\mathbf{h}_u^{(l-1)}}{|N(v)|}\right) hv(l)=σ uN(v)W(l)N(v)hu(l1)
异质图的消息传递和聚合: h v ( l + 1 ) = σ ( ∑ r ∈ R ∑ u ∈ N v r 1 c v , r W r ( l ) h u ( l ) + W 0 ( l ) h v ( l ) ) \mathbf{h}_v^{(l+1)}=\sigma\left(\sum_{r \in R} \sum_{u \in N_v^r} \frac{1}{c_{v, r}} \mathbf{W}_r^{(l)} \mathbf{h}_u^{(l)}+\mathbf{W}_0^{(l)} \mathbf{h}_v^{(l)}\right) hv(l+1)=σ rRuNvrcv,r1Wr(l)hu(l)+W0(l)hv(l)
其中对于每种类型的边r,对应的邻居节点u,两节点之间传播的信息为: m u , r ( l ) = 1 c v , r W r ( l ) h u ( l ) \mathbf{m}_{u, r}^{(l)}=\frac{1}{c_{v, r}} \mathbf{W}_r^{(l)} \mathbf{h}_u^{(l)} mu,r(l)=cv,r1Wr(l)hu(l)

附:时间安排

任务任务内容截止时间注意事项
2月11日开始
task1图机器学习导论2月14日周二完成
task2图的表示和特征工程2月15、16日周四完成
task3NetworkX工具包实践2月17、18日周六完成
task4图嵌入表示2月19、20日周一完成
task5deepwalk、Node2vec论文精读2月21、22、23、24日周五完成
task6PageRank2月25、26日周日完成
task7标签传播与节点分类2月27、28日周二完成
task8图神经网络基础3月1、2日周四完成
task9图神经网络的表示能力3月3日周五完成
task10图卷积神经网络GCN3月4日周六完成
task11图神经网络GraphSAGE3月5日周七完成
task12图神经网络GAT3月6日周一完成

Reference

[1] https://docs.dgl.ai/en/0.8.x/generated/dgl.nn.pytorch.conv.GINConv.html?highlight=ginconv#dgl.nn.pytorch.conv.GINConv
[2] CS224W官网:https://web.stanford.edu/class/cs224w/index.html
[3] https://github.com/TommyZihao/zihao_course/tree/main/CS224W
[4] cs224w(图机器学习)2021冬季课程学习笔记18 Colab 4:异质图
[5] https://github.com/dmlc/dgl
[6] DIFFPOOL:一种图网络的分层池化方法
[7] https://relph1119.github.io/my-team-learning/#/cs224w_learning46/ext-task
[8] 【CS224W学习笔记 day09】 异质图神经网络

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/2284.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

第十八天 Vue-前端工程化总结

目录 Vue-前端工程化 1. 前后端分离开发 1.1 介绍 1.2 Yapi 2. 前端工程化 2.1 环境准备 2.2 Vue项目简介 2.3 Vue项目开发流程 3. Vue组件库Element 3.1 快速入门 3.2 常用组件 3.3 案例 Vue-前端工程化 前面我们已经讲解了HTML、CSS、JavaScript以及Vue等知识。已…

【粉丝投稿】上海某大厂的面试题,岗位是测开(25K*16)

简单介绍一句,大专出身,三年经验。跳了四次槽,面试了无数次,现在把自己的面试经验整理出来分享给大家,堪称必杀技! 1,一切从实际出发,对实际工作进行适当修饰 2,不会的简…

【进阶数据结构】平衡搜索二叉树 —— AVL树

🌈感谢阅读East-sunrise学习分享——[进阶数据结构]AVL树 博主水平有限,如有差错,欢迎斧正🙏感谢有你 码字不易,若有收获,期待你的点赞关注💙我们一起进步🚀 🌈我们上一篇…

学习Linux只要学会这个命令就够了!

大家好,我是良许。 这段时间又是搬家,又是找新办公室,现在终于安顿下来了,有时间给大家分享干货了。 今天给大家介绍一个 Linux 超级实用命令,有了这个命令,你就可以愉快使用 Linux 上几乎所有常用命令了…

【Unity入门】3D物体

【Unity入门】3D物体 大家好,我是Lampard~~ 欢迎来到Unity入门系列博客,所学知识来自B站阿发老师~感谢 (一)物体移动旋转缩放 (1)物体移动 在上一篇文章【Unity入门】场景视图操作我们学会了在场景中创建3…

Java现在好找工作吗?

Java到2023年已经28岁了,可能你会怀疑它是否还一如当年一样的强大,在应用层领域独占鳌头。但是基于Java庞大的市场占有率和需求,它依然在保持着更新迭代,依然是最常用的底层开发语言,基于其安全性、开放性、稳定性和跨…

springboot Aspect切面

问题描述 配置切面,但未切到目标类上 切面类 Component Aspect public class ControllerAspect {//Pointcut("execution(* com.yzk.learn.springbootsecurity.controller.UserController.info(..))")Pointcut("execution(* com.learn..*.controlle…

类ChatGPT开源项目的部署与微调:从LLaMA到ChatGLM-6B

前言 近期,除了研究ChatGPT背后的各种技术细节 不断看论文(至少100篇,100篇目录见此:ChatGPT相关技术必读论文100篇),还开始研究一系列开源模型(包括各自对应的模型架构、训练方法、训练数据、本地私有化部署、硬件配置要求、微调…

Java代码是如何被CPU狂飙起来的?

📣📣📣📣📣📣📣 🎍大家好,我是慕枫 🎍前阿里巴巴高级工程师,InfoQ签约作者、阿里云专家博主,一直致力于用大白话讲解技术知识 &#x…

安全防御之防火墙篇(二)

目录 1.防火墙如何处理双通道协议? 2.防火墙如何处理NAT? 3.防火墙支持哪些NAT技术,主要应用的场景是什么? 4.当内网PC通过公网域名解析访问内网服务器的时候,会存在什么问题,如何解决?请详细…

【MySQL】CentOS编译安装MySQL5.7实战

前言 这篇文章是关于MySQL编译安装的,重点掌握的是编译的过程,以及体会排错的痛苦。出错在所难免,最重要的是要有一颗不放弃的心。 本文收录于《数据库入门与精通》专栏, 本专栏写作的过程中,联合了csdn几位DBA大佬,…

SpringBoot整合Kafka(包含Kafka_2.12-3.3.1单节点安装,kafka可视化程序efak v3.0.1安装)

SpringBoot整合Kafka(包含Kafka_2.12-3.3.1单节点安装,kafka可视化程序efka v3.0.1安装)kafka、efak安装包下载kafka安装资源下载:下载tgz安装包:http://archive.apache.org/dist/kafka/ //解压 tar -zxvf /home/soft/…

自定义类型的超详细讲解ᵎᵎ了解结构体和位段这一篇文章就够了ᵎ

目录 1.结构体的声明 1.1基础知识 1.2结构体的声明 1.3结构体的特殊声明 1.4结构体的自引用 1.5结构体变量的定义和初始化 1.6结构体内存对齐 那对齐这么浪费空间,为什么要对齐 1.7修改默认对齐数 1.8结构体传参 2.位段 2.1什么是位段 2.2位段的内存分配…

【java】笔试强训Day1

⛳选择题 1.在 Java 中,存放字符串常量的对象属于 ( )类对象 A、Character B、String C、StringBuffer D、Vector 🙈大家觉得答案是什么呢 🙉答案是…

GPT-4发布:人工智能新高度,以图生文技术震撼,短时间内挤爆OpenAI模型付费系统

“GPT-4,起飞!”今日凌晨1点,OpenAI正式推出史上最强大的GPT-4文本生成AI系统 GPT-4:人工智能的新里程碑 你可能已经听说过GPT-3,它是一种能够生成自然语言文本的强大模型,可以用来回答问题、写文章、编程…

【Java SE】变量的本质

目录一. 前言二. 变量(variable)2.1 性质2.2 变量类型2.2.1 核心区别2.3 变量的使用三. 总结一. 前言 一天一个Java小知识点,助力小伙伴更好地入门Java,掌握更深层次的语法。 二. 变量(variable) 2.1 性质 变量本质上就是代表一个”可操作的存储空间”…

STL库中list的迭代器实现痛点分析

前文本篇文章准备换个模式,之前都是先详解模拟实现,但是模拟实现的基本逻辑大多数老铁都是明白的,所以我们这次主要讲解STL库中list的独特性,也就是模拟实现中的重难点文末有模拟实现的源码一,list实现的特殊类list实现…

【pytorch】使用deepsort算法进行目标跟踪,原理+pytorch实现

目录deepsort流程一、匈牙利算法二、卡尔曼滤波车速预测例子动态模型的概念卡尔曼滤波在deepsort中的动态模型三、预测值及测量值的含义deepsort在pytorch中的运行deepsort流程 DeepSORT是一种常用的目标跟踪算法,它结合了深度学习和传统的目标跟踪方法。DeepSORT的…

WireShark如何抓包,各种协议(HTTP、ARP、ICMP)的过滤或分析,用WireShark实现TCP三次握手和四次挥手

WireShark一、开启WireShark的大门二、如何抓包 搜索关键字2.1 协议过滤2.2 IP过滤2.3 过滤端口2.4 过滤MAC地址2.5 过滤包长度2.6 HTTP模式过滤三、ARP协议分析四、WireShark之ICMP协议五、TCP三次握手与四次挥手5.1 TCP三次握手实验5.2 可视化看TCP三次握手5.3 TCP四次挥手5.…

PCL 使用ICP点云拼接

一、简介 ICP算法详解——我见过最清晰的解释_负壹的博客-CSDN博客 两个点集,source和target,target不变,source经过旋转(Rotation)和平移(Translation)甚至加上尺度(Scale&#x…