yolov8-seg 分割推理流程

目录

一、分割+检测

二、图像预处理

二、推理

三、后处理与可视化

3.1、后处理

3.2、mask可视化

四、完整pytorch代码


一、分割+检测

注:本篇只是阐述推理流程,tensorrt实现后续跟进。

yolov8-pose的tensorrt部署代码稍后更新,还是在仓库:GitHub - FeiYull/TensorRT-Alpha: 🔥🔥🔥TensorRT-Alpha supports YOLOv8、YOLOv7、YOLOv6、YOLOv5、YOLOv4、v3、YOLOX、YOLOR...🚀🚀🚀CUDA IS ALL YOU NEED.🍎🍎🍎It also supports end2end CUDA C acceleration and multi-batch inference.

也可以关注:TensorRT系列教程-CSDN博客

以下是官方预测代码:

from ultralytics import YOLO
model = YOLO(model='yolov8n-pose.pt')
model.predict(source="d:/Data/1.jpg", save=True)

推理过程无非是:图像预处理 -> 推理 -> 后处理 + 可视化,这三个关键步骤在文件大概247行:D:\CodePython\ultralytics\ultralytics\engine\predictor.py,代码如下:

# Preprocess
with profilers[0]:
	im = self.preprocess(im0s) # 图像预处理

# Inference
with profilers[1]:
	preds = self.inference(im, *args, **kwargs) # 推理

# Postprocess
with profilers[2]:
	self.results = self.postprocess(preds, im, im0s) # 后处理

二、图像预处理

通过debug,进入上述self.preprocess函数,看到代码实现如下。处理流程大概是:padding(满足矩形推理),图像通道转换,即:BGR装RGB,检查图像数据是否连续,存储顺序有HWC转为CHW,然后归一化。需要注意,原始pytorch框架图像预处理的时候,会将图像缩放+padding为HxW的图像,其中H、W为32倍数,而导出tensorrt的时候,为了高效推理,H、W 固定为640x640。

def preprocess(self, im):
	"""Prepares input image before inference.

	Args:
		im (torch.Tensor | List(np.ndarray)): BCHW for tensor, [(HWC) x B] for list.
	"""
	not_tensor = not isinstance(im, torch.Tensor)
	if not_tensor:
		im = np.stack(self.pre_transform(im))
		im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
		im = np.ascontiguousarray(im)  # contiguous
		im = torch.from_numpy(im)

	img = im.to(self.device)
	img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
	if not_tensor:
		img /= 255  # 0 - 255 to 0.0 - 1.0
	return img

二、推理

图像预处理之后,直接推理就行了,这里是基于pytorch推理。

def inference(self, im, *args, **kwargs):
	visualize = increment_path(self.save_dir / Path(self.batch[0][0]).stem,
							   mkdir=True) if self.args.visualize and (not self.source_type.tensor) else False
	return self.model(im, augment=self.args.augment, visualize=visualize)

三、后处理与可视化

3.1、后处理

640x640输入之后,有两个输出,其中

  • output1:尺寸为:116X8400,其中116=4+80+32,32为seg部分特征,经过NMS之后,输出为:N*38,其中38=4 + 2 + 32
  • output2:尺寸为32x160x160,拿上面NMS后的特征图后面,即:N*38矩阵后面部分N*32的特征图和output2作矩阵乘法,得到N*160*160的矩阵,接着执行sigmiod,然后拉平得到N*160*160 的mask。

然后将bbox缩放160*160的坐标系,如下代码,用于截断越界的mask,就是如下函数。最后,将所有mask上采样到640*640,然后用阀值0.5过一下。最后mask中只有0和1了,结束。

有关def crop_mask(masks, boxes):的理解:

def crop_mask(masks, boxes):
    """
    It takes a mask and a bounding box, and returns a mask that is cropped to the bounding box

    Args:
      masks (torch.Tensor): [n, h, w] tensor of masks
      boxes (torch.Tensor): [n, 4] tensor of bbox coordinates in relative point form

    Returns:
      (torch.Tensor): The masks are being cropped to the bounding box.
    """
    n, h, w = masks.shape
    x1, y1, x2, y2 = torch.chunk(boxes[:, :, None], 4, 1)  # x1 shape(n,1,1)
    r = torch.arange(w, device=masks.device, dtype=x1.dtype)[None, None, :]  # rows shape(1,1,w)
    c = torch.arange(h, device=masks.device, dtype=x1.dtype)[None, :, None]  # cols shape(1,h,1)

    return masks * ((r >= x1) * (r < x2) * (c >= y1) * (c < y2))

上面代码最后一句return,如下图理解,mask中所有点,例如点(r,c)必须在bbox内部。做法就是将bbox缩放到和mask一样的坐标系(160x160)如下图,然后使用绿色的bbox将mask进行截断:

3.2、mask可视化

直接将mask从灰度图转为彩色图,然后将类别对应的颜色乘以0.4,最后加在彩色图上就行了。

四、完整pytorch代码

将以上流程合并起来,并加以修改,完整代码如下:

import torch
import cv2 as cv
import numpy as np
from ultralytics.data.augment import LetterBox
from ultralytics.utils import ops
from ultralytics.engine.results import Results
import copy

# path = 'd:/Data/1.jpg'
path = 'd:/Data/640640.jpg'
device = 'cuda:0'
conf = 0.25
iou = 0.7

# preprocess
im = cv.imread(path)
# letterbox
im = [im]
orig_imgs = copy.deepcopy(im)
im = [LetterBox([640, 640], auto=True, stride=32)(image=x) for x in im]
im = im[0][None] # im = np.stack(im)
im = im[..., ::-1].transpose((0, 3, 1, 2))  # BGR to RGB, BHWC to BCHW, (n, 3, h, w)
im = np.ascontiguousarray(im)  # contiguous
im = torch.from_numpy(im)
img = im.to(device)
img = img.float()
img /= 255
# load model pt
ckpt = torch.load('yolov8n-seg.pt', map_location='cpu')
model = ckpt['model'].to(device).float()  # FP32 model
model.eval()

# inference
preds = model(img)

# poseprocess
p = ops.non_max_suppression(preds[0], conf, iou, agnostic=False, max_det=300, nc=80, classes=None)
results = []
# 如果导出onnx,第二个输出维度是1,应该就是mask,需要后续上采样
proto = preds[1][-1] if len(preds[1]) == 3 else preds[1]  # second output is len 3 if pt, but only 1 if exported???????
for i, pred in enumerate(p):
    orig_img = orig_imgs[i]
    if not len(pred):  # save empty boxes
        results.append(Results(orig_img=orig_img, path=path, names=model.names, boxes=pred[:, :6]))
        continue
    masks = ops.process_mask(proto[i], pred[:, 6:], pred[:, :4], img.shape[2:], upsample=True)  # HWC
    if not isinstance(orig_imgs, torch.Tensor):
        pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
    results.append(Results(orig_img=orig_img, path=path, names=model.names, boxes=pred[:, :6], masks=masks))

# show
plot_args = {'line_width': None,'boxes': True,'conf': True, 'labels': True}
plot_args['im_gpu'] = img[0]
result = results[0]
plotted_img = result.plot(**plot_args)
cv.imshow('plotted_img', plotted_img)
cv.waitKey(0)
cv.destroyAllWindows()

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/198898.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

git的创建以及使用

1、上传本地仓库 首先确定项目根目录中没有.git文件&#xff0c;有的话就删了&#xff0c;没有就下一步。在终端中输入git init命令。注意必须是根目录&#xff01; 将代码存到暂存区 将代码保存到本地仓库 2、创建git仓库 仓库名称和路径&#xff08;name&#xff09;随便写…

4P营销模型

4P营销模型 菲利普科特勒在其畅销书《营销管理&#xff1a;分析、规划与控制》中进一步确认了以4P为核心的营销组合方法. 模型介绍 「4P营销模型」是市场营销中的经典理论&#xff0c;代表了产品、价格、促销和渠道四个要素。这些要素是制定市场营销策略和实施计划的关键组成部…

MySql的InnoDB的三层B+树可以存储两千万左右条数据的计算逻辑

原创/朱季谦 B树是一种在非叶子节点存放排序好的索引而在叶子节点存放数据的数据结构&#xff0c;值得注意的是&#xff0c;在叶子节点中&#xff0c;存储的并非只是一行表数据&#xff0c;而是以页为单位存储&#xff0c;一个页可以包含多行表记录。非叶子节点存放的是索引键…

SA与NSA网络架构的区别

SA与NSA网络架构的区别 1. 三大运营商网络制式&#xff1a;2. 5G组网方式及业务特性3. NSA-3系列4. NSA—4系列5. NSA-7系列6. 5G SA网络架构7. 运营商策略 1. 三大运营商网络制式&#xff1a; 联通&#xff1a;3G(WCDMA)\4G(FDD-LTE/TD-LTE)\5G(SA/NSA)移动&#xff1a;2G(GS…

健全隧道健康监测,保障隧道安全管理

隧道工程事故的严重性不容忽视。四川隧道事故再次凸显了隧道施工的危险性&#xff0c;以及加强隧道安全监管的必要性。隧道工程事故不仅会给受害人带来巨大的痛苦和家庭悲剧&#xff0c;也会对整个社会产生严重的负面影响。因此&#xff0c;如何有效地降低隧道工程事故的发生率…

开发知识点-CSS样式

CSS样式 fontCSS 外边距 —— 围绕在元素边框的空白区域# linear-gradient() ——创建一个线性渐变的 "图像"# transform ——旋转 元素![在这里插入图片描述](https://img-blog.csdnimg.cn/20191204100321698.png)# rotate() [旋转] # 边框 (border) —— 围绕元素内…

Peter算法小课堂—高精度减法

给大家看个小视频高精度减法_哔哩哔哩_bilibili 基本思想 计算机模拟人类做竖式计算&#xff0c;从而得到正确答案 大家还记得小学时学的“减法竖式”吗&#xff1f;是不是这样 x-y问题 函数总览&#xff1a; 1.converts() 字符串转为高精度大数 2.le() 判断大小 3.sub() …

无分类编址 CIDR

在域名系统出现之后的第一个十年里&#xff0c;基于分类网络进行地址分配和路由IP数据包的设计就已明显显得可扩充性不足&#xff08;参见RFC 1517&#xff09;。为了解决这个问题&#xff0c;互联网工程工作小组在1993年发布了一新系列的标准——RFC 1518和RFC 1519——以定义…

Git分支管理--Bug分支

愿所有美好如期而遇 我们现在正在dev4分支上进行开发&#xff0c;但是在我们开发过程中&#xff0c;并且我们还未提交&#xff0c;master分支上出现了bug&#xff0c;需要我们修复&#xff0c;我们先来看情景 我们添加一行代码并且不提交充作开发&#xff0c;模拟正在进行开发时…

面试题:汉诺塔问题 · 递归

你好&#xff0c;我是安然无虞。 文章目录 汉诺塔问题问题描述解题思路代码详解 汉诺塔问题 问题描述 解题思路 这道题的名字还是很响的&#xff0c;基本上都能看出来使用递归解题&#xff0c;但是具体怎么实现还是需要细细想一想。 我们一步一步来&#xff0c;请看&#xff…

【搜维尔科技】产品推荐:Virtuose 6D RV,大型工作空间触觉设备

Virtuose 6D RV为一款具有大工作空间并在所有6自由度上提供力反馈的触觉设备&#xff0c;设计专用于虚拟现实环境&#xff0c;特别适合于大型虚拟物体的处理。 Virtuose 6D RV是当今市场上唯一将高工作效率与高工作量相结合在一起的产品。6D RV特别适合于缩放与操纵等应用&…

uni-app x生成的安卓包,安装时,提示不兼容。解决方案

找到 manifest.json 进入&#xff1a;源码视图 代码 {"name" : "xxx康养","appid" : "__xxx6","description" : "xxx康养","versionName" : "1.0.12","versionCode" : 100012,&…

MacBook如何远程控制华为手机?

将手机屏幕投影到电脑上可以提供更大的屏幕空间&#xff0c;方便观看电影、浏览照片、阅读文档等。然而&#xff0c;除了想将手机投屏到电脑&#xff0c;还想要在电脑上直接操作手机&#xff0c;有方法可以实现吗&#xff1f; 现在使用AirDroid Cast的远程控制手机功能就可以实…

从 0 搭建 Vite 3 + Vue 3 Js版 前端工程化项目

之前分享过一篇vue3+ts+vite构建工程化项目的文章,针对小的开发团队追求开发速度,不想使用ts想继续使用js,所以就记录一下从0搭建一个vite+vue3+js的前端项目,做记录分享。 技术栈 Vite 3 - 构建工具 Vue 3 Vue Router - 官方路由管理器 Pinia - Vue Store你也可以选择vue…

使用Moment.js中获取上周的开始日期和结束日期(可自定义)

前言 有时候需求是这样的&#xff0c;想要获取上周的开始日期和结束日期&#xff0c;或者前几周的时间范围 比如今天是2023.11.28号&#xff0c;我想获取上周的周一到周日&#xff0c;也就是&#xff0c;上周的开始日期: 2023-11-20&#xff0c;上周的结束日期: 2023-11-26 1.…

1742. 盒子中小球的最大数量

力扣&#xff08;LeetCode&#xff09;官网 - 全球极客挚爱的技术成长平台备战技术面试&#xff1f;力扣提供海量技术面试资源&#xff0c;帮助你高效提升编程技能&#xff0c;轻松拿下世界 IT 名企 Dream Offer。https://leetcode.cn/problems/maximum-number-of-balls-in-a-b…

JavaScript 的 DOM 知识点有哪些?

文档对象模型&#xff08;Document Object Model&#xff0c;简称 DOM&#xff09;&#xff0c;是一种与平台和语言无关的模型&#xff0c;用来表示 HTML 或 XML 文档。文档对象模型中定义了文档的逻辑结构&#xff0c;以及程序访问和操作文档的方式。 当网页加载时&#xff0…

UDP实现群聊通信

服务器端 #include <myhead.h> #define UDPIP "192.168.115.92" #define UDPPORT 6666 //存储客户信息的链表结构体 typedef struct Node {char name[20];struct sockaddr_in cin;struct Node *next; }*linklist; //数据结构体 struct data_cli {char type;ch…

从 0 到 1 开发一个 node 命令行工具

G2 5.0 推出了服务端渲染的能力&#xff0c;为了让开发者更快捷得使用这部分能力&#xff0c;最写了一个 node 命令行工具 g2-ssr-node&#xff1a;用于把 G2 的 spec 转换成 png、jpeg 或者 pdf 等。基本的使用如下&#xff1a; $ g2-ssr-node g2png -i ./bar.json -o ./bar.…

【Web】BJDCTF 2020 个人复现

目录 ①easy_md5 ②ZJCTF&#xff0c;不过如此 ③Cookie is so subtle! ④Ezphp ⑤The Mystery of IP ①easy_md5 ffifdyop绕过SQL注入 sql注入&#xff1a;md5($password,true) 右键查看源码 数组绕过 ?a[]1&b[]2 跳转到levell14.php 同样是数组绕过 param1[…