sagment-anything官方代码使用详解

文章目录

  • 一. sagment-anything官方例程说明
    • 1. 结果显示函数说明
    • 2. SamAutomaticMaskGenerator对象
      • (1) SamAutomaticMaskGenerator初始化参数
    • 3. SamPredictor对象
      • (1) 初始化参数
      • (2) set_image()
      • (3) predict()
  • 二. SamPredictor流程说明
    • 1. 导入所需要的库
    • 2. 读取图像
    • 3. 加载模型
    • 4. 生成预测对象
    • 5. 设置要检测的图像
    • 6. 根据不同输入需求对图像进行掩膜预测
      • (1) 根据输入一个点,输出对于这个点的三个不同置信度的掩膜
      • (2) 通过多个点获取一个对象的掩膜
      • (3) 通过设置反向点反选掩膜
      • (4) boxes输入生成掩膜
      • (5) 同时输入点与boxes生成掩膜
      • (6) 多个输入输出不同预测结果
  • 三. SamAutomaticMaskGenerator预测流程
    • 1. 导入所需要的库
    • 2. 读取图像
    • 3. 加载模型
    • 4. 生成预测对象
    • 5. 设置要检测的图像
    • 6. 给分割出来的物体上色,显示分割效果
  • 四. SamAutomaticMaskGenerator不同参数下的检测效果
    • 1. points_per_side参数测试
    • 2. pred_iou_thresh参数测试
    • 3. stability_score_thresh参数测试
    • 4. box_nms_thresh参数测试
    • 5. crop_nms_thresh参数测试

一. sagment-anything官方例程说明

1. 结果显示函数说明

def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
    ax = plt.gca()
    ax.set_autoscale_on(False)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask
    ax.imshow(img)

2. SamAutomaticMaskGenerator对象

(1) SamAutomaticMaskGenerator初始化参数

  • model (Sam): 用于掩模预测的Sam模型。
  • points_per_side (int or None): 沿图像一侧要采样的点的数量。总点数为points_per_side 2 ^2 2。如果为None,则point_grids必须提供显式点采样。默认为32
  • points_per_batch (int): 设置模型同时检测的点数。更高的数字可能更快,但使用更多的GPU内存。默认为64
  • pred_iou_thresh (float): [0,1]中的滤波阈值,使用模型的预测掩码质量。默认值为0.88
  • stability_score_thresh (float): [0,1]中的滤波阈值,使用掩码在截断值变化下的稳定性,用于对模型的掩码预测进行二值化。默认值为0.95
  • stability_score_offset (float): 计算稳定性分数时,偏移截止值的量。默认值为1.0
  • box_nms_thresh (float): 非最大抑制用于过滤重复掩码的框IoU截止。默认值为0.7
  • crop_n_layers (int): 如果>0,将对图像的裁剪再次运行掩膜预测。设置要运行的层数,其中每层具有2*i_layer数量的图像裁剪。默认值为0
  • crop_nms_thresh (float): 非最大抑制用于过滤不同物体之间的重复掩码的框IoU截止。默认值为0.7
  • crop_overlap_ratio (float): 设置物体重叠的程度。在第一个裁剪层中,裁剪将重叠图像长度的这一部分。物体较多的后几层会缩小这种重叠。默认值为512 / 1500
  • crop_n_points_downscale_factor (int): 在层n中采样的每侧的点数按比例缩小crop_n_points_downscale_factor n ^n n。默认值为1
  • point_grids (list(np.ndarray) or None): 用于采样的点的显式网格上的列表,归一化为[0,1]。列表中的第n个栅格用于第n个裁剪层。与points_per_side独占。默认值为None
  • min_mask_region_area (int): 如果>0,将应用后处理来移除面积小于min_mask_region_area的掩膜来中断开连接的区域和孔。需要opencv。默认为0
  • output_mode (str): 表单掩码在中返回。可以是binary_maskuncompressed_rlecoco_rlecoco_rle需要pycocotools。对于大分辨率,binary_mask可能会消耗大量内存。默认为'binary_mask'
    “”"

3. SamPredictor对象

(1) 初始化参数

  • model (Sam): 用于掩模预测的Sam模型。

(2) set_image()

说明:
	设置检测的图像
参数:
	image(np.ndarray):用于计算掩码的图像。应为HWC uint8格式的图像,像素值为[0,255]。
	image_format(str):图像的颜色格式,以'RGB''BGR'为单位。

(3) predict()

说明:
	使用当前设置的图像预测给定输入提示的掩码。
参数:
	point_coords(np.ndarray或None):存放指向图像中物体的点的Nx2数组。每个点都以像素为单位(X,Y)。
	point_labels(np.ndarray或None):点提示的长度为N的标签阵列。1表示前景点,0表示背景点。
	box(np.ndarray或None):长度为4的数组,以XYXY格式向模型提供长方体提示。
	mask_input(np.ndarray):输入到模型的低分辨率掩码,通常来自先前的预测迭代。形式为1xHxW,其中对于SAM,H=W=256。
	multimask_output(bool):如果为true,则模型将返回三个掩码。对于不明确的输入提示(如单击),这通常会产生比单个预测更好的掩码。
	                          如果只需要单个遮罩,则可以使用模型的预测质量分数来选择最佳遮罩。对于非模糊提示,例如多个输入提示,
	                          multimask_output=False可以提供更好的结果。
	return_logits(bool):如果为true,则返回未阈值掩码logits,而不是二进制掩码。
返回值:
    (np.ndarray):CxHxW格式的输出掩码,其中C是掩码的数量,(H,W)是原始图像大小。
    (np.ndarray):长度为C的数组,包含模型对每个掩码质量的预测。
    (np.ndarray):形状为CxHxW的数组,其中C是掩码的数量,H=W=256。这些低分辨率logits可以作为掩码输入传递给后续迭代。

二. SamPredictor流程说明

1. 导入所需要的库

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

2. 读取图像

image = cv2.imread('images/dog.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()

3. 加载模型

sam_checkpoint = "sam_vit_h_4b8939.pth"  # 模型文件所在路径
model_type = "vit_h"  # 模型的类型
device = "cuda"  # 运行模型的设备

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)  # 注册模型
sam.to(device=device)

4. 生成预测对象

mask_predictor = SamPredictor(sam)  # 生成sam预测对象

5. 设置要检测的图像

predictor.set_image(image)

6. 根据不同输入需求对图像进行掩膜预测

(1) 根据输入一个点,输出对于这个点的三个不同置信度的掩膜

input_point = np.array([[250, 187]])
input_label = np.array([1])

# 在'multimask_output=True'(默认设置)的情况下,SAM输出3个掩码,其中“scores”给出了模型对这些掩码质量的估计。
masks, scores, logits = predictor.predict(point_coords=input_point, point_labels=input_label, multimask_output=True,)

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10,10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

(2) 通过多个点获取一个对象的掩膜

# 通过多个点获取一个对象的掩膜
input_point = np.array([[237, 244], [273, 259]])
input_label = np.array([1, 1])  # 把两个点的标签都设置为1,代表两个点为同一个目标物所有 

masks, _, _ = predictor.predict(point_coords=input_point, point_labels=input_label, multimask_output=False)

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()

在这里插入图片描述

(3) 通过设置反向点反选掩膜

# 通过多个点获取一个对象的掩膜
input_point = np.array([[237, 244], [319, 274]])
input_label = np.array([1, 0])  # 把两个点的标签都设置为1,代表两个点为同一个目标物所有 

masks, _, _ = predictor.predict(point_coords=input_point, point_labels=input_label, multimask_output=False)

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()

在这里插入图片描述

(4) boxes输入生成掩膜

input_box = np.array([228, 230, 280, 276])

masks, _, _ = predictor.predict(point_coords=None, point_labels=None, box=input_box[None, :], multimask_output=False,)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()

在这里插入图片描述

(5) 同时输入点与boxes生成掩膜

input_point = np.array([[237, 244]])
input_label = np.array([1])
input_box = np.array([228, 230, 280, 276])

masks, _, _ = predictor.predict(point_coords=input_point, point_labels=input_label, box=input_box[None, :], multimask_output=False,)
plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_points(input_point, input_label, plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()

在这里插入图片描述

(6) 多个输入输出不同预测结果

SamPredictor可以使用predict_tarch方法对同一图像输入多个提示(points、boxes)。该方法假设输入点已经是tensor张量,且boxes信息与image size相符合。例如,假设我们有几个来自对象检测器的输出结果。
SamPredictor对象(此外也可以使用segment_anything.utils.transforms)可以将boxes信息编码为特征向量(以实现对任意数量boxes的支持,transformed_boxes),然后预测mask。

input_boxes = torch.tensor([
    [228, 230, 280, 276],
    [495, 90, 554, 125],
    [447, 499, 494, 548],
    [162, 346, 214, 390],
], device=predictor.device) #假设这是目标检测的预测结果

transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])

masks, _, _ = predictor.predict_torch(point_coords=None, point_labels=None, boxes=transformed_boxes, multimask_output=False)

plt.figure(figsize=(10, 10))
plt.imshow(image)
for mask in masks:
    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box in input_boxes:
    show_box(box.cpu().numpy(), plt.gca())
plt.axis('off')
plt.show()

在这里插入图片描述

三. SamAutomaticMaskGenerator预测流程

1. 导入所需要的库

import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import sys
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

2. 读取图像

image = cv2.imread('images/dog.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.figure(figsize=(20,20))
plt.imshow(image)
plt.axis('off')
plt.show()

3. 加载模型

sam_checkpoint = "sam_vit_h_4b8939.pth"  # 模型文件所在路径
model_type = "vit_h"  # 模型的类型
device = "cuda"  # 运行模型的设备

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)  # 注册模型
sam.to(device=device)

4. 生成预测对象

mask_generator = SamAutomaticMaskGenerator(model=sam,
                                           points_per_side=32,
                                           points_per_batch=64,
                                           pred_iou_thresh=0.88,
                                           stability_score_thresh=0.95,
                                           stability_score_offset=1.0,
                                           box_nms_thresh=0.7,
                                           crop_n_layers=0,
                                           crop_nms_thresh=0.7,
                                           crop_overlap_ratio=0.34133,
                                           crop_n_points_downscale_factor=1,
                                           point_grids=None,
                                           min_mask_region_area=0,
                                           output_mode='binary_mask')

5. 设置要检测的图像

# 将图像送入推理对象进行推理分割,输出结果为一个列表,其中存的每个字典对象内容为:
# segmentation : 分割出来的物体掩膜(与原图像同大小,有物体的地方为1其他地方为0)
# area : 物体掩膜的面积
# bbox : 掩膜的边界框(XYWH)
# predicted_iou : 模型自己对掩模质量的预测
# point_coords : 生成此掩码的采样输入点
# stability_score : 掩模质量的一个附加度量
# crop_box : 用于以XYWH格式生成此遮罩的图像的裁剪
masks = mask_generator.generate(image)

6. 给分割出来的物体上色,显示分割效果

# 给分割出来的物体上色,显示分割效果
plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()

在这里插入图片描述

四. SamAutomaticMaskGenerator不同参数下的检测效果

1. points_per_side参数测试

  1. points_per_side=4,检测到9个物体
    在这里插入图片描述

  2. points_per_side=16,检测到211个物体
    在这里插入图片描述

  3. points_per_side=64,检测到683个物体
    在这里插入图片描述

  4. points_per_side=256,检测到872个物体
    在这里插入图片描述

2. pred_iou_thresh参数测试

  1. pred_iou_thresh=1, 检测到1个物体
    在这里插入图片描述
  2. pred_iou_thresh=0.95, 检测到274个物体
    在这里插入图片描述
  3. pred_iou_thresh=0.8, 检测到792个物体
    在这里插入图片描述

3. stability_score_thresh参数测试

  1. stability_score_thresh=1,检测到0个物体
    kjui
  2. stability_score_thresh=0.95,检测到683个物体
    在这里插入图片描述
  3. stability_score_thresh=0.95,检测到764个物体
    在这里插入图片描述

4. box_nms_thresh参数测试

  1. box_nms_thresh=1,检测到4680个物体
    在这里插入图片描述

  2. box_nms_thresh=0.7,检测到683个物体
    在这里插入图片描述

  3. box_nms_thresh=0.4,检测到621个物体
    在这里插入图片描述

  4. box_nms_thresh=0.1,检测到458个物体
    在这里插入图片描述

  5. box_nms_thresh=0,检测到201个物体
    在这里插入图片描述

5. crop_nms_thresh参数测试

  1. crop_nms_thresh=1,检测到683个物体
    在这里插入图片描述

  2. crop_nms_thresh=0.7,检测到683个物体
    在这里插入图片描述

  3. crop_nms_thresh=0.1,检测到683个物体
    在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/217636.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

IntelliJ IDEA的下载安装配置步骤详解

引言 IntelliJ IDEA 是一款功能强大的集成开发环境,它具有许多优势,适用于各种开发过程。本文将介绍 IDEA 的主要优势,并提供详细的安装配置步骤。 介绍 IntelliJ IDEA(以下简称 IDEA)之所以被广泛使用,…

Kubernetes存储搭建NFS挂载失败处理

搞NFS存储时候发现如下问题: Events:Type Reason Age From Message---- ------ ---- ---- -------Normal Scheduled 5m1s default-scheduler Successful…

【web安全】RCE漏洞原理

前言 菜某的笔记总结,如有错误请指正。 RCE漏洞介绍 简而言之,就是代码中使用了可以把字符串当做代码执行的函数,但是又没有对用户的输入内容做到充分的过滤,导致可以被远程执行一些命令。 RCE漏洞的分类 RCE漏洞分为代码执行…

如何基于Akamai IoT边缘平台打造一个无服务器的位置分享应用

与地理位置有关的应用相信大家都很熟悉了,无论是IM软件里的位置共享或是电商、外卖应用中的配送地址匹配,我们几乎每天都在使用类似的功能与服务。不过你有没有想过,如何在自己开发的应用中嵌入类似的功能? 本文Akamai将为大家提…

C语言中如何取一串比特中的特定位的比特

#include <iostream> #include <bitset> using namespace std; /* 向右的移位操作相当于丢掉最后的几位&#xff0c;然后剩下的位数进行“与”运算即可。 */ int main() {int a 0x2FB7; //0x2FB70010 1111 1011 0111char end3 (a >> 4) & 0x07; //取a…

从零开始搭建博客网站-----框架页

实现效果如下 发布的功能还没有实现&#xff0c;仅仅实现了简单的页面显示 关键代码如下 <template><div class"layout"><el-header class"header"><div class"logo">EasyBlog</div></el-header><el-c…

室内外融合便携式定位终端5G+UWB+RTK

一、介绍 便携式定位终端主要用于提供高精度的位置数据&#xff0c;支持室内UWB定位和室外北斗系统定位功能&#xff0c;支持5G公网和5G专网通信功能&#xff0c;便携式定位终端中超宽带(UWB)和实时动态(RTK)技术的集成代表了精确位置跟踪方面的重大进步。这款UWBRTK便携式定位…

SpringBootWeb案例_02

Web后端开发_05 SpringBootWeb案例_02 1.新增员工 1.1需求 在新增用户时&#xff0c;我们需要保存用户的基本信息&#xff0c;并且还需要上传的员工的图片&#xff0c;目前我们先完成第一步操作&#xff0c;保存用户的基本信息。 1.2 接口文档 基本信息 请求路径&#xff…

springboot + vue 企业级工位管理系统

qq&#xff08;2829419543&#xff09;获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;springboot 前端&#xff1a;采用vue技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xf…

【23-24 秋学期】NNDL 作业13 优化算法3D可视化

编程实现优化算法&#xff0c;并3D可视化 1. 函数3D可视化 分别画出和的3D图 NNDL实验 优化算法3D轨迹 鱼书例题3D版_优化算法3d展示-CSDN博客 2.加入优化算法&#xff0c;画出轨迹 分别画出和的3D轨迹图 从轨迹、速度等多个角度讲解各个算法优缺点 NNDL实验 优化算法3D轨…

Abaper入门实战篇 ——从 0 - 1 完成一个ALV

SAP ABAP 顾问&#xff08;开发工程师&#xff09;能力模型_Terry谈企业数字化的博客-CSDN博客文章浏览阅读516次。目标&#xff1a;基于对SAP abap 顾问能力模型的梳理&#xff0c;给一年左右经验的abaper 快速成长为三年经验提供超级燃料&#xff01;https://blog.csdn.net/j…

基于腾讯云手把手教你搭建网站

目录 前言前期准备工作具体搭建网站番外篇&#xff1a;网站开发及优化结束语 前言 在当今数字化时代浪潮之下&#xff0c;作为开发者拥有一个属于自己的网站是非常有必要的&#xff0c;也是展示个人形象、打造影响力和给别人提供服务的重要途径。网站不仅可以作为打造自己影响…

算法通关村-----跳跃游戏问题

跳跃游戏 问题描述 给你一个非负整数数组 nums &#xff0c;你最初位于数组的 第一个下标 。数组中的每个元素代表你在该位置可以跳跃的最大长度。 判断你是否能够到达最后一个下标&#xff0c;如果可以&#xff0c;返回 true &#xff1b;否则&#xff0c;返回 false 。 详见…

MySQL笔记-第02章_MySQL环境搭建

视频链接&#xff1a;【MySQL数据库入门到大牛&#xff0c;mysql安装到优化&#xff0c;百科全书级&#xff0c;全网天花板】 文章目录 第02章_MySQL环境搭建1. MySQL的卸载步骤1&#xff1a;停止MySQL服务步骤2&#xff1a;软件的卸载步骤3&#xff1a;残余文件的清理步骤4&am…

【网络安全技术】密钥管理

一、分级密钥概念 典型的密钥分级分为三级&#xff0c;三级密钥就是一次会话的session key&#xff0c;用来加密通信&#xff0c;所以通常使用对称密钥。 二级密钥就是分发三级密钥的密钥&#xff0c;用来加密三级密钥来分发三级密钥。 一级密钥就是分发二级密钥的密钥&…

Linux系统与python常用密码的加密解密方法

Linux系统与python常用加密&解密方法 文章目录 Linux系统与python常用加密&解密方法Linux系统加密解密方法一、openssl二、示例1、加密规则语法2、解密语法规则3、shell脚本 Python密码加密方法一、Base64加密1、加密2、解密 二、哈希算法加密三、Fernet对称加密算法1、…

运维03:LAMP

黄金架构LAMP 什么是LAMP LAMP是公认的最常见&#xff0c;最古老的黄金web技术栈 快速部署LAMP架构 #停止nginx&#xff0c;并且把nginx应用卸载了 systemctl stop nginx yum remove nginx -y#关闭防火墙 iptables -F #清空防火墙规则&#xff0c;比如哪些请求允许进入服…

7. 系统信息与系统资源

7. 系统信息与系统资源 1. 系统信息1.1 系统标识 uname()1.2 sysinfo()1.3 gethostname()1.4 sysconf() 2. 时间、日期2.1 Linux 系统中的时间2.1.1 Linux 怎么记录时间2.1.2 jiffies 的引入 2.2 获取时间 time/gettimeofday2.2.1 time()2.2.2 gettimeofday() 2.3 时间转换函数…

Java集合(二)

1. Map 1.1 HashMap 和 Hashtable 的区别 线程是否安全&#xff1a; HashMap 是非线程安全的&#xff0c;Hashtable 是线程安全的,因为 Hashtable 内部的方法基本都经过synchronized 修饰。&#xff08;如果你要保证线程安全的话就使用 ConcurrentHashMap 吧&#xff01;&…

[多线程]阻塞队列和生产者消费者模型

目录 1.阻塞队列 1.1引言 1.2Java标准库中的阻塞队列 1.3自主通过Java代码实现一个阻塞队列(泛型实现) 2.生产者消费者模型 1.阻塞队列 1.1引言 阻塞队列是多线程部分一个重要的概念,它相比于一般队列,有两个特点: 1.线程是安全的 2.带有阻塞功能 1) 队列为空,出队列就会阻…