基于华为atlas的unet分割模型探索

Unet模型使用官方基于kaggle Carvana Image Masking Challenge数据集训练的模型。

模型输入为572*572*3,输出为572*572*2。分割目标分别为,0:背景,1:汽车。

Pytorch的pth模型转化onnx模型:

import torch

from unet import UNet

model = UNet(n_channels=3, n_classes=2, bilinear=False)
model = model.to(memory_format=torch.channels_last)

state_dict = torch.load("unet_carvana_scale1.0_epoch2.pth", map_location="cpu")
#del state_dict['mask_values']
model.load_state_dict(state_dict)

dummy_input = torch.randn(1, 3, 572, 572)

torch.onnx.export(model, dummy_input, "unet.onnx", verbose=True)

模型输入输出节点分析:

使用工具Netron查看模型结构,确定模型输入节点名称为input.1,输出节点名称为/outc/conv/Conv

onnx模型转化atlas模型:

atc --model=./unet.onnx --framework=5 --output=unet --soc_version=Ascend310P3  --input_shape="input.1:1,3,572,572" --output_type="/outc/conv/Conv:0:FP32" --out_nodes="/outc/conv/Conv:0"

推理代码实现:

import base64
import json
import os
import time

import numpy as np
import cv2

import MxpiDataType_pb2 as mxpi_data
from StreamManagerApi import InProtobufVector
from StreamManagerApi import MxProtobufIn
from StreamManagerApi import StreamManagerApi


def check_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir, exist_ok=True)


class SDKInferWrapper:
    def __init__(self): # 完成初始化
        self._stream_name = None
        self._stream_mgr_api = StreamManagerApi()

        if self._stream_mgr_api.InitManager() != 0:
            raise RuntimeError("Failed to init stream manager.")

        pipeline_name = './nested_unet.pipeline'

        self.load_pipeline(pipeline_name)

        self.width = 572
        self.height = 572

    def load_pipeline(self, pipeline_path):
        with open(pipeline_path, 'r') as f:
            pipeline = json.load(f)

        self._stream_name = list(pipeline.keys())[0].encode() # 'unet_pytorch'
        if self._stream_mgr_api.CreateMultipleStreams(
                json.dumps(pipeline).encode()) != 0:
            raise RuntimeError("Failed to create stream.")

    def do_infer(self, img_bgr):

        # preprocess
        image = cv2.resize(img_bgr, (self.width, self.height))
        image = cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
        image = image.astype('float32') / 255.0
        image = image.transpose(2, 0, 1)


        tensor_pkg_list = mxpi_data.MxpiTensorPackageList()
        tensor_pkg = tensor_pkg_list.tensorPackageVec.add()
        tensor_vec = tensor_pkg.tensorVec.add()
        tensor_vec.deviceId = 0
        tensor_vec.memType = 0

        for dim in [1, *image.shape]:
            tensor_vec.tensorShape.append(dim) # tensorshape属性为[1,3,572,572]

        input_data = image.tobytes()
        tensor_vec.dataStr = input_data
        tensor_vec.tensorDataSize = len(input_data)

        protobuf_vec = InProtobufVector()
        protobuf = MxProtobufIn()
        protobuf.key = b'appsrc0'
        protobuf.type = b'MxTools.MxpiTensorPackageList'
        protobuf.protobuf = tensor_pkg_list.SerializeToString()
        protobuf_vec.push_back(protobuf)

        unique_id = self._stream_mgr_api.SendProtobuf(
            self._stream_name, 0, protobuf_vec)

        if unique_id < 0:
            raise RuntimeError("Failed to send data to stream.")

        infer_result = self._stream_mgr_api.GetResult(
            self._stream_name, unique_id)

        if infer_result.errorCode != 0:
            raise RuntimeError(
                f"GetResult error. errorCode={infer_result.errorCode}, "
                f"errorMsg={infer_result.data.decode()}")
        
        output_tensor = self._parse_output_data(infer_result)
        output_tensor = np.squeeze(output_tensor)
        output_tensor = softmax(output_tensor)

        mask = np.argmax(output_tensor, axis =0)
        score = np.max(output_tensor, axis = 0)


        mask = cv2.resize(mask, [img_bgr.shape[1], img_bgr.shape[0]], interpolation=cv2.INTER_NEAREST)
        score = cv2.resize(score, [img_bgr.shape[1], img_bgr.shape[0]], interpolation=cv2.INTER_NEAREST)

        return mask, score



    def _parse_output_data(self, output_data):
        infer_result_data = json.loads(output_data.data.decode())
        content = json.loads(infer_result_data['metaData'][0]['content'])
        tensor_vec = content['tensorPackageVec'][0]['tensorVec'][0]
        data_str = tensor_vec['dataStr']
        tensor_shape = tensor_vec['tensorShape']
        infer_array = np.frombuffer(base64.b64decode(data_str), dtype=np.float32)
        return infer_array.reshape(tensor_shape)



    def draw(self, mask):
        color_lists = [(255, 0, 0), (0, 255, 0), (0, 0, 255)]

        drawed_img = np.stack([mask, mask, mask], axis = 2)
        for i in np.unique(mask):
            drawed_img[:,:,0][drawed_img[:,:,0]==i] = color_lists[i][0]
            drawed_img[:,:,1][drawed_img[:,:,1]==i] = color_lists[i][1]
            drawed_img[:,:,2][drawed_img[:,:,2]==i] = color_lists[i][2]

        return drawed_img

def softmax(x):
    exps = np.exp(x - np.max(x))
    return exps/np.sum(exps)



def sigmoid(x):
    y = x.copy()
    y[x >= 0] = 1.0 / (1 + np.exp(-x[x >= 0]))
    y[x < 0] = np.exp(x[x < 0]) / (1 + np.exp(x[x < 0]))
    return y



def check_dir(dir):
    if not os.path.exists(dir):
        os.makedirs(dir, exist_ok=True)


def test():
    dataset_dir = './sample_data'
    output_folder = "./infer_result"   
    os.makedirs(output_folder, exist_ok=True)

    sdk_infer = SDKInferWrapper()


    # read img
    image_name = "./sample_data/images/111.jpg"
    img_bgr = cv2.imread(image_name)
    
    # infer
    t1 = time.time()
    mask, score = sdk_infer.do_infer(img_bgr)
    t2 = time.time()
    print(t2-t1, mask, score)
    drawed_img = sdk_infer.draw(mask)
    cv2.imwrite("infer_result/draw.png", drawed_img)
    

if __name__ == "__main__":
    test()

运行代码:

set -e
. /usr/local/Ascend/ascend-toolkit/set_env.sh
# Simple log helper functions
info() { echo -e "\033[1;34m[INFO ][MxStream] $1\033[1;37m" ; }
warn() { echo >&2 -e "\033[1;31m[WARN ][MxStream] $1\033[1;37m" ; }

#export MX_SDK_HOME=/home/work/mxVision
export LD_LIBRARY_PATH=${MX_SDK_HOME}/lib:${MX_SDK_HOME}/opensource/lib:${MX_SDK_HOME}/opensource/lib64:/usr/local/Ascend/ascend-toolkit/latest/acllib/lib64:${LD_LIBRARY_PATH}
export GST_PLUGIN_SCANNER=${MX_SDK_HOME}/opensource/libexec/gstreamer-1.0/gst-plugin-scanner
export GST_PLUGIN_PATH=${MX_SDK_HOME}/opensource/lib/gstreamer-1.0:${MX_SDK_HOME}/lib/plugins

#to set PYTHONPATH, import the StreamManagerApi.py
export PYTHONPATH=$PYTHONPATH:${MX_SDK_HOME}/python

python3 unet.py
exit 0

运行效果:

个人思考:

华为atlas的参考案例细节不到位,步骤缺失较多,摸索困难,代码写法较差,信创化道路任重而道远。

参考资料:

GitHub - milesial/Pytorch-UNet: PyTorch implementation of the U-Net for image semantic segmentation with high quality images

https://gitee.com/ascend/samples/tree/master/python/level2_simple_inference/3_segmentation/unet++

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/430852.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

bun 单元测试

bun test Bun 附带了一个快速、内置、兼容 Jest 的测试运行程序。测试使用 Bun 运行时执行&#xff0c;并支持以下功能。 TypeScript 和 JSX生命周期 hooks快照测试UI 和 DOM 测试使用 --watch 的监视模式使用 --preload 预加载脚本 Bun 旨在与 Jest 兼容&#xff0c;但并非所…

北京Excel表格线下培训班

Excel培训目标 熟练掌握职场中Excel所需的公式函数计算&#xff0c;数据处理分析&#xff0c;各种商务图表制作、动态仪表盘的制作、熟练使用Excel进行数据分析&#xff0c;处理&#xff0c;从复杂的数据表中把数据进行提取汇总 Excel培训形式 线下面授5人以内小班&#xff…

分享Web.dev.cn中国开发者可以正常访问

谷歌开发者很高兴地宣布&#xff0c;web.dev 和 Chrome for Developers 现在都可以通过 .cn 域名访问&#xff0c;这将帮助中国的开发者更加容易获取我们的内容。 在 .cn 域名上&#xff0c;我们已向您提供所有镜像后的内容&#xff0c;并提供支持的语言版本。 Web.dev 中国开…

uipath调用js代码

1&#xff0c;调用js代码&#xff0c;不带参数&#xff0c;没有返回值 为了去掉按钮的disabled属性 function(){ document.getElementsByClassName(submitBtn)[0].removeAttribute(disabled); } 2&#xff0c;调用js代码&#xff0c;带参数&#xff0c;没有返回值 输入参数&a…

el-dialog封装组件

父页面 <template><div><el-button type"primary" click"visible true">展示弹窗</el-button><!-- 弹窗组件 --><PlayVideo v-if"visible" :visible.syncvisible /></div> </template><sc…

Python-Numpy-计算向量间的欧式距离

两个向量间的欧式距离公式&#xff1a; a np.array([[2, 2], [4, 5], [6, 7]]) b np.array([[1, 1]]) # 使用L2范数计算 dev1 np.linalg.norm(a - b, ord2, axis1) # 使用公式计算 dev2 np.sqrt(np.sum((a - b) ** 2, axis1)) print(dev1.reshape((-1, 1)), dev2.reshape((…

掌握WhatsApp手机号质量评分:增加信息可达性

WhatsApp手机号质量评分是用于衡量用户手机号与平台互动的健康度&#xff0c;确保用户通讯时的合规性和安全性。在实掌握WhatsApp手机号质量评分实际应用中&#xff0c;这个评分会影响用户的消息发送的可达性。高质量的评分意味着用户的账户被视为可信赖的&#xff0c;其发送的…

2024最新ChatGPT网站源码, AI绘画系统

一、前言说明 R5Ai创作系统是基于ChatGPT进行开发的Ai智能问答系统和Midjourney绘画系统&#xff0c;支持OpenAI-GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美&#xff0c;那么如何搭建部署AI创作ChatGPT&#xff1f;小编这里写一个详细图文教程吧。已支持GP…

平替电容笔推荐:2024五大高口碑电容笔机型别错过!

现在电容笔已成为许多人工作、学习和创作的重要配件之一&#xff0c;它可以很好的提高我们的书写、绘画效率&#xff0c;无纸化学习也能减轻我们书本重量&#xff0c;让学习更加高效&#xff0c;然而&#xff0c;市场上电容笔种类繁&#xff0c;也少不了一些质量不佳的产品&…

掼蛋“六必治”策略

“六必治”&#xff0c;即是指当对手手中只剩下六张牌的时候&#xff0c;我们不管是用炸弹还是登基牌还是其他大牌都要及时压制对手&#xff0c;夺得出牌权&#xff0c;不能让他再次出牌&#xff0c;防止他有一手整牌或者一炸加上一手牌。 对手剩六张牌&#xff0c;有以下几种情…

正大国际:期货结算价是如何理解呢?结算价有什么作用?

如何理解期货结算价&#xff1a; 什么是商品期货当日结算价&#xff0c; 商品期货当日结算价是指某一期货合约当日交易期间成交价格按成交量的加权平均价。当日 无成交的&#xff0c;当日结算价按照交易所相关规定确定。 股指期货当日结算价是指某一期货合约当日交易期间最后一…

The Design and Implementation of a Capacity-Variant Storage System——论文泛读

FAST 2024 Paper 分布式元数据论文整理 问题 随着SSD的使用&#xff0c;其性能稳步下降。如图1所示&#xff0c;SSD的性能随着SSD的磨损的下降率为4.2%&#xff0c;吞吐量下降不太可能是由于垃圾收集造成的&#xff0c;因为&#xff08;1&#xff09;这是几个月来每天测量的&…

手写分布式配置中心(二)实现分布式配置中心的简单版本

这一篇文章比较简单&#xff0c;就是一个增删改查的服务端和一个获取配置的客户端&#xff0c;旨在搭建一个简单的配置中心架构&#xff0c;代码在 https://gitee.com/summer-cat001/config-center 服务端 服务端选择用springboot 2.7.14搭建&#xff0c;设计了4个接口/confi…

Guava处理异常

guava由Google开发&#xff0c;它提供了大量的核心Java库&#xff0c;例如&#xff1a;集合、缓存、原生类型支持、并发库、通用注解、字符串处理和I/O操作等。 异常处理 传统的Java异常处理通常包括try-catch-finally块和throws关键字。 遇到FileNotFoundException或IOExce…

49、WEB攻防——通用漏洞业务逻辑水平垂直越权访问控制脆弱验证

文章目录 前置知识点水平越权——YXCMS 前置知识点 逻辑越权原理&#xff1a; 水平越权&#xff1a;同级用户权限共享。用户信息获取时未对用户与ID比较判断直接查询等&#xff1b;垂直越权&#xff1a;低高级用户权限共享。数据库中用户类型编号接受篡改或高权限未作验证等。 …

Unity 使用AddListener监听事件与取消监听

在Unity中&#xff0c;有时候我们会动态监听组件中的某个事件。当我们使用代码动态加载多次&#xff0c;每次动态加载后我们会发现原来的和新的事件都会监听&#xff0c;如若我们只想取代原来的监听事件&#xff0c;那么就需要取消监听再添加监听了。 如实现如下需求&#xff…

一加 Ace 3 原神刻晴定制机首销现象级火爆,京东天猫双平台火速售罄

3 月 5 日上午 10 点&#xff0c;一加 Ace 3 原神刻晴定制机正式开售&#xff0c;京东天猫双平台火速售罄。一加 Ace 3 原神刻晴定制机以打造2024行业深度定制新标杆为目标&#xff0c;凭借行业首创工艺、典藏级限定周边、深度的系统定制以及专业的游戏表现&#xff0c;一经发布…

elementUI el-table中的对齐问题

用elementUI时&#xff0c;遇到了一个无法对齐的问题&#xff1a;代码如下&#xff1a; <el-table :data"form.dataList" <el-table-column label"验收结论" prop"checkResult" width"200"> <template slot-sco…

少儿编程 中国电子学会C++等级考试一级历年真题答案解析【持续更新 已更新82题】

C 等级考试一级考纲说明 一、能力目标 通过本级考核的学生&#xff0c;能对 C 语言有基本的了解&#xff0c;会使用顺序结构、选择结构、循环结构编写程序&#xff0c;具体用计算思维的方式解决简单的问题。 二、考核目标 考核内容是根据软件开发所需要的技能和知识&#x…

Premiere Pro 2024:革新视频编辑,打造专业影视新纪元

在数字化时代&#xff0c;视频已经成为人们获取信息、娱乐消遣的重要媒介。对于视频制作者而言&#xff0c;拥有一款功能强大、易于操作的视频编辑软件至关重要。Premiere Pro 2024&#xff0c;作为Adobe旗下的旗舰视频编辑软件&#xff0c;凭借其卓越的性能和创新的特性&#…