英伟达SSD视觉算法分类代码解析

一、官方原代码

#!/usr/bin/env python3
#
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.
#

import sys
import argparse

from jetson_inference import imageNet
from jetson_utils import videoSource, videoOutput, cudaFont, Log

# parse the command line
parser = argparse.ArgumentParser(description="Classify a live camera stream using an image recognition DNN.", 
                                 formatter_class=argparse.RawTextHelpFormatter, 
                                 epilog=imageNet.Usage() + videoSource.Usage() + videoOutput.Usage() + Log.Usage())

parser.add_argument("input", type=str, default="", nargs='?', help="URI of the input stream")
parser.add_argument("output", type=str, default="", nargs='?', help="URI of the output stream")
parser.add_argument("--network", type=str, default="googlenet", help="pre-trained model to load (see below for options)")
parser.add_argument("--topK", type=int, default=1, help="show the topK number of class predictions (default: 1)")

try:
	args = parser.parse_known_args()[0]
except:
	print("")
	parser.print_help()
	sys.exit(0)


# load the recognition network
net = imageNet(args.network, sys.argv)

# note: to hard-code the paths to load a model, the following API can be used:
#
# net = imageNet(model="model/resnet18.onnx", labels="model/labels.txt", 
#                 input_blob="input_0", output_blob="output_0")

# create video sources & outputs
input = videoSource(args.input, argv=sys.argv)
output = videoOutput(args.output, argv=sys.argv)
font = cudaFont()

# process frames until EOS or the user exits
while True:
    # capture the next image
    img = input.Capture()

    if img is None: # timeout
        continue  

    # classify the image and get the topK predictions
    # if you only want the top class, you can simply run:
    #   class_id, confidence = net.Classify(img)
    predictions = net.Classify(img, topK=args.topK)

    # draw predicted class labels
    for n, (classID, confidence) in enumerate(predictions):
        classLabel = net.GetClassLabel(classID)
        confidence *= 100.0

        print(f"imagenet:  {confidence:05.2f}% class #{classID} ({classLabel})")

        font.OverlayText(img, text=f"{confidence:05.2f}% {classLabel}", 
                         x=5, y=5 + n * (font.GetSize() + 5),
                         color=font.White, background=font.Gray40)
                         
    # render the image
    output.Render(img)

    # update the title bar
    output.SetStatus("{:s} | Network {:.0f} FPS".format(net.GetNetworkName(), net.GetNetworkFPS()))

    # print out performance info
    net.PrintProfilerTimes()

    # exit on input/output EOS
    if not input.IsStreaming() or not output.IsStreaming():
        break

二、代码解析

代码增加中文注释

#!/usr/bin/env python3
#
# 版权所有 (c) 2020, NVIDIA CORPORATION. 保留所有权利。
#
# 特此免费授予获得此软件和相关文档文件(“软件”)副本的任何人,允许他们在不受限制的情况下处理软件,
# 包括但不限于使用、复制、修改、合并、发布、分发、再许可和/或出售软件副本,并允许提供软件的人
# 这样做,条件如下:
#
# 上述版权声明和本许可声明应包含在软件的所有副本或主要部分中。
#
# 本软件按“原样”提供,不提供任何形式的明示或暗示保证,包括但不限于适销性、
# 适用于特定目的和不侵权的保证。在任何情况下,作者或版权持有人均不对因使用本软件或其他交易,
# 或因使用本软件或其他交易而产生的任何索赔、损害或其他责任负责,无论是在合同诉讼、侵权诉讼还是其他诉讼中。
#

import sys
import argparse

from jetson_inference import imageNet
from jetson_utils import videoSource, videoOutput, cudaFont, Log

# 解析命令行参数
parser = argparse.ArgumentParser(
    description="使用图像识别DNN对实时摄像头流进行分类。",
    formatter_class=argparse.RawTextHelpFormatter,
    epilog=imageNet.Usage() + videoSource.Usage() + videoOutput.Usage() + Log.Usage()
)

parser.add_argument("input", type=str, default="", nargs='?', help="输入流的URI")
parser.add_argument("output", type=str, default="", nargs='?', help="输出流的URI")
parser.add_argument("--network", type=str, default="googlenet", help="要加载的预训练模型(参见下方选项)")
parser.add_argument("--topK", type=int, default=1, help="显示前K个类别预测(默认:1)")

try:
    args = parser.parse_known_args()[0]
except:
    print("")
    parser.print_help()
    sys.exit(0)

# 加载识别网络
net = imageNet(args.network, sys.argv)

# 注意:要硬编码加载模型的路径,可以使用以下API:
# net = imageNet(model="model/resnet18.onnx", labels="model/labels.txt", 
#                input_blob="input_0", output_blob="output_0")

# 创建视频源和输出
input = videoSource(args.input, argv=sys.argv)
output = videoOutput(args.output, argv=sys.argv)
font = cudaFont()

# 处理帧直到输入结束或用户退出
while True:
    # 捕获下一帧图像
    img = input.Capture()

    if img is None:  # 超时
        continue  

    # 对图像进行分类并获取前K个预测
    # 如果只需要最顶层的类别,可以简单地运行:
    #   class_id, confidence = net.Classify(img)
    predictions = net.Classify(img, topK=args.topK)

    # 绘制预测的类别标签
    for n, (classID, confidence) in enumerate(predictions):
        classLabel = net.GetClassLabel(classID)
        confidence *= 100.0

        print(f"imagenet:  {confidence:05.2f}% class #{classID} ({classLabel})")

        font.OverlayText(
            img, 
            text=f"{confidence:05.2f}% {classLabel}", 
            x=5, y=5 + n * (font.GetSize() + 5),
            color=font.White, background=font.Gray40
        )
                         
    # 渲染图像
    output.Render(img)

    # 更新标题栏
    output.SetStatus("{:s} | Network {:.0f} FPS".format(net.GetNetworkName(), net.GetNetworkFPS()))

    # 打印性能信息
    net.PrintProfilerTimes()

    # 输入/输出流结束时退出
    if not input.IsStreaming() or not output.IsStreaming():
        break

这段Python代码是一个使用NVIDIA的Jetson平台进行图像分类的示例程序。代码解析如下:

头部版权声明和许可信息

这部分代码声明了版权信息和软件许可,允许免费使用、复制和分发软件。

导入模块

import sys
import argparse
from jetson_inference import imageNet
from jetson_utils import videoSource, videoOutput, cudaFont, Log
  • sys: 处理系统特定的参数和功能。
  • argparse: 解析命令行参数。
  • jetson_inferencejetson_utils模块用于加载和处理图像分类模型、视频源、视频输出、绘制字体和日志记录。

解析命令行参数

parser = argparse.ArgumentParser(description="Classify a live camera stream using an image recognition DNN.", 
                                 formatter_class=argparse.RawTextHelpFormatter, 
                                 epilog=imageNet.Usage() + videoSource.Usage() + videoOutput.Usage() + Log.Usage())

parser.add_argument("input", type=str, default="", nargs='?', help="URI of the input stream")
parser.add_argument("output", type=str, default="", nargs='?', help="URI of the output stream")
parser.add_argument("--network", type=str, default="googlenet", help="pre-trained model to load (see below for options)")
parser.add_argument("--topK", type=int, default=1, help="show the topK number of class predictions (default: 1)")

try:
	args = parser.parse_known_args()[0]
except:
	print("")
	parser.print_help()
	sys.exit(0)
  • 使用argparse模块定义和解析命令行参数,包括输入和输出流的URI、使用的预训练模型和显示前K个预测结果的数量。
  • 尝试解析命令行参数,如果解析失败,则显示帮助信息并退出程序。

加载图像分类网络

net = imageNet(args.network, sys.argv)
  • 使用imageNet类加载预训练的神经网络模型。

创建视频源和视频输出

input = videoSource(args.input, argv=sys.argv)
output = videoOutput(args.output, argv=sys.argv)
font = cudaFont()
  • 使用videoSource类创建视频输入流。
  • 使用videoOutput类创建视频输出流。
  • 使用cudaFont类创建用于绘制文本的字体。

处理视频帧

while True:
    # capture the next image
    img = input.Capture()

    if img is None: # timeout
        continue  

    # classify the image and get the topK predictions
    predictions = net.Classify(img, topK=args.topK)

    # draw predicted class labels
    for n, (classID, confidence) in enumerate(predictions):
        classLabel = net.GetClassLabel(classID)
        confidence *= 100.0

        print(f"imagenet:  {confidence:05.2f}% class #{classID} ({classLabel})")

        font.OverlayText(img, text=f"{confidence:05.2f}% {classLabel}", 
                         x=5, y=5 + n * (font.GetSize() + 5),
                         color=font.White, background=font.Gray40)
                         
    # render the image
    output.Render(img)

    # update the title bar
    output.SetStatus("{:s} | Network {:.0f} FPS".format(net.GetNetworkName(), net.GetNetworkFPS()))

    # print out performance info
    net.PrintProfilerTimes()

    # exit on input/output EOS
    if not input.IsStreaming() or not output.IsStreaming():
        break
  • 使用input.Capture()捕获下一帧图像。
  • 如果捕获失败,则继续下一次循环。
  • 使用net.Classify()对图像进行分类,并获取前K个预测结果。
  • 使用font.OverlayText()在图像上绘制预测的类别标签和置信度。
  • 使用output.Render()渲染图像。
  • 使用output.SetStatus()更新输出窗口的标题栏,显示网络名称和帧率。
  • 使用net.PrintProfilerTimes()打印性能信息。
  • 如果输入或输出流停止,则退出循环。

总体来说,这段代码实现了一个实时图像分类系统,使用Jetson平台的预训练模型对视频流中的每一帧进行分类,并在输出视频中叠加预测结果。



# 三、imageNet类有哪些对象和方法
`imageNet`类是Jetson Inference库的一部分,用于处理图像分类任务。以下是`imageNet`类的主要对象和方法:

## 对象属性
- `NetworkType`: 预定义的网络类型枚举。
- `model`: 加载的模型文件路径。
- `labels`: 加载的标签文件路径。
- `input_blob`: 输入层名称。
- `output_blob`: 输出层名称。

## 主要方法

1. **`__init__`**:
   ```python
   def __init__(self, network="googlenet", argv=[], model=None, labels=None, input_blob="input_0", output_blob="output_0"):
  • 初始化图像分类网络。
  • 参数包括网络类型、命令行参数、模型文件路径、标签文件路径、输入层名称和输出层名称。
  1. Classify:

    def Classify(self, img, topK=1):
    
    • 对给定的图像进行分类。
    • 返回预测结果的列表,每个结果包括类别ID和置信度。
  2. GetClassLabel:

    def GetClassLabel(self, classID):
    
    • 返回指定类别ID的标签。
  3. GetNetworkName:

    def GetNetworkName(self):
    
    • 返回网络的名称。
  4. GetNetworkFPS:

    def GetNetworkFPS(self):
    
    • 返回网络处理帧率(FPS)。
  5. PrintProfilerTimes:

    def PrintProfilerTimes(self):
    
    • 打印网络的性能分析信息。
  6. Usage:

    @staticmethod
    def Usage():
    
    • 返回类的用法说明,通常用于命令行帮助信息。

使用示例

以下是如何使用imageNet类的简单示例:

from jetson_inference import imageNet
from jetson_utils import loadImage

# 初始化图像分类网络
net = imageNet("googlenet")

# 加载图像
img = loadImage("example.jpg")

# 对图像进行分类
class_id, confidence = net.Classify(img)

# 获取类别标签
class_label = net.GetClassLabel(class_id)

print(f"Image is classified as {class_label} with {confidence * 100:.2f}% confidence")

这个示例展示了如何初始化一个imageNet对象,加载一张图像,并对其进行分类,最后打印分类结果和置信度。

三、使用示例

 python3 imagenet.py /dev/video0 display://0 --network=googlenet

在这里插入图片描述
在这里插入图片描述

四、训练自己的分类模型

以下是包含中文注释的SSD训练代码示例:

1. 安装TAO Toolkit

确保在具有NVIDIA GPU的系统上安装了Docker和NVIDIA Container Toolkit。

2. 拉取TAO Toolkit Docker容器

docker pull nvcr.io/nvidia/tao/tao-toolkit-tf:v3.21.11-tf1.15.5-py3

3. 准备数据

准备训练和验证数据,数据应按照Kitti或Pascal VOC格式组织,包含图像文件和对应的标注文件。

4. 创建SSD配置文件

以下是SSD配置文件的示例,并包含中文注释:

random_seed: 42  # 随机种子,用于确保实验的可重复性
dataset_config {
  data_sources: {
    label_directory_path: "/path/to/labels"  # 训练数据的标签路径
    image_directory_path: "/path/to/images"  # 训练数据的图像路径
  }
  validation_data_sources: {
    label_directory_path: "/path/to/val_labels"  # 验证数据的标签路径
    image_directory_path: "/path/to/val_images"  # 验证数据的图像路径
  }
}
model_config {
  pretrained_model_file: "/path/to/pretrained/model"  # 预训练模型文件路径
  num_layers: 18  # 模型的层数
  all_proposals: 200  # 所有提议框的数量
}
train_config {
  batch_size: 8  # 批次大小
  learning_rate: 0.001  # 学习率
  num_epochs: 80  # 训练轮数
  augmentations: {
    horizontal_flip: true  # 是否进行水平翻转数据增强
    vertical_flip: false  # 是否进行垂直翻转数据增强
  }
}

5. 运行训练

使用以下命令运行训练任务,并包含中文注释:

docker run --gpus all -v /path/to/your/data:/data -v /path/to/your/config:/config -v /path/to/your/output:/output nvcr.io/nvidia/tao/tao-toolkit-tf:v3.21.11-tf1.15.5-py3 ssd train \
  -e /config/ssd_config.yaml \  # 配置文件路径
  -r /output/experiment_dir \  # 实验输出目录
  -k $API_KEY  # TAO Toolkit的API密钥
  • --gpus all: 使用所有可用的GPU。
  • -v /path/to/your/data:/data: 将本地数据目录挂载到容器内的/data路径。
  • -v /path/to/your/config:/config: 将本地配置文件目录挂载到容器内的/config路径。
  • -v /path/to/your/output:/output: 将本地输出目录挂载到容器内的/output路径。
  • -e /config/ssd_config.yaml: 指定配置文件。
  • -r /output/experiment_dir: 指定实验输出目录。
  • -k $API_KEY: 指定TAO Toolkit的API密钥。

6. 导出模型

训练完成后,使用以下命令导出模型,并包含中文注释:

docker run --gpus all -v /path/to/your/output:/output nvcr.io/nvidia/tao/tao-toolkit-tf:v3.21.11-tf1.15.5-py3 ssd export \
  -m /output/experiment_dir/model.tlt \  # 输入的TAO模型路径
  -o /output/experiment_dir/model.etlt \  # 输出的ETLT模型路径
  -k $API_KEY  # TAO Toolkit的API密钥

总结

通过TAO Toolkit,你可以方便地对SSD目标检测模型进行训练。准备数据、配置训练参数并运行训练命令,可以帮助你快速训练自定义的目标检测模型并进行部署。详细的指南和更多高级功能可以参考TAO Toolkit的官方文档。

这样,代码和配置文件中都增加了中文注释,便于理解和使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/700208.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【电路笔记】-电子放大器介绍

电子放大器介绍 文章目录 电子放大器介绍1、概述2、四极表示法3、理想模型4、真实放大器的限制5、噪音考虑因素6、电子放大器的类型1、概述 放大器是一种电子模块,可放大电位信号(电压放大器)、强度信号(电流放大器)或两者(功率放大器)。 放大器由两个输入组成,分别是…

开门预警系统技术规范(简化版)

开门预警系统技术规范(简化版) 1 系统概述2 预警区域3 预警目标4 功能需求5 功能条件6 显示需求7 指标需求1 系统概述 开门预警系统(DOW),在自车停止开门过程中,安装在车辆的传感器(如安装在车辆后保险杆两个角雷达)检测从自车后方接近的目标车(汽车、摩托车等)的相对…

Django面试题

1. 什么是wsgi? WSGI 是 “Web Server Gateway Interface” 的缩写,它是一种用于 Python Web 应用程序和 Web 服务器之间通信的标准接口。它定义了一组规则和约定,使 Web 服务器能够与任何符合 WSGI 规范的 Python Web 应用程序进行交互。 #…

2024年中级会计报名失败原因汇总❗

2024年中级会计报名失败原因汇总❗ ❌这四类考生不能报考24年中级⇩⇩⇩ 1️⃣不参加会计信息采集的同学 2️⃣未按规定完成继续教育的同学 3️⃣不符合会计工作年限要求的同学 4️⃣报名前未做好材料准备 需要准备有效期内身份证、本人学历或学位证书、户籍证或者居住证明、符…

翻转链表-链表题

LCR 141. 训练计划 III - 力扣(LeetCode) 非递归 class Solution { public:ListNode* trainningPlan(ListNode* head) {if(head ! nullptr && head->next ! nullptr){ListNode* former nullptr;ListNode* mid head;ListNode* laster nul…

C++ PDF转图片

C PDF转图片#include "include/fpdfview.h" #include <fstream> #include <include/core/SkImage.h>sk_sp<SkImage> pdfToImg(sk_sp<SkData> pdfData) {sk_sp<SkImage> img;FPDF_InitLibrary(nullptr);FPDF_DOCUMENT doc;FPDF_PAGE …

Character Region Awareness for Text Detection论文学习

​1.首先将模型在Synth80k数据集上训练 Synth80k数据集是合成数据集&#xff0c;里面标注是使用单个字符的标注的&#xff0c;也就是这篇文章作者想要的标注的样子&#xff0c;但是大多数数据集是成堆标注的&#xff0c;也就是每行或者一堆字体被整体标注出来&#xff0c;作者…

人工智能ChatGPT的多种应用:提示词工程

简介 ChatGPT 的主要优点之一是它能够理解和响应自然语言输入。在日常生活中&#xff0c;沟通本来就是很重要的一门课程&#xff0c;沟通的过程中表达的越清晰&#xff0c;给到的信息越多&#xff0c;那么沟通就越顺畅。 和 ChatGPT 沟通也是同样的道理&#xff0c;如果想要 …

33.星号三角阵(二)

上海市计算机学会竞赛平台 | YACSYACS 是由上海市计算机学会于2019年发起的活动,旨在激发青少年对学习人工智能与算法设计的热情与兴趣,提升青少年科学素养,引导青少年投身创新发现和科研实践活动。https://www.iai.sh.cn/problem/742 题目描述 给定一个整数 𝑛,输出一个…

专属部署简介

什么是专属部署 专属部署(也称为专用部署)是一种部署选择&#xff0c;它允许用户将数据和应用部署到自己的专用云基础架构中&#xff0c;而不是与其他租户共享基础架构。这种部署方式可以提供更高的安全性、控制力和性能优化&#xff0c;因为用户可以完全控制和管理自己的基础设…

大众点评全国爱车店铺POI采集177万家-2024年5月底

大众点评全国爱车店铺POI采集177万家-2024年5月底 店铺POI点位示例&#xff1a; 店铺id H69Y6l1Ixs2jLGg2 店铺名称 HEEJOO豪爵足道(伍家店) 十分制服务评分 7.7 十分制环境评分 7.7 十分制划算评分 7.7 人均价格 134 评价数量 2982 店铺地址 桔城路2号盛景商业广场1-3…

46【Aseprite 作图】发光

1 通过“编辑 - 特效 - 卷积矩阵”&#xff0c;这次选择“7*7”&#xff0c;可以做出窗户的效果

面试题:什么是线程的上下文切换?

线程的上下文切换是指在操作系统中&#xff0c;CPU从执行一个线程的任务切换到执行另一个线程任务的过程。在现代操作系统中&#xff0c;为了实现多任务处理和充分利用CPU资源&#xff0c;会同时管理多个线程的执行。由于CPU在任意时刻只能执行一个线程&#xff0c;因此需要在这…

【QT5】<知识点> IMX6ULL开发板运行QT

目录 1. 安装交叉编译器 2. 命令行交叉编译QT项目 3. 运行该可执行程序 4. 开发板上运行UDP程序与Ubuntu通信 1. 安装交叉编译器 第一步&#xff1a;进入正点原子论坛找到IMX6ULL开发板的资料&#xff0c;下载“开发工具”&#xff0c;将“交叉编译工具”中的fsl-imx-x11-…

讲透计算机网络知识(实战篇)01——计算机网络和协议

一、计算机网络和协议 1、网络和互联网络 1.1 网络、互联网、Internet 用交换机、集线器连接在一起的计算机构成一个网络。 用路由器连接多个网络&#xff0c;形成互联网。 全球最大的互联网&#xff1a;Internet。 1.2 网络举例 家庭互联网 图中的无线拨号路由器既是路由…

mysql和redis备份和恢复数据的笔记

一、mysql的备份及恢复方法&#xff1a; 1.完全备份与恢复 1.1物理备份与恢复 物理备份又叫冷备份&#xff0c;需停止数据库服务&#xff0c;适合线下服务器 备份数据流程&#xff1a; 第一步:制作备份文件 systemctl stop mysqld #创建存放备份文件的目录 mkdir /bakdir …

一夜之间,苹果杀死无数AI工具创业公司!GPT-4o深度整合进苹果

就在刚刚&#xff0c;苹果发布会WWDC2024官宣了一系列AI相关的重磅升级。 由于这一波AI升级攒的太大了&#xff0c;苹果甚至索性创造了一个新的概念——苹果智能&#xff08;Apple Intelligence&#xff09;。 如果你认为 苹果智能 Siri升级&#xff0c;那你就大错特错了。 …

Ubuntu,Linux服务器安装Mellanox MCX653105A IB网卡HCA卡驱动

驱动下载地址 https://network.nvidia.com/products/infiniband-drivers/linux/mlnx_ofed/ 选择对应操作系统 进入目录运行 安装成功显示 如果中途报错&#xff0c;需要核对下载的版本&#xff0c;并且把原来安装的卸载

在 TypeScript 中,定义类型时你用 Types 还是 Interfaces?

什么是 Types 和 Interfaces&#xff1f; Types 和 Interfaces 是 TypeScript 中两种用于定义数据结构的工具。它们可以帮助开发者在编写代码时约束变量和对象的类型&#xff0c;从而减少错误并提高代码的可读性。 Types&#xff1a;Types 允许你定义各种类型&#xff0c;包括基…

TCP四次挥手全过程详解

TCP四次挥手全过程 有几点需要澄清&#xff1a; 1.首先&#xff0c;tcp四次挥手只有主动和被动方之分&#xff0c;没有客户端和服务端的概念 2.其次&#xff0c;发送报文段是tcp协议栈的行为&#xff0c;用户态调用close会陷入到内核态 3.再者&#xff0c;图中的情况前提是双…