Language Segment-Anything is an open-source project that combines the power of instance segmentation and text prompts to generate masks for specific objects in an image. Built on Meta's recently released segment-anything model and the GroundingDINO detection model, it is an easy-to-use and effective tool for object detection and image segmentation. In the overall pipeline, however, GroundingDINO typically takes about 0.6 s while segment-anything takes about 5-8 s, because the segment-anything model is computationally expensive. Replacing SAM with MobileSAM cuts the total prediction time of the pipeline to roughly 0.7 s.
Project address: https://github.com/luca-medeiros/lang-segment-anything/tree/main
1. Prerequisites
Follow https://blog.csdn.net/a486259/article/details/136820052 to download and install the LangSAM project code.
Install MobileSAM with the following command:
pip install git+https://github.com/ChaoningZhang/MobileSAM.git
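To confirm the installation, you can check that the MobileSAM model registry exposes the vit_t entry used later in this post:
python -c "from mobile_sam import sam_model_registry; print(list(sam_model_registry.keys()))"
The printed list should include 'vit_t'.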
Download mobile_sam.pt (available from the weights directory of the MobileSAM repository) and save it in the lang_sam directory.
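If you prefer to script the download, a minimal sketch (the URL assumes the checkpoint still lives at weights/mobile_sam.pt in the MobileSAM repository; adjust it if the file has moved):

import urllib.request

# Assumed location of the MobileSAM checkpoint inside its GitHub repository.
URL = "https://github.com/ChaoningZhang/MobileSAM/raw/master/weights/mobile_sam.pt"
urllib.request.urlretrieve(URL, "lang_sam/mobile_sam.pt")  # run from the project root
print("saved lang_sam/mobile_sam.pt")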
2. Modify lang_sam\lang_sam.py
Replace the code in lang_sam\lang_sam.py with the following:
import os
import time

import groundingdino.datasets.transforms as T
import numpy as np
import torch
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.inference import predict
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from huggingface_hub import hf_hub_download
from mobile_sam import sam_model_registry as sam_mobile_model_registry
from segment_anything import sam_model_registry
from segment_anything import SamPredictor

SAM_MODELS = {
    "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
    "vit_t": "./mobile_sam.pt",  # local MobileSAM checkpoint, stored next to this file
}

CACHE_PATH = os.environ.get("TORCH_HOME", os.path.expanduser("~/.cache/torch/hub/checkpoints"))

def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    # Load the GroundingDINO config and weights from the local Hugging Face cache.
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename, local_files_only=True)
    args = SLConfig.fromfile(cache_config_file)
    model = build_model(args)
    args.device = device
    cache_file = hf_hub_download(repo_id=repo_id, filename=filename, local_files_only=True)
    checkpoint = torch.load(cache_file, map_location='cpu')
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print(f"Model loaded from {cache_file} \n => {log}")
    model.eval()
    return model

def transform_image(image) -> torch.Tensor:
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    image_transformed, _ = transform(image, None)
    return image_transformed

class LangSAM():

    def __init__(self, sam_type="vit_t", ckpt_path=None):
        self.sam_type = sam_type
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.build_groundingdino()
        self.build_sam(ckpt_path)

    def build_sam(self, ckpt_path):
        if self.sam_type is None or ckpt_path is None:
            if self.sam_type is None:
                print("No sam type indicated. Using vit_t by default.")
                self.sam_type = "vit_t"
            checkpoint_url = SAM_MODELS[self.sam_type]
            try:
                if self.sam_type == 'vit_t':
                    # MobileSAM: resolve the checkpoint relative to this file.
                    pt_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), checkpoint_url)
                    print(pt_path)
                    sam = sam_mobile_model_registry[self.sam_type](pt_path)
                    print("use mobile sam!")
                else:
                    sam = sam_model_registry[self.sam_type]()
                    state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
                    sam.load_state_dict(state_dict, strict=True)
            except Exception:
                raise ValueError(f"Problem loading SAM/MobileSAM. Please make sure you have the right model type: {self.sam_type} "
                                 f"and a working checkpoint: {checkpoint_url}. Recommend deleting the checkpoint and re-downloading it.")
            sam.to(device=self.device)
            self.sam = SamPredictor(sam)
        else:
            try:
                if self.sam_type == 'vit_t':
                    sam = sam_mobile_model_registry[self.sam_type](ckpt_path)
                    print("use mobile sam!")
                else:
                    sam = sam_model_registry[self.sam_type](ckpt_path)
            except Exception:
                raise ValueError(f"Problem loading SAM. Your model type: {self.sam_type} "
                                 f"should match your checkpoint path: {ckpt_path}. Recommend calling LangSAM "
                                 f"using matching model type AND checkpoint path.")
            sam.to(device=self.device)
            self.sam = SamPredictor(sam)

    def build_groundingdino(self):
        ckpt_repo_id = "ShilongLiu/GroundingDINO"
        ckpt_filename = "groundingdino_swinb_cogcoor.pth"
        ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"
        self.groundingdino = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename)

    def predict_dino(self, image_pil, text_prompt, box_threshold, text_threshold):
        image_trans = transform_image(image_pil)
        boxes, logits, phrases = predict(model=self.groundingdino,
                                         image=image_trans,
                                         caption=text_prompt,
                                         box_threshold=box_threshold,
                                         text_threshold=text_threshold,
                                         device=self.device)
        W, H = image_pil.size
        # Convert normalized cxcywh boxes to absolute xyxy pixel coordinates.
        boxes = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
        return boxes, logits, phrases

    def predict_sam(self, image_pil, boxes):
        image_array = np.asarray(image_pil)
        self.sam.set_image(image_array)
        transformed_boxes = self.sam.transform.apply_boxes_torch(boxes, image_array.shape[:2])
        masks, _, _ = self.sam.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes.to(self.sam.device),
            multimask_output=False,
        )
        return masks.cpu()

    def predict(self, image_pil, text_prompt, box_threshold=0.3, text_threshold=0.25):
        t0 = time.time()
        boxes, logits, phrases = self.predict_dino(image_pil, text_prompt, box_threshold, text_threshold)
        t1 = time.time()
        print("self.predict_dino use time:", t1 - t0)
        masks = torch.tensor([])
        if len(boxes) > 0:
            masks = self.predict_sam(image_pil, boxes)
            masks = masks.squeeze(1)
        t2 = time.time()
        print("self.predict_sam use time:", t2 - t1)
        return masks, boxes, phrases, logits
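To check that the modified module really picks up MobileSAM, a quick sanity test from the repository root (a sketch; it assumes mobile_sam.pt is already in lang_sam/ and the GroundingDINO weights are cached per the prerequisites):

from lang_sam import LangSAM

model = LangSAM()  # sam_type defaults to "vit_t"
num_params = sum(p.numel() for p in model.sam.model.parameters())
print(f"SAM parameters: {num_params / 1e6:.1f}M")

A parameter count on the order of 10M confirms the lightweight MobileSAM backbone is loaded (the vit_h SAM is roughly 600M parameters).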
3. Modify app.py
Replace the code in app.py with exactly the same content as the modified lang_sam\lang_sam.py shown in section 2.
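Note that this listing only defines the LangSAM class, so app.py still needs an entry point to produce output when executed. A minimal sketch to append at the end of app.py, modeled on the usage example in the lang-segment-anything README (./assets/car.jpeg and the prompt "wheel" come from that repository; substitute your own image and prompt):

from PIL import Image

if __name__ == "__main__":
    model = LangSAM()  # sam_type defaults to "vit_t", i.e. MobileSAM
    image_pil = Image.open("./assets/car.jpeg").convert("RGB")
    masks, boxes, phrases, logits = model.predict(image_pil, "wheel")
    print(f"{len(boxes)} objects matched: {phrases}")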
4. Run app.py
Execute app.py. The console prints the timing lines added in predict(), one for predict_dino and one for predict_sam.
The recognition results are identical to the output of the original LangSAM project, but the pipeline runs about 5-6x faster.