LangSAM项目优化,将SAM修改为MoblieSAM,提速5~6倍

Language Segment-Anything 是一个开源项目,它结合了实例分割和文本提示的强大功能,为图像中的特定对象生成蒙版。它建立在最近发布的 Meta 模型、segment-anything 和 GroundingDINO 检测模型之上,是一款易于使用且有效的对象检测和图像分割工具。 然而在整个流程中,GroundingDINO 通常耗时0.6s作用,segment-anything 通常耗时5-8s左右。这是因为segment-anything model运算量比较大,将SAM修改为MoblieSAM可以将整个流程的预测时间降低到0.7s左右。

项目地址:https://github.com/luca-medeiros/lang-segment-anything/tree/main

1、前置条件

参考 https://blog.csdn.net/a486259/article/details/136820052 实现LangSAM项目代码的下载安装

使用以下命令 安装MobileSAM

pip install git+https://github.com/ChaoningZhang/MobileSAM.git

下载mobile_sam.pt存储到lang_sam目录下,具体如下图所示
在这里插入图片描述

2、修改lang_sam\lang_sam.py

将lang_sam\lang_sam.py的代码修改为以下内容

import os

import groundingdino.datasets.transforms as T
import numpy as np
import torch
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.inference import predict
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from huggingface_hub import hf_hub_download
from segment_anything import sam_model_registry
from segment_anything import SamPredictor
from mobile_sam import sam_model_registry as sam_moblie_model_registry
import time
SAM_MODELS = {
    "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
    "vit_t": "./mobile_sam.pt"
}

CACHE_PATH = os.environ.get("TORCH_HOME", os.path.expanduser("~/.cache/torch/hub/checkpoints"))


def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename, local_files_only=True )

    args = SLConfig.fromfile(cache_config_file)
    model = build_model(args)
    args.device = device

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename, local_files_only=True )
    #cache_file = 'F:\OPEN_PROJECT\GroundingDINO-main\weights\groundingdino_swint_ogc.pth'
    checkpoint = torch.load(cache_file, map_location='cpu')
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print(f"Model loaded from {cache_file} \n => {log}")
    model.eval()
    return model


def transform_image(image) -> torch.Tensor:
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    image_transformed, _ = transform(image, None)
    return image_transformed


class LangSAM():

    def __init__(self, sam_type="vit_t", ckpt_path=None):
        self.sam_type = sam_type
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.build_groundingdino()
        self.build_sam(ckpt_path)

    def build_sam(self, ckpt_path):
        if self.sam_type is None or ckpt_path is None:
            if self.sam_type is None:
                print("No sam type indicated. Using vit_t by default.")
                self.sam_type = "vit_t"
            checkpoint_url = SAM_MODELS[self.sam_type]
            try:
                if self.sam_type=='vit_t':
                    pt_url = os.path.dirname(os.path.abspath(__file__))+'/'+checkpoint_url
                    print(pt_url)
                    sam = sam_moblie_model_registry[self.sam_type](pt_url)
                    print("         use mobile sam!")
                else:
                    sam = sam_model_registry[self.sam_type]()
                    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
                    sam.load_state_dict(state_dict, strict=True)
            except:
                raise ValueError(f"Problem loading MobileSAM please make sure you have the right model type: {self.sam_type} \
                    and a working checkpoint: {checkpoint_url}. Recommend deleting the checkpoint and \
                    re-downloading it.")
            sam.to(device=self.device)
            self.sam = SamPredictor(sam)
        else:
            try:
                
                if self.sam_type=='vit_t':
                    sam = sam_moblie_model_registry[self.sam_type](ckpt_path)
                    print("         use mobile sam!")
                else:
                    sam = sam_model_registry[self.sam_type](ckpt_path)
            except:
                raise ValueError(f"Problem loading SAM. Your model type: {self.sam_type} \
                should match your checkpoint path: {ckpt_path}. Recommend calling LangSAM \
                using matching model type AND checkpoint path")
            sam.to(device=self.device)
            self.sam = SamPredictor(sam)

    def build_groundingdino(self):
        ckpt_repo_id = "ShilongLiu/GroundingDINO"
        ckpt_filename = "groundingdino_swinb_cogcoor.pth"
        ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"
        self.groundingdino = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename)

    def predict_dino(self, image_pil, text_prompt, box_threshold, text_threshold):
        image_trans = transform_image(image_pil)
        boxes, logits, phrases = predict(model=self.groundingdino,
                                         image=image_trans,
                                         caption=text_prompt,
                                         box_threshold=box_threshold,
                                         text_threshold=text_threshold,
                                         device=self.device)
        W, H = image_pil.size
        boxes = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])

        return boxes, logits, phrases

    def predict_sam(self, image_pil, boxes):
        image_array = np.asarray(image_pil)
        self.sam.set_image(image_array)
        transformed_boxes = self.sam.transform.apply_boxes_torch(boxes, image_array.shape[:2])
        masks, _, _ = self.sam.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes.to(self.sam.device),
            multimask_output=False,
        )
        return masks.cpu()

    def predict(self, image_pil, text_prompt, box_threshold=0.3, text_threshold=0.25):
        t0=time.time()
        boxes, logits, phrases = self.predict_dino(image_pil, text_prompt, box_threshold, text_threshold)
        t1=time.time()
        print("self.predict_dino  use time:",t1-t0)
        masks = torch.tensor([])
        if len(boxes) > 0:
            masks = self.predict_sam(image_pil, boxes)
            masks = masks.squeeze(1)
        t2=time.time()
        print("self.predict_sam  use time:",t2-t1)
        return masks, boxes, phrases, logits


3、修改app.py

将原来app.py中的代码修改为以下内容。


import os

import groundingdino.datasets.transforms as T
import numpy as np
import torch
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.inference import predict
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from huggingface_hub import hf_hub_download
from segment_anything import sam_model_registry
from segment_anything import SamPredictor
from mobile_sam import sam_model_registry as sam_moblie_model_registry
import time
SAM_MODELS = {
    "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
    "vit_t": "./mobile_sam.pt"
}

CACHE_PATH = os.environ.get("TORCH_HOME", os.path.expanduser("~/.cache/torch/hub/checkpoints"))


def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename, local_files_only=True )

    args = SLConfig.fromfile(cache_config_file)
    model = build_model(args)
    args.device = device

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename, local_files_only=True )
    #cache_file = 'F:\OPEN_PROJECT\GroundingDINO-main\weights\groundingdino_swint_ogc.pth'
    checkpoint = torch.load(cache_file, map_location='cpu')
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print(f"Model loaded from {cache_file} \n => {log}")
    model.eval()
    return model


def transform_image(image) -> torch.Tensor:
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    image_transformed, _ = transform(image, None)
    return image_transformed


class LangSAM():

    def __init__(self, sam_type="vit_t", ckpt_path=None):
        self.sam_type = sam_type
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.build_groundingdino()
        self.build_sam(ckpt_path)

    def build_sam(self, ckpt_path):
        if self.sam_type is None or ckpt_path is None:
            if self.sam_type is None:
                print("No sam type indicated. Using vit_t by default.")
                self.sam_type = "vit_t"
            checkpoint_url = SAM_MODELS[self.sam_type]
            try:
                if self.sam_type=='vit_t':
                    pt_url = os.path.dirname(os.path.abspath(__file__))+'/'+checkpoint_url
                    print(pt_url)
                    sam = sam_moblie_model_registry[self.sam_type](pt_url)
                    print("         use mobile sam!")
                else:
                    sam = sam_model_registry[self.sam_type]()
                    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
                    sam.load_state_dict(state_dict, strict=True)
            except:
                raise ValueError(f"Problem loading MobileSAM please make sure you have the right model type: {self.sam_type} \
                    and a working checkpoint: {checkpoint_url}. Recommend deleting the checkpoint and \
                    re-downloading it.")
            sam.to(device=self.device)
            self.sam = SamPredictor(sam)
        else:
            try:
                
                if self.sam_type=='vit_t':
                    sam = sam_moblie_model_registry[self.sam_type](ckpt_path)
                    print("         use mobile sam!")
                else:
                    sam = sam_model_registry[self.sam_type](ckpt_path)
            except:
                raise ValueError(f"Problem loading SAM. Your model type: {self.sam_type} \
                should match your checkpoint path: {ckpt_path}. Recommend calling LangSAM \
                using matching model type AND checkpoint path")
            sam.to(device=self.device)
            self.sam = SamPredictor(sam)

    def build_groundingdino(self):
        ckpt_repo_id = "ShilongLiu/GroundingDINO"
        ckpt_filename = "groundingdino_swinb_cogcoor.pth"
        ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"
        self.groundingdino = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename)

    def predict_dino(self, image_pil, text_prompt, box_threshold, text_threshold):
        image_trans = transform_image(image_pil)
        boxes, logits, phrases = predict(model=self.groundingdino,
                                         image=image_trans,
                                         caption=text_prompt,
                                         box_threshold=box_threshold,
                                         text_threshold=text_threshold,
                                         device=self.device)
        W, H = image_pil.size
        boxes = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])

        return boxes, logits, phrases

    def predict_sam(self, image_pil, boxes):
        image_array = np.asarray(image_pil)
        self.sam.set_image(image_array)
        transformed_boxes = self.sam.transform.apply_boxes_torch(boxes, image_array.shape[:2])
        masks, _, _ = self.sam.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes.to(self.sam.device),
            multimask_output=False,
        )
        return masks.cpu()

    def predict(self, image_pil, text_prompt, box_threshold=0.3, text_threshold=0.25):
        t0=time.time()
        boxes, logits, phrases = self.predict_dino(image_pil, text_prompt, box_threshold, text_threshold)
        t1=time.time()
        print("self.predict_dino  use time:",t1-t0)
        masks = torch.tensor([])
        if len(boxes) > 0:
            masks = self.predict_sam(image_pil, boxes)
            masks = masks.squeeze(1)
        t2=time.time()
        print("self.predict_sam  use time:",t2-t1)
        return masks, boxes, phrases, logits

4、使用app.py

执行app.py,控制台输出如下所示
在这里插入图片描述
识别效果如下所示,与原始的LangSAM项目输出一模一样,但快了5~6倍
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/497185.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

定时任务 之 cron 表达式

Cron 表达式产生的背景:在Unix系统中,用户经常需要设置一些周期性被执行的任务,如定期备份文件、发送邮件等。为了满足这种需求,Unix系统提供了crontab命令,允许用户定义任务的时间表,并在指定的时间点自动…

实现实时查询并带有查询结果列表的输入框

这个功能主要是实现了一个可以实时查询结果的搜索框,并具备点击外部关闭搜索结果框体的功能,除了v-show和transition依托于vue实现以外其余功能都基于原生JS实现。 效果图: 该功能的实现主要是很久之前面试被问到过,当时没有做出…

Linux:进程控制

进程创建 进程&#xff1a;内核的相关管理数据结构&#xff08;task_structmm_struct页表&#xff09;代码&#xff08;<-共享&#xff09;和数据(<-写时拷贝) fork函数初识 在 linux 中 fork 函数时非常重要的函数&#xff0c;它从已存在进程中创建一个新进程。新进程…

1992-2022年经过矫正的夜间灯光数据

DMSP/OLS夜间灯光的年份是1992-2013年&#xff0c;NPP/VIIRS夜间灯光的年份是2012-2021&#xff0c;且由于光谱分辨率、空间分辨率、辐射分辨率、产品更新周期等方面的差异&#xff0c;DMSP-OLS和SNPP-VIIRS数据不兼容&#xff0c;也就是说我们无法直接对比分析DMSP-OLS和SNPP-…

Linux常用命令-文件操作

文章目录 ls基本用法常用选项组合选项示例注意事项 cd基本用法示例注意事项 pwd基本用法示例选项总结 cp基本用法常见选项示例注意事项 rm基本用法常见选项示例删除单个文件&#xff1a;交互式删除文件&#xff1a;强制删除文件&#xff1a;递归删除目录&#xff1a;交互式递归…

实验02-1 C#和ASP.NET控件:在Web窗体中输出九九乘法表

【实验内容及要求】 1. 在Web窗体中输出九九乘法表 浏览效果如图2-1所示。 图2-1 在Default.aspx.cs中编写C#代码 using System; using System.Collections.Generic; using System.Linq; using System.Web; using System.Web.UI; using System.Web.UI.WebControls;public par…

项目四-图书管理系统

1.创建项目 流程与之前的项目一致&#xff0c;不再进行赘述。 2.需求定义 需求: 1. 登录: ⽤⼾输⼊账号,密码完成登录功能 2. 列表展⽰: 展⽰图书 3.前端界面测试 无法启动&#xff01;&#xff01;&#xff01;--->记得加入mysql相关操作记得在yml进行配置 配置后启动…

操作系统系列学习——多级页表与快表

文章目录 前言多级页表与快表 前言 一个本硕双非的小菜鸡&#xff0c;备战24年秋招&#xff0c;计划学习操作系统并完成6.0S81&#xff0c;加油&#xff01; 本文总结自B站【哈工大】操作系统 李治军&#xff08;全32讲&#xff09; 老师课程讲的非常好&#xff0c;感谢 【哈工…

PythonGUI应用:模拟航空订票小程序

在本教程中&#xff0c;我们将创建一个基本的航空订票管理系统GUI应用&#xff0c;用户可以通过图形界面执行各种操作。我们将使用Python编程语言和Tkinter库来实现此应用。 功能概述&#xff1a; 航班管理&#xff1a; 用户可以添加新的航班&#xff0c;输入航班号、起始地、目…

Convex and Semi-Nonnegative Matrix Factorizations

我们提出了非负矩阵分解&#xff08;NMF&#xff09;主题的几种新变体。考虑形式为X FG^T的因子分解&#xff0c;我们关注的是G被限制为包含非负元素的算法&#xff0c;但允许数据矩阵X具有混合符号&#xff0c;从而扩展了NMF方法的适用范围。我们还考虑了基向量F被约束为数据…

Ubuntu20.04更换镜像源------最简单的教程

本教程适用于&#xff1a;Ubuntu22.04 操作流程 打开终端&#xff0c;运行以下命令&#xff1a; sudo sed -i "shttp://.*archive.ubuntu.comhttps://mirrors.tuna.tsinghua.edu.cng" /etc/apt/sources.list 运行后即完成更改。 如果找不到22.04的镜像&#xff…

海外盲盒APP:加速开拓海外盲盒市场

盲盒是年轻群体消费中增速较快的模式&#xff0c;从前几年起&#xff0c;盲盒就在我国掀起了一股热潮&#xff0c;市场得到了迅速发展。 如今&#xff0c;盲盒经济已经遍布到了全球&#xff0c;尤其是在亚洲地区&#xff0c;盲盒消费呈现出了高速发展态势&#xff0c;在海外市…

支小蜜校园防霸凌系统的具体功能是什么?

在当今社会&#xff0c;校园霸凌问题日益严重&#xff0c;成为影响学生健康成长的一大隐患。为了应对这一问题&#xff0c;许多学校开始引入校园防霸凌系统。这一系统以其独特的功能&#xff0c;为校园安全提供了有力保障&#xff0c;为学生的健康成长创造了良好环境。 校园防…

蓝桥杯单片机快速开发笔记——PCF8591的DAC模拟电压输出

一、原理分析 PCF8591电压信号探测器&#xff1a;http://t.csdnimg.cn/R38tC IIC原理&#xff1a;http://t.csdnimg.cn/v4dSv IIC指令&#xff1a;http://t.csdnimg.cn/RY6yi HC573/HC138&#xff1a;http://t.csdnimg.cn/W0a0U 数码管&#xff1a;http://t.csdnimg.cn/kfm9Y 独…

反序列化动态调用 [NPUCTF2020]ReadlezPHP1

在源代码上看到提示 访问一下看看 代码审计一下 <?php #error_reporting(0); class HelloPhp {public $a;public $b;public function __construct(){$this->a "Y-m-d h:i:s";$this->b "date";}public function __destruct(){$a $this->a;…

编译安装飞桨fastdeploy@FreeBSD(失败)

FastDeploy是一款全场景、易用灵活、极致高效的AI推理部署工具&#xff0c; 支持云边端部署。提供超过 &#x1f525;160 Text&#xff0c;Vision&#xff0c; Speech和跨模态模型&#x1f4e6;开箱即用的部署体验&#xff0c;并实现&#x1f51a;端到端的推理性能优化。包括 物…

上传镜像到仓库

上传镜像到公开仓库 1、给要上传的镜像打标签 # 从206节点上传镜像到仓库&#xff08;201&#xff09;magedu项目&#xff0c;查看206镜像 [rootk8s-node2 ~]# docker images REPOSITORY TAG IMAGE ID CRE…

arp 协议

数据链路层 我们之前学习到的 IP 协议解决的是数据跨网络传输的问题。 数据链路层解决的是&#xff1a;直接相连的主机&#xff0c;进行数据交付的问题&#xff01; 直接相连的设备包括我们的电脑&#xff0c;路由器等等哈&#xff01; 我们在网络基础那篇文章中讲过什么是以…

OneDiff加速“图生生”,解锁电商AI图像处理新范式

2024年&#xff0c;电商领域正目睹生成式AI软件工具的飞速发展&#xff0c;AI Generated Content (AIGC) 技术在电商应用中的普及率正在显著提升&#xff0c;这类技术能够显著提高商业运营的效率&#xff0c;并促进业绩的稳步增长。 硅基流动研发的图片/视频生成推理引擎OneDif…

近线数仓优化改造

近线数仓优化改造 1. 背景2. 优化3. 改造3.1. 重构3.2. 优化 1. 背景 大概就是有那么一个数仓&#xff0c;然后简略结构如下&#xff1a; #mermaid-svg-PVoUzuQhj2BK7Qge {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid…