GCN,GraphSAGE 到底在训练什么呢?

根据DGL 来做的,按照DGL 实现来讲述

1. GCN Cora 训练代码:

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv


class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h


def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    for e in range(100):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that you should only compute the losses of the nodes in the training set.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if e % 5 == 0:
            print(
                f"In epoch {e}, loss: {loss:.3f}, val acc: {val_acc:.3f} (best {best_val_acc:.3f}), test acc: {test_acc:.3f} (best {best_test_acc:.3f})"
            )




if __name__ == "__main__" :
    dataset = dgl.data.CoraGraphDataset()
    # print(f"Number of categories: {dataset.num_classes}")
    g = dataset[0]
    g = g.to('cuda')
    model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes).to('cuda')
    train(g, model)

一些基础python torch.tensor语法概述:

1.  

if __name__ == "__main__" :
    XXXXXXX
    XXXXXXX

当我们直接执行这个脚本时,__name__属性被设置为__main__,因此满足if条件,语句块中的代码被调用。
但如果我们将该脚本作为模块导入到另一个脚本中,则__name__属性会被设置为模块的名称(例如"example"),语句块中的代码不会被执行。

2. 

# Compute prediction
pred = logits.argmax(1)    # 返回沿着第一个维度(即维度索引为1)的最大值的索引。
                           # 即,加入有5个样本,每个样本有3个维度的评分,那么就会给出没个样本3中维度评分最高的哪个维度的索引序号

 

3. numpy 关于 tensor 的一个用法:

在DGL 中使用一串 True 或 False 组成的 一维tensor 来标识 这个节点到底是属于 train test val 哪一类

train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]

而后,由于对于torch中的tensor来说:

就可以:select_label_tensor = labels[train_mask] 了

import torch

# 定义一个Tensor
tensor = torch.tensor([1, 2, 3, 4, 5])

# 定义一个布尔数组,选择索引为1和4的元素
mask = torch.tensor([False, True, False, False, True])

# 通过布尔索引选择元素
selected_tensor = tensor[mask]

print(selected_tensor)  # tensor([2, 5])

顺便,查看一个变量到底是什么类型可以使用 type() 函数:

train_mask = g.ndata["train_mask"]
print(type(train_mask))

# 输出为:
# <class 'torch.Tensor'>

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/217141.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

质量小议35 -- SQL注入

已经记不得上次用到SQL注入是什么时候了&#xff0c;一些概念和操作已经模糊。 最近与人聊起SQL注入&#xff0c;重新翻阅&#xff0c;暂记于此。 重点&#xff1a;敏感信息、权限过大、未脱敏的输入/输出、协议、框架、数据包、明文、安全意识 SQL - Structured Query La…

时间序列预测实战(二十三)进阶版LSTM多元和单元预测(课程设计毕业设计首选)

一、本文介绍 本篇文章给大家带来的是利用我个人编写的架构进行LSTM模型进行时间序列建模&#xff08;专门为了时间序列领域新人编写的架构&#xff0c;简单且不同于市面上大家用GPT写的代码&#xff09;&#xff0c;包括结果可视化、支持单元预测、多元预测、模型拟合效果检测…

(C++)盛水最多的容器--双指针法

个人主页&#xff1a;Lei宝啊 愿所有美好如期而遇 力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台备战技术面试&#xff1f;力扣提供海量技术面试资源&#xff0c;帮助你高效提升编程技能&#xff0c;轻松拿下世界 IT 名企 Dream Offer。https://le…

软著项目推荐 深度学习图像风格迁移

文章目录 0 前言1 VGG网络2 风格迁移3 内容损失4 风格损失5 主代码实现6 迁移模型实现7 效果展示8 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 深度学习图像风格迁移 - opencv python 该项目较为新颖&#xff0c;适合作为竞赛课题…

【ArcGIS Pro微课1000例】0048:深度学习--人群计数

文章目录 一、小学回忆录二、深度学习计算人头数三、案例实现一、小学回忆录 加载配套实验数据包中的图片及训练模型。你还记得当年的小学毕业班有多少同学吗?今天我们就用ArcGIS提供的人工智能工具,重温一下童年记忆。 二、深度学习计算人头数 本案例使用到的是深度学习中…

[树莓派3B+][内核版本6.1]的linux内核编译+替换 (超详细)

学习Linux的内核编译&#xff0c;我使用的是x86 64位的18.04的ubuntu-linux虚拟机&#xff1a; 目录 树莓派的Linux内核源码安装 操作系统的启动过程 & Bootloader 单片机裸机&#xff1a;C51,STM32 X86&#xff0c;Intel&#xff1a;windows 嵌入式产品&#xff1a;…

CUDA简介——同步

1. 引言 前序博客&#xff1a; CUDA简介——基本概念CUDA简介——编程模式CUDA简介——For循环并行化CUDA简介——Grid和Block内Thread索引CUDA简介——CUDA内存模式 本文重点关注Thread同步和Barriers。 Threads并行执行&#xff0c;可能存在如下问题&#xff1a; 1&#…

芯擎科技与芯华章深度合作,软硬件协同开发加速车规级芯片创新

12月4日&#xff0c;系统级验证EDA解决方案提供商芯华章&#xff0c;与国产高端车规芯片设计公司芯擎科技正式建立战略合作。双方强强联手&#xff0c;芯擎科技导入芯华章相关EDA验证工具&#xff0c;赋能车规级芯片和应用软件的协同开发&#xff0c;助力大规模缩短产品上市周期…

PoE技术详解

标准的五类网线有四对双绞线&#xff0c;IEEE 802.3af和IEEE 802.3at允许两种用法&#xff1a;通过空闲线对供电或者数据线对供电。IEEE 802.3bt允许通过空闲线对供电、通过数据线对供电或者空闲线对和数据线对一起供电&#xff0c;如图16.1所示。 图 16.1 PoE供电线对 当在一…

第一节JavaScript 简介与使用

JavaScript简介 JavaScript是互联网上最流行的脚本语言&#xff0c;这门语言可用于HTML和Web&#xff0c;更广泛用于服务器、PC、电脑、智能手机等设备上。 JavaScript是一种轻量级的编程语言。 JavaScript是可插入HTML页面的编程代码。 JavaScript插入HTML页面后&#xff…

使用coco数据集进行语义分割:数据预处理与损失函数

如何coco数据集进行目标检测的介绍已经有很多了&#xff0c;但是关于语义分割几乎没有。本文旨在说明如何处理 stuff_train2017.json stuff_val2017.json panoptic_train2017.json panoptic_val2017.json&#xff0c;将上面那些json中的dict转化为图片的label mask&am…

【Proteus仿真】【Arduino单片机】蔬菜大棚温湿度控制系统设计

文章目录 一、功能简介二、软件设计三、实验现象联系作者 一、功能简介 本项目使用Proteus8仿真Arduino单片机控制器&#xff0c;使用PCF8574、LCD1602液晶、DHT11温湿度传感器、按键、继电器、蜂鸣器、加热、水泵电机等。 主要功能&#xff1a; 系统运行后&#xff0c;LCD160…

【模电】设置静态工作点的必要性

设置静态工作点的必要性 静态工作点为什么要设置静态工作点 静态工作点 在放大电路中&#xff0c;当有信号输入时&#xff0c;交流量与直流量共存。将输入信号为零、即直流电源单独作用时晶体管的基极电流 I B I\tiny B IB、集电极电流 I C I\tiny C IC、b - e间电压 U B E U\t…

玩转大数据5:构建可扩展的大数据架构

1. 引言 随着数字化时代的到来&#xff0c;大数据已经成为企业、组织和个人关注的焦点。大数据架构作为大数据应用的核心组成部分&#xff0c;对于企业的数字化转型和信息化建设至关重要。我们将探讨大数据架构的基本要素和原则&#xff0c;以及Java在大数据架构中的角色&…

智能指针与动态内存

动态内存 new placement new 是 C 中的一种内存分配方式&#xff0c;它允许在给定的内存地址上构造对象&#xff0c;而不是在默认的堆上分配新的内存。这对于某些特殊的内存管理场景非常有用&#xff0c;例如在特定的内存池中分配对象。 C11 引入了 "new auto" 语法…

【C语言】递归详解

目录 1.前言2. 递归的定义3. 递归的限制条件4. 递归举例4.1 求n的阶乘4.1.1 分析和代码实现4.1.2 画图演示 4.2 顺序打印一个整数的每一位4.2.1 分析和代码实现4.2.2 画图推演 4.3 求第n个斐波那契数 5. 递归与迭代5.1 迭代求第n个斐波那契数 1.前言 这次博客内容是与递归有关&…

leetcode 876.链表的中间结点

补充上次的环形链表没细讲的快慢指针&#xff08;这三道题现在可以连起来看&#xff09;&#xff0c;希望对你做题思路有帮助 876.链表的中间结点 题目 给你单链表的头结点 head &#xff0c;请你找出并返回链表的中间结点。 如果有两个中间结点&#xff0c;则返回第二个中间结…

kubernetes监控GPA安装部署

本文在于指导如何对k8s的监控GPA(Grafana&#xff0c;prometheus以及alertmanager)进行安装部署。 1. 介绍 Prometheus 在真正部署Prometheus之前&#xff0c;应了解一下Prometheus的各个组件之间的关系及作用&#xff1a; 1&#xff09;MertricServer&#xff1a;是k8s集群…

面试官:说说synchronized与ReentrantLock的区别

程序员的公众号&#xff1a;源1024&#xff0c;获取更多资料&#xff0c;无加密无套路&#xff01; 最近整理了一波电子书籍资料&#xff0c;包含《Effective Java中文版 第2版》《深入JAVA虚拟机》&#xff0c;《重构改善既有代码设计》&#xff0c;《MySQL高性能-第3版》&…

【NI-RIO入门】Real Time(实时系统解释)

1.什么是实时系统&#xff1f; 实时系统可以非常精确可靠的执行需要特定时许要求的系统&#xff0c;对于许多工程项目来说&#xff0c;通用操作系统&#xff0c;例如WINDOWS等标准PC上运行测量和控制程序是无法精确控制计时的&#xff0c;这些系统很容易受用户的其他程序、图像…