基于ROPNet项目训练modelnet40数据集进行3d点云的配置

项目地址: https://github.com/zhulf0804/ROPNet 在 MVP Registration Challenge (ICCV Workshop 2021)(ICCV Workshop 2021)中获得了第二名。项目可以在win10环境下运行。
论文地址: https://arxiv.org/abs/2107.02583

网络简介: 一种新的深度学习模型,该模型利用具有区别特征的代表性重叠点进行配准,将部分到部分配准转换为部分完全配准。基于pointnet输出的特征设计了一个上下文引导模块,使用一个编码器来提取全局特征来预测点重叠得分。为了更好地找到有代表性的重叠点,使用提取的全局特征进行粗对齐。然后,引入一种变压器来丰富点特征,并基于点重叠得分和特征匹配去除非代表性点。在部分到完全的模式下建立相似度矩阵,最后采用加权支持向量差来估计变换矩阵。
在这里插入图片描述
实施效果: 从数据上看ROPNet与RPMNet与保持了断崖式的领先地位
在这里插入图片描述

1、运行环境安装

1.1 项目下载

打开https://github.com/zhulf0804/ROPNet,点Download ZIP然后将代码解压到指定目录下即可。
在这里插入图片描述

1.2 依赖项安装

在装有pytorch的环境终端,进入ROPNet-master/src目录,执行以下安装命令。如果已经安装了torch 环境和open3d包,则不用再进行安装了

pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

pip install open3d

1.3 模型与数据下载

modelnet40数据集 here [435M]
数据集下载后存储为以下路径即可。
在这里插入图片描述

官网预训练模型,无。
第三方预训练模型:使用ROPNet项目在modelnet40数据集上训练的模型

2、关键代码

2.1 dataloader

作者所提供的dataloader只能加载https://shapenet.cs.stanford.edu/media/modelnet40_ply_hdf5_2048.zip 数据集,其所返回的tgt_cloud, src_cloud实质上是基于一个点云采样而来的。 其中的self.label2cat, self.cat2label, self.symmetric_labels等对象代码实际上是没有任何作用的。

import copy
import h5py
import math
import numpy as np
import os
import torch

from torch.utils.data import Dataset
import sys

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOR_DIR = os.path.dirname(BASE_DIR)
sys.path.append(ROOR_DIR)
from utils import  random_select_points, shift_point_cloud, jitter_point_cloud, \
    generate_random_rotation_matrix, generate_random_tranlation_vector, \
    transform, random_crop, shuffle_pc, random_scale_point_cloud, flip_pc
    


half1 = ['airplane', 'bathtub', 'bed', 'bench', 'bookshelf', 'bottle', 'bowl',
         'car', 'chair', 'cone', 'cup', 'curtain', 'desk', 'door', 'dresser',
         'flower_pot', 'glass_box', 'guitar', 'keyboard', 'lamp']
half1_symmetric = ['bottle', 'bowl', 'cone', 'cup', 'flower_pot', 'lamp']

half2 = ['laptop', 'mantel', 'monitor', 'night_stand', 'person', 'piano',
         'plant', 'radio', 'range_hood', 'sink', 'sofa', 'stairs', 'stool',
         'table', 'tent', 'toilet', 'tv_stand', 'vase', 'wardrobe', 'xbox']
half2_symmetric = ['tent', 'vase']


class ModelNet40(Dataset):
    def __init__(self, root, split, npts, p_keep, noise, unseen, ao=False,
                 normal=False):
        super(ModelNet40, self).__init__()
        self.single = False # for specific-class visualization
        assert split in ['train', 'val', 'test']
        self.split = split
        self.npts = npts
        self.p_keep = p_keep
        self.noise = noise
        self.unseen = unseen
        self.ao = ao # Asymmetric Objects
        self.normal = normal
        self.half = half1 if split in 'train' else half2
        self.symmetric = half1_symmetric + half2_symmetric
        self.label2cat, self.cat2label = self.label2category(
            os.path.join(root, 'shape_names.txt'))
        self.half_labels = [self.cat2label[cat] for cat in self.half]
        self.symmetric_labels = [self.cat2label[cat] for cat in self.symmetric]
        files = [os.path.join(root, 'ply_data_train{}.h5'.format(i))
                 for i in range(5)]
        if split == 'test':
            files = [os.path.join(root, 'ply_data_test{}.h5'.format(i))
                     for i in range(2)]
        self.data, self.labels = self.decode_h5(files)
        print(f'split: {self.split}, unique_ids: {len(np.unique(self.labels))}')

        if self.split == 'train':
            self.Rs = [generate_random_rotation_matrix() for _ in range(len(self.data))]
            self.ts = [generate_random_tranlation_vector() for _ in range(len(self.data))]

    def label2category(self, file):
        with open(file, 'r') as f:
            label2cat = [category.strip() for category in f.readlines()]
            cat2label = {label2cat[i]: i for i in range(len(label2cat))}
        return label2cat, cat2label

    def decode_h5(self, files):
        points, normal, label = [], [], []
        for file in files:
            f = h5py.File(file, 'r')
            cur_points = f['data'][:].astype(np.float32)
            cur_normal = f['normal'][:].astype(np.float32)
            cur_label = f['label'][:].flatten().astype(np.int32)
            if self.unseen:
                idx = np.isin(cur_label, self.half_labels)
                cur_points = cur_points[idx]
                cur_normal = cur_normal[idx]
                cur_label = cur_label[idx]
            if self.ao and self.split in ['val', 'test']:
                idx = ~np.isin(cur_label, self.symmetric_labels)
                cur_points = cur_points[idx]
                cur_normal = cur_normal[idx]
                cur_label = cur_label[idx]
            if self.single:
                idx = np.isin(cur_label, [8])
                cur_points = cur_points[idx]
                cur_normal = cur_normal[idx]
                cur_label = cur_label[idx]
            points.append(cur_points)
            normal.append(cur_normal)
            label.append(cur_label)
        points = np.concatenate(points, axis=0)
        normal = np.concatenate(normal, axis=0)
        data = np.concatenate([points, normal], axis=-1).astype(np.float32)
        label = np.concatenate(label, axis=0)
        return data, label

    def compose(self, item, p_keep):
        tgt_cloud = self.data[item, ...]
        if self.split != 'train':
            np.random.seed(item)
            R, t = generate_random_rotation_matrix(), generate_random_tranlation_vector()
        else:
            tgt_cloud = flip_pc(tgt_cloud)
            R, t = generate_random_rotation_matrix(), generate_random_tranlation_vector()

        src_cloud = random_crop(copy.deepcopy(tgt_cloud), p_keep=p_keep[0])
        src_size = math.ceil(self.npts * p_keep[0])
        tgt_size = self.npts
        if len(p_keep) > 1:
            tgt_cloud = random_crop(copy.deepcopy(tgt_cloud),
                                    p_keep=p_keep[1])
            tgt_size = math.ceil(self.npts * p_keep[1])

        src_cloud_points = transform(src_cloud[:, :3], R, t)
        src_cloud_normal = transform(src_cloud[:, 3:], R)
        src_cloud = np.concatenate([src_cloud_points, src_cloud_normal],
                                   axis=-1)
        src_cloud = random_select_points(src_cloud, m=src_size)
        tgt_cloud = random_select_points(tgt_cloud, m=tgt_size)

        if self.split == 'train' or self.noise:
            src_cloud[:, :3] = jitter_point_cloud(src_cloud[:, :3])
            tgt_cloud[:, :3] = jitter_point_cloud(tgt_cloud[:, :3])
        tgt_cloud, src_cloud = shuffle_pc(tgt_cloud), shuffle_pc(
            src_cloud)
        return src_cloud, tgt_cloud, R, t

    def __getitem__(self, item):
        src_cloud, tgt_cloud, R, t = self.compose(item=item,
                                                  p_keep=self.p_keep)
        if not self.normal:
            tgt_cloud, src_cloud = tgt_cloud[:, :3], src_cloud[:, :3]
        return tgt_cloud, src_cloud, R, t

    def __len__(self):
        return len(self.data)

2.2 模型设计

模型设计如下:
在这里插入图片描述

2.3 loss设计

其主要包含Init_loss、Refine_loss和Ol_loss。
其中Init_loss是用于计算 预测点 云 0 预测点云_0 预测点0与目标点云的mse或mae loss,
Refine_loss用于计算 预测点 云 [ 1 : ] 预测点云_{[1:]} 预测点[1:]与目标点云的加权mae loss
Ol_loss用于计算两个输入点云输出的重叠分数,使两个点云对应点的重叠分数是一样的。
在这里插入图片描述

具体实现代码如上:


import math
import torch
import torch.nn as nn
from utils import square_dists


def Init_loss(gt_transformed_src, pred_transformed_src, loss_type='mae'):

    losses = {}
    num_iter = 1
    if loss_type == 'mse':
        criterion = nn.MSELoss(reduction='mean')
        for i in range(num_iter):
            losses['mse_{}'.format(i)] = criterion(pred_transformed_src[i],
                                                   gt_transformed_src)
    elif loss_type == 'mae':
        criterion = nn.L1Loss(reduction='mean')
        for i in range(num_iter):
            losses['mae_{}'.format(i)] = criterion(pred_transformed_src[i],
                                                   gt_transformed_src)
    else:
        raise NotImplementedError

    total_losses = []
    for k in losses:
        total_losses.append(losses[k])
    losses = torch.sum(torch.stack(total_losses), dim=0)
    return losses


def Refine_loss(gt_transformed_src, pred_transformed_src, weights=None, loss_type='mae'):
    losses = {}
    num_iter = len(pred_transformed_src)
    for i in range(num_iter):
        if weights is None:
            losses['mae_{}'.format(i)] = torch.mean(
                torch.abs(pred_transformed_src[i] - gt_transformed_src))
        else:
            losses['mae_{}'.format(i)] = torch.mean(torch.sum(
                weights * torch.mean(torch.abs(pred_transformed_src[i] -
                                               gt_transformed_src), dim=-1)
                / (torch.sum(weights, dim=-1, keepdim=True) + 1e-8), dim=-1))

    total_losses = []
    for k in losses:
        total_losses.append(losses[k])
    losses = torch.sum(torch.stack(total_losses), dim=0)

    return losses


def Ol_loss(x_ol, y_ol, dists):
    CELoss = nn.CrossEntropyLoss()
    x_ol_gt = (torch.min(dists, dim=-1)[0] < 0.05 * 0.05).long() # (B, N)
    y_ol_gt = (torch.min(dists, dim=1)[0] < 0.05 * 0.05).long() # (B, M)
    x_ol_loss = CELoss(x_ol, x_ol_gt)
    y_ol_loss = CELoss(y_ol, y_ol_gt)
    ol_loss = (x_ol_loss + y_ol_loss) / 2
    return ol_loss


def cal_loss(gt_transformed_src, pred_transformed_src, dists, x_ol, y_ol):
    losses = {}
    losses['init'] = Init_loss(gt_transformed_src,
                               pred_transformed_src[0:1])
    if x_ol is not None:
        losses['ol'] = Ol_loss(x_ol, y_ol, dists)
    losses['refine'] = Refine_loss(gt_transformed_src,
                                   pred_transformed_src[1:],
                                   weights=None)
    alpha, beta, gamma = 1, 0.1, 1
    if x_ol is not None:
        losses['total'] = losses['init'] + beta * losses['ol'] + gamma * losses['refine']
    else:
        losses['total'] = losses['init'] + losses['refine']
    return losses

3、训练与预测

先进入src目录,并将modelnet40_ply_hdf5_2048.zip解压在src目录下
在这里插入图片描述

3.1 训练

训练命令及训练输出如下所示

python train.py --root modelnet40_ply_hdf5_2048/ --noise --unseen

python请添加图片描述
在训练过程中会在work_dirs\models\checkpoints目录下生成两个模型文件
在这里插入图片描述

3.2 验证

训练命令及训练输出如下所示

python eval.py --root modelnet40_ply_hdf5_2048/  --unseen --noise  --cuda --checkpoint work_dirs/models/checkpoints/min_rot_error.pth

请添加图片描述

3.3 测试

测试训练数据的命令如下

python vis.py --root modelnet40_ply_hdf5_2048/  --unseen --noise  --checkpoint work_dirs/models/checkpoints/min_rot_error.pth

具体配准效果如下所示,其中绿色点云为输入点云,红色点云为参考点云,蓝色点云为配准后的点云。可以看到蓝色点云基本与红色点云重合,可以确定其配准效果十分完好。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.4 处理自己的数据集

基于该项目训练并处理自己数据的教程后续会给出。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/220322.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

flask web学习之flask与http(一)

文章目录 一、请求响应循环二、HTTP请求1. 请求报文2. request对象3. 在flask中处理请求3.1 路由匹配3.2 设置监听的http方法3.3 URL处理 三、请求钩子 一、请求响应循环 每一个web应用都包含这种处理方式&#xff0c;请求-响应循环&#xff1a;客户端发出请求&#xff0c;服务…

实战经验分享,Python 连接 Oracle 踩坑实录

最近的一个测试任务需要测试 oracle 同步 hive 数据库的性能&#xff0c;那就需要对 oracle 数据库灌注测试数据。我就又打开了我的IDE&#xff0c;准备把我之前一下可以灌50w数据到 MySQL 的代码&#xff0c;改一改&#xff0c;直接用。 因为我在网上看到&#xff0c;语法上也…

基于Springboot的社区医院管理服务系统(有报告)。Javaee项目,springboot项目。

演示视频&#xff1a; 基于Springboot的社区医院管理服务系统&#xff08;有报告&#xff09;。Javaee项目&#xff0c;springboot项目。 项目介绍&#xff1a; 采用M&#xff08;model&#xff09;V&#xff08;view&#xff09;C&#xff08;controller&#xff09;三层体系…

高低温交变湿热实验箱

产品概述 武汉凯迪正大高低温实验箱&#xff08;恒温恒湿试验箱&#xff09;乃针对各种材质表面处理&#xff0c;包含涂料、电镀、有机及无机皮膜&#xff0c;阳极处理&#xff0c;防锈油等防腐处理后测试其耐腐蚀性&#xff0c;从而确立产品的质量。 产品特点 1、内箱尺寸…

全网最新最全面的Appium自动化:Appium常用操作之按键类操作

按键类操作 按键类操作用来模拟在手机设备上进行按键操作&#xff08;推荐使用 方式一 &#xff09; 方式一、press_keycode(self,keycode,metastateNone,flagsNone)&#xff1a;模拟按键输入&#xff0c;其中&#xff1a; keycode&#xff1a;发送到设备的键值编码可以通过An…

华为快应用中自定义Slider效果

文章目录 一、前言二、实现代码三、参考链接 一、前言 在华为快应用中官方提供了<slider>控件&#xff0c;但是这个控件的限制比较多&#xff0c;比如滑块无法自定义&#xff0c;所以这里进行下自定义&#xff0c;自己修改样式。 二、实现代码 整体效果如下: 源码如下…

【数据结构(七)】查找算法

文章目录 查找算法介绍1. 线性查找算法2. 二分查找算法2.1. 思路分析2.2. 代码实现2.3. 功能拓展 3. 插值查找算法3.1. 前言3.2. 相关概念3.3. 实例应用 4. 斐波那契(黄金分割法)查找算法4.1. 斐波那契(黄金分割法)原理4.2. 实例应用 查找算法介绍 在 java 中&#xff0c;我们…

全面解析修复msvcr120.dll缺失问题的方法,msvcr120.dll丢失的原因

在计算机使用过程中&#xff0c;我们经常会遇到一些错误提示&#xff0c;其中最常见的就是“msvcr120.dll丢失”。这个错误通常会导致某些程序无法正常运行&#xff0c;给用户带来很大的困扰。那么&#xff0c;当我们遇到这个问题时&#xff0c;应该如何修复呢&#xff1f;本文…

Linux基础项目开发1:量产工具——UI系统(五)

前言&#xff1a; 前面我们已经把显示系统、输入系统、文字系统搭建好了&#xff0c;现在我们就要给它实现按钮操作了&#xff0c;也就是搭建UI系统&#xff0c;下面让我们一起实现UI系统的搭建吧 目录 一、按钮数据结构抽象 ui.h 二、按键编程 1.button.c 2.disp_manager…

使用rust slint开发桌面应用

安装QT5&#xff0c;过程省略 安装rust&#xff0c;过程省略 创建工程 cargo new slint_demo 在cargo.toml添加依赖 [dependencies] slint "1.1.1" [build-dependencies] slint-build "1.1.1" 创建build.rs fn main() {slint_build::compile(&quo…

使用 async/await 是必须避免的陷阱

使用 async/await 是必须避免的陷阱 如果我们使用过 nodejs&#xff0c;那么我们可能已经在 javaSoript 中使用了异步操作。异步任务是一个独立于 JavaSoript 引擎的主线程执行的操作。从本质上讲&#xff0c;这就是应用程序功能没有阻塞的 UI 的原因。 nodejs 的单线程性质&a…

华容道问题求解第一部分_思路即方案设计

一、前言 华容道是一种传统的益智游戏&#xff0c;通常由一个长方形木板和若干个方块组成。其中包括一个或多个不同颜色的方块&#xff08;也称为车块&#xff09;和其他大小相同的方块&#xff08;也称为障碍块&#xff09;。游戏的目标是将车块从木板的一个端点移动到另一个…

【mysql】mysgld.log文件太大怎么办

我们有一台测试服务器。跑着一个msyq&#xff0c;发现没有空间了。差看日志文件占用了很多。 怎么破 使用下面命令 echo "" >mysqld.log 执行命令后

PostGIS学习教程九:空间连接

PostGIS学习教程九&#xff1a;空间连接 空间连接&#xff08;spatial joins&#xff09;是空间数据库的主要组成部分&#xff0c;它们允许你使用空间关系作为连接键&#xff08;join key&#xff09;来连接来自不同数据表的信息。我们认为“标准GIS分析”的大部分内容可以表示…

直播预告 | 降本增效持续深化,如何找准 FinOps 关键着力点?

企业落地 FinOps 有哪些实施路径和阶段规划&#xff1f;2023 年&#xff0c;业界 FinOps 取得了哪些进展&#xff1f;12 月 6 日&#xff0c;「降本增效持续深化&#xff0c;如何找准 FinOps 关键着力点」专题直播即将开讲。小红书基础技术部混合云资源管理负责人梁啟成将带来《…

无法从SD卡中删除文件怎么办?

在使用SD卡时&#xff0c;有时我们会无法从SD卡中删除文件&#xff0c;那么这该怎么办呢&#xff1f;下面我们就一起来了解一下吧。 方式1. 检查SD卡&#xff08;读卡器&#xff09;上的写保护选项卡 对于某些SD卡&#xff0c;SD卡的一侧可能有一个开关&#xff0c;并有标有Lo…

AntDesignBlazor示例——创建项目

本示例是AntDesign Blazor的入门示例&#xff0c;在学习的同时分享出来&#xff0c;以供新手参考。 示例代码仓库&#xff1a;https://gitee.com/known/AntDesignDemo 1. 开发环境 VS2022 17.8.2.NET8AntDesign 0.16.2 2. 学习目标 创建新项目安装AntDesign组件包及使用方…

Leetcode 77 组合

题意理解&#xff1a; 给定两个整数 n 和 k&#xff0c;返回范围 [1, n] 中所有可能的 k 个数的组合。 如&#xff1a;n3,k2,则有&#xff1a;12 13 23 一般&#xff0c;我们使用回溯法来解决组合问题。 组合问题没有顺序要求&#xff0c;所以 12 21 是同一个组合&#xff08;如…

【Linux驱动开发】环境搭建Linux驱动开发环境

环境搭建Linux驱动开发环境 1. 简单描述2. 资源3. 安装4. 基本操作和设置 1. 简单描述 基于讯为电子rk3568教程 2. 资源 下载 VMware Workstation Pro 17 链接 Ubuntu 桌面版&#xff08;64位&#xff09; 链接 3. 安装 需要选择自定义硬件&#xff08;内存大于16g 硬盘500g…

试验数字化平台WDP 助力车企数据管理加速度

一 现状 随着现代测控技术的提高&#xff0c;数据结构变得越来越复杂多样&#xff0c;数据量也在日益增大。又因试验条件的限制&#xff0c;大多数企业的数据管理方式主要是通过各类电子文档将试验数据保存在每个工程师的移动电脑中&#xff0c;再进行汇总存储和共享。这种落后…