Progressive Feature Fusion Framework Based on Graph Convolutional Network

以Resnet50作为主干网络,然后使用GCN逐层聚合多级特征,逐级聚合这种模型架构早已不新鲜,这篇文章使用GCN的方式对特征进行聚合,没有代码。这篇文章没有过多的介绍如何构造的节点特征和邻接矩阵,我觉得对于图卷积来说,最重要的一点就是确定那些特征作为图节点以及节点直接的连接关系。

很多方法是直接将特征图的每个像素作为一个节点,那这样的话怎么确定每个像素之间的连接关系呢?

对于邻接矩阵来说,两个节点相连置为一,两个节点不相连置为零,通过将节点矩阵和邻接矩阵进行相乘来进行节点之间的信息交互。这种交互是只要两个节点之间相连就将两个节点的特征值进行相加。

这种直接相加的方式忽略了节点与节点之间的重要程度,可以使用图注意力来给图的节点与节点之间施加一个权重,这个权重可以通过自注意力的方式得到,也可以通过图注意力网络中的计算方式得到节点与节点之间的权重关系。图注意力网络的代码如下:

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import networkx as nx
 
 
def get_weights(size, gain=1.414):
    weights = nn.Parameter(torch.zeros(size=size))
    nn.init.xavier_uniform_(weights, gain=gain)
    return weights
 
class GraphAttentionLayer(nn.Module):
    '''
    Simple GAT layer 图注意力层 (inductive graph)
    '''
    def __init__(self, in_features, out_features, dropout, alpha, concat = True, head_id = 0):
        ''' One head GAT '''
        super(GraphAttentionLayer, self).__init__()
        self.in_features = in_features  #节点表示向量的输入特征维度
        self.out_features = out_features    #节点表示向量的输出特征维度
        self.dropout = dropout  #dropout参数
        self.alpha = alpha  #leakyrelu激活的参数
        self.concat = concat    #如果为true,再进行elu激活
        self.head_id = head_id  #表示多头注意力的编号
 
        self.W_type = nn.ParameterList()
        self.a_type = nn.ParameterList()
        self.n_type = 1 #表示边的种类
        for i in range(self.n_type):
            self.W_type.append(get_weights((in_features, out_features)))
            self.a_type.append(get_weights((out_features * 2, 1)))
 
        #定义可训练参数,即论文中的W和a
        self.W = nn.Parameter(torch.zeros(size = (in_features, out_features)))
        nn.init.xavier_uniform_(self.W.data, gain = 1.414)  #xavier初始化
        self.a = nn.Parameter(torch.zeros(size = (2 * out_features, 1)))
        nn.init.xavier_uniform_(self.a.data, gain = 1.414)  #xavier初始化
 
        #定义dropout函数防止过拟合
        self.dropout_attn = nn.Dropout(self.dropout)
        #定义leakyrelu激活函数
        self.leakyrelu = nn.LeakyReLU(self.alpha)
 
    def forward(self, node_input, adj, node_mask = None):
        '''
        node_input: [batch_size, node_num, feature_size] feature_size 表示节点的输入特征向量维度
        adj: [batch_size, node_num, node_num] 图的邻接矩阵
        node_mask:  [batch_size, node_mask]
        '''
 
        zero_vec = torch.zeros_like(adj)
        scores = torch.zeros_like(adj)
 
        for i in range(self.n_type):
            h = torch.matmul(node_input, self.W_type[i])
            h = self.dropout_attn(h)
            N, E, d = h.shape   # N == batch_size, E == node_num, d == feature_size
 
            a_input = torch.cat([h.repeat(1, 1, E).view(N, E * E, -1), h.repeat(1, E, 1)], dim = -1)
            a_input = a_input.view(-1, E, E, 2 * d)     #([batch_size, E, E, out_features])
 
            score = self.leakyrelu(torch.matmul(a_input, self.a_type[i]).squeeze(-1))   #([batch_size, E, E, 1]) => ([batch_size, E, E])
            #图注意力相关系数(未归一化)
 
            zero_vec = zero_vec.to(score.dtype)
            scores = scores.to(score.dtype)
            scores += torch.where(adj == i+1, score, zero_vec.to(score.dtype))
 
        zero_vec = -1*30 * torch.ones_like(scores)  #将没有连接的边置为负无穷
        attention = torch.where(adj > 0, scores, zero_vec.to(scores.dtype))    #([batch_size, E, E])
        # 表示如果邻接矩阵元素大于0时,则两个节点有连接,则该位置的注意力系数保留;否则需要mask并置为非常小的值,softmax的时候最小值不会被考虑
 
        if node_mask is not None:
            node_mask = node_mask.unsqueeze(-1)
            h = h * node_mask   #对结点进行mask
 
        attention = F.softmax(attention, dim = 2)   #[batch_size, E, E], softmax之后形状保持不变,得到归一化的注意力权重
        h = attention.unsqueeze(3) * h.unsqueeze(2) #[batch_size, E, E, d]
        h_prime = torch.sum(h, dim = 1)             #[batch_size, E, d]
 
        # h_prime = torch.matmul(attention, h)    #[batch_size, E, E] * [batch_size, E, d] => [batch_size, N, d]
 
        #得到由周围节点通过注意力权重进行更新的表示
        if self.concat:
            return F.elu(h_prime)
        else:
            return h_prime
 
class GAT(nn.Module):
    def __init__(self, in_dim, hid_dim, dropout, alpha, n_heads, concat = True):
        '''
        Dense version of GAT
        in_dim输入表示的特征维度、hid_dim输出表示的特征维度
        n_heads 表示有几个GAL层,最后进行拼接在一起,类似于self-attention从不同的子空间进行抽取特征
        '''
        super(GAT, self).__init__()
        assert hid_dim % n_heads == 0
        self.dropout = dropout
        self.alpha = alpha
        self.concat = concat
 
        self.attn_funcs = nn.ModuleList()
        for i in range(n_heads):
            self.attn_funcs.append(
                #定义multi-head的图注意力层
                GraphAttentionLayer(in_features = in_dim, out_features = hid_dim // n_heads,
                                    dropout = dropout, alpha = alpha, concat = concat, head_id = i)
            )
 
        self.dropout = nn.Dropout(self.dropout)
 
    def forward(self, node_input, adj, node_mask = None):
        '''
        node_input: [batch_size, node_num, feature_size]    输入图中结点的特征
        adj:    [batch_size, node_num, node_num]    图邻接矩阵
        node_mask:  [batch_size, node_num]  表示输入节点是否被mask
        '''
        hidden_list = []
        for attn in self.attn_funcs:
            h = attn(node_input, adj, node_mask = node_mask)
            hidden_list.append(h)
 
        h = torch.cat(hidden_list, dim = -1)
        h = self.dropout(h) #dropout函数防止过拟合
        x = F.elu(h)     #激活函数
        return x
 
 
#特征矩阵
x = torch.randn((2, 4, 8))
#邻接矩阵
adj = torch.tensor([[[0, 1, 0, 1],
                    [1, 0, 1, 0],
                    [0, 1, 0, 1],
                    [1, 0, 1, 0]]])
adj = adj.repeat(2, 1, 1)
#mask矩阵
node_mask = torch.Tensor([[1, 0, 0, 1],
                          [0, 1, 1, 1]])
 
 
gat_layer = GraphAttentionLayer(in_features = 8, out_features = 8, dropout = 0.1, alpha = 0.2, concat = True)  #输入特征维度8, 输出特征维度8, 使用多头注意力机制
gat_ = GAT(in_dim = 8, hid_dim = 8, dropout = 0.1, alpha = 0.2, n_heads = 2, concat = True)    #输入特征维度8, 输出特征维度8, 使用多头注意力机制
 
output_ = gat_(x, adj, node_mask)
print(output_.shape)  
 
output_ = gat_(x, adj, node_mask)
print(output_.shape)
 
 
#输出:
torch.Size([2, 4, 8])
torch.Size([2, 4, 8])

自注意力和图注意力在计算节点之间权重的方式稍有不同,在自注意力的计算方式中之进行了矩阵相乘并没有可训练的参数。在图注意力计算节点之间权重时,采用了线性映射的方式,这两种权重计算方式那个更好一点还要通过实验来进行验证。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/690986.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MyBatisPlus插件生成代码

文章目录 概要安装插件使用插件 概要 MyBatis-Plus 是 MyBatis 的增强工具,旨在简化 MyBatis 的开发。MyBatis-Plus 代码生成器插件可以自动生成项目中常见的代码,如实体类、Mapper 接口、Service 接口和实现类、Controller 等,从而减少手动…

【Vue】mutations

文章目录 一、定义mutations二、组件中提交 mutations三、带参数的 mutations 一、定义mutations mutations是vuex中的对象,这个对象可以定义在当前store的配置项中 const store new Vuex.Store({state: {count: 0},// 定义mutations// mutations是一个对象&#x…

Java面试八股之什么是反射,实现原理是什么

Java中什么是反射,实现原理是什么 Java中的反射(Reflection)是一种强大的特性,它允许程序在运行时检查和操作类、接口、字段和方法的信息。简而言之,反射机制使得程序能够在运行时动态地了解和使用自身或其他程序集中…

ubuntu 用户名及密码忘记操作

1、重启系统,长按Shift键,直到出现菜单,选则高级设置。选择recovery mode,即恢复模式 2、选择root 3、# 后面敲入 sudo passwd 用户名 4、# passwd "用户名" 之后再敲两次密码就可以了。(如果提示修改失败可先执行&a…

【Ardiuno】实验使用ESP32连接Wifi(图文)

ESP32最为精华和有特色的地方当然是wifi连接,这里我们就写程序实验一下适使用ESP32主板连接wifi,为了简化实验我们这里只做了连接部分,其他实验在后续再继续。 由于本实验只要在串口监视器中查看结果状态即可,因此电路板上无需连…

vue2 中如何使用 render 函数编写组件

vue2 中如何使用 render 函数编写组件 render 基础语法createElement返回值:VNode参数处理样式和类组件 propsHTML 特性和 DOM 属性处理事件插槽指令v-model 指令其他属性 使用 render 封装一个输入框其他问题参考 vue 提供了声明式编写 UI 的方式,即 vu…

tcp aimd 窗口的推导

旧事重提,今天用微分方程的数值解观测 tcp aimd 窗口值。 设系统 AI,MD 参数分别为 a 1,b 0.5,丢包率由 buffer 大小,red 配置以及线路误码率共同决定,设为 p,窗口为 W,则有&…

C++STL简介

一、STL介绍 STL(standard template libaray-标准模板库):是C标准库的重要组成部分,不仅是一个可复用的组件库,而且是一个包罗数据结构与算法的软件框架。 二、STL的版本 1、原始版本 Alexander Stepanov、Meng Lee 在惠普实验室完…

数字证书和CA

CA(Certificate Authority)证书颁发机构 验证数字证书是否可信需要使用CA的公钥 操作系统或者软件本身携带一些CA的公钥,同时也可以向提供商申请公钥 数字证书的内容 数字证书通常包含以下几个主要部分: 主体信息&#xff08…

搭建多平台比价系统需要了解的电商API接口?

搭建一个多平台比价系统涉及多个步骤,以下是一个大致的指南: 1. 确定需求和目标 平台选择:确定你想要比较价格的平台,例如电商网站、在线旅行社等。数据类型:明确你需要收集哪些数据,如产品价格、产品名称…

仪表板展示|DataEase看中国:2024年高考数据前瞻

背景介绍 2024年高考即将来临。根据教育部公布的数据,2024年全国高考报名人数为1342万人,相比2023年增加了51万人。高考报名人数的增加,既体现了我国基础教育的普及范围之广,也反映了社会对高等教育的重视和需求。 随着中央和各…

VL830 USB4 最高支持40Gbps芯片功能阐述以及原理图分享

前文斥巨资拆了一个扩展坞供大家参考。其中核心即为本文要说的这个VL830,USB4的HUB芯片。 拆解报告传送门:USB4 Gen3x2 最高40Gbps传输速率的HUB扩展坞拆解分析 OK,闲话少叙。直接进入主题,我就直接翻译规格书了。 VL830是一款USB4端点设备…

【无标题】 Notepad++ plugin JSONViewer 下载地址32位

JSONViewer download | SourceForge.net 1、下载插件压缩包并解压出dll:Jsonviewer2.dll(64位)或NPPJSONViewer.dll(32位); 2.、拷贝对应dll到Notepad安装目录下的plugins目录。 3、重启Notepad程序,在插…

【Jenkins】Jenkins - 节点

选择系统设置 - 节点设置 -添加节点 下载对应的 jar包 ,执行命令 测试运行节点生效 1. 创建测试项目 test1 2. 选择节点执行: 在配置页面的“General”部分,找到“限制项目的运行节点”(Restrict where this project can be run…

类似crossover的容器软件有哪些 除了crossover还有什么 Mac虚拟机替代品

CrossOver是Mac用来运行exe文件的一款软件,但是并不是所有的exe文件CrossOver都支持运行。想要在Mac上运行exe文件的方法并不是只有使用CrossOver这一种,那么有没有类似的软件也可以实现exe文件在Mac上运行呢? CrossOver类似软件有哪些 1、Pl…

Electron qt开发教程

模块安装打包 npm install -g electron-forge electron-forge init my-project --templatevue npm start //进入目录启动 //打包成一个目录到out目录下,注意这种打包一般用于调试,并不是用于分发 npm run package //打出真正的分发包,放在o…

java的基础知识包括哪些?

java入门基础知识点需要学什么?入门学习一定要找到适合自己的方法才能事半功倍,对需要掌握的知识点有一个大概的了解,Java入门基础知识包含:标识符、变量、AScii码和Unicod码、基本数据类型转化String类、进制、运算符、程序流程控…

【AIGC】基于大模型+知识库的Code Review实践

一、背景描述 一句话介绍就是:基于开源大模型 知识库的 Code Review 实践,类似一个代码评审助手(CR Copilot)。信息安全合规问题:公司内代码直接调 ChatGPT / Claude 会有安全/合规问题,为了使用 ChatGPT…

【数据结构】初识数据结构之复杂度与链表

【数据结构】初识数据结构之复杂度与链表 🔥个人主页:大白的编程日记 🔥专栏:C语言学习之路 文章目录 【数据结构】初识数据结构之复杂度与链表前言一.数据结构和算法1.1数据结构1.2算法1.3数据结构和算法的重要性 二.时间与空间…

[AIGC] Springboot 自动配置的作用及理由

在详细解释SpringBoot的自动配置之前,先介绍以下背景知识。在创建现代复杂的应用程序时,一个困难的部分是正确地设置您的开发环境。这个问题尤其在Java世界中尤为突出,因为您必须管理和配置许多独立的标准和技术。 当我们谈论Spring Boot的自…