图神经网络:(处理点云)PointNet++的实现

文章说明:
1)参考资料:PYG官方文档。超链。
2)博主水平不高,如有错误还望批评指正。
3)我在百度网盘上传了这篇文章的jupyter notebook和有关文献。超链。提取码8848。

文章目录

    • 简单前置工作学习
    • 文献阅读
    • PointNet++的实现
    • 模型问题

简单前置工作学习

工作目标:根据点云去进行40分类。
工作流程:1.读取PyG内置的几何图形数据。2.随机但是均匀采样。3.K最邻近算法构边建图。4.使用PointNet++进行图分类。
导库,下载数据,导库,定义函数

from torch_geometric.datasets import GeometricShapes
dataset=GeometricShapes(root='/Data/GeometricShapes')
import matplotlib.pyplot as plt
def visualize_mesh(pos,face):
    fig=plt.figure()
    ax=fig.add_subplot(111,projection='3d')
    ax.axes.xaxis.set_ticklabels([])
    ax.axes.yaxis.set_ticklabels([])
    ax.axes.zaxis.set_ticklabels([])
    ax.plot_trisurf(pos[:,0],pos[:,1],pos[:,2],triangles=face.t(),antialiased=False)
    plt.show()

PS1:这段代码会在C盘生成一个DATA的文件并将数据集放在DATA中,有强迫症注意一下。
PS2:就是几何图形网格。细节可以点击这里。
打印信息与可视化

print(dataset)
data=dataset[0]
print(data)
visualize_mesh(data.pos,data.face)
data=dataset[4]
print(data)
visualize_mesh(data.pos,data.face)

jupyter notebook内输出如下
在这里插入图片描述
导库以及定义函数

from torch_geometric.transforms import SamplePoints
import torch
def visualize_points(pos,edge_index=None,index=None):
    fig=plt.figure(figsize=(4, 4))
    if edge_index is not None:
        for (src,dst) in edge_index.t().tolist():
            src=pos[src].tolist()
            dst=pos[dst].tolist()
            plt.plot([src[0],dst[0]],[src[1],dst[1]],linewidth=1,color='black')
    if index is None:
        plt.scatter(pos[:,0],pos[:,1],s=50,zorder=1000)
    else:
       mask=torch.zeros(pos.size(0),dtype=torch.bool)
       mask[index]=True
       plt.scatter(pos[~mask,0],pos[~mask,1],s=50,color='lightgray',zorder=1000)
       plt.scatter(pos[mask,0],pos[mask,1],s=50,zorder=1000)
    plt.axis('off')
    plt.show()

从图形表面均匀地采样,打印信息与可视化

dataset.transform=SamplePoints(num=256)
data=dataset[0]
print(data)
visualize_points(data.pos)
data=dataset[4]
print(data)
visualize_points(data.pos)

jupyter notebook内输出如下
在这里插入图片描述

文献阅读

参考文献: PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space

文章概述: “Deep learning on point sets for 3d classification and segmentation”是参考文献之前前沿工作,核心思想对每个点空间编码然后聚合所有单点要素到全局的空间。显然这样无法捕捉局部特征。受到卷积神经网络启发,这里参考文献便就来了。具体步骤: 第一步:进行局部划分;第二步:组合局部特征;第三步:加工局部特征,重复上述过程直到点云所有特征都被利用。所以面临三个问题。第一问:如何进行局部划分。第二问:如何组合局部特征。第三问:如何加工局部特征。解决第一问:Farthest Point Sampling,FPS。解决第二问:Ball Query。解决第三问:上面那篇文章的Point

分层的点云学习器: Sampling layer: Farthest Point Sampling,FPS。 可以使用K最近邻算法但是不好。固定一个区域更加有普适性。PS:注意一下KNN与Ball Query的区别。Grouping layer: 输入: N × ( d + C ) N \times (d+C) N×(d+C) 以及 N ′ × d N'\times d N×d 输出: N ′ × K × ( d + C ) N' \times K \times (d + C) N×K×(d+C)。符号说明: N N N是点的数量, d d d是质心坐标, C C C是点的特征维数, N ′ N' N是质心数量, K K K是邻域内点数量。Point Net layer: 输入: N ′ × K × ( d + C ) N' \times K \times (d + C) N×K×(d+C) 输出: N ′ × ( d + C ′ ) N' \times (d + C') N×(d+C) 。这个模型鲁棒性强,对于不均匀的数据效果同样。这个图挺好的。
在这里插入图片描述
PS1:原文还有其他很好工作,有兴趣有时间建议去看,但是我们这里跳过。
PS2:对于上面前置工作,由于采用是均匀的,可以这样建图。如下:
导库

from torch_cluster import knn_graph

打印信息与可视化

data=dataset[0]
data.edge_index=knn_graph(data.pos,k=6)
print(data.edge_index.shape)
visualize_points(data.pos,edge_index=data.edge_index)
data=dataset[4]
data.edge_index=knn_graph(data.pos, k=6)
print(data.edge_index.shape)
visualize_points(data.pos,edge_index=data.edge_index)

jupyter notebook内输出如下
在这里插入图片描述

PointNet++的实现

我们使用数学公式首先进行EdgeConv的描述: h i ( l ) = m a x j ∈ N i M L P ( h i ( l − 1 ) , h j ( l − 1 ) − h i ( l − 1 ) ) h_i^{(l)}=max_{j\in \mathcal{N}_i}MLP(h_i^{(l-1)},h_j^{(l-1)}-h_i^{(l-1)}) hi(l)=maxjNiMLP(hi(l1),hj(l1)hi(l1))。Point++类似于这个公式: h i ( l ) = m a x j ∈ N i M L P ( h i ( l − 1 ) , p j ( l − 1 ) − p i ( l − 1 ) ) h_i^{(l)}=max_{j\in \mathcal{N}_i}MLP(h_i^{(l-1)},p_j^{(l-1)}-p_i^{(l-1)}) hi(l)=maxjNiMLP(hi(l1),pj(l1)pi(l1))
搭建多层的PointNel++

from torch_geometric.nn import MessagePassing
from torch.nn import Sequential,Linear,ReLU

class PointNetLayer(MessagePassing):

    def __init__(self,in_channels,out_channels):
        super().__init__(aggr='max')
        self.mlp=Sequential(Linear(in_channels+3,out_channels),ReLU(),Linear(out_channels,out_channels))
        
    def forward(self,h,pos,edge_index):
        return self.propagate(edge_index,h=h,pos=pos)
    
    def message(self,h_j,pos_j,pos_i):
        input=pos_j-pos_i
        if h_j is not None:
            input=torch.cat([h_j,input],dim=-1)
        return self.mlp(input)
from torch_geometric.nn import global_max_pool

class PointNet(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.conv1=PointNetLayer(3,32)
        self.conv2=PointNetLayer(32,32)
        self.classifier=Linear(32,dataset.num_classes)
        
    def forward(self,pos,batch):
        edge_index=knn_graph(pos,k=16,batch=batch,loop=True)
        h=self.conv1(h=pos,pos=pos,edge_index=edge_index)
        h=h.relu()
        h=self.conv2(h=h,pos=pos,edge_index=edge_index)
        h=h.relu()
        h=global_max_pool(h,batch)
        return self.classifier(h)

model=PointNet()
print(model)
#输出如下
#PointNet(
#  (conv1): PointNetLayer()
#  (conv2): PointNetLayer()
#  (classifier): Linear(in_features=32, out_features=40, bias=True)
#)

导库,训测拆分数据变换以及划分批量

from torch_geometric.loader import DataLoader
train_dataset=GeometricShapes(root='/Data/GeometricShapes',train=True,transform=SamplePoints(128))
test_dataset=GeometricShapes(root='/Data/GeometricShapes',train=False,transform=SamplePoints(128))
train_loader=DataLoader(train_dataset,batch_size=10,shuffle=True)
test_loader=DataLoader(test_dataset,batch_size=10)

进行实验

model=PointNet();optimizer=torch.optim.Adam(model.parameters(),lr=0.01);criterion=torch.nn.CrossEntropyLoss()

def train(model,optimizer,loader):
    model.train()
    total_loss=0
    for data in loader:
        optimizer.zero_grad()
        logits=model(data.pos,data.batch)
        loss=criterion(logits,data.y)
        loss.backward()
        optimizer.step()
        total_loss+=loss.item()*data.num_graphs
    return total_loss/len(train_loader.dataset)

def test(model,loader):
    model.eval()
    total_correct=0
    for data in loader:
        logits=model(data.pos,data.batch)
        pred=logits.argmax(dim=-1)
        total_correct+=int((pred==data.y).sum())
    return total_correct/len(loader.dataset)

for epoch in range(1,51):
    loss=train(model,optimizer,train_loader)
    test_acc=test(model,test_loader)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Test Accuracy: {test_acc:.4f}')
#输出如下(这里只有最后一次):
#Epoch: 50, Loss: 0.7294, Test Accuracy: 0.8250

模型问题

出现问题: 由于模型使用坐标进行输入并且选择笛卡尔坐标系传递信息所以旋转坐标就不可行。可以按照如下方式进行实验。

from torch_geometric.transforms import Compose,RandomRotate
random_rotate=Compose([
    RandomRotate(degrees=180,axis=0),
    RandomRotate(degrees=180,axis=1),
    RandomRotate(degrees=180,axis=2),
])
dataset=GeometricShapes(root='/DATA//GeometricShapes',transform=random_rotate)
data=dataset[0]
print(data)
visualize_mesh(data.pos,data.face)
data=dataset[4]
print(data)
visualize_mesh(data.pos,data.face)

jupyter notebook内输出如下
在这里插入图片描述

transform=Compose([
    random_rotate,
    SamplePoints(num=128),
])
test_dataset=GeometricShapes(root='/DATA/GeometricShapes',train=False,transform=transform)
test_loader=DataLoader(test_dataset,batch_size=10)
test_acc=test(model,test_loader)
print(f'Test Accuracy: {test_acc:.4f}')
#输出如下:
#Test Accuracy: 0.2000
print(len(test_dataset))
#输出如下:
#40

可以看到,模型效果,就不好了。有解决方法的。暂时就这样吧。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/22092.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

智慧井盖监测终端,智能井盖-以科技解决智慧城市“顽疾”,守护城市生命线

平升电子智慧井盖监测终端,智能井盖-以科技解决智慧城市“顽疾”,守护城市生命线-智慧井盖,实现对井下设备和井盖状态的监测及预警,是各类智慧管网管理系统中不可或缺的重要设备,解决了井下监测环境潮湿易水淹、电力供应困难、通讯不畅等难题…

XDP入门--BPF程序如何转发报文到其它网卡

本文目录 1、测试环境:2、实现的功能,使用bpf_redirect直接转发收到的报文到另外一张网卡3、测试步骤与测试结果 1、测试环境: 参照把树莓派改造成无线网卡(3)-----共享无线网络,无线网络转换成有线网络,让有线网络设…

插入排序、选择排序、冒泡排序小结(45)

小朋友们好,大朋友们好! 我是猫妹,一名爱上Python编程的小学生。 和猫妹学Python,一起趣味学编程。 今日主题 插入排序、选择排序、冒泡排序有什么区别? 原理不同 插入排序是将未排序的元素逐个插入到已排序序列中…

Unity之ASE从入门到精通 目录

前言 Amplify Shader Editor (ASE) 是受行业领先软件启发的基于节点着色器创建工具。它是一个开放且紧密集成的解决方案,提供了熟悉和连贯的开发环境,使 Unity 的 UI 约定和着色器的使用无缝地融合一起 目录 这里是ASE从入门到精通专栏的目录,不停更新中,有问题随时留…

入门JavaScript编程:上手实践四个常见操作和一个轮播图案例

部分数据来源:ChatGPT 简介 JavaScript是一门广泛应用于Web开发的脚本语言,它主要用于实现动态效果和客户端交互。下面我们将介绍几个例子,涵盖了JavaScript中一些常见的操作,包括:字符串、数组、对象、事件等。 例子…

rk3568 适配rk809音频

rk3568 适配rk809音频 RK809是一款集成了多种功能的电源管理芯片,主要用于笔记本电脑、平板电脑、工控机等设备的电源管理。以下是RK809的详细功能介绍: 电源管理:控制电源的开关、电压、电流等参数,保证设备的稳定运行。音频管…

Unity之使用Photon PUN开发多人游戏教程

前言 Photon是一个网络引擎和多人游戏平台,可以处理其服务器上的所有请求,我们可以在 Unity(或其他游戏引擎)中使用它,并快速把游戏接入Photon的网络中,而我们就可以专注于在项目中添加逻辑,专注于游戏玩法和功能了。 PUN(Photon Unity Networking)是一种开箱即用的解…

什么是DevOps?如何理解DevOps思想?

博文参考总结自:https://www.kuangstudy.com/course/play/1573900157572333569 仅供学习使用,若侵权,请联系我删除! 1、什么是DevOps? DevOps是一种思想或方法论,它涵盖开发、测试、运维的整个过程。DevOps强调软件开…

Maven方式构建Spring Boot项目

文章目录 一,创建Maven项目二,添加依赖三,创建入口类四,创建控制器五,运行入口类六,访问Web页面七,修改访问映射路径八,定制启动标语1、创建标语文件2、生成标语字符串3、编辑标语文…

DNDC模型在土地利用变化、未来气候变化下的建模方法及温室气体时空动态模拟实践技术

DNDC模型讲解 1.1 碳循环模型简介 1.2 DNDC模型原理 1.3 DNDC下载与安装 1.4 DNDC注意事项 ​ DNDC初步操作 2.1 DNDC界面介绍 2.2 DNDC数据及格式 2.3 DNDC点尺度模拟 2.4 DNDC区域尺度模拟 2.5 DNDC结果分析 ​ DNDC气象数据制备 3.1 数据制备中的遥感和GIS技术 3…

Vue3 + TypeScript + Uniapp 开发小程序【医疗小程序完整案例·一篇文章精通系列】

当今的移动应用市场已经成为了一个日趋竞争激烈的领域,而开发一个既能在多个平台上运行,又能够高效、可维护的应用则成为了一个急需解决的问题。 在这个领域中,Vue3 TypeScript Uniapp 的组合已经成为了一种受欢迎的选择,特别…

ODB 2.4.0 使用延迟指针 lazy_shared_ptr 时遇到的问题

最近在学习使用C下的ORM库——ODB,来抽象对数据库的CURD,由于C的ORM实在是太冷门了,ODB除了官方英语文档,几乎找不到其他好用的资料,所以在使用过程中也是遇到很多疑惑,也解决很多问题。近期遇到的一个源码…

推荐系统系列之推荐系统概览(下)

在推荐系统概览的第一讲中,我们介绍了推荐系统的常见概念,常用的评价指标以及首页推荐场景的通用召回策略。本文我们将继续介绍推荐系统概览的其余内容,包括详情页推荐场景中的通用召回策略,排序阶段常用的排序模型,推…

Keil Debug 逻辑分析仪使用

Keil Debug 逻辑分析仪使用 基础配置 更改对应的bebug窗口参数 两边的 Dialog DLL 更改为:DARMSTM.DLL两边的 Parameter (这里的根据单片机型号更改)更改为:-pSTM32F103VE 选择左边的 Use Simulator 选项。 打开Debug和其中的逻…

数据全生命周期管理

数据存储 时代"海纳百川,有容乃大"意味结构化、半结构和非结构化多样化的海量的 ,也意味着批数据和流数据多种数据形式的存储和计算。面对不同数据结构、数据形式、时效性与性能要求和存储与计算成本等因素考虑,应该使用适合的存储…

iptables防火墙(二)

iptables防火墙(二) 一、SNAT策略1、SNAT策略简述2、配置实验 二、DNAT策略1、DNAT策略简述2、配置实验 三、Linux抓包工具tcpdump四、防火墙规则保存 一、SNAT策略 1、SNAT策略简述 SNAT策略就是将从内网传给外网的数据包的源IP由私网IP转换成公网IP&…

四川省信创联盟2023年第一次理事会顺利召开,MIAOYUN荣获“信创企业优秀奖”!

5月18日,四川省技术创新促进会信创工委会(四川省信创产业联盟)在成都市高新区新川科技园成功召开《2023年第一次理事单位(扩大)会议》,四川省技术创新促进会专家组杜纯文副组长、四川省技术创新促进会任渝英…

EasyRecovery16适用于Windows和Mac的专业硬盘恢复软件

无论你对数据恢复了解多少, 我们将为您处理所有复杂的流程并简化恢复!适用于Windows和Mac的 专业硬盘恢复软件 硬盘数据无法保证绝对安全。有时会发生数据丢失,需要使用硬盘恢复工具。支持恢复不同存储介质数据:硬盘、光盘、U盘/移动硬盘、数…

AC规则-1

本文主要参考规范 GPD_Secure Element Access Control_vxxx.pdf OMA 架构 基本定义 GP(GlobalPlatform)定义了一套允许各应用提供方独立且安全地管理其在SE上的应用的安全框架,而AC(Access Control),顾名思义,是对外部应用进行SE上应用访问…

网络知识点之-动态路由

动态路由是指路由器能够自动地建立自己的路由表,并且能够根据实际情况的变化适时地进行调整。 中文名:动态路由外文名:dynamic routing 简述 动态路由是与静态路由相对的一个概念,指路由器能够根据路由器之间的交换的特定路由信息…