Keras实现图注意力模型GAT

简介:本文实现了一个GAT图注意力机制的网络层,可以在Keras中像调用Dense网络层、Input网络层一样直接搭积木进行网络组合。

一,基本展示

        如下图所示,我们输入邻接矩阵和节点特征矩阵之后,可以直接调用myGraphAttention网络层得到每一头的注意力输出(节点emdbeding),十分的方便。

        注意:上图有个BUG,最终的输出层应该是8,和输入节点特征保持一致,上图只是举一个例子。

二,代码实现

2.1 GAT网络层

        GAT网络层的代码如下。

from __future__ import absolute_import
from keras.activations import relu
from keras import activations, constraints, initializers, regularizers
from keras import backend as K
from keras.layers import Layer, Dropout, LeakyReLU

# 定义图卷积层
class myGraphAttention(Layer):

    def __init__(self,
                 F_,
                 activation='relu',
                 use_bias=True,
                 drop_rate = 0,
                 kernel_initializer='glorot_uniform',
                 bias_initializer='zeros',
                 attn_kernel_initializer='glorot_uniform',
                 kernel_regularizer=None,
                 bias_regularizer=None,
                 attn_kernel_regularizer=None,
                 activity_regularizer=None,
                 kernel_constraint=None,
                 bias_constraint=None,
                 attn_kernel_constraint=None,
                 **kwargs):

        self.F_ = F_  # 输出的节点embeding维度
        self.activation = activations.get(activation) # 输出结果之前的激活函数
        self.use_bias = use_bias

        """
        其他代码………………
        """
        super(myGraphAttention, self).__init__(**kwargs)

    def build(self, input_shape):
        """
        其他代码………………
        """

    def call(self, inputs):
        X = inputs[0]  # 节点特征 (N x F)
        A = inputs[1]  # 邻接矩阵 (N x N)
 
        """
        其他代码………………
        """
  
        # 加上偏置
        if self.use_bias:
            node_features = K.bias_add(node_features, self.bias)

        # 最终的输出之前得激活一下
        output = self.activation(node_features)
        return output

    def compute_output_shape(self, input_shape):
        output_shape = input_shape[0][0], self.output_dim
        return output_shape

2.2 模型搭建

        模型搭建的代码如下。

from keras.layers import Layer,Input,Dense,add,Lambda
from keras.models import Model

inp_adj_martrix = Input(shape=(5,5),name='adj_martrix')
inp_node_features = Input(shape=(5,8),name='node_features_martrix')

# 在这里直接调用网络层
flat0 = myGraphAttention(12,name="head_0")([inp_node_features,inp_adj_martrix])
flat1 = myGraphAttention(12,name="head_1")([inp_node_features,inp_adj_martrix])
flat2 = myGraphAttention(12,name="head_2")([inp_node_features,inp_adj_martrix])

flat = add([flat0,flat1,flat2])

lorder = 1
one_node = Lambda(lambda inp: inp[:,0,:],name = "the-first-node-feature")(flat)


o1 = Dense(32,activation="relu")(one_node)
o2 = Dense(32,activation="relu")(o1)
out = Dense(12)(o2)

model = Model([inp_adj_martrix,inp_node_features],[out])

        创作不易,需要完整代码4_liao我哦。还有很多预测网络结构。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/135944.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

C语言之pthread_once实例总结(八十三)

简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…

史上最全最新Ubuntu20.04安装教程(图文)

总的来说,安装Ubantu包含以下三个步骤: 一、安装虚拟机 二、Ubuntu镜像下载 三、虚拟机配置 一、安装虚拟机 选择安装VMware Workstation,登录其官网下载安装包,链接如下: 下载 VMware Workstation Pro​www.vmwa…

设计模式之--原型模式(深浅拷贝)

原型模式 缘起 某天,小明的Leader找到小明:“小明啊,如果有个发简历的需求,就是有个简历的模板,然后打印很多份,要去一份一份展示出来,用编程怎么实现呢?” 小明一听,脑袋里就有了…

ARM64 linux并发与同步之内存屏障

1.2 内存屏障 1.2.1 概念理解 原理部分比较苦涩难懂,我们先不过多详细介绍这部分的由来和经过,接下来着重讲解什么用途和实现; ARM64架构中提供了3条内存屏障指令。 数据存储屏障(Data Memory Barrier, DMB)指令。数据同步屏障(Data Synch…

劲松HPV防治诊疗中心发布:HPV感染全面防治解决方案

在当今社会,HPV(人乳头瘤病毒)感染问题已成为广大公众关注的焦点。作为一种高度传染性的病毒,HPV感染不仅可能导致生殖器疣,还可能引发各种癌症。面对这一严重威胁,劲松HPV防治诊疗中心以其专业的医疗团队、正规的治疗流程和全方位…

Python基础入门例程51-NP51 列表的最大与最小(循环语句)

最近的博文: Python基础入门例程50-NP50 程序员节(循环语句)-CSDN博客 Python基础入门例程49-NP49 字符列表的长度-CSDN博客 Python基础入门例程48-NP48 验证登录名与密码(条件语句)-CSDN博客 目录 最近的博文&…

函数极限求解方法归纳

1、连续函数直接代入值(加减不可以部分代入值) 例题1 配凑构造等价无穷小 等价无穷小 注意:不要在加减中部分使用等价无穷小,可以利用拆极限的方式求,拆出来的每一部分都要有极限,如果有一部分没有极限就是…

STM32F4X定时器之通用定时器

一、STM32通用定时器概述 通用定时器包括一个16位或32位自动重载计数器,可通过可编程预分频器进行驱动。定时器可以实现多种功能,包括测量输入信号的脉冲宽度和生成输出波形,通过使用定时器预分频器和RCC时钟控制器预分频器,可以…

目标检测——Yolo系列(YOLOv1/2/v3/4/5/x/6/7/8)

目标检测概述 什么是目标检测? 滑动窗口(Sliding Window) 滑动窗口的效率问题和改进 滑动窗口的效率问题:计算成本很大 改进思路 1:使用启发式算法替换暴力遍历 例如 R-CNN,Fast R-CNN 中使用 Selectiv…

C++算法:完美矩形

题目 给你一个数组 rectangles ,其中 rectangles[i] [xi, yi, ai, bi] 表示一个坐标轴平行的矩形。这个矩形的左下顶点是 (xi, yi) ,右上顶点是 (ai, bi) 。 如果所有矩形一起精确覆盖了某个矩形区域,则返回 true ;否则&#xf…

AI 绘画 | Stable Diffusion WebUI的基本设置和插件扩展

前言 Stable Diffusion WebUI是一个基于Gradio库的浏览器界面,用于配置和生成AI绘画作品,并且进行各种精细地配置。它支持目前主流的开源AI绘画模型,例如NovelAI/Stable Diffusion。 在基本设置方面,Stable Diffusion WebUI的默…

asp.net外卖网站系统VS开发mysql数据库web结构c#编程Microsoft Visual Studio

一、源码特点 asp.net外卖网站系统 是一套完善的web设计管理系统,系统采用mvc模式(BLLDALENTITY)系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环境为vs2010,数据库为mysql,使用c#语…

Git的原理与使用(一)

目录 Git初始 Git安装 Git基本操作 创建git本地仓库 配置git 工作区,暂存区,版本库 添加文件,提交文件 查看.git文件 修改文件 版本回退 小结 Git初始 git是一个非常强大的版本控制工具.可以快速的将我们的文档和代码等进行版本管理. 下面这个实例看理解下为什么需…

CountDownLatch和CyclicBarrier详解

1. CountDownLatch 1.1 简介 CountDownLatch 是 Java 中并发包(java.util.concurrent)提供的一种同步工具,用于在多线程环境中协调多个线程之间的执行顺序。它的作用是允许一个或多个线程等待其他线程完成操作。 CountDownLatch 通过一个计…

java使用geotools导出shp文件

SHP格式是一种矢量数据格式,用于存储地理信息系统(GIS)数据。 SHP文件由一系列有序的文件组成,我们导出的shp文件包括.shp、.shx、.dbf、.prj以及.fix文件。 .shp(shape)文件:存储矢量地图数据&…

自定义类型:联合和枚举

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 目录 前言 1. 联合体 1.1 联合体类型的声明 1.2 联合体的特点 1.3 相同成员的结构体和联合体对比 1.4 联合体大小的计算 1.5 联合的一个练习 2. 枚举类型 2.1 枚举类型的声明…

思维模型 目标效应

本系列文章 主要是 分享 思维模型,涉及各个领域,重在提升认知。明确目标,激发内在动机。 1 目标效应的应用 1.1 目标效应在教育领域的应用-棉花糖实验 美国斯坦福大学心理学系的教授米歇尔(Walter Mischel)曾经进行了…

1204. 错误票据

题目: 1204. 错误票据 - AcWing题库 思路: 将输入的数据存入数组,从小到大排序后遍历,若 (a[i] a[i - 1])res1 a[i]--->重号;若(a[i] - a[i - 1] > 2)res2 a[i] - 1--->断号。 难点:题目只告诉我们输入…

1977 智慧校园平台开发与实现JSP【程序源码+文档+调试运行】

摘要 本文旨在设计和实现一个智慧校园平台系统,以满足管理员、教师和学生三类用户的需求。管理员拥有最高管理权限,可对教师和学生用户进行管理;教师用户可查看和管理本班学生信息、成绩等;学生用户可查看个人信息、考试成绩等。…

Leetcode139单词拆分及其多种变体问题

1 声明 1.1 首先,大家常常把这道题当作了背包问题来做,因为其循环结构和背包问题刚好相反,但事实如此嘛? 背包问题通常都是组合问题,这其实是一道“”面向目标值的排列问题“,具体和背包问题有什么不同可…