基于图卷积网络的人体3D网格分割

深度学习在 2D 视觉识别任务上取得了巨大成功。十年前被认为极其困难的图像分类和分割等任务,现在可以通过具有类似人类性能的神经网络来解决。这一成功归功于卷积神经网络 (CNN),它取代了手工制作的描述符。

NSDT工具推荐: Three.js AI纹理开发包 - YOLO合成数据生成器 - GLTF/GLB在线编辑 - 3D模型格式在线转换 - 可编程3D场景编辑器 - REVIT导出3D模型插件 - 3D模型语义搜索引擎 - Three.js虚拟轴心开发包 - 3D模型在线减面 - STL模型在线切割 

受 CNN 在 2D 图像上取得成功的推动,研究人员已将 CNN 扩展到不规则图数据,例如 3D 网格。将卷积运算扩展到非结构化图数据并非易事,因为相对位置的概念已经丢失。因此,我们专门撰写了这篇博文和一个 Colab笔记本 来回答如何将 CNN 推广到不规则图。特别是,我们讨论了如何修改 CNN 以预测 3D 网格上的分割蒙版。

1、为什么 CNN 不适用?

在欧几里得空间中定义的规则图具有固定的节点排序和邻域大小。通用图没有邻居节点排序,并且其邻居数量是可变的。

不幸的是,CNN 不能直接转换为不规则图。CNN 在欧几里得空间和规则图上工作,其中相邻节点可以通过它们与中心节点的相对位置相互区分,并且相邻节点的数量是恒定的。卷积运算利用这两个属性来得出信息节点嵌入。对于不规则和非结构化图,这些属性不满足。非结构化图没有相对位置的概念—没有顶部、底部、右侧或左侧。此外,邻域连接/大小可能因节点而异。因此,不可能在不规则图上执行卷积运算。

使用 3x3 CNN 层计算节点嵌入 h 的更新公式。请注意,为每个邻居学习唯一的权重矩阵 W。这是可能的,因为可以使用邻居与中心节点的相对位置来区分邻居。

2、图卷积网络

由于普通卷积运算不适用于不规则图,因此需要重新制定该运算。图卷积网络 (GCN) 和图卷积运算就是这样的重新制定。与传统 CNN 类似,GCN 对局部邻域进行推理。与 CNN 一样,GCN 通过聚合来自直接邻居的信息来计算嵌入。然而,GCN 确保此信息聚合操作不依赖于节点的相对位置,也不依赖于节点排序,并且与节点的邻域大小无关 - CNN 不会强制执行不变性。

  • 不规则图有节点顺序不变性,因此 GCN 的输出必须与图的节点排序无关。具体而言,无论节点排序如何,神经网络的输出都必须相同。如果输入了两个不同的图节点排序,我们不希望神经网络输出两个不同的值。此属性称为置换不变性。我们在下面进行说明。
  • 节点的连接性在整个图中可能有所不同。因此,图卷积运算必须与邻域大小无关。

对于不同顺序的计划,学习函数必须将节点映射到相同的输出,而不管它们的顺序如何(排列等变性——如图所示),并且必须将图映射到相同的向量(排列不变性),而不管图的节点顺序如何。

单个 GCN 层由三个操作定义:消息计算、聚合和更新。执行这些操作是为了导出节点嵌入。

  • 消息计算:消息计算相当于转换图中每个节点的特征向量。应用的学习转换对所有节点都是通用的,以强制置换不变性。
  • 聚合:导出所有消息后,每个节点都会通过置换等变函数(例如均值、最大值或总和函数)聚合其邻居的消息。这种消息聚合使信息能够在整个图中传播。
  • 更新:最后,更新函数将聚合消息和当前节点的先前节点嵌入结合起来。从而更新节点的嵌入。通过结合来自节点邻域和节点本身的信息,导出的节点描述符对结构和节点级信息进行编码。

具有节点嵌入 h 的 GCN 的一般消息、聚合和更新公式。消息按节点度进行归一化,聚合是邻居和当前节点的总和,并应用激活来更新嵌入。请注意,W 不依赖于当前节点或其邻居。

如需更详细的解释,请参阅斯坦福大学关于图形机器学习的课程 CS224W 的第 6 和第 7 讲。

3、图卷积网络的应用

为了探索引入的图卷积操作,我们转向图分割任务。与图像分割类似,其中像素被分配给语义类别,图分割相当于找到图节点和一组类别之间的映射。我们专注于将身体部位标签分配给人形 3D 网格节点的问题。

图形分割任务:网格中的每个顶点被分配给十二个身体部位之一

4、3D 网格数据

为了解决所提出的分割任务,我们利用了 3D 网格中编码的所有数据。3D 网格通过顶点和三角面的集合来定义表面。它由两个矩阵表示:

  • 一个维度为 (n, 3) 的顶点矩阵,其中每行指定 3D 空间中顶点的空间位置 [x, y, z]。
  • 一个维度为 (m, 3) 的面整数矩阵,其中每行包含定义三角面的顶点矩阵的三个索引。

请注意,顶点矩阵捕获节点级特征信息,而面矩阵描述节点连通性。正式地,每个网格可以转换为具有顶点 V 的图 G = (X, A),其中 X 具有维度  (|V|, 3) 并定义 V 中每个节点 u 的空间 xyz 特征,而邻接矩阵 A 具有维度  (|V|, |V|) 并定义每个节点的连通邻域。我们使用这个推导出的网格图。

5、人体3D网格分割GCN实现

我们依靠特征引导图卷积从上述推到网格图中得出节点级身体部位分配。这个 GCN 层是由 Verma 等人提出的。

通过特征引导图卷积进行节点级嵌入更新

特征引导图卷积的工作原理如下。通过聚合相邻节点(绿色)的变换嵌入以及其自身的变换嵌入来更新中心节点(红色)的嵌入。具体来说,通过将其嵌入传递到 M 个权重矩阵,为每个节点创建 M 条不同的消息。因此,总共创建了 M x Ni 条消息。这些 M x Ni 通过可调加权平均值进行聚合。这可以表示为:

其中 Nv 是顶点 v 的相邻顶点(邻居)集(包括 v 本身)。

M 条消息通过可学习的注意力机制加权。这些消息按以下方式针对每一层进行计算:

其中 u 和 c 是可学习的参数,特定于每个层。请注意,此注意力权重函数是平移不变的,这在使用空间输入特征时是可取的。它确保输入网格的平移不会影响计算出的注意力权重。

6、在 PyG 中实现

此图卷积操作可以使用 PyTorch geometric(PyG) 消息传递类来实现。三种类方法定义了一个 PyG 消息传递类。它们是 forward(执行前向传递)、 message(构造节点级消息)和 aggregate (聚合)。 message和 forward函数定义如下:

class FeatureSteeredConvolution(MessagePassing):
    """Implementation of FeatureSteeredConvolution
    References
    ----------
    .. [1] Verma, Nitika, Edmond Boyer, and Jakob Verbeek.
       "Feastnet: Feature-steered graph convolutions for 3d shape analysis."
       Proceedings of the IEEE conference on computer vision and pattern recognition. 2018.
    """
    
    ...

    def forward(self, x, edge_index):
        """Forward pass through a feature steered convolution layer.
        Parameters
        ----------
        x: torch.tensor [|V|, in_features]
            Input feature matrix, where each row describes
            the input feature descriptor of a node in the graph.
        edge_index: torch.tensor [2, E]
            Edge matrix capturing the graph's
            edge structure, where each row describes an edge
            between two nodes in the graph.
        Returns
        -------
        torch.tensor [|V|, out_features]
            Output feature matrix, where each row corresponds
            to the updated feature descriptor of a node in the graph.
        """
        if self.with_self_loops:
            edge_index, _ = remove_self_loops(edge_index)
            edge_index, _ = add_self_loops(edge_index=edge_index, num_nodes=x.shape[0])

        out = self.propagate(edge_index, x=x)
        return out if self.bias is None else out + self.bias

    def _compute_attention_weights(self, x_i, x_j):
        """Computation of attention weights.
        Parameters
        ----------
        x_i: torch.tensor [|E|, in_feature]
            Matrix of feature embeddings for all central nodes,
            collecting neighboring information to update its embedding.
        x_j: torch.tensor [|E|, in_features]
            Matrix of feature embeddings for all neighboring nodes
            passing their messages to the central node along
            their respective edge.
        Returns
        -------
        torch.tensor [|E|, M]
            Matrix of attention scores, where each row captures
            the attention weights of transformed node in the graph.
        """
        if x_j.shape[-1] != self.in_channels:
            raise ValueError(
                f"Expected input features with {self.in_channels} channels."
                f" Instead received features with {x_j.shape[-1]} channels."
            )
        if self.v is None:
            attention_logits = self.u(x_i - x_j) + self.c
        else:
            attention_logits = self.u(x_i) + self.b(x_j) + self.c
        return F.softmax(attention_logits, dim=1)

    def message(self, x_i, x_j):
        """Message computation for all nodes in the graph.
        Parameters
        ----------
        x_i: torch.tensor [|E|, in_feature]
            Matrix of feature embeddings for all central nodes,
            collecting neighboring information to update its embedding.
        x_j: torch.tensor [|E|, in_features]
            Matrix of feature embeddings for all neighboring nodes
            passing their messages to the central node along
            their respective edge.
        Returns
        -------
        torch.tensor [|E|, out_features]
            Matrix of updated feature embeddings for
            all nodes in the graph.
        """
        attention_weights = self._compute_attention_weights(x_i, x_j)
        x_j = self.linear(x_j).view(-1, self.num_heads, self.out_channels)
        return (attention_weights.view(-1, self.num_heads, 1) * x_j).sum(dim=1)

完整实现可在我们的 Colab 中找到。

7、网络架构

特征引导图卷积构成了我们身体部位标记网络的主干。我们的一般网络架构如下:

  • 首先使用多层感知器编码器嵌入节点级输入特征向量(xyz 坐标)。
  • 然后,编码后的节点级特征依次通过四个特征引导卷积层。这些层各自聚合来自 12 个注意头(身体部位输出标签的数量)的消息。
  • 最后,精炼后的节点级特征通过预测头,即多感知器预测。它为每个节点输出一个类对应关系。

我们训练神经网络以最小化预测和真实分割标签之间的交叉熵损失。该损失通过整个网络反向传播以更新参数。由此产生的训练循环定义如下:

def train(net, train_data, optimizer, loss_fn, device):
    net.train()
    cumulative_loss = 0.0
    for data in train_data:
        data = data.to(device)
        optimizer.zero_grad()
        out = net(data)
        loss = loss_fn(out, data.segmentation_labels.squeeze())
        loss.backward()
        cumulative_loss += loss.item()
        optimizer.step()
    return cumulative_loss / len(train_data)

最后,我们通过计算预测分割标签与真实分割标签之间的准确度来评估网络的性能。计算结果如下所示:

def accuracy(predictions, gt_seg_labels):
    """Compute accuracy of predicted segmentation labels.
    Parameters
    ----------
    predictions: [|V|, num_classes]
        Soft predictions of segmentation labels.
    gt_seg_labels: [|V|]
        Ground truth segmentations labels.
    Returns
    -------
    float
        Accuracy of predicted segmentation labels.
    """
    predicted_seg_labels = predictions.argmax(dim=-1, keepdim=True)
    if predicted_seg_labels.shape != gt_seg_labels.shape:
        raise ValueError("Expected Shapes to be equivalent")
    correct_assignments = (predicted_seg_labels == gt_seg_labels).sum()
    num_assignemnts = predicted_seg_labels.shape[0]
    return float(correct_assignments / num_assignemnts)

8、数据集

我们在 MPI FAUST 数据集 [2] 上训练我们的网络。它包含 100 个严密的人形网格。我们通过标记单个网格来生成分割标签。这些标签通过数据集中包含的地面真实顶点对应关系转移到所有其他网格。我们使用 80 个网格来训练我们的神经网络,并对 20 个网格进行评估。

MPI FAUST 数据集中人形网格的 10 个姿势

网络在此数据集上的学习过程如下所示。我们看到类别预测收敛到正确的身体部位标签。经过训练的网络在测试数据集上的准确率达到 95%。

9、结束语

在本文中,我们展示了如何将卷积运算从欧几里得空间中定义的规则图推广到具有任意连通性的不规则图。这种推广使我们定义了 GCN,即 CNN 的扩展。与传统 CNN 不同,GCN 可以在不规则图上运行。因此,GCN 可以应用于各种数据——我们所示的网格、分子、社交网络,甚至 2D 图像。


原文链接:基于GCN的3D网格分割 - BimAnt

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/657805.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

阿里云和AWS的CDN产品对比分析

在现代互联网时代,内容分发网络(CDN)已成为确保网站和应用程序高性能和可用性的关键基础设施。作为两家领先的云服务提供商,阿里云和Amazon Web Services(AWS)都提供了成熟的CDN解决方案,帮助企业优化网络传输和提升用户体验。我们九河云一直致力于阿里云和AWS云相关业务&#…

【全开源】西陆家政系统源码小程序(FastAdmin+ThinkPHP+原生微信小程序)

打造高效便捷的家政服务平台 一、引言:家政服务的数字化转型 随着人们生活节奏的加快,家政服务需求日益增长。为了满足广大用户对高效、便捷的家政服务的需求,家政小程序系统源码应运而生。这款源码不仅能够帮助家政服务提供商快速搭建自己…

IDEA2023.2单击Setting提示报错:Cannot get children Easy Code

1、单击Setting,不能弹出对话框 2、打开IDE Internal Errors发生错误 原因: 报错信息 "Cannot get children Easy Code" 通常指的是 IntelliJ IDEA 在尝试访问或操作 Easy Code 插件的子设置时遇到了问题。 主要检查是有网络(断断…

R包Colorfindr识别图片颜色|用刀剑神域方式打开SCI科研配色

1.前言 最近忙里偷闲,捣鼓一下配色,把童年回忆里的动漫都搬进来,给科研信仰充值吧~ 提取颜色之前写过一个Py的,那个很准确不过调参会有点麻烦。这里分享一个比较懒人点的R包吧,虽然会有一定误差&#xff…

【XR806开发板试用】基础篇,从零开始搭建一个LCD彩屏时钟(ST7735S驱动)

本文从搭建环境开始,step by step教大家使用XR806实现驱动SPI屏幕(ST7735S驱动),并连接WiFi实现ntp对时,最终实现把时间显示到屏幕上。 #1. 搭建开发环境 1. 安装编译环境所需的依赖包 基于ubuntu 20.04,按…

作业-day-240522

思维导图 使用IO多路复用实现并发 select实现TCP服务器端 #include <myhead.h>#define SER_IP "192.168.125.112" #define SER_PORT 8888int main(int argc, const char *argv[]) {int sfdsocket(AF_INET,SOCK_STREAM,0);if(sfd -1){perror("socket er…

李廉洋:5.29黄金震荡,原油持续走高,今日美盘行情走势分析及策略。

黄金消息面分析&#xff1a;当前美国存在一个令人担忧且未被充分关注的问题&#xff1a;房地产行业低迷、高利率和抵押贷款利率、租金高涨以及美联储的紧缩政策构成了一个恶性循环。由于高房价和高抵押贷款利率&#xff0c;美国住房经济活动远低于两年前的水平。为了让该行业好…

Java特性之设计模式【备忘录模式】

一、备忘录模式 概述 备忘录模式&#xff08;Memento Pattern&#xff09;保存一个对象的某个状态&#xff0c;以便在适当的时候恢复对象&#xff0c;备忘录模式属于行为型模式 备忘录模式允许在不破坏封装性的前提下&#xff0c;捕获和恢复对象的内部状态 主要解决&#xff…

Python爬虫实战(实战篇)—17获取【CSDN某一专栏】数据转为Markdown列表放入文章中

文章目录 专栏导读背景结果预览1、页面分析2、通过返回数据发现适合利用lxmlxpath3、进行Markdown语言拼接总结 专栏导读 在这里插入图片描述 &#x1f525;&#x1f525;本文已收录于《Python基础篇爬虫》 &#x1f251;&#x1f251;本专栏专门针对于有爬虫基础准备的一套基…

【Linux】22. 线程控制

Linux线程控制 POSIX线程库 与线程有关的函数构成了一个完整的系列&#xff0c;绝大多数函数的名字都是以“pthread_”打头的 要使用这些函数库&#xff0c;要通过引入头文<pthread.h> 链接这些线程函数库时要使用编译器命令的“-lpthread”选项 线程创建 pthread_cr…

AI办公自动化:kimi批量新建文件夹

工作任务&#xff1a;批量新建多个文件夹&#xff0c;每个文件夹中的年份不一样 在kimi中输入提示词&#xff1a; 你是一个Python编程专家&#xff0c;要完成一个编写关于录制电脑上的键盘和鼠标操作的Python脚本的任务&#xff0c;具体步骤如下&#xff1a; 打开文件夹&…

二叉树习题精讲-相同的树

相同的树 100. 相同的树 - 力扣&#xff08;LeetCode&#xff09;https://leetcode.cn/problems/same-tree/description/ /*** Definition for a binary tree node.* struct TreeNode {* int val;* struct TreeNode *left;* struct TreeNode *right;* };*/ bool i…

夏日防晒笔记

1 防晒霜 使用方法&#xff1a;使用前上下摇晃瓶身4至5次&#xff0c;在距离肌肤10至15cm处均匀喷上。如在面部使用&#xff0c;请先喷在掌心再均匀涂抹于面部。排汗量较多时或擦拭肌肤后&#xff0c;请重复涂抹以确保防晒效果。卸除时使用普通洁肤产品洗净即可。

通过date命令给日志文件添加日期

一、背景 服务的日志没有使用日志工具&#xff0c;每次重启后生成新日志文件名称相同&#xff0c;新日志将会把旧日志文件冲掉&#xff0c;旧日志无法保留。 为避免因旧日志丢失导致无法定位问题&#xff0c;所以需要保证每次生成的日志文件名称不同。 二、解决 在启动时&am…

cs61B-sp21 | lab6

cs61B-sp21 | lab6 TODO 1 在 CapersRepository.java 中 static final File CAPERS_FOLDER null; // TODO Hint: look at the join // function in Utils在 Utils.java 我们找到 join 函数&#xff0c;第一个 join 的作用是将 first 和 others 连接起来形成一个路径…

【ArcGISPro】3.1.5下载和安装教程

下载教程 arcgis下载地址&#xff1a;Трекер (rutracker.net) 点击磁力链下载弹出对应的软件进行下载 ArcGISPro3.1新特性 ArcGIS Pro 3.1是ArcGIS Pro的最新版本&#xff0c;它引入了一些新的特性和功能&#xff0c;以提高用户的工作效率和数据分析能力。以下是ArcGIS…

c#对操作系统的时间无法更新?

&#x1f3c6;本文收录于「Bug调优」专栏&#xff0c;主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案&#xff0c;希望能够助你一臂之力&#xff0c;帮你早日登顶实现财富自由&#x1f680;&#xff1b;同时&#xff0c;欢迎大家关注&&收藏&&…

I.MX6ULL主频和时钟配置实验

系列文章目录 I.MX6ULL主频和时钟配置实验 I.MX6ULL主频和时钟配置实验 系列文章目录一、前言二、I.MX6U 时钟系统详解三、硬件原理四、 7 路 PLL 时钟源五、时钟树简介六、内核时钟设置七、PFD 时钟设置八、AHB、IPG 和 PERCLK 根时钟设置九、实验程序编写十、编译下载10.1编写…

线程池的实现

线程池是一种池式组件&#xff0c;通过创建和维护一定数量的线程&#xff0c;实现这些线程的重复使用&#xff0c;避免了频繁创建和销毁线程的开销&#xff0c;从而提升了性能 线程池的作用&#xff1a; 1.复用线程资源&#xff1b; 2.减少线程创建和销毁的开销&#xff1b; …