【论文源码实战】轻量化MobileSAM,分割一切大模型出现,模型缩小60倍,速度提高40倍

前言

MobileSAM模型是在2023年发布的,其对之前的SAM分割一切大模型进行了轻量化的优化处理,模型整体体积缩小了60倍,运行速度提高40倍,但分割效果却依旧很好。

MobileSAM在使用方法上沿用了SAM模型的接口,因此可以与SAM模型的使用方法几乎可以无缝对接,真的是非常Nice。唯一的区别就是在模型加载的时候需要修改一点点。

一、环境配置

创建专属环境

conda create -n MobileSAM python=3.9

​​​​​

激活环境

conda activate MobileSAM

 

安装 Pytorch 环境

pip install -i https://pypi.tuna.tsinghua.edu.cn/simple "torch-1.13.0+cu116-cp39-cp39-win_amd64.whl" 
pip install -i https://pypi.tuna.tsinghua.edu.cn/simple "torchvision-0.14.0+cu116-cp39-cp39-win_amd64.whl"

二、代码测试

网页版使用

安装相关库

pip install -r requirements.txt

pip install gradio -i https://pypi.tuna.tsinghua.edu.cn/simple

pip install timm -i https://pypi.tuna.tsinghua.edu.cn/simple

代码运行

python app/app.py

点击链接进入下方的网页界面

在此就可以就行在网页上进行分割操作,下方是一些分割的图片:

Instructions for point mode(点模式说明)

  1. Restart by click the Restart button(单击“重新启动”按钮重新启动)

  2. Select a point with Add Mask for the foreground (Must)(选择具有“添加遮罩”的点作为前景(必须))

  3. Select a point with Remove Area for the background (Optional)(选择具有“删除区域”的点作为背景(可选))

  4. Click the Start Segmenting.(单击“开始分割”)

纯代码实现

Predictor 方法【提示点分割代码】

加载模型
def load_sam():
    # Selecting objects with SAM
    sam_checkpoint = "./weights/mobile_sam.pt"
    model_type = "vit_t"

    device = "cuda" if torch.cuda.is_available() else "cpu"

    sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
    sam.to(device=device)
    sam.eval()
    return SamPredictor(sam)
绘制结果
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white',
               linewidth=1.25)


def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
单点得到掩码
# 使用 MobileSAM 从提示中得到掩码对象
# 选择需要分割的图像上的一个点。点以 (x,y) 格式输入到模型中,标签为 1(前景点)或 0(背景点)。
input_point = np.array([[400, 400]])
input_label = np.array([1])

# 使用 `SamPredictor.predict`进行预测
# 返回值:掩码、这些掩码的质量预测值和低分辨率掩码数值,这些数据可传递给下一次进行迭代预测。
# 当 `multimask_output=True`(默认设置)时,SAM 会输出 3 个掩码,其中 `scores` 给出了模型对这些掩码质量的估计值。
# 此设置用于模棱两可的输入提示,帮助模型区分与提示一致的不同对象。如果设置为 "false",则将返回单一掩码。
# 对于模棱两可的提示(如单点),建议使用 `multimask_output=True`,即使只需要单个掩码;可以通过选择在 `scores` 中返回的分数最高的掩码来选择最佳的单个掩码,这通常会产生更好的掩码。
masks, scores, logits = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    multimask_output=True,
)

for i, (mask, score) in enumerate(zip(masks, scores)):
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(input_point, input_label, plt.gca())
    plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
    plt.axis('off')
    plt.show()
多点得到掩码
# Specifying a specific object with additional points(指定具有附加点的特定对象)
# 单个输入点模棱两可,模型返回了多个与之一致的对象。要获得单一对象,可以提供多个点。
# 如果有上一次迭代的掩码,也可以提供给模型以帮助预测。
# 在使用多个提示指定单个对象时,可以通过设置 `multimask_output=False` 来得到单个掩码。
input_point = np.array([[400, 400], [450, 350]])
input_label = np.array([1, 1])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
多点得到掩码前景和背景
input_point = np.array([[400, 400], [100, 500]])
input_label = np.array([0, 1])

mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask

masks, _, _ = predictor.predict(
    point_coords=input_point,
    point_labels=input_label,
    mask_input=mask_input[None, :, :],
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
通过方框得到掩码
input_box = np.array([190, 70, 460, 280])

masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
plt.axis('off')
plt.show()

SamAutomaticMaskGenerator 方法【一键全景分割代码】

Automatic mask generator (得到全部图像掩码)
# Automatic mask generation(自动生成掩码)
mask_generator = load_sam()

# To generate masks
masks = mask_generator.generate(image)

# Mask generation returns a list over masks, where each mask is a dictionary containing various data about the mask.
# These keys are:
# segmentation: the mask
# area: the area of the mask in pixels
# bbox: the boundary box of the mask in XYWH format
# predicted_iou: the model's own prediction for the quality of the mask
# point_coords: the sampled input point that generated this mask
# stability_score: an additional measure of mask quality
# crop_box: the crop of the image used to generate this mask in XYWH format

# Show all the masks overlayed on the image.
plt.figure(figsize=(20,20))
plt.imshow(image)
show_anns(masks)
plt.axis('off')
plt.show()
Automatic mask generation options
# Automatic mask generation options
# There are several tunable parameters in automatic mask generation that control how densely points are sampled and what the thresholds are for removing low quality or duplicate masks. Additionally, generation can be automatically run on crops of the image to get improved performance on smaller objects, and post-processing can remove stray pixels and holes. Here is an example configuration that samples more masks:
# 在自动掩模生成中有几个可调参数,用于控制采样点的密度以及去除低质量或重复掩模的阈值。此外,生成可以在图像的裁剪上自动运行,以提高较小对象的性能,后处理可以去除杂散像素和孔洞。
# 以下是一个示例配置,用于对更多掩码进行采样:
mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,  # Requires open-cv to run post-processing
)
masks2 = mask_generator_2.generate(image)
plt.figure(figsize=(20, 20))
plt.imshow(image)
show_anns(masks2)
plt.axis('off')
plt.show()

Onnx 推理【提示点分割代码】

模型转换(.pt 转 .onnx)
# 得到多个输出
python scripts/export_onnx_model.py --checkpoint ./weights/mobile_mul_sam.pt --model-type vit_t --output ./weights/mobile_mul_sam.onnx

# 得到单个输出
python scripts/export_onnx_model.py --checkpoint ./weights/mobile_single_sam.pt --model-type vit_t --return-single-mask --output ./weights/mobile_single_sam.onnx
量化onnx模型
onnx_model_path = 'mobile_single_sam.onnx'

onnx_model_quantized_path = "mobile_single_sam_quantized.onnx"
# 通过对模型进行量化和优化。我们发现,这显著改善了web运行时,而质量性能的变化可以忽略不计。
quantize_dynamic(
    model_input=onnx_model_path,
    model_output=onnx_model_quantized_path,
    optimize_model=True,
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
)
onnx_model_path = onnx_model_quantized_path
使用onnx 模型
# Using an ONNX model
ort_session = onnxruntime.InferenceSession(onnx_model_path)

checkpoint = "../weights/mobile_sam.pt"
model_type = "vit_t"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device='cpu')
predictor = SamPredictor(sam)
predictor.set_image(image)

image_embedding = predictor.get_image_embedding().cpu().numpy()
Onnx 推理参数

The ONNX model has a different input signature than SamPredictor.predict. The following inputs must all be supplied. Note the special cases for both point and mask inputs. All inputs are np.float32.

  1. image_embeddings: The image embedding from predictor.get_image_embedding(). Has a batch index of length 1.
  2. point_coords: Coordinates of sparse input prompts, corresponding to both point inputs and box inputs. Boxes are encoded using two points, one for the top-left corner and one for the bottom-right corner. Coordinates must already be transformed to long-side 1024. Has a batch index of length 1.
  3. point_labels: Labels for the sparse input prompts. 0 is a negative input point, 1 is a positive input point, 2 is a top-left box corner, 3 is a bottom-right box corner, and -1 is a padding point. If there is no box input, a single padding point with label -1 and coordinates (0.0, 0.0) should be concatenated.
  4. mask_input: A mask input to the model with shape 1x1x256x256. This must be supplied even if there is no mask input. In this case, it can just be zeros.
  5. has_mask_input: An indicator for the mask input. 1 indicates a mask input, 0 indicates no mask input.
  6. orig_im_size: The size of the input image in (H,W) format, before any transformation.

Additionally, the ONNX model does not threshold the output mask logits. To obtain a binary mask, threshold at sam.mask_threshold (equal to 0.0).

单点得到掩码
input_point = np.array([[250, 375]])
input_label = np.array([1])

# Add a batch index, concatenate a padding point, and transform.
onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)

onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

# Create an empty mask input and an indicator for no mask.
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

# Package the inputs to run in the onnx model
ort_inputs = {
    "image_embeddings": image_embedding,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32)
}

# Predict a mask and threshold it.
masks, _, low_res_logits = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
多点得到掩码
input_point = np.array([[250, 375], [490, 380], [375, 360]])
input_label = np.array([1, 1, 0])

# Use the mask output from the previous run. It is already in the correct form for input to the ONNX model.
onnx_mask_input = low_res_logits

# Transform the points as in the previous example.
onnx_coord = np.concatenate([input_point, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, np.array([-1])], axis=0)[None, :].astype(np.float32)

onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

# The `has_mask_input` indicator is now 1.
onnx_has_mask_input = np.ones(1, dtype=np.float32)

# Package inputs, then predict and threshold the mask.
ort_inputs = {
    "image_embeddings": image_embedding,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32)
}

masks, _, _ = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold

plt.figure(figsize=(10,10))
plt.imshow(image)
show_mask(masks, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()
points和box得到掩码
input_box = np.array([210, 200, 350, 500])
input_point = np.array([[275, 400]])
input_label = np.array([0])

# Add a batch index, concatenate a box and point inputs, add the appropriate labels for the box corners, and transform. There is no padding point since the input includes a box input.
onnx_box_coords = input_box.reshape(2, 2)
onnx_box_labels = np.array([2,3])

onnx_coord = np.concatenate([input_point, onnx_box_coords], axis=0)[None, :, :]
onnx_label = np.concatenate([input_label, onnx_box_labels], axis=0)[None, :].astype(np.float32)

onnx_coord = predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)

# Package inputs, then predict and threshold the mask.
onnx_mask_input = np.zeros((1, 1, 256, 256), dtype=np.float32)
onnx_has_mask_input = np.zeros(1, dtype=np.float32)

ort_inputs = {
    "image_embeddings": image_embedding,
    "point_coords": onnx_coord,
    "point_labels": onnx_label,
    "mask_input": onnx_mask_input,
    "has_mask_input": onnx_has_mask_input,
    "orig_im_size": np.array(image.shape[:2], dtype=np.float32)
}

masks, _, _ = ort_session.run(None, ort_inputs)
masks = masks > predictor.model.mask_threshold

plt.figure(figsize=(10, 10))
plt.imshow(image)
show_mask(masks[0], plt.gca())
show_box(input_box, plt.gca())
show_points(input_point, input_label, plt.gca())
plt.axis('off')
plt.show()

三、总结

从结果来看,MobileSAM相比于SAM,模型整体体积缩小了60倍,运行速度提高40倍,但分割效果却保持相当水平。个人认为,这对于视觉大模型在移动端的部署与应用是具有里程碑意义的。

关于MobileSAM模型的相关代码、论文PDF、预训练模型、使用方法等,我都已打包好,供需要的小伙伴交流研究,获取方式如下:

关注公众号,回复:MobileSAM,即可获取MobileSAM相关代码、论文、预训练模型、使用方法示例

四、链接作者

欢迎关注我的公众号:@AI算法与电子竞赛

硬性的标准其实限制不了无限可能的我们,所以啊!少年们加油吧!

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/559470.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

超越GPT-4V,苹果多模态大模型上新,神经形态计算加速MLLM(一)

4月8日,苹果发布了其最新的多模态大语言模型(MLLM )——Ferret-UI,能够更有效地理解和与屏幕信息进行交互,在所有基本UI任务上都超过了GPT-4V! 苹果开发的多模态模型Ferret-UI增强了对屏幕的理解和交互&am…

【YOLOv8改进[Backbone]】使用MobileNetV3助力YOLOv8网络结构轻量化并助力涨点

目录 一 MobileNetV3 1 面向块搜索的平台感知NAS和NetAdapt 2 反向残差和线性瓶颈 二 使用MobileNetV3助力YOLOv8 1 整体修改 ① 添加MobileNetV3.py文件 ② 修改ultralytics/nn/tasks.py文件 ③ 修改ultralytics/utils/torch_utils.py文件 2 配置文件 3 训练 其他 …

内置管线升级到SBP,如何复用之前打包的AssetBundle

1)内置管线升级到SBP,如何复用之前打包的AssetBundle 2)安卓真机,在Unity 2021.3.31版本下Buffer数据异常 3)URP里CullResults.CreateSharedRendererScene下面的消耗 4)移动端是否支持曲面细分着色 这是第3…

C#基础|Debug程序调试学习和技巧总结

哈喽,你好啊,我是雷工! 在程序的开发过程中,可能绝大部分时间是用来调试程序, 当完成了某个功能的编程,都需要调试一下程序,看编程是否存在问题。 01 为什么需要程序调试 无论是电气工程师还…

Zed,有望打败 VS Code 吗?

大家好,我是楷鹏。 先说结论,不行。 Zed,又一款新起的文本代码编辑器 👉 https://zed.dev 今年一月二十四号正式开源,短短不到三个月,GitHub 上已经冲上 3 万 star 正如 Zed 的口号所说「Code at the spe…

win11家庭中文版安装docker遇到Hyper-V启用失败,如何解决??

🏆本文收录于「Bug调优」专栏,主要记录项目实战过程中的Bug之前因后果及提供真实有效的解决方案,希望能够助你一臂之力,帮你早日登顶实现财富自由🚀;同时,欢迎大家关注&&收藏&&…

统一所有 LLM API:支持预算与速率限制 | 开源日报 No.229

BerriAI/litellm Stars: 6.7k License: NOASSERTION litellm 是一个使用 OpenAI 格式调用所有 LLM API 的工具。它支持 Bedrock、Azure、OpenAI、Cohere、Anthropic 等 100 多种 LLMs,提供企业级代理服务器和稳定版本 v1.30.2。 主要功能和优势包括: 将…

Jenkins的安装和部署

文章目录 概述Jenkins部署项目的流程jenkins的安装启动创建容器进入容器浏览器访问8085端口 Jenkins创建项目创建example项目 概述 Jenkins:是一个开源的、提供友好操作界面的持续集成(CLI)工具,主要用于持续、自动构建的一些定时…

nVisual在线网络规划设计软件

●01● nVisual在线网络规划设计软件 在信息化快速发展的今天,网络基础设施的建设与优化变得尤为关键。为了满足现代通信行业对高效、精准的网络规划需求,nVisual在线网络规划设计软件应运而生,它通过集成先进的GIS技术和网络规划工具&#…

如何快速学习盲打键盘的指法

学习盲打键盘的指法需要一定的时间和练习,但是以下几个方法可以帮助你加快学习的速度: 掌握正确的手位:了解标准的键盘布局以及手指应该放置的位置是学习盲打的第一步。在QWERTY键盘上,你的左手应该放在ASDF键上,右手应…

基于开源项目改造,我制作了15个酷炫的数据大屏(附 Python 源码)

数据可视化大屏在许多领域都有广泛的应用,其带来了好处也是显而易见的: 直观展示数据: 大屏幕数据可视化能够将庞大的数据集以图形化的方式展示出来,使人们能够更容易地理解和分析数据。这种可视化形式使信息更加直观,…

SSM项目前后端分离详细说明

1.后端 1.1打包 说明:使用idea打开项目,然后进行打包。 1.2tomcat 说明:把后端打成war包后放入tomcat启动。 1.3启动tomcat 说明: 找到tomcat中bin目录中的startup.bat文件,进行启动。如果启动失败,可以…

Stream 流常见基本操作

文章目录 概述一、Stream 流的常见生成方式二、Stream 流中间操作方法1、常用中间操作方法2、使用示例13、使用示例24、使用示例35、使用示例46、使用示例57、Stream 流使用注意事项 三、Stream 流终结操作方法1、常用终结方法2、使用示例13、使用示例24、使用示例35、Stream 基…

界面控件DevExpress Blazor UI v23.2 - 浅谈增强的可访问性

DevExpress Blazor UI组件库提供了一套全面的原生Blazor组件(包括DataGrid、Pivot Grid、 调度程序、图表、数据编辑器和报表),使用C#为Blazor Server和Blazor WebAssembly创建高影响力的用户体验! 获取DevExpress v23.2正式版下载(Q技术交流&#xff1…

【Linux开发 第三篇】vmtools安装,快照

虚拟机克隆 方式一 直接拷贝一份安装好的虚拟机文件,再用VM打开文件即可 方式二 使用vmware的克隆操作(克隆时要先关闭Linux系统) 虚拟机快照 如果你在使用虚拟机的时候,担心现在的操作,想回到操作之前的状态&a…

开放式耳机哪个牌子好?热门开放式耳机合集,买前必看!

随着人们对运动健康的重视,越来越多的运动爱好者开始关注如何在运动中享受音乐。开放式蓝牙耳机凭借其独特的设计,成为了户外运动的理想选择。它不仅让你在运动时能够清晰听到周围环境的声音,保持警觉,还能让你在需要时与他人轻松…

08 JavaScript学习:数据类型

JavaScript 数据类型 值类型(基本类型):字符串(String)、数字(Number)、布尔(Boolean)、空(Null)、未定义(Undefined)、Symbol。 引用数据类型(对象类型):对…

插入排序与希尔排序

文章目录 插入排序配图详解核心思想核心代码 源代码运行结果 希尔排序实现逻辑源代码运行结果 插入排序 插入排序在少量数据中是一个高效的算法,你可以想象在打牌的时候,左手是已经整理好的牌,右手是正在抓取的牌。 配图详解 对一组数据 5&…

手机号码空号过滤API:有效验证和过滤无效电话号码

随着移动通信技术的发展,手机号码成为人们日常生活和工作中不可或缺的一部分。然而,随着时间的推移,一些手机号码可能会变成空号,这给企业在进行电话营销和数据分析时带来了一定的困扰。为了解决这个问题,挖数据平台提…

武汉星起航:引领跨境电商新潮流,一站式孵化助力卖家轻松出海

武汉星起航电子商务有限公司,作为跨境电商领域的领军者,始终秉持“走出去”的战略理念,依托自营店铺的丰富经验和对跨境电商资源的深度整合,成功打造了一站式卖家孵化体系。这一体系集开店策划、运营教学、资源服务于一体&#xf…