图像抠图DIS——自然图像中高精度二分图像抠图的方法(C++/python模型推理)

概述

DIS(Dichotomous Image Segmentation)是一种新的图像分割任务,旨在从自然图像中分割出高精度的物体。与传统的图像分割任务相比,DIS更侧重于具有单个或几个目标的图像,因此可以提供更丰富准确的细节。

为了研究DIS任务,研究人员创建了一个名为DIS5K的大规模、可扩展的数据集。DIS5K数据集包含了5,470张高分辨率图像,每张图像都配有高精度的二值分割掩码。这个数据集的建立有助于推动多个应用方向的发展,如图像去背景、艺术设计、模拟视图运动、基于图像的增强现实(AR)应用、基于视频的AR应用、3D视频制作等。

通过研究DIS任务和使用DIS5K数据集,研究人员可以探索新的图像分割方法,并为各种应用领域提供更精确、更可靠的图像分割技术,从而推动分割技术在更广泛的领域中的应用。

官网:https://xuebinqin.github.io/dis/index.html
Github:https://github.com/xuebinqin/DIS

数据集

图像二类分割是将图像分割成两个主要区域:前景和背景。在这种情况下,前景代表图像中的某个类别的物体,而背景则是除了该物体之外的所有内容。
官方公布了算所使用的数据集DIS5K, DIS5K数据集中的每张图像都经过了像素级别的手工标注,标注的真值掩码非常精确,每张图像的标记时间相当长。这种高精度的标注使得数据集中的每个像素都与其相应的类别关联起来,从而为模型提供了可靠的训练数据。这种高精度的标注是实现图像二类分割的关键,因为模型需要能够准确地识别和分割出前景物体。

在DIS5K数据集中,标注对象的类型多样,包括透明和半透明的物体,标注使用单个像素的二值掩码进行。这种精确的标注确保了模型训练的有效性和准确性,并且使得模型能够预测出高精度的物体分割结果。

DIS5K数据集网盘地址:https://pan.baidu.com/s/1umNk2AeBG5aB5kXlHTHdIg
提取码:7qfs

模型训练

模型训练可参考git上的官方的文档

模型推理

模型C++使用onnxruntime进行推理

#include <opencv2/opencv.hpp>
#include <onnxruntime_cxx_api.h>


class DIS
{
public:
	DIS(std::string model_path);
	void inference(cv::Mat& cv_src, cv::Mat& cv_mask);
private:
	std::vector<float> input_image_;
	int inpWidth;
	int inpHeight;
	int outWidth;
	int outHeight;
	const float score_th = 0;

	Ort::Env env = Ort::Env(ORT_LOGGING_LEVEL_ERROR, "DIS");
	Ort::Session* ort_session = nullptr;
	Ort::SessionOptions sessionOptions = Ort::SessionOptions();
	std::vector<char*> input_names;
	std::vector<char*> output_names;
	std::vector<std::vector<int64_t>> input_node_dims; // >=1 outputs
	std::vector<std::vector<int64_t>> output_node_dims; // >=1 outputs
};



DIS::DIS(std::string model_path)
{
	std::wstring widestr = std::wstring(model_path.begin(), model_path.end());
	//OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0);
	sessionOptions.SetGraphOptimizationLevel(ORT_ENABLE_BASIC);
	ort_session = new Ort::Session(env, widestr.c_str(), sessionOptions);
	size_t numInputNodes = ort_session->GetInputCount();
	size_t numOutputNodes = ort_session->GetOutputCount();
	Ort::AllocatorWithDefaultOptions allocator;
	for (int i = 0; i < numInputNodes; i++)
	{
		input_names.push_back(ort_session->GetInputName(i, allocator));
		Ort::TypeInfo input_type_info = ort_session->GetInputTypeInfo(i);
		auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();
		auto input_dims = input_tensor_info.GetShape();
		input_node_dims.push_back(input_dims);
	}
	for (int i = 0; i < numOutputNodes; i++)
	{
		output_names.push_back(ort_session->GetOutputName(i, allocator));
		Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo(i);
		auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
		auto output_dims = output_tensor_info.GetShape();
		output_node_dims.push_back(output_dims);
	}
	this->inpHeight = input_node_dims[0][2];
	this->inpWidth = input_node_dims[0][3];
	this->outHeight = output_node_dims[0][2];
	this->outWidth = output_node_dims[0][3];
}


void DIS::inference(cv::Mat& cv_src, cv::Mat& cv_mask)
{
	cv::Mat cv_dst;
	cv::resize(cv_src, cv_dst, cv::Size(this->inpWidth, this->inpHeight));
	this->input_image_.resize(this->inpWidth * this->inpHeight * cv_dst.channels());
	for (int c = 0; c < 3; c++)
	{
		for (int i = 0; i < this->inpHeight; i++)
		{
			for (int j = 0; j < this->inpWidth; j++)
			{
				float pix = cv_dst.ptr<uchar>(i)[j * 3 + 2 - c];
				this->input_image_[c * this->inpHeight * this->inpWidth + i * this->inpWidth + j] = pix / 255.0 - 0.5;
			}
		}
	}
	std::array<int64_t, 4> input_shape_{ 1, 3, this->inpHeight, this->inpWidth };

	auto allocator_info = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
	Ort::Value input_tensor_ = Ort::Value::CreateTensor<float>(allocator_info,
		input_image_.data(), input_image_.size(), input_shape_.data(), input_shape_.size());

	std::vector<Ort::Value> ort_outputs = ort_session->Run(Ort::RunOptions{ nullptr }, &input_names[0],
		&input_tensor_, 1, output_names.data(), output_names.size());   // 开始推理
	float* pred = ort_outputs[0].GetTensorMutableData<float>();
	cv::Mat mask(outHeight, outWidth, CV_32FC1, pred);
	double min_value, max_value;
	minMaxLoc(mask, &min_value, &max_value, 0, 0);
	mask = (mask - min_value) / (max_value - min_value);
	cv::resize(mask, cv_mask, cv::Size(cv_src.cols, cv_src.rows));
}

void show_img(std::string name, const cv::Mat& img)
{
	cv::namedWindow(name, 0);
	int max_rows = 500;
	int max_cols = 600;
	if (img.rows >= img.cols && img.rows > max_rows) {
		cv::resizeWindow(name, cv::Size(img.cols * max_rows / img.rows, max_rows));
	}
	else if (img.cols >= img.rows && img.cols > max_cols) {
		cv::resizeWindow(name, cv::Size(max_cols, img.rows * max_cols / img.cols));
	}
	cv::imshow(name, img);
}

cv::Mat replaceBG(const cv::Mat cv_src, cv::Mat& alpha, std::vector<int>& bg_color)
{
	int width = cv_src.cols;
	int height = cv_src.rows;

	cv::Mat cv_matting = cv::Mat::zeros(cv::Size(width, height), CV_8UC3);

	float* alpha_data = (float*)alpha.data;
	for (int i = 0; i < height; i++)
	{
		for (int j = 0; j < width; j++)
		{
			float alpha_ = alpha_data[i * width + j];
			cv_matting.at < cv::Vec3b>(i, j)[0] = cv_src.at < cv::Vec3b>(i, j)[0] * alpha_ + (1 - alpha_) * bg_color[0];
			cv_matting.at < cv::Vec3b>(i, j)[1] = cv_src.at < cv::Vec3b>(i, j)[1] * alpha_ + (1 - alpha_) * bg_color[1];
			cv_matting.at < cv::Vec3b>(i, j)[2] = cv_src.at < cv::Vec3b>(i, j)[2] * alpha_ + (1 - alpha_) * bg_color[2];
		}
	}

	return cv_matting;
}

int main()
{
	DIS dis_net("isnet_general_use_720x1280.onnx");

	std::string path = "images";
	std::vector<std::string> filenames;
	cv::glob(path, filenames, false);

	for (auto file_name : filenames)
	{
		cv::Mat cv_src = cv::imread(file_name);
		//std::vector<cv::Mat> cv_dsts;
		cv::Mat cv_dst, cv_mask;
		dis_net.inference(cv_src, cv_mask);
		std::vector<int> color{255, 0, 0};
		cv_dst=replaceBG(cv_src, cv_mask, color);

		show_img("src", cv_src);
		show_img("mask", cv_mask);
		show_img("dst", cv_dst);

		cv::waitKey(0);
	}
}

python推理代码也依赖onnxruntime

import argparse
import cv2
import numpy as np
import onnxruntime
### onnxruntime load ['isnet_general_use_HxW.onnx', 'isnet_HxW.onnx', 'isnet_Nx3xHxW.onnx']  inference failed
class DIS():
    def __init__(self, modelpath, score_th=None):
        so = onnxruntime.SessionOptions()
        so.log_severity_level = 3
        self.net = onnxruntime.InferenceSession(modelpath, so)
        self.input_height = self.net.get_inputs()[0].shape[2]
        self.input_width = self.net.get_inputs()[0].shape[3]
        self.input_name = self.net.get_inputs()[0].name
        self.output_name = self.net.get_outputs()[0].name
        self.score_th = score_th

    def detect(self, srcimg):
        img = cv2.resize(srcimg, dsize=(self.input_width, self.input_height))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32) / 255.0 - 0.5
        blob = np.expand_dims(np.transpose(img, (2, 0, 1)), axis=0).astype(np.float32)
        outs = self.net.run([self.output_name], {self.input_name: blob})
        
        mask = np.array(outs[0]).squeeze()
        min_value = np.min(mask)
        max_value = np.max(mask)
        mask = (mask - min_value) / (max_value - min_value)
        if self.score_th is not None:
            mask = np.where(mask < self.score_th, 0, 1)
        mask *= 255
        mask = mask.astype('uint8')

        mask = cv2.resize(mask, dsize=(srcimg.shape[1], srcimg.shape[0]), interpolation=cv2.INTER_LINEAR)
        return mask

def generate_overlay_image(srcimg, mask):
    overlay_image = np.zeros(srcimg.shape, dtype=np.uint8)
    overlay_image[:] = (255, 255, 255)
    mask = np.stack((mask,) * 3, axis=-1).astype('uint8') 
    mask_image = np.where(mask, srcimg, overlay_image)
    return mask, mask_image

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--imgpath", type=str, default='images/cam_image47.jpg')
    parser.add_argument("--modelpath", type=str, default='weights/isnet_general_use_480x640.onnx')
    args = parser.parse_args()
    
    mynet = DIS(args.modelpath)
    srcimg = cv2.imread(args.imgpath)
    mask = mynet.detect(srcimg)
    mask, overlay_image = generate_overlay_image(srcimg, mask)

    winName = 'Deep learning object detection in onnxruntime'
    cv2.namedWindow(winName, cv2.WINDOW_NORMAL)
    cv2.imshow(winName, np.hstack((srcimg, mask)))
    cv2.waitKey(0)
    cv2.destroyAllWindows()

推理结果
在这里插入图片描述

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
资源和模型下载地址:https://download.csdn.net/download/matt45m/89024664

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/486718.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Java只有中国人在搞了吗?

还是看你将来想干啥。想干应用架构&#xff0c;与Java狗谈笑风生&#xff0c;沆瀣一气&#xff0c;你就好好写Java&#xff0c;学DDD&#xff0c;看Clean Architecture。你想成为炼丹玄学工程师&#xff0c;年入百万&#xff0c;就选python&#xff0c;专精各种paper。你不在意…

如何修改SystemUI Clock的样式

开机的流程为&#xff1a; 在 CollapsedStatusBarFragment 的onCreateView 方法中 inflate R.layout.status_bar.xml&#xff0c; 里面定义有Clock。 CollapsedStatusBarFragment 的被调用流程为&#xff1a; 在StatusBar 的makeStatusBarView方法中显示出来。 所以可以在文…

Vue 若依框架 form-generator添加表格组件和动态表单组件

效果图&#xff1a; 在若依框架自带的流程表单配置基础上添加这两个组件 config.js // 表单属性【右面板】 export const formConf {formRef: elForm,formModel: formData,other: other,size: medium,labelPosition: right,labelWidth: 100,formRules: rules,gutter: 15,dis…

vue2 和 vue3 配置路由有什么区别

vue2 和 vue3 配置路由有什么区别 初始化路由器实例&#xff1a;注入到应用中&#xff1a;动态路由参数和捕获所有路由&#xff1a;编程式导航 API&#xff1a;异步加载组件&#xff1a; vue2 如何 使用路由 第一步&#xff1a;安装 vue-router第二步&#xff1a;创建路由组件第…

在面对API的安全风险,WAAP全站防护能做到哪些?

随着数字化转型的加速&#xff0c;API&#xff08;应用程序接口&#xff09;已经成为企业间和企业内部系统交互的核心组件。在应用程序开发过程中&#xff0c;API能够在不引起用户注意的情况下&#xff0c;无缝、流畅地完成各种任务。例如从一个应用程序中提取所需数据并传递给…

【MySQL】知识点 + 1

# &#xff08;1&#xff09;查询当前日期、当前时间以及到2022年1月1日还有多少天&#xff0c;然后通过mysql命令执行命令。 select curdate() AS 当前日期,curtime() AS 当前时间,datediff(2022-01-01, curdate()) AS 距离2022年1月1日还有天数;# &#xff08;2&#xff09;利…

【iOS ARKit】3D文字

首先&#xff0c;3D场景中渲染的任何虚拟元素都必须具有网格&#xff08;顶点及顶点间的拓扑关系&#xff09;&#xff0c;没有网格的元素无法利用GPU 进行渲染&#xff0c;因此&#xff0c;在3D 场景申渲染 3D文字时&#xff0c;文字也必须具有网格。在计算机系统中&#xff0…

发展新质生产力,亚信科技切中产业痛点

管理学大师拉姆查兰认为&#xff0c;经营性不确定性通常在预知范围之内&#xff0c;不会对原有格局产生根本性影响&#xff1b;而结构性不确定性则源于外部环境的根本性变化&#xff0c;将彻底改变产业格局&#xff0c;带来根本性影响。 毫无疑问&#xff0c;一个充满结构性不…

VS Code配置Python环境

首先贴一张完全卸载VS Code的图&#xff0c;包括一些配置和插件。 讲述一下如何配置Python环境以及和Conda的配合使用(涉及到虚拟环境) VS Code配置Python需要三步&#xff1a;安装Python环境&#xff1b;在VS Code软件中下载Python插件&#xff1b;新建python文件开始coding。…

Docker容器初始

华子目录 docker简介虚拟化技术硬件级虚拟化硬件级虚拟化历史操作系统虚拟化历史基于服务的云计算模式 什么是dockerDocker和传统虚拟化方式的不同之处为什么要使用docker&#xff1f;Docker 在如下几个方面具有较大的优势 对比传统虚拟机总结docker应用场景docker改变了什么 基…

抖音小店和抖音橱窗有什么区别?普通人最适合做哪个?

大家好&#xff0c;我是电商糖果 说起抖音卖货&#xff0c;很多人都会搞不清楚抖音小店和抖音橱窗有什么不同。 甚至有的朋友将他们认为是一个项目。 这里糖果就帮大家仔细的分辨一下&#xff0c;想在抖音卖货的普通人&#xff0c;看看它们谁最适合自己。 来百度APP畅享高清…

MySQL中的基本SQL语句

文章目录 MySQL中的基本SQL语句查看操作创建与删除数据库和表修改表格数据库用户管理 MySQL中的基本SQL语句 查看操作 1. 查看有哪些数据库 show databases; 2.切换数据库 use 数据库名;比如切换至 mysql数据库 use mysql;3.查看数据库中的表 show tables;4.查看表中…

干货分享 | TSMaster如何同时记录标定变量和DBC信号至BLF文件

客户在使用TSMaster软件标定功能时&#xff0c;有如下使用场景&#xff1a;将DBC文件中的信号与A2L文件中的标定变量同时记录在一个记录文件。针对此应用场景&#xff0c;TSMaster软件提供了一种方法来满足此需求。今天重点和大家分享一下关于TSMaster软件中同时记录标定变量和…

【计算机组成】计算机组成与结构(四)

上一篇&#xff1a;【计算机组成】计算机组成与结构&#xff08;三&#xff09; &#xff08;7&#xff09;存储系统 计算机采用分级存储体系的主要目的是为了解决存储容量、成本和速度之间的矛盾问题。 两级存储:cache-主存、主存-辅存(虚拟存储体系) 局部性原理 ◆ 局部性…

openssl 升级1.1.1.1k 到 3.0.13

下载 https://www.openssl.org/source/ tar -zxvf openssl-3.0.13.tar.gzcd openssl-3.0.13/./config enable-fips --prefix/usr/local --openssldir/usr/local/opensslmake && make install 将原有openssl备份 mv /usr/bin/openssl /usr/bin/openssl.bak mv /usr/i…

ElasTool v3.0 程序:材料弹性和机械性能的高效计算和可视化工具包

分享一个材料弹性和机械性能的高效计算和可视化工具包&#xff1a; ElasTool v3.0。 感谢论文的原作者&#xff01; 主要内容 “弹性和机械性能的高效计算和可视化对于材料的选择和新材料的设计至关重要。该工具包标志着材料弹性和机械性能计算分析和可视化方面的重大进步…

Linux脚本打开多个终端执行不懂程序(树莓派)

1、首先需要安装gnome-terminal sudo apt install gnome-terminal 2、然后编写一下代码到RunApp.sh脚本&#xff0c;设置窗口上方名字&#xff08;Qt5.8.0写的Server和Client&#xff09; #!/bin/sh gnome-terminal --title "Client" -- bash -c "./Client&q…

算法打卡day24|回溯法篇04|Leetcode 93.复原IP地址、78.子集、90.子集II

算法题 Leetcode 93.复原IP地址 题目链接:93.复原IP地址 大佬视频讲解&#xff1a;复原IP地址视频讲解 个人思路 这道题和昨天的分割回文串有点类似&#xff0c;但这里是限制了只能分割3次以及分割块的数字大小&#xff0c;根据这些不同的条件用回溯法解决就好啦 解法 回溯…

二维码门楼牌管理应用平台建设:提升城市管理效率的新路径

文章目录 前言一、二维码门楼牌管理应用平台的建设背景二、人工数据审核的重要性三、地址匹配校验的作用四、数据修改后的状态管理五、二维码门楼牌管理应用平台的未来展望 前言 随着城市管理的不断升级&#xff0c;二维码门楼牌管理应用平台正逐渐成为城市管理的新宠。本文将…

今天简单聊聊容器化

什么是容器化 容器化&#xff08;Containerization&#xff09;是一种软件开发和部署的方法&#xff0c;其核心思想是将应用程序及其所有依赖项打包到一个独立的运行环境中&#xff0c;这个环境被称为容器。容器化技术使得应用程序可以在不同的计算环境中以一致的方式运行&…