论文快过(图像配准|Coarse_LoFTR_TRT)|适用于移动端的LoFTR算法的改进分析 1060显卡上45fps

项目地址:https://github.com/Kolkir/Coarse_LoFTR_TRT
创建时间:2022年
相关训练数据:BlendedMVS
在这里插入图片描述
LoFTR [19]是一种有效的深度学习方法,可以在图像对上寻找合适的局部特征匹配。本文报道了该方法在低计算性能和有限内存条件下的设备上的优化工作。原来的LoFTR方法是基于一个ResNet [6]backbone和两个基于线性transformer[22]架构的模块。在本研究中,只剩下粗匹配块,参数的数量显著减少,并使用知识蒸馏技术对网络进行了训练。对比结果表明,在粗匹配块中,尽管模型大小显著减少,但该方法仍可以获得适当的特征检测精度。此外,本文还展示了使模型与NVIDIA TensorRT运行时兼容所需的额外步骤,并展示了一种优化针对低端gpu的训练方法的方法。

简化后的算法运行速度在1060显卡上提升了45倍,针对640×480图像,fps可达45

1、改进思考

1.1 算法背景

为了解决高计算复杂度的问题,对transformer架构的各种修改已经被开发出来。LoFTR方法采用线性变换[22]方法,该方法提出通过将注意层中使用的指数核替换为𝜑(·)=𝑒𝑙𝑢(·)+1.的替代核𝑠𝑖𝑚(𝑄,𝐾)=𝜑(𝑄)* 𝜑 ( K ) T (K)^T (K)T从而降低计算复杂度到𝑂(𝑁)该方法对计算机视觉任务具有良好的计算性能的提高和内存消耗的降低。这很重要,因为这种类型的任务的序列长度等于输入图像中的像素数。使用线性transformer允许对640x480张的图像进行特征匹配,在高端gpu上具有可接受的性能。然而,该体系结构中的更改仍然不足使transformer可以运行在低端gpu上。

1.2 优化方向

使复杂模型适应低端器件[5]要求的工程方法主要有:量化、剪枝[9]和知识蒸馏。

量化是用于计算和权重存储的数据类型的位宽降低。通常浮点计算转换为16位浮点或8位整数类型。为了达到与原始模型相媲美的精度,这种方法通常需要一个特殊的训练过程,考虑到缩小或缩小后的附加模型校准。量化通常在消费者级或嵌入式gpu上不可用,而且它的实现只能在高端gpu中可用。然而,对于基于cpu的设备,该方法是可用的,可以提供良好的效果。

剪枝是一种去除网络参数的方法,它对结果的精度没有太大的贡献。通常一个合适的剪枝条件是权重接近于零。由此得到的模型可能需要更少的内存,而且在推理方面可能更有效。有许多剪枝类型,但可以区分以下两种主要类型:结构化剪枝,当对称的权值块被删除时,例如层,和非结构化剪枝,当被删除的块可能是不同的形状时。由于这种方法改变了模型架构,因此通常需要进行手动调整来恢复正常的模型工作。结构化方法可能更可取,因为它对全局架构进行的更改更少,而且恢复模型操作更容易,甚至可能不必要。然而,流行的深度学习框架通常会实现非结构化的方法。在复杂模型中应用非结构化处理后适应网络操作可能是一项重要的任务,需要很多时间来解决,而且由于该方法不能保证一个稳定的结果,因此应用它并不总是合适的。

知识蒸馏是在教师的帮助下训练模型的一种方法。教师可以是具有相同架构但具有更多参数的网络,也可以是具有其他架构的网络。大多数训练是使用复杂的损失函数,转移教师的知识。转移到学生模型中的知识元素可以是教师网络中某些层的输出值,例如,在分类中可能是softmax之前的输出。也可以使用教师网络[2]的内部层输出值。知识蒸馏在保持所需的精度的同时,显示了良好的结果,但没有标准的方法来组织这样的过程。而成功则取决于正确的知识转移技术、精心选择的损失函数和学生模型架构。

如上所示,没有单一的方法来优化低端设备的深度学习模型。因此,通常会针对特定的架构开发专门的解决方案。本文提出了一种针对LoFTR特征匹配方法的优化方法。

1.3 本文方案

该方法的主要思想是显著减少模型参数的数量和从原始模型中的知识转移。
决定只保留一个transformer block用于粗特征匹配,尽管原始模型包含第二个transformer模块用于细匹配同时,在所有模型块中进行了手动迭代选择较少的层网络结构简化。设计了知识蒸馏损失函数,并使用了一个较小的训练数据集设计知识蒸馏方案。然而,地面-真实的特征点的匹配也可以使用深度图来确定。训练过程是开发使用自动混合精度(AMP)技术和梯度积累方法来节省内存和加快计算。

源代码被改编为以NVIDIA TensorRT [13]引擎格式编译。选择工作内存大小为2Gb的NVIDIA Jetson Nano [12]作为目标设备。并选择了基于英特尔i5处理器和Nvidia GTX 1060 6Gb GPU的桌面机作为训练平台。

2、模型改进

2.1 适配性修改

最初的LoFTR模型是用Python编写的,使用PyTorch作为深度学习框架。为了创建TensorTR模型,有两种可能性,一种是使用torch-TensorRT[14]编译器,第二种是将模型转换为ONNX [1]格式,然后使用NVIDIA TensorTR SDK编译它。由于目标平台的资源有限,无法应用第一个选项,因为使用Torch-TensorRT编译成TensorTR格式意味着在目标设备上运行它以进行实时优化。实验发现,编译ONNX需要的资源更少,并且在目标设备上是可能的,因此选择了第二个选项。

然而,einsum操作在onnx中并不支持。
在这里插入图片描述
所有将运算方式修改为以下,使onnx与tensorRT都支持。
在这里插入图片描述

2.2 结构优化

为了目标设备上实现可接受的性能,即选择块中的层的数量和尺寸。为此目的,我们开发了一个在实时网络摄像头图像上搜索特征匹配的演示应用程序。性能是通过呈现相应匹配时的FPS数量来估计的。然后,在此应用程序的帮助下,迭代地选择了表1中所示的模型配置。
在这里插入图片描述

原始模型的作者报告说,完整模型在RTX 2080Ti上处理116 ms处理一对640×480图像,约8 FPS [19]。简化后的算法运行速度在1060显卡上提升了45倍。
在这里插入图片描述
表3显示了参数数量的变化。从表中可以看到,原始模型的尺寸显著减少,以便在目标设备上实现可接受的性能。
在这里插入图片描述

2.3 训练设置

针对低性能硬件的局限性,对知识精馏训练过程进行了优化。为了加速梯度计算和减少内存消耗,我们使用了自动混合精度(AMP)技术,因为它的实现在PyTorch深度学习框架中可用。该技术的本质是,梯度计算所需的一些操作使用浮点32,而另一部分使用浮点16种数据类型。例如,卷积运算和线性层相关的矩阵计算使用float16计算速度更快。而其他操作,如减法,需要使用一个浮动32范围。这项技术使我们能够为模型训练中涉及的所有操作自动选择适当的数据类型。它的使用可以显著减少模型ResNet+FPN头的内存消耗。然而,AMP技术存在较小梯度值的数值计算问题。因此,为了稳定损失函数,增加了放大因子。

尽管使用了AMP,但在GTX 1060上进行训练也只能是支持到batch为4的640x480的图像。因此,为了增加batch的大小,我们采用了梯度积累的方法。这意味着大batchsize被分为𝑛系列的小batchsize。对于每个系列,进行正向和反向循环,不清除产生的梯度值,而是求和。其中,𝑛=𝐵𝑖𝑔𝐵𝑎𝑡𝑐ℎ𝑆𝑖𝑧𝑒、𝑆𝑚𝑎𝑙𝑙𝐵𝑎𝑡𝑐ℎ𝑆𝑖𝑧𝑒。在每次迭代中,损失函数值乘以比例因子1/𝑛。只有经过所有𝑛迭代后才更新网络参数,然后将梯度归零。因此,利用该技术模拟了大批量的训练。在这项工作中,虚拟批处理大小等于32。尽管,在现实中,硬件处理了8批梯度累计,每批的尺寸为4。梯度积累技术并没有实现实际大批量使用的精确对应关系,因此这两种方法的损失和梯度值将是不同的。

此外,我们还注意到,应用学习速率调度器可以显著加快训练过程。本研究采用了具有标准参数的AdamW [10]优化算法。初始学习速率值为10−3,每15个epoch乘以10−3。

每个epoch都从原始数据集中随机选择大小为5000对图像。

2.4 训练效果

图1显示了有教师和没有教师的训练的损失函数值。这张图清楚地表明,当与教师一起进行训练时,绝对损失函数值明显更小,学习过程本身更稳定。
在这里插入图片描述
图2,它显示了平均绝对误差(MAE)与训练持续时间的依赖关系。它显示了预测的特征匹配分数与地面真实值之间的平均差异。我们可以看到,当与老师一起进行训练时,MAE值远远接近于零。我们可以假设,在没有教师的情况下训练一个较小的网络会使它对其结果缺乏信心。然而,与此同时,这个图1显示了所选择的模型架构能够在没有老师的情况下学习,但可能需要更长的时间来获得可比的结果,并且需要更低的阈值来确定最重要的匹配。
在这里插入图片描述
图3显示了在数据集图像上的模型结果的示例。白点表示原模型作为教师使用的粗LoFTR模块的匹配结果。黑点表示较小模型的结果。从实验结果中可以清楚地看出,较小的模型比教师模型更关注图像的不同部分。最可能的原因是头层数量较少,transformer参数不同,使得模型强调更明显的特征点。也可以注意到较小模型的特征匹配中存在错误,尽管通常特征匹配是相当准确的。
在这里插入图片描述
室外数据配准效果
在这里插入图片描述

3、代码运行

打开 https://github.com/Kolkir/Coarse_LoFTR_TRT,即可下载项目
在这里插入图片描述

3.1 前置修改

如果电脑没有摄像头,则需要进行下列额外代码修改

修改一: webcam.py中默认参数camid,类型修改为str,默认值修改为自己准备好的视频文件

def main():
    parser = argparse.ArgumentParser(description='LoFTR demo.')
    parser.add_argument('--weights', type=str, default='weights/outdoor_ds.ckpt',
                        help='Path to network weights.')
    # parser.add_argument('--camid', type=int, default=0,
    #                     help='OpenCV webcam video capture ID, usually 0 or 1.')
    parser.add_argument('--camid', type=str, default=r"C:\Users\Administrator\Videos\风景视频素材分享_202477135455.mp4",
                        help='OpenCV webcam video capture ID, usually 0 or 1.')

修改二:camera.py中的代码修改为以下,用于支持读取视频文件

import cv2
from threading import Thread


class Camera(object):
    def __init__(self, index):
        self.index=index
        if isinstance(self.index,int):#加载摄像头视频流
            self.cap = cv2.VideoCapture(self.index, cv2.CAP_V4L2)
        else:#加载视频
            self.cap = cv2.VideoCapture(self.index)
        if not self.cap.isOpened():
            print('Failed to open camera {0}'.format(index))
            exit(-1)

        # self.cap.set(cv2.CAP_PROP_FRAME_WIDTH, 1920)
        # self.cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 1080)

        self.thread = Thread(target=self.update, args=())
        self.thread.daemon = True
        self.thread.start()
        self.status = False
        self.frame = None

    def update(self):
        while True:
            try:
                if self.cap.isOpened():
                    (self.status, self.frame) = self.cap.read()
                    if not self.status:
                        if isinstance(self.index,int):#加载摄像头视频流
                            self.cap = cv2.VideoCapture(self.index, cv2.CAP_V4L2)
                        else:#加载视频
                            self.cap = cv2.VideoCapture(self.index)
                else:
                    break
            except cv2.error as e:
                print(e)
                break

    def get_frame(self):
        return self.frame, self.status

    def close(self):
        self.cap.release()
        self.thread.join()

3.2 运行效果

然后运行webcam.py,可以发现fps为25左右,此时硬件环境为win10笔记本、1660显卡,26ms即可处理完一个640*480的图片。但整体fps稳定在16~26左右。
在这里插入图片描述
再次加速,将推理时的图像分辨率修改为320x240 ,即将webcam.py中的 img_size 设置(320, 240),loftr\utils\cvpr_ds_config.py中对应的设置。发现速度没有显著提升,但整体fps稳定在22~28左右。

_CN.INPUT_WIDTH = 320
_CN.INPUT_HEIGHT = 240

在这里插入图片描述
onnx运行效果如下,整体fps稳定在20左右
在这里插入图片描述
将模型配置loftr\utils\cvpr_ds_config.py 中的尺寸修改如下,然后重新运行export_onnx.py,导出模型,再基于webcam.py运行onnx模型,可以发现fps高达40以上。

_CN.INPUT_WIDTH = 320
_CN.INPUT_HEIGHT = 320

在这里插入图片描述

3.3 图像配准

使用Coarse_LoFTR_TRT进行图像配准可以参考
https://blog.csdn.net/a486259/article/details/140241276 中章节5的操作。操作前最好先修改 loftr\utils\cvpr_ds_config.py 的尺寸为320,具体修改如下,然后重新运行export_onnx.py,导出模型。

_CN.INPUT_WIDTH = 320
_CN.INPUT_HEIGHT = 320

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/870365.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

MongoDB教程(十五):MongoDB原子操作

💝💝💝首先,欢迎各位来到我的博客,很高兴能够在这里和您见面!希望您在这里不仅可以有所收获,同时也能感受到一份轻松欢乐的氛围,祝你生活愉快! 文章目录 引言一、MongoD…

深入浅出WebRTC—Pacer

平滑发包(Pacer)是 WebRTC 实现高质量实时通信不可或缺的一部分。在视频通信中,单帧视频可能包含大量的数据,如果未经控制地立即发送,可能瞬间对网络造成巨大压力。Pacer 能够根据网络条件动态调整发送速率&#xff0c…

论文阅读:面向自动驾驶场景的多目标点云检测算法

论文地址:面向自动驾驶场景的多目标点云检测算法 概要 点云在自动驾驶系统中的三维目标检测是关键技术之一。目前主流的基于体素的无锚框检测算法通常采用复杂的二阶段修正模块,虽然在算法性能上有所提升,但往往伴随着较大的延迟。单阶段无锚框点云检测算法简化了检测流程,…

“微软蓝屏”全球宕机,敲响基础软件自主可控警钟

上周五,“微软蓝屏”“感谢微软 喜提假期”等词条冲上热搜,全球百万打工人受此影响,共同见证这一历史性事件。据微软方面发布消息称,旗下Microsoft 365系列服务出现访问中断。随后在全球范围内,包括企业、政府、个人在…

基于DPU与SmartNic的云原生SDN解决方案

1. 方案背景与挑战 随着云计算,大数据和人工智能等技术的蓬勃发展,数据中心面临着前所未有的数据洪流和计算压力,这对SDN提出了更高的性能和效率要求。自云原生概念被提出以来,Kubernetes为云原生应用的落地提供了一个轻量级&am…

node+mysql实现(账户密码,阿里云短信验证,QQ邮箱注册登录,短信验证密码重置,邮箱密码重置)之注册,登录密码重置总篇

node+mysql实现账户登录 注意效果图项目插件代码参数说明短信验证模块邮箱验证模块注册方式登录方式密码重置前端页面部分登录页面账户登录页面(login.html)短信验证登录页面(smsLogin.html)邮箱登录页面(emailLogin.html)注册部分页面短信验证注册页面(register.html)邮…

产品经理NPDP好考吗?

NPDP是新产品开发专业人员的资格认证,对于希望在产品管理领域取得认可的专业人士来说,NPDP认证是一项重要的资格。 那么,产品经理考取NPDP资格认证究竟难不难呢? 首先,NPDP考试的难易程度取决于考生的背景和准备情况…

C++11并发编程

目录 一、线程的创建 1、介绍thread类 2、创建线程 二、线程的2种工作方式 其一:关联主线程 其二:拆离主线程 两种工作方式的使用-代码示例 detach join 三、线程安全问题 1、什么是线程安全 2、怎么使程序线程安全 保护对共享数据的操作-加…

Redis (常用数据结构和命令)

目录 简介 概述 特点 数据结构 常用命令 通用命令 keys del exists expire 与 ttl String 命令 SET 和GET: MSET和MGET INCR和INCRBY和DECY SETNX SETEX Redis 命令 Key 的层级结构 key层级关系 : Hash命令 HSET和HGET HMSET和HMGET HGETALL H…

深入浅出WebRTC—ULPFEC

FEC 通过在发送端添加额外的冗余信息,使接收端即使在部分数据包丢失的情况下也能恢复原始数据,从而减轻网络丢包的影响。在 WebRTC 中,FEC 主要有两种实现方式:ULPFEC 和 FlexFEC,FlexFEC 是 ULPFEC 的扩展和升级&…

数据结构——堆(C语言版)

树 树的概念: 树(Tree)是一种抽象数据结构,它由节点(node)的集合组成,这些节点通过边相连,把 节点集合按照逻辑顺序抽象成图像,看起来就像一个倒挂着的树,也…

OpenCV图像滤波(1)双边滤波函数bilateralFilter的使用

操作系统:ubuntu22.04 OpenCV版本:OpenCV4.9 IDE:Visual Studio Code 编程语言:C11 功能描述 bilateralFilter是图像处理和计算机视觉领域中的一种高级图像滤波技术,特别设计用于在去除噪声的同时保留图像的边缘和细节。相比于传…

网络编程总复习

TCP的创建: 服务器端 : 客户端:

ESP8266用AT指令实现连接MQTT

1准备工作 硬件(ESP8266)连接电脑 硬件已经烧入了MQTT透传固件 2实现连接 2-1(进入AT模式) 打开串口助手发送如下指令 AT 2-2(复位) ATRST 2-3(开启DHCP,自动获取IP&#x…

The Llama 3 Herd of Models.Llama 3 模型论文全文

现代人工智能(AI)系统是由基础模型驱动的。本文提出了一套新的基础模型,称为Llama 3。它是一组语言模型,支持多语言、编码、推理和工具使用。我们最大的模型是一个密集的Transformer,具有405B个参数和多达128K个tokens的上下文窗口。本文对Llama 3进行了广泛的实证评价。我们…

Linux系统上安装Redis

百度网盘: 通过网盘分享的文件:redis_linux 链接: https://pan.baidu.com/s/1ZcECygWA15pQWCuiVdjCtg?pwd8888 提取码: 8888 1.把安装包拖拽到/ruanjian/redis/文件夹中(自己选择) 2.进入压缩包所在文件夹,解压压缩…

深入浅出WebRTC—LossBasedBweV2

WebRTC 同时使用基于丢包的带宽估计算法和基于延迟的带宽估计算法那,能够实现更加全面和准确的带宽评估和控制。基于丢包的带宽估计算法主要依据网络中的丢包情况来动态调整带宽估计,以适应网络状况的变化。本文主要讲解最新 LossBasedBweV2 的实现。 1…

计算机网络实验-RIP配置与分析

前言:本博客仅作记录学习使用,部分图片出自网络,如有侵犯您的权益,请联系删除 一、相关知识 路由信息协议(Routing Information Protocol,RIP)是一种基于距离向量(Distance-Vector&…

python题解

宽度与对齐 输出455、-123、987654,宽度为5,分别左对齐和右对齐 格式 输入格式: 无 输出格式: 输出为整型,空格分隔。每个数的输出占一行 样例 1 输入: 无 复制 输出: 455 455 -123 -123 98…

智慧工地视频汇聚管理平台:打造现代化工程管理的全新视界

一、方案背景 科技高速发展的今天,工地施工已发生翻天覆地的变化,传统工地管理模式很容易造成工地管理混乱、安全事故、数据延迟等问题,人力资源的不足也进一步加剧了监管不到位的局面,严重影响了施工进度质量和安全。 视频监控…