实验记录 | 点云处理 | K-NN算法3种实现的性能比较

引言

K近邻(K-Nearest Neighbors, KNN)算法作为一种经典的无监督学习算法,在点云处理中的应用尤为广泛。它通过计算点与点之间的距离来寻找数据点的邻居,从而有效进行点云分类、聚类和特征提取。本菜在复现点云文章过程,遇到了三种 KNN 的实现方式,故在此一并对比总结,最后对三种实现方案进行了性能比较

在本文中,我将K近邻(KNN)算法的应用分为两种情况:

  • 全局查询:对整个点云的所有 N 个点进行查询,找到每个点的 K 个最近邻点,最终返回的结果维度为 [B, N, K],B 表示批次大小,N 表示点的总数量,K 表示每个点的邻近点数量。

  • 局部查询:针对已知的 S 个查询点,在整个点云的 N 个点中寻找每个查询点的 K 个最近邻点,最终返回的结果维度为 [B, S, K],其中 S 表示查询点的数量。


全局查询

def knn(x, k):
    """
    Input:
        x: all points, [B, C, N]
        k: k nearest points of each point
    Return:
        idx: grouped points index, [B, N, k]
    """
    inner = -2*torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x**2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)
 
    idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)
    return idx

这段代码来源于点云网络的高引之作《Dynamic Graph CNN for Learning on Point Clouds》,实现了一个 KNN(K近邻)查询,目的是计算点云中每个点的 k 个最近邻点的索引。

函数清晰易懂,便不赘述。我一直以为点云学习是需要先采样,再用采样得到的中心点进行 KNN 邻域查询,直到看到这篇 DGCNN 的方法,才打破了我的固有认知:DGCNN没有下采样过程,直接使用 N 个点进行近邻查询和特征更新。

插个题外话,这篇文章真的值得一读,简单高效!不愧是高引之作。


局部查询

(1)knn_point 函数

def square_distance(src, dst):
    """
    Calculate Euclid distance between each two points.
    src^T * dst = xn * xm + yn * ym + zn * zm;
    sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn;
    sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm;
    dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2
         = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst
    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist

def knn_point(nsample, xyz, new_xyz):
    """
    Input:
        nsample: max sample number in local region
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    sqrdists = square_distance(new_xyz, xyz)
    _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
    return group_idx

这段代码来源于另一个高引之作《Rethinking Network Design and Local Geometry in Point Cloud: A Simple Residual MLP Framework》,代码也是相当眉清目秀,不再赘述。其实这份代码的实现还是比较经典的,很多的模型代码都可以看到它的身影。


(2)knn_cuda 库函数

import torch

# Make sure your CUDA is available.
assert torch.cuda.is_available()

from knn_cuda import KNN
"""
if transpose_mode is True, 
    ref   is Tensor [bs x nr x dim]
    query is Tensor [bs x nq x dim]
    
    return 
        dist is Tensor [bs x nq x k]
        indx is Tensor [bs x nq x k]
else
    ref   is Tensor [bs x dim x nr]
    query is Tensor [bs x dim x nq]
    
    return 
        dist is Tensor [bs x k x nq]
        indx is Tensor [bs x k x nq]
"""

knn = KNN(k=10, transpose_mode=True)

ref = torch.rand(32, 1000, 5).cuda()
query = torch.rand(32, 50, 5).cuda()

dist, indx = knn(ref, query)  # 32 x 50 x 10

大佬把 KNN 封装为了库函数,来源于 KNN_CUDA 此仓库,可以参考 readme 进行安装。库函数的调用也非常方便。

需要强调的是,这里提到的 knn_point 和 knn_cuda 虽然算局部查询,但其实只要将局部查询点云 [B, S, Dim] 换成全局点云 [B, N, Dim] 作为输入,也就是全局查询了


性能比较

(1)测试代码

import torch
import time
from knn_cuda import KNN

def knn(x, k):
    inner = -2*torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x**2, dim=1, keepdim=True)
    pairwise_distance = -xx - inner - xx.transpose(2, 1)
 
    idx = pairwise_distance.topk(k=k, dim=-1)[1]   # (batch_size, num_points, k)
    return idx

def square_distance(src, dst):
    B, N, _ = src.shape
    _, M, _ = dst.shape
    dist = -2 * torch.matmul(src, dst.permute(0, 2, 1))
    dist += torch.sum(src ** 2, -1).view(B, N, 1)
    dist += torch.sum(dst ** 2, -1).view(B, 1, M)
    return dist

def knn_point(nsample, xyz, new_xyz):
    sqrdists = square_distance(new_xyz, xyz)
    _, group_idx = torch.topk(sqrdists, nsample, dim=-1, largest=False, sorted=False)
    return group_idx

# Custom knn implementation
def test_knn(query, k, times):
    query = query.permute(0,2,1)
    start_time = time.time()  # Start timer
    for i in range(times):
        indx = knn(query, k = k)
    end_time = time.time()  # End timer
    return end_time - start_time  # Return elapsed time

# Custom knn_point implementation
def test_knn_point(ref, query, k, times):
    start_time = time.time()  # Start timer
    for i in range(times):
        indx = knn_point(k, ref, query)
    end_time = time.time()  # End timer
    return end_time - start_time  # Return elapsed time

# knn_cuda implementation
def test_knn_cuda(ref, query, k, times):
    knn = KNN(k=k, transpose_mode=True)
    start_time = time.time()  # Start timer
    for i in range(times):
        dist, indx = knn(ref, query)
    end_time = time.time()  # End timer
    return end_time - start_time  # Return elapsed time


# Main testing function
def test_knn_methods(ref, query, k, times):

    print("Test times: %d" % times)

    # Test custom knn
    time_knn = test_knn(query, k, times)
    print(f"knn      : {time_knn:.6f} seconds")

    # Test custom knn_point
    time_point = test_knn_point(ref, query, k, times)
    print(f"knn_point: {time_point:.6f} seconds")

    # Test knn_cuda
    time_cuda = test_knn_cuda(ref, query, k, times)
    print(f"knn_cuda : {time_cuda:.6f} seconds")
    

if __name__ == '__main__':

    # Sample input
    B, N, S, C = 32, 1024, 50, 3      # Batch size, total points, query points, coordinates
    k = 24                            # Number of nearest neighbors
    ref = torch.randn(B, N, C).cuda() # Reference points

    # Test above methods
    times_list = [1,2,3,10,50,100]
    for times in times_list:
        test_knn_methods(ref, ref, k, times)

这段代码测试了三种 K 近邻(KNN)算法的实现效率,分别是自定义的 knnknn_point 以及基于 knn_cuda 库的实现。分别对每种方法运行多次,记录每种方法在不同重复次数(如 1、2、3、10、50、100 次)的运行时间,最终输出各方法的执行时间。

图注:三种实现方法的性能测评结果

上图展示了测试代码的结果,可以看到 knn_cuda 的实现方式表现最差的(我也表示非常不理解);knn 和 knn_point 性能表现相当。或许这也是为什么很多较新的模型使用的也是 knn_point,而不是 knn_cuda。

当然,这份测试代码实际是在一个小规模数据的单卡上进行的,或许无法很好地展现出他们在实际训练的性能,因此我又分别将他们部署在 DGCNN 模型上进行训练,对比性能。


(2)模型训练

图注:使用 knn 函数的训练时间
图注:使用 knn 函数的训练时间

图注:使用 knn_point 的训练时间

图注:使用 knn_cuda 库的训练时间

 

直接将他们部署在模型的训练中,能够最真实反映出他们的性能。这次实验,Batchsize 设置为了32,epoch 设置为256,选择前2个epoch观察。从训练状态可以看到,红色框选区域表示训练和测试的时间,knn_cuda 依然稳定发挥,表现最差哈哈哈哈,knn 和 knn_point 的函数实现表现相当。


总结

我原以为 knn_cuda 会很厉害,毕竟是直接封装起来了,但实际表现不尽人意。看似很小的性能差异,放在规模较大的数据集上,训练成本可是指数级倍增的。所以,还是尽可能使用 knn 和 knn_point 来实现全局/局部的邻近查询。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/873280.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

详解React setState调用原理和批量更新的过程

1. React setState 调用的原理 setState目录 1. React setState 调用的原理2. React setState 调用之后发生了什么?是同步还是异步?3. React中的setState批量更新的过程是什么? 具体的执行过程如下(源码级解析)&#x…

基于SpringBoot+Vue+MySQL的宿舍维修管理系统

系统展示 前台界面 管理员界面 维修员界面 学生界面 系统背景 在当今高校后勤管理的日益精细化与智能化背景下,宿舍维修管理系统作为提升校园生活品质、优化资源配置的关键环节,其重要性日益凸显。随着学生规模的扩大及住宿条件的不断提升,宿…

人机交互系统中的人脸讲话生成系统调研

《Human-Computer Interaction System: A Survey of Talking-Head Generation》 图片源:https://github.com/Yazdi9/Talking_Face_Avatar 目录 前言摘要一、背景介绍二、人机交互系统体系结构2.1. 语音模块2.2. 对话系统模块2.3. 人脸说话动作生成 三 人脸动作生成…

来啦| LVMH路威酩轩25届校招智鼎高潜人才思维能力测验高分攻略

路威酩轩香水化妆品(上海)有限公司是LVMH集团于2000年成立,负责集团旗下的部分香水化妆品品牌在中国的销售包括迪奥、娇兰、纪梵希、贝玲妃、玫珂菲、凯卓、帕尔马之水以及馥蕾诗等。作为目前全球最大的奢侈品集团LVMH 集团秉承悠久的历史,不断打破常规&…

【微处理器系统原理和应用设计第六讲】片上微处理器系统系统架构

一、概念辨析 首先来厘清以下概念:微处理器,微控制器,单片机,片上微处理器系统 (1)微处理器:即MPU(Microprocessor Unit),微处理器是一种计算机的中央处理单…

如何打造个性化大学生聊天室?Java SpringBoot Vue实战,2025最新设计指南

✍✍计算机毕业编程指导师** ⭐⭐个人介绍:自己非常喜欢研究技术问题!专业做Java、Python、微信小程序、安卓、大数据、爬虫、Golang、大屏等实战项目。 ⛽⛽实战项目:有源码或者技术上的问题欢迎在评论区一起讨论交流! ⚡⚡ Java…

【深度学习】向量化

1. 什么是向量化 向量化通常是消除代码中显示for循环语句的技巧,在深度学习实际应用中,可能会遇到大量的训练数据,因为深度学习算法往往在这种情况下表现更好,所以代码的运行速度非常重要,否则如果它运行在一个大的数据…

【Linux】翻山越岭——进程地址空间_c语言父子进程地址空间

文章目录 一、是什么 写时拷贝 二、为什么三、怎么做 区域划分和调整 一、是什么 回顾我们学习C/C时的地址空间: 有了这个基本框架,我们对于语言的学习更加易于理解,但是地址空间究竟是什么❓我们对其并不了解,是不是内存呢&…

海外云服务器安装 MariaDB10.6.X (Ubuntu 18.04 记录篇二)

本文首发于 秋码记录 MariaDB 的由来(历史) 谈起新秀MariaDB,或许很多人都会感到陌生吧,但若聊起享誉开源界、业界知名的关系型数据库——Mysql,想必混迹于互联网的人们(coder)无不知晓。 其…

MonoHuman: Animatable Human Neural Field from Monocular Video 精读

一、共享双向变形模块 1. 模块的核心思想 共享双向变形模块的核心目标是解决从单目视频中生成不同姿态下的3D人体形状问题。因为视频中的人物可能处于各种动态姿态下,模型需要能够将这些不同姿态的几何形状进行变形处理,以适应标准的姿态表示并生成新的…

SVN下载安装使用方法

目录 🌕SVN是什么?🌙SVN跟Git比的优势🌙SVN的用处 🌕下载安装使用方法 🌕🌙⭐ 🌕SVN是什么? 代码版本管理工具 它能记住你每次的修改 查看所有的修改记录 恢复到任何历…

【Linux网络】详解TCP协议(1)

🎉博主首页: 有趣的中国人 🎉专栏首页: Linux网络 🎉其它专栏: C初阶 | C进阶 | 初阶数据结构 小伙伴们大家好,本片文章将会讲解 TCP协议 的相关内容。 如果看到最后您觉得这篇文章写得不错&am…

【大数据】深入浅出Hadoop,干货满满

【大数据】深入浅出Hadoop 文章脉络 Hadoop HDFS MapReduce YARN Hadoop集群硬件架构 假设现在有一个PB级别的数据库表要处理。 在单机情况下,只能升级你的内存、磁盘、CPU,那么这台机器就会变成 “超算”,成本太高,商业公司肯…

通过卷积神经网络(CNN)识别和预测手写数字

一:卷积神经网络(CNN)和手写数字识别MNIST数据集的介绍 卷积神经网络(Convolutional Neural Networks,简称CNN)是一种深度学习模型,它在图像和视频识别、分类和分割任务中表现出色。CNN通过模仿…

8. GIS数据分析师岗位职责、技术要求和常见面试题

本系列文章目录: 1. GIS开发工程师岗位职责、技术要求和常见面试题 2. GIS数据工程师岗位职责、技术要求和常见面试题 3. GIS后端工程师岗位职责、技术要求和常见面试题 4. GIS前端工程师岗位职责、技术要求和常见面试题 5. GIS工程师岗位职责、技术要求和常见面试…

【高等代数笔记】线性空间(一到四)

3. 线性空间 令 K n : { ( a 1 , a 2 , . . . , a n ) ∣ a i ∈ K , i 1 , 2 , . . . , n } \textbf{K}^{n}:\{(a_{1},a_{2},...,a_{n})|a_{i}\in\textbf{K},i1,2,...,n\} Kn:{(a1​,a2​,...,an​)∣ai​∈K,i1,2,...,n},称为 n n n维向量 规定(规定…

【技术前沿】智能反向寻车解决方案:提升停车场用户体验与运营效率

亲爱的技术员及停车场管理者们,您是否曾遇到过车主在庞大的停车场中迷失方向,耗费大量时间寻找爱车的困境?这不仅影响了车主的停车体验,也无形中增加了停车场的管理难度和运营成本。本文专为解决这一痛点而生,介绍最新…

基于人工智能的手写数字识别系统

目录 引言项目背景环境准备 硬件要求软件安装与配置系统设计 系统架构关键技术代码示例 数据预处理模型训练模型预测应用场景结论 1. 引言 手写数字识别是一种经典的计算机视觉任务,目标是让机器能够识别手写数字。通过人工智能技术,特别是卷积神经网…

京东物流查询|开发者调用API接口实现

快递聚合查询的优势 1、高效整合多种快递信息。2、实时动态更新。3、自动化管理流程。 聚合国内外1500家快递公司的物流信息查询服务,使用API接口查询京东物流的便捷步骤,首先选择专业的数据平台的快递API接口:物流快递查询API接口-单号查询…

【论文分享】MyTEE: Own the Trusted Execution Environment on Embedded Devices 23‘NDSS

目录 AbstractINTRODUCTIONBACKGROUNDARMv8 ArchitectureSecurity statesTrustZone extensionsVirtualization Communication with Peripherals MOTIVATIONATTACK MODEL AND ASSUMPTIONSYSTEM DESIGNOverviewExecution Environments IsolationDMA FilterExternal DMA controlle…