Python GNN图神经网络代码实战;GAT代码模版,简单套用,易于修改和提升,图注意力机制代码实战

1.GAT简介

GAT(Graph Attention Network)模型是一种用于图数据的深度学习模型,由Veličković等人在2018年提出。它通过自适应地在图中计算节点之间的注意力来学习节点之间的关系,并在节点表示中捕捉全局和局部信息。

GAT模型的核心思想是通过注意力机制,对图中的节点进行加权聚合。与传统的图卷积网络(GCN)模型不同,GAT不仅考虑节点本身的特征信息,还考虑了节点与其邻居节点之间的关系。每个节点在聚合邻居节点的特征时,会分配不同的注意力权重,以捕捉不同邻居节点对该节点的贡献程度。

GAT模型具有以下特点和优势:

  1. 自适应学习的注意力机制:GAT模型能够根据数据自动学习节点之间的注意力权重,从而捕捉到不同节点之间的重要性和关系。
  2. 并行计算效率高:由于注意力权重是节点间独立计算的,可以高效地并行计算,适用于大规模图数据。
  3. 稀疏性:GAT模型引入了注意力系数,可以将注意力集中在有用的邻居节点上,减小计算量和存储需求。
  4. 灵活性:GAT模型可以根据任务需求设计不同的注意力权重计算方式,适应不同的图学习任务。

2.代码实战

模型架构分为两部分:GAT主体部分,GAT的注意力计算部分

注意力机制:首先输入参数为(节点的特征表示hi,邻接矩阵),注意这个hi可以来源于上一层,也可以是原始的;先计算每个节点到中心节点的权值,也可以称为权重或者系数,然后对所有的权值进行归一化,最后对每个邻居节点与对应的权值相乘,然后相加就得到了中心节点的最终表示,注意求权值的时候是要考虑中心节点本身的;

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class GATLayer(nn.Module):
    def __init__(self, in_features, out_features, dropout, alpha, concat=True):
        super(GATLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dropout = dropout
        self.alpha = alpha
        self.concat = concat

        self.W = nn.Linear(in_features, out_features)
        self.a = nn.Linear(2*out_features, 1)

    def forward(self, h, adj):
        Wh = self.W(h)  # W*h
        N = h.size()[0]  # Number of nodes

        a_input = torch.cat([Wh.repeat(1, N).view(N*N, -1), Wh.repeat(N, 1)], dim=1).view(N, -1, 2*self.out_features)
        e = F.leaky_relu(self.a(a_input).squeeze(2), negative_slope=self.alpha)
        
        zero_vec = -9e15*torch.ones_like(e)
        attention = torch.where(adj > 0, e, zero_vec)
        attention = F.softmax(attention, dim=1)
        attention = F.dropout(attention, p=self.dropout, training=self.training)
        h_prime = torch.matmul(attention, Wh)

        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime

class GAT(nn.Module):
    def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):
        super(GAT, self).__init__()
        self.dropout = dropout
        self.hidden = nn.ModuleList([GATLayer(nfeat, nhid, dropout, alpha, concat=True) for _ in range(nheads)])
        self.out_att = GATLayer(nhid*nheads, nclass, dropout, alpha, concat=False)
    
    def forward(self, x, adj):
        x = F.dropout(x, self.dropout, training=self.training)
        x = torch.cat([att(x, adj) for att in self.hidden], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)
        x = F.sigmoid(self.out_att(x, adj))
        return F.log_softmax(x, dim=1)

# 创建示例数据和邻接矩阵
adj = torch.tensor([[0, 1, 1, 0],
                    [1, 0, 1, 1],
                    [1, 1, 0, 1],
                    [0, 1, 1, 0]])  # 邻接矩阵
features = torch.randn(4, 5)  # 特征矩阵

# 创建GAT模型
model = GAT(nfeat=5, nhid=8, nclass=2, dropout=0.6, alpha=0.2, nheads=2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# 训练模型
for epoch in range(100):
    optimizer.zero_grad()
    output = model(features, adj)
    # 假设这里有标签数据y
    y = torch.LongTensor([0, 1, 0, 1])  # 标签
    loss = criterion(output, y)
    loss.backward()
    optimizer.step()

# 测试模型
output = model(features, adj)
_, predictions = output.max(dim=1)
correct = (predictions == y).sum().item()
accuracy = correct / len(y)
print("准确率:", accuracy)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/674151.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

实现spring配置bean类机制

大家好,这里是教授.F 流程说明: 我们自己实现spring配置bean类的机制,要先了解原本是怎么实现的。 原本的机制就是有一个bean配置文件,还有一个ApplicationContext spring文件。bean类写着要扫描的文件信息,spring文…

vscode编译c/c++找不到jni.h文件

解决办法: 一、下载JDK 访问Oracle官网的Java下载页面:Java Downloads | Oracle 选择适合您操作系统的JDK版本: 对于Windows,选择“Windows x64”或“Windows x86”(取决于您的系统是64位还是32位)。对于Linux&#…

扩散世界模型已训练出赶超人类的智能体?

论文标题: Diffusion for World Modeling:Visual Details Matter in Atari 论文作者: Eloi Alonso, Adam Jelley, Vincent Micheli, Anssi Kanervisto, Amos Storkey, Tim Pearce, Franois Fleuret 项目地址: https://github.com/eloial…

封装了一个使用UICollectionViewLayout 实现的吸附居左banner图

首先查看效果图 实现的原理就是通过自定义UICollectionView layout,然后 设置减速速率是快速就可以达到吸附的效果 _collectionView.decelerationRate UIScrollViewDecelerationRateFast; 下面贴出所有代码 这里是.h // // LBMiddleExpandLayout.h // Liubo…

Java零基础-顺序结构

哈喽,各位小伙伴们,你们好呀,我是喵手。运营社区:C站/掘金/腾讯云;欢迎大家常来逛逛 今天我要给大家分享一些自己日常学习到的一些知识点,并以文字的形式跟大家一起交流,互相学习,一…

10 个最佳 MP4 转换器,可帮助您将视频转换为 MP4

许多人正在寻找一种强大的工具将视频转换为 MP4。网上有很多 MP4 转换器,但只有少数能够有效地将视频转换为 MP4。我们根据实验室测试和用户报告确定了前 10 名 MP4 转换器。在这篇文章中,我们将向您展示这些 MP4 转换器具有哪些功能以及如何使用它们。 …

【Python】 Python中的`mkdir -p`功能解析与应用

基本原理 在Linux系统中,mkdir -p是一个常用的命令,用于创建目录。这个命令的特点是,如果目标目录已经存在,它不会报错,而是直接跳过;如果目标目录不存在,它会创建整个目录路径中所需的所有目录…

166.二叉树:相同的树(力扣)

代码解决 /*** Definition for a binary tree node.* struct TreeNode {* int val;* TreeNode *left;* TreeNode *right;* TreeNode() : val(0), left(nullptr), right(nullptr) {}* TreeNode(int x) : val(x), left(nullptr), right(nullptr) {}* Tre…

无线麦克风哪个品牌音质最好?最好的无线麦克风品牌排行推荐

xx 虽然Vlog随手就能拍,不过Vlog不仅要记录画面,还要记录声音,毕竟一段声色俱全的视频要比一张照片有意义得多。把镜头擦拭干净可以留下清晰明朗的画面,但是在户外参杂了各种嘈杂的声音手机很难收录清晰的人声,所以一…

一点连接千家银行,YonSuite让“银行回单”一键获取

在当今日益复杂多变的商业环境中,企业的资金管理变得尤为重要。传统的银行回单管理方式,如手动登录网银、逐一下载回单、核对信息等,不仅效率低下,而且容易出错,给企业的财务管理带来了极大的挑战。 然而,…

OBC充电机的基础认识

OBC是电动汽车上的充电设备,主要用于将外部交流电源转换为直流电源,为电动汽车的动力电池组充电。OBC是电动汽车的重要组成部分,其性能直接影响到电动汽车的续航里程和充电效率。 OBC的主要功能包括:将交流电转换为直流电&#xf…

C++设计模式|结构型 代理模式

1.什么是代理模式? 代理模式Proxy Pattern是一种结构型设计模式,用于控制对其他对象的访问。 在代理模式中,允许一个对象(代理)充当另一个对象(真实对象)的接口,以控制对这个对象的…

《论文阅读》具有人格自适应注意的个性化对话生成 AAAI 2023

《论文阅读》具有人格自适应注意的个性化对话生成 AAAI 2023 前言 简介挑战与机遇任务定义模型架构Context EncoderPersona EncoderDialog DecoderPersona-Adaptive Attention损失函数实验结果 前言 亲身阅读感受分享,细节画图解释,再也不用担心看不懂论…

Linux 服务查询命令(包括 服务器、cpu、数据库、中间件)

Linux 服务查询命令(包括 服务器、cpu、数据库、中间件) Linux获取当前服务器ipLinux使用的是麒麟版本还是cenos版本Linux获取系统信息Linux查询nignx版本 Linux获取当前服务器ip hostname -ILinux使用的是麒麟版本还是cenos版本 这个文件通常包含有关L…

社交媒体数据恢复:易信

我们可以参考其他类似软件的数据恢复方法尝试解决问题。 检查备份:首先,检查您是否在易信或其他云服务中备份了数据。如果有备份,您可以尝试从备份中恢复数据。 联系易信客服:如果找不到备份,您可以联系易信的客户服务…

Redis 持久化: RDB和AOF

文章目录 ⛄1.RDB持久化🪂🪂1.1.执行时机🪂🪂1.2.RDB原理🪂🪂1.3.小结 ⛄2.AOF持久化🪂🪂2.1.AOF原理🪂🪂2.2.AOF配置🪂🪂2.3.AOF文件…

电脑显示屏亮度怎么调?3招帮你调整亮度

在使用电脑时,调整显示屏亮度是一项常见的操作,它可以帮助我们适应不同的环境光线,提高视觉舒适度。然而,许多用户可能不清楚电脑显示屏亮度怎么调。本文将介绍3种简单实用的方法,帮助您轻松调整电脑显示屏的亮度&…

计算机网络介绍

计算机网络介绍 概述网络概述相关硬件 链路层VLAN概念VLAN 特点VLAN 的划分帧格式端口类型原理 STP概念特点原理 Smart Link概念特点组网 网络层ARP概念原理 IP概念版本IP 地址 IPv4IP 地址数据报格式 IPv6特点IP 地址数据报格式 ICMP概念分类报文格式 VRRP概念原理报文格式 OS…

原生APP和H5 APP的区别

原生APP(Native App)和H5 APP(也称为Web App或Hybrid App)是两种不同的移动应用开发方式,它们在开发技术、性能、用户体验、开发成本和维护等方面存在显著区别。以下是它们的主要区别。北京木奇移动技术有限公司&#…

番外篇-用户购物偏好标签BP-推荐算法ALS

引言 推荐系统式信息过载所采用的措施,面对海量的数据信息,从中快速推荐出符合用户特点的物品。 推荐系统是自动化的通过分析用户对历史行为数据,完成用户的个性化建模,从而主动给用户推荐能够满足他们兴趣和需求的软件系统。 数…