从0到1制作单只鳌虾运动轨迹追踪软件

前言

需要准备windows10操作系统,python3.11.9,cuDNN8.9.2.26,CUDA11.8,paddleDetection2.7

流程:

  1. 准备数据集-澳洲鳌虾VOC数据集 
  2. 基于RT-DETR目标检测模型训练
  3. 导出onnx模型进行python部署
  4. 平滑滤波处理视频帧保留的物体质心坐标
  5. 基于pywebview为软件前端,falsk为软件后端制作UI
  6. 使用pyinstaller打包成exe
  7. 使用into setup生成安装包

本人代码禁止任何商业化用途,个人开发者随意。所有代码均开源

项目目录

XXX 项目总目录
    static 存放js静态文件
        plotly.js
    templates 存放html文件
        index.html
    temp 用户上传文件保存路径
    venv 虚拟环境
    main.py 主程序
    model.onnx 模型文件
    1.ico 打包的程序图标

准备数据集

点击下载澳洲鳌虾VOC数据集

下载后解压,文件目录为

data
    Annotations
        0.xml
        1.xml
        ...
    imgs
        0.jpg
        1.jpg
        ...
    lables.txt

然后使用如下的脚本把数据集划分为训练集和测试集

import os
import random
import shutil


def splitDatasets(images_dir,xmls_dir,train_dir,test_dir):

    if os.path.exists(train_dir):
        shutil.rmtree(train_dir)
        
    os.makedirs(train_dir)
    os.makedirs(train_dir+'/imgs')
    os.makedirs(train_dir+'/annotations')
        
    if os.path.exists(test_dir):
        shutil.rmtree(test_dir)
        
    os.makedirs(test_dir)
    os.makedirs(test_dir+'/imgs')
    os.makedirs(test_dir+'/annotations')
        
    images=os.listdir(images_dir)
    random.shuffle(images)

    split_index=int(0.9*len(images))

    train_images=images[:split_index]
    test_images=images[split_index:]

    with open(train_dir+'/train.txt','w') as file:
        for img in train_images:
            shutil.copy(os.path.join(images_dir,img),os.path.join(train_dir,'imgs',img))
            ann=img.replace('jpg','xml')
            shutil.copy(os.path.join(xmls_dir,ann),os.path.join(train_dir,'annotations',ann))
            line=os.path.join(train_dir,'imgs',img)+' '+os.path.join(train_dir,'annotations',ann)+'\n'
            file.write(line)

    with open(test_dir+'/test.txt','w') as file:
        for img in test_images:
            shutil.copy(os.path.join(images_dir,img),os.path.join(test_dir,'imgs',img))
            ann=img.replace('jpg','xml')
            shutil.copy(os.path.join(xmls_dir,ann),os.path.join(test_dir,'annotations',ann))
            line=os.path.join(test_dir,'imgs',img)+' '+os.path.join(test_dir,'annotations',ann)+'\n'
            file.write(line)
        
    shutil.rmtree(images_dir)
    shutil.rmtree(xmls_dir)
    
if __name__=='__main__':
    # 填写img文件夹所在绝对路径
    images_dir='/home/aistudio/work/voc/imgs'
    # 填写Annotations文件夹所在绝对路径
    xmls_dir='/home/aistudio/work/voc/Annotations'
    # 填写 训练集 的存放的绝对路径
    train_dir='/home/aistudio/work/voc/trains'
    # 填写 测试集 的存放的绝对路径
    test_dir='/home/aistudio/work/voc/tests'
    
    splitDatasets(images_dir,xmls_dir,train_dir,test_dir)

训练模型

可在aistudio云平台训练,我放好了所有的相关文件,点击进入,里面的说明很详细

也可在本地进行训练,下面来配置本地的训练环境

配置相关文件

下载paddleDetection2.7

原始目录如下

paddleDetection2.7
    .github
    .travis
    activity
    benchmark
    configs 模型配置文件
    dataset 里面有数据集下载的脚本文件
    demo
    deploy 推理的相关文件
    docs 说明文档
    industrial_tutorial
    ppdet 模型运行的核心文件
    scripts
    test_pic
    tools 模型训练入口,测试,验证,导出等脚本文件
    .gitignore
    .pre-commit-config.yaml
    .style.yapf
    .travis.yml
    LICENSE
    README_cn.md 说明文档中文版
    README_en.md 说明文档英文版
    requirements.txt 相关依赖库
    setup.py 模型编译的相关脚本

需要删除一些目录,把README_en.md改名为README.md,处理过的目录如下

paddleDetection2.7
    configs
    dataset
    deploy
    ppdet
    tools
    README.md
    requirements.txt
    setup.py

把dataset里所有东西都删除,再将划分好的数据集放到该文件下,处理好的目录如下

dataset
    voc
        trains
            annotations
            imgs
            train.txt
        tests
            annotations
            imgs
            test.txt
        labels.txt

进入tools目录,只保留如下文件,其余全删除,处理后的文件目录如下

tools
    train.py
    infer.py
    eval.py
    export_model.py

进入configs目录,只保留下面三个文件和目录,处理后的目录如下

configs
    datasets
    rtdetr
    runtime.yml

进入datasets目录,只保留voc.yml,其余文件全删除,处理后的目录如下

datasets
    voc.yml

并用如下内容覆盖voc.yml

metric: VOC
map_type: 11point
num_classes: 1

TrainDataset:
  name: VOCDataSet
  dataset_dir: dataset/voc
  anno_path: trains/train.txt
  label_list: labels.txt
  data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']

EvalDataset:
  name: VOCDataSet
  dataset_dir: dataset/voc
  anno_path: tests/test.txt
  label_list: labels.txt
  data_fields: ['image', 'gt_bbox', 'gt_class', 'difficult']

TestDataset:
  name: ImageFolder
  anno_path: dataset/labels.txt

进入rtdetr目录,只保留如下2个文件和目录,处理后的目录如下:

rtdetr
    _base_
    rtdetr_hgnetv2_x_6x_coco.yml

进入_base_目录,找到optimizer_6x.yml,修改第一行为epoch: 200,意思是训练200轮

找到rtdetr_reader.yml,根据自己的CPU和GPU调整相关参数,如果是4核CPU,worker_num可为8,batch_size根据显存调整,占用80%到90%的显存即可

安装依赖库

建议在虚拟环境中操作

!pip install -r requirements.txt
!pip install pycocotools
!pip install filterpy
!pip install flask
!pip install pyinstaller
!pip install pywebview
!pip install onnxruntime-gpu
!pip install onnxruntime
!pip install onnx
!pip install paddle2onnx
!python setup.py install

开始训练

建议命令行输入,先进入paddleDetection所在位置,再执行以下命令

python tools/train.py -c configs/rtdetr/rtdetr_hgnetv2_x_6x_coco.yml --eval --use_vdl True --vdl_log_dir vdl_log_dir/scalar

然后就是漫长的等待

导出模型

生成的模型在paddleDetection/output/best_model/model.pdparams

先进入paddleDetection所在位置,再执行以下命令

python tools/export_model.py -c configs/rtdetr/rtdetr_hgnetv2_x_6x_coco.yml -o weights=output/best_model/model.pdparams

转onnx

先进入paddleDetection所在位置,再执行以下命令,可以根据需要选择保存路径

paddle2onnx --model_dir=output_inference/rtdetr_hgnetv2_x_6x_coco/ \
            --model_filename model.pdmodel  \
            --params_filename model.pdiparams \
            --opset_version 16 \
            --save_file /home/work/infer/model.onnx

模型部署

导包

import webview
from flask import Flask, request, jsonify,render_template,stream_with_context,Response
import os
import time
import cv2
from onnxruntime import InferenceSession
import numpy as np
from werkzeug.utils import secure_filename

总览代码

class TrackShrimp():
    def __init__(self,video_path,model_path,onnx_threshold=0.7):
        # 获取帧数据
        self.cap=self.init_video(video_path)
        frame_width=int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height=int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # 图形尺寸
        im_shape = np.array([[frame_height, frame_width]], dtype='float32')
        # y轴缩放量
        self.im_scale_y=640.0/frame_height
        # x轴缩放量
        self.im_scale_x=640.0/frame_width
        scale_factor = np.array([[self.im_scale_y,self.im_scale_x]]).astype('float32')
        # 定义模型输入
        self.inputs_dict = {
            'im_shape': im_shape,
            'image': None,
            'scale_factor': scale_factor
            }
        # 初始化模型
        self.sess=self.init_session(model_path)
        # 模型输出阈值
        self.onnx_threshold=onnx_threshold

    def init_video(self,input_path):
        cap=cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise ValueError(f'无法打开视频{input_path}')
        return cap

    def init_session(self,model_path):
        try:
            return InferenceSession(model_path, providers=['CUDAExecutionProvider']) 
        except:
            return InferenceSession(model_path, providers=['CPUExecutionProvider'])
        

    def precess_img(self,frame):
        img = cv2.resize(frame, None,None,fx=self.im_scale_x,fy=self.im_scale_y,interpolation=2)
        img = img.astype(np.float32) / 255.0
        img = np.transpose(img, [2, 0, 1])
        img = img[np.newaxis, :, :, :]
        return img
 
    def postcess(self,results:np.ndarray,all_centers:list[np.ndarray]):
        results=results[(results[:, 0] == 0) & (results[:, 1] > self.onnx_threshold)]
        x_centers = (results[:, 2] + results[:, 4]) / 2
        y_centers = (results[:, 3] + results[:, 5]) / 2
        centers = np.column_stack((x_centers, y_centers))
        all_centers.extend(centers)

    def by_smoothfilter(self,centers:list[np.ndarray],window_size=24):
        """
        :param centers: list[np.ndarray,np.ndarray,...]
        :param window_size: 平滑窗口大小
        :return: 平滑后的质心坐标NumPy数组
        """
        centers=np.stack(centers)
        # 计算滑动窗口的平均值,pad函数在序列前后补零以处理边界情况
        padded_centers = np.pad(centers, ((window_size//2, window_size//2), (0, 0)), mode='edge')
        window_sum = np.cumsum(padded_centers, axis=0)
        smoothed_centers = (window_sum[window_size:] - window_sum[:-window_size]) / window_size
        return smoothed_centers

    def calculate_distance(self,centers:np.ndarray):
        '''
        centers:np.ndarray n*2
        '''
        # 计算相邻点之间的差
        diffs = centers[1:] - centers[:-1]
        # 计算每个差值的欧几里得距离
        distances = np.linalg.norm(diffs, axis=1)
        # 计算总路程
        return int(np.sum(distances))

    def gain_position(self,centers:np.ndarray):
        position_list=centers.tolist()
        return position_list
    
    def run(self):
        global schedule
        global run_task
        # 帧数
        frame_count=int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_number=0
        center_list=[]
        for frame_number in range(frame_count):
            if not run_task:
                return
            success, frame = self.cap.read()
            if not success:
                break
            schedule=int(frame_number/frame_count*100)
            # 打印进度
            if frame_number%10==0:
                print('Process: ',schedule)
            # 图片预处理
            img=self.precess_img(frame)
            self.inputs_dict['image']=img
            results=self.sess.run(None,self.inputs_dict)[0]
            if results is not None:
                self.postcess(results,center_list)
        # 使用平滑滤波
        filtered_centers = self.by_smoothfilter(center_list)
        self.cap.release()

        # 返回路程,轨迹坐标
        return self.calculate_distance(filtered_centers),self.gain_position(filtered_centers)

由于是对视频进行推理,所以首先得初始化视频打开的方法

def init_video(self,input_path):
        cap=cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise ValueError(f'无法打开视频{input_path}')
        return cap

初始化onnx运行引擎,优先使用显卡,如果CUDA环境有问题,就使用CPU运行

    def init_session(self,model_path):
        try:
            return InferenceSession(model_path, providers=['CUDAExecutionProvider']) 
        except:
            return InferenceSession(model_path, providers=['CPUExecutionProvider'])

onnx引擎需要一定的输入格式,放到类的init里

    def __init__(self,video_path,model_path,onnx_threshold=0.7):
        # 获取帧数据
        self.cap=self.init_video(video_path)
        frame_width=int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height=int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # 图形尺寸
        im_shape = np.array([[frame_height, frame_width]], dtype='float32')
        # y轴缩放量
        self.im_scale_y=640.0/frame_height
        # x轴缩放量
        self.im_scale_x=640.0/frame_width
        scale_factor = np.array([[self.im_scale_y,self.im_scale_x]]).astype('float32')
        # 定义模型输入
        self.inputs_dict = {
            'im_shape': im_shape,
            'image': None,
            'scale_factor': scale_factor
            }
        # 初始化模型
        self.sess=self.init_session(model_path)
        # 模型输出阈值
        self.onnx_threshold=onnx_threshold

在提取每一帧后需要进行图像处理,resize图片为模型输入的要求,归一化

    def precess_img(self,frame):
        img = cv2.resize(frame, None,None,fx=self.im_scale_x,fy=self.im_scale_y,interpolation=2)
        img = img.astype(np.float32) / 255.0
        img = np.transpose(img, [2, 0, 1])
        img = img[np.newaxis, :, :, :]
        return img

在提取到视频的每一帧中的鳌虾的质心坐标后,由于每一帧的图像都不一样,输入模型后再输出的结果就不一样,会抖动,也就是噪声,我们需要滤波去噪,这里使用平滑滤波,相比卡尔曼滤波简单使用快速出结果。

    def by_smoothfilter(self,centers:list[np.ndarray],window_size=24):
        """
        :param centers: list[np.ndarray,np.ndarray,...]
        :param window_size: 平滑窗口大小
        :return: 平滑后的质心坐标NumPy数组
        """
        centers=np.stack(centers)
        # 计算滑动窗口的平均值,pad函数在序列前后补零以处理边界情况
        padded_centers = np.pad(centers, ((window_size//2, window_size//2), (0, 0)), mode='edge')
        window_sum = np.cumsum(padded_centers, axis=0)
        smoothed_centers = (window_sum[window_size:] - window_sum[:-window_size]) / window_size
        return smoothed_centers

我们需要计算鳌虾的运动总路程,用滤波后的质心坐标计算

    def calculate_distance(self,centers:np.ndarray):
        '''
        centers:np.ndarray n*2
        '''
        # 计算相邻点之间的差
        diffs = centers[1:] - centers[:-1]
        # 计算每个差值的欧几里得距离
        distances = np.linalg.norm(diffs, axis=1)
        # 计算总路程
        return int(np.sum(distances))

滤波后的质心坐标是numpy数组,需要一定的转换再发送到前端进行渲染(matplotlib画的图太丑了,不如plotly.js)

    def gain_position(self,centers:np.ndarray):
        position_list=centers.tolist()
        return position_list

在获取每一帧图像后,送入模型。模型会输出一对numpy数组,需要进行一对的后处理,低于阈值的就抛弃,然后取阈值最高的,计算质心坐标并保存

    def postcess(self,results:np.ndarray,all_centers:list[np.ndarray]):
        results=results[(results[:, 0] == 0) & (results[:, 1] > self.onnx_threshold)]
        x_centers = (results[:, 2] + results[:, 4]) / 2
        y_centers = (results[:, 3] + results[:, 5]) / 2
        centers = np.column_stack((x_centers, y_centers))
        all_centers.extend(centers)

需要在一个主函数里将上述打开视频,图像预处理,送入模型,后处理连起来

    def run(self):
        global schedule
        global run_task
        # 帧数
        frame_count=int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_number=0
        center_list=[]
        for frame_number in range(frame_count):
            if not run_task:
                return
            success, frame = self.cap.read()
            if not success:
                break
            schedule=int(frame_number/frame_count*100)
            # 打印进度
            if frame_number%10==0:
                print('Process: ',schedule)
            # 图片预处理
            img=self.precess_img(frame)
            self.inputs_dict['image']=img
            results=self.sess.run(None,self.inputs_dict)[0]
            if results is not None:
                self.postcess(results,center_list)
        # 使用平滑滤波
        filtered_centers = self.by_smoothfilter(center_list)
        self.cap.release()

        # 返回路程,轨迹坐标
        return self.calculate_distance(filtered_centers),self.gain_position(filtered_centers)

前端的设计

以pywebview为平台,html和css设计前端

 

 

 

代码总览

index.html

<!DOCTYPE html>
<html>
<head>
    <title></title>
    <link rel="shortcut icon" href="#" />
    <script src="https://code.jquery.com/jquery-3.6.0.min.js"></script>
    <script src="../static/plotly.js"></script>
    <style>
        html,body{
            width: 100%;
            height: 100%;
            margin: 0 auto;

        }
        body{
            display: flex;
            align-items: center;
            justify-content: center;
            height: 100vh;
            background-color: rgb(6, 32, 80);
        }
        main{
            display: grid;
            grid-template-columns: 1fr 3fr;
            column-gap: 2%;
            width: 98%;
            height: 98%;
        }
        fieldset{
            border: 2px solid rgb(32, 139, 139);
            color: rgb(32, 139, 139);
            margin: 8% 0 8% 0;
        }
        #s2{
            text-align: center;
            display: flex;
            justify-content: center;
            align-items: center;
            background-color: rgba(32, 139, 139, 0.301);
            border: 2px solid rgb(32, 139, 139);
        }
        #progress-circle{
            border: 1em solid rgb(32, 139, 139);
            width: 40vh;
            height: 40vh;
            border-radius: 20vh;
            display: flex;  
            justify-content: center; 
            align-items: center;
        }
        #progress-num{
            font-size: 18vh;
            color: rgb(32, 139, 139);
        }

    </style>
</head>
<body>
    <main>
        <section id="s1">
            <form id="form" enctype="multipart/form-data">
                <fieldset>
                    <legend>选择你要检测的视频</legend>
                    <input type="file" accept=".mp4" id="video" name="vedio">
                </fieldset>
                <fieldset>
                    <legend>功能按键</legend>
                    <button onclick="submit_to()" id="submit">开始上传</button>
                    <button onclick="stopRun()">终止运行</button>
                </fieldset>
            </form>
            <script>
                async function stopRun(){
                    try{
                        const response=await fetch('/stopRun',{method:'POST'})
                        if (!response.ok) {  
                            throw new Error('Network response was not ok.');  
                        }
                        data=await response.json()
                        alert(data.data)
                    }catch(error){
                        console.log(error)
                    }
                }
                
                async function submit_to(){
                    // 防重复激发
                    const button = document.getElementById('submit');  
                    button.disabled = true;
                    try{
                        // 获取文件
                        const input=document.getElementById('video')
                        const file=input.files[0]
                        if (!file){
                            throw new Error('未选择文件')
                        }
                        if(file.type!=='video/mp4'){
                            throw new Error('请选择MP4文件')
                        }
                        // 刷新界面 
                        const s2=document.getElementById('s2')
                        Plotly.purge(s2)
                        // 初始化进度显示
                        const progressCircle=document.getElementById('progress-circle')
                        const progressNum=document.getElementById('progress-num')
                        progressCircle.style.display='flex'
                        progressNum.innerHTML='0%'
                        // 更新进度
                        let source = new EventSource("/progress")
                        source.onmessage = function(event) {
                        progressNum.innerHTML = event.data+'%'
                        }
                        // 发送请求
                        const formData=new FormData()
                        formData.append('video', file)
                        const response=await fetch('/shrimp',{method:'POST',body:formData})
                        if (!response.ok) {
                            throw new Error('Network response was not ok.');  
                        }
                        source.close()
                        const data=await response.json()
                        button.disabled=false
                        if(data.data==='任务被终止'){
                            alert(data.data)
                        }
                        else{
                            progressCircle.style.display='none'
                            $('#distance').text('总路程'+data.distance)
                            // 画图
                            var trace=[{
                                x: data.position_data.map(item=>item[0]),
                                y: data.position_data.map(item=>item[1]),
                                mode:"lines",
                                line:{
                                        color:'rgb(32, 139, 139)'
                                    }
                            }]
                            var layout = {
                                xaxis: {
                                    range: [0, 600],
                                    title: "x(像素)",
                                    titlefont: {  
                                        color: 'rgb(32, 139, 139)' // 轴标签颜色  
                                    },  
                                    linecolor: 'rgb(32, 139, 139)', // 轴线颜色  
                                    tickfont: {  
                                        color: 'lrgb(32, 139, 139)' // 轴刻度标签颜色  
                                    }
                                },
                                yaxis: {range: [0, 600],
                                    title: "y(像素)",
                                    titlefont: {  
                                        color: 'rgb(32, 139, 139)' // 轴标签颜色  
                                    },  
                                    linecolor: 'rgb(32, 139, 139)', // 轴线颜色  
                                    tickfont: {  
                                        color: 'lrgb(32, 139, 139)' // 轴刻度标签颜色  
                                    }
                                },  
                                title: "鳌虾运动轨迹",
                                titlefont:{
                                    color:'rgb(32, 139, 139)'
                                },
                                plot_bgcolor: 'rgba(0,0,0,0)',
                                paper_bgcolor:'rgba(0,0,0,0)'
                                }
                            Plotly.newPlot("s2", trace, layout,{scrollZoom: true,editable: true }) 
                        }
                    }catch(error){
                        button.disabled = false
                        if(error.message.startsWith('Failed to fetch')){}
                        else{alert(error)}
                    }
                }
            
            </script>

            <fieldset>
                <legend>输出结果</legend>
                <P id="distance">总路程:</P>
            </fieldset>
            <fieldset>
                <legend>注意事项</legend>
                <p>本程序运行将消耗大量算力和内存,最好使用高配电脑。不支持windows10以下的操作系统。在后台有任务在跑时,切勿重复上传视频,
                    等待后台跑完出图时再上传新的视频。如果选错视频并上传了,请点击'终止运行'再重新上传视频。有问题联系wx:m989783106</p>
            </fieldset>
        </section>

        <section id="s2">
            <div id="progress-circle"><p id="progress-num"></p></div>
        </section>
    </main>
</body>
</html>

plotly.js从官网下载

代码分览

总体设计是以<html>和<body>为底,<main>为主容器内使用grid2列布局,2个<section>作为内容器占据左右2个网格。

左边的<section>容纳文件上传表单,功能按钮,数据显示,使用说明

        <section id="s1">
            <form id="form" enctype="multipart/form-data">
                <fieldset>
                    <legend>选择你要检测的视频</legend>
                    <input type="file" accept=".mp4" id="video" name="vedio">
                </fieldset>
                <fieldset>
                    <legend>功能按键</legend>
                    <button onclick="submit_to()" id="submit">开始上传</button>
                    <button onclick="stopRun()">终止运行</button>
                </fieldset>
            </form>

            <fieldset>
                <legend>输出结果</legend>
                <P id="distance">总路程:</P>
            </fieldset>
            <fieldset>
                <legend>注意事项</legend>
                <p>本程序运行将消耗大量算力和内存,最好使用高配电脑。不支持windows10以下的操作系统。在后台有任务在跑时,切勿重复上传视频,
                    等待后台跑完出图时再上传新的视频。如果选错视频并上传了,请点击'终止运行'再重新上传视频。有问题联系wx:m989783106</p>
            </fieldset>
        </section>

之间用<fieldset>做了区域划分,简单又美观。

<button>均使用onclick属性进行触发

在上传前会检测用户是否选择文件,是否选择的是MP4文件

// 获取文件
const input=document.getElementById('video')
const file=input.files[0]
if (!file){
    throw new Error('未选择文件')
}
if(file.type!=='video/mp4'){
    throw new Error('请选择MP4文件')
}

 一共有3个请求:

  • 请求上传文件,将MP4上传给后端,然后后端运行模型发送质心坐标给前端渲染
  • 请求终止程序,当用户想终止后端运行模型,重新上传文件时
  • 请求获取模型处理进度,后端返回进度给前端,前端进行渲染展示

画轨迹图,前端用plotly.js将质心坐标进行渲染,同时轨迹图还有一定的交互能力。

// 画图
var trace=[{
    x: data.position_data.map(item=>item[0]),
    y: data.position_data.map(item=>item[1]),
    mode:"lines",
    line:{
            color:'rgb(32, 139, 139)'
        }
}]
var layout = {
    xaxis: {
        range: [0, 600],
        title: "x(像素)",
        titlefont: {  
            color: 'rgb(32, 139, 139)' // 轴标签颜色  
        },  
        linecolor: 'rgb(32, 139, 139)', // 轴线颜色  
        tickfont: {  
            color: 'lrgb(32, 139, 139)' // 轴刻度标签颜色  
        }
    },
    yaxis: {range: [0, 600],
        title: "y(像素)",
        titlefont: {  
            color: 'rgb(32, 139, 139)' // 轴标签颜色  
        },  
        linecolor: 'rgb(32, 139, 139)', // 轴线颜色  
        tickfont: {  
            color: 'lrgb(32, 139, 139)' // 轴刻度标签颜色  
        }
    },  
    title: "鳌虾运动轨迹",
    titlefont:{
        color:'rgb(32, 139, 139)'
    },
    plot_bgcolor: 'rgba(0,0,0,0)',
    paper_bgcolor:'rgba(0,0,0,0)'
    }
Plotly.newPlot("s2", trace, layout,{scrollZoom: true,editable: true })

其余的就是代码的排布顺序,异步执行调度,错误处理能力,系统稳定性,用户交互能力的提升,细节很多,均包含在代码中


右边的<section>容纳进度圈,轨迹图

        <section id="s2">
            <div id="progress-circle"><p id="progress-num"></p></div>
        </section>

在文件上传时,就初始化渲染进度条,然后异步请求获取进度,渲染到页面;当进度到达一定值,比如99%,就关闭获取进度的请求,同时设置进度条的display=none。当用户打断程序执行或者重新运行程序,就清理轨迹图,初始化进度条,循环往复。

后端设计

后端整体使用flask,jinjia模板,将flask与pywebview结合。把模型检测代码封装到一个类TrackShrimp,其余的就是各种请求函数。

代码总览

import webview
from flask import Flask, request, jsonify,render_template,stream_with_context,Response
import os
import time
import cv2
from onnxruntime import InferenceSession
import numpy as np
from werkzeug.utils import secure_filename

class TrackShrimp():
    def __init__(self,video_path,model_path,onnx_threshold=0.7):
        # 获取帧数据
        self.cap=self.init_video(video_path)
        frame_width=int(self.cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height=int(self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        # 图形尺寸
        im_shape = np.array([[frame_height, frame_width]], dtype='float32')
        # y轴缩放量
        self.im_scale_y=640.0/frame_height
        # x轴缩放量
        self.im_scale_x=640.0/frame_width
        scale_factor = np.array([[self.im_scale_y,self.im_scale_x]]).astype('float32')
        # 定义模型输入
        self.inputs_dict = {
            'im_shape': im_shape,
            'image': None,
            'scale_factor': scale_factor
            }
        # 初始化模型
        self.sess=self.init_session(model_path)
        # 模型输出阈值
        self.onnx_threshold=onnx_threshold

    def init_video(self,input_path):
        cap=cv2.VideoCapture(input_path)
        if not cap.isOpened():
            raise ValueError(f'无法打开视频{input_path}')
        return cap

    def init_session(self,model_path):
        try:
            return InferenceSession(model_path, providers=['CUDAExecutionProvider']) 
        except:
            return InferenceSession(model_path, providers=['CPUExecutionProvider'])
        

    def precess_img(self,frame):
        img = cv2.resize(frame, None,None,fx=self.im_scale_x,fy=self.im_scale_y,interpolation=2)
        img = img.astype(np.float32) / 255.0
        img = np.transpose(img, [2, 0, 1])
        img = img[np.newaxis, :, :, :]
        return img
 
    def postcess(self,results:np.ndarray,all_centers:list[np.ndarray]):
        results=results[(results[:, 0] == 0) & (results[:, 1] > self.onnx_threshold)]
        x_centers = (results[:, 2] + results[:, 4]) / 2
        y_centers = (results[:, 3] + results[:, 5]) / 2
        centers = np.column_stack((x_centers, y_centers))
        all_centers.extend(centers)

    def by_smoothfilter(self,centers:list[np.ndarray],window_size=24):
        """
        :param centers: list[np.ndarray,np.ndarray,...]
        :param window_size: 平滑窗口大小
        :return: 平滑后的质心坐标NumPy数组
        """
        centers=np.stack(centers)
        # 计算滑动窗口的平均值,pad函数在序列前后补零以处理边界情况
        padded_centers = np.pad(centers, ((window_size//2, window_size//2), (0, 0)), mode='edge')
        window_sum = np.cumsum(padded_centers, axis=0)
        smoothed_centers = (window_sum[window_size:] - window_sum[:-window_size]) / window_size
        return smoothed_centers

    def calculate_distance(self,centers:np.ndarray):
        '''
        centers:np.ndarray n*2
        '''
        # 计算相邻点之间的差
        diffs = centers[1:] - centers[:-1]
        # 计算每个差值的欧几里得距离
        distances = np.linalg.norm(diffs, axis=1)
        # 计算总路程
        return int(np.sum(distances))

    def gain_position(self,centers:np.ndarray):
        position_list=centers.tolist()
        return position_list
    
    def run(self):
        global schedule
        global run_task
        # 帧数
        frame_count=int(self.cap.get(cv2.CAP_PROP_FRAME_COUNT))
        frame_number=0
        center_list=[]
        for frame_number in range(frame_count):
            if not run_task:
                return
            success, frame = self.cap.read()
            if not success:
                break
            schedule=int(frame_number/frame_count*100)
            # 打印进度
            if frame_number%10==0:
                print('Process: ',schedule)
            # 图片预处理
            img=self.precess_img(frame)
            self.inputs_dict['image']=img
            results=self.sess.run(None,self.inputs_dict)[0]
            if results is not None:
                self.postcess(results,center_list)
        # 使用平滑滤波
        filtered_centers = self.by_smoothfilter(center_list)
        self.cap.release()

        # 返回路程,轨迹坐标
        return self.calculate_distance(filtered_centers),self.gain_position(filtered_centers)


app = Flask(__name__)
UPLOAD_FOLDER = './temp'
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
schedule=0
run_task=False

def run_flask():
    app.run(debug=False, threaded=True,host='127.0.0.1',port=5000)

def video_process(video_path):
    return TrackShrimp(video_path,'./model.onnx').run()

# 主页面
@app.route('/',methods=['POST','GET'])
def return_main_page():
    return render_template('index.html')

# 检测视频页面
@app.route('/shrimp',methods=['POST'])
def shrimp_track():
    global run_task
    run_task=True
    file=request.files.get('video')
    filename = secure_filename(file.filename)
    video_path=os.path.join(app.config['UPLOAD_FOLDER'],filename)
    file.save(video_path)
    try:
        results=video_process(video_path)
        if results is not None:
            distance,position_data=results
            data = {
                'distance': distance,
                'position_data': position_data
            }
            run_task=False
            return jsonify(data)
        else:
            return jsonify({'data':'任务被终止'})
    except Exception as e:
        print('error:',e)
        return jsonify({'data':'任务被终止'})
    finally:
        if os.path.exists(video_path):
            os.remove(video_path)

@app.route('/stopRun',methods=['GET','POST'])
def stopRun():
    global run_task
    global schedule
    if run_task:
        run_task=False
        schedule=0
        return jsonify({'data':'正在停止任务'})
    else:
        return jsonify({'data':'当前没有任务运行'})

# 进度查询路由
@app.route('/progress',methods=['GET'])
def progress():
    @stream_with_context
    def generate():
        global run_task
        ratio = schedule
        while ratio < 95 and run_task:
            yield "data:" + str(ratio) + "\n\n"
            ratio = schedule
            time.sleep(5)
    return Response(generate(), mimetype='text/event-stream')

if __name__=='__main__':
    # 启动后端  
    # flask_thread = threading.Thread(target=run_flask)  
    # flask_thread.start()
    # time.sleep(1)
    # 启动前端
    webview.create_window('鳌虾轨迹侦测',url=app,width=900,height=600)
    # webview.create_window('鳌虾轨迹侦测',url=f'http://127.0.0.1:5000',width=900,height=600)
    webview.start()

代码分览

一个onnx部署的类TrackShrimp,详细见前面。

一些常量的定义

app = Flask(__name__)
UPLOAD_FOLDER = './temp' # 文件的上传路径,后端需要该路径保留用户上传的文件
app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
schedule=0 # 实时进度,初始化进度为0
run_task=False # 一个onnx模型是否在运行的标志,用于接收用户中断信号从而终止模型运行

定义一个flask的·启动函数,用于web调试,浏览器F12启动调试窗口

def run_flask():
    app.run(debug=False, threaded=True,host='127.0.0.1',port=5000)

主页面的请求函数,该页面为主要的UI

# 主页面
@app.route('/',methods=['POST','GET'])
def return_main_page():
    return render_template('index.html')

用户请求中断的请求函数

首先通过标志位(run_task)检测模型是否在跑,如果检测到模型正在运行,就把标志位设为False,然后把进度归0

@app.route('/stopRun',methods=['GET','POST'])
def stopRun():
    global run_task
    global schedule
    if run_task:
        run_task=False
        schedule=0
        return jsonify({'data':'正在停止任务'})
    else:
        return jsonify({'data':'当前没有任务运行'})

进度查询

这里设置当进度为95%时,就停止查询。

@app.route('/progress',methods=['GET'])
def progress():
    @stream_with_context
    def generate():
        global run_task
        ratio = schedule
        while ratio < 95 and run_task:
            yield "data:" + str(ratio) + "\n\n"
            ratio = schedule
            time.sleep(5)
    return Response(generate(), mimetype='text/event-stream')

一个检测的入口函数

def video_process(video_path):
    return TrackShrimp(video_path,'./model.onnx').run()

接收用户上传文件的函数

一旦用户上传文件,就设置运行标志位为True,然后将文件保存,再送入模型运行接口函数,当用户请求终止时,results为None,所以使用if else进行区分。模型结果出来后就把标志位设为False,同时将数据传到前端

@app.route('/shrimp',methods=['POST'])
def shrimp_track():
    global run_task
    run_task=True
    file=request.files.get('video')
    filename = secure_filename(file.filename)
    video_path=os.path.join(app.config['UPLOAD_FOLDER'],filename)
    file.save(video_path)
    try:
        results=video_process(video_path)
        if results is not None:
            distance,position_data=results
            data = {
                'distance': distance,
                'position_data': position_data
            }
            run_task=False
            return jsonify(data)
        else:
            return jsonify({'data':'任务被终止'})
    except Exception as e:
        print('error:',e)
        return jsonify({'data':'任务被终止'})
    finally:
        if os.path.exists(video_path):
            os.remove(video_path)

接着就是启动所有代码了,为了调试方便,我写了2份代码,一份用于调试,一份用于成品

if __name__=='__main__':
    # 启动前端
    webview.create_window('鳌虾轨迹侦测',url=app,width=900,height=600)
    webview.start()
if __name__=='__main__':
    # 启动后端  
    flask_thread = threading.Thread(target=run_flask)  
    flask_thread.start()
    time.sleep(1)
    # 启动前端
    webview.create_window('鳌虾轨迹侦测',url=f'http://127.0.0.1:5000',width=900,height=600)
    webview.start()

pyinstaller打包

进入项目目录,命令行输入

piinstaller -D -w main.py

找到生成的main.spec文件,按如下修改

# -*- mode: python ; coding: utf-8 -*-


a = Analysis(
    ['main.py'],
    pathex=[],
    binaries=[],
    datas=[('templates/','templates/'),('static/','static/'),('venv/Lib/site-packages/onnxruntime/capi/onnxruntime_providers_shared.dll','onnxruntime/capi/'),('venv/Lib/site-packages/onnxruntime/capi/onnxruntime_providers_cuda.dll','onnxruntime/capi/')],
    hiddenimports=[],
    hookspath=[],
    hooksconfig={},
    runtime_hooks=[],
    excludes=[],
    noarchive=False,
    optimize=0,
)
pyz = PYZ(a.pure)

exe = EXE(
    pyz,
    a.scripts,
    [],
    exclude_binaries=True,
    name='main',
    debug=False,
    bootloader_ignore_signals=False,
    strip=False,
    upx=True,
    console=False,
    disable_windowed_traceback=False,
    argv_emulation=False,
    target_arch=None,
    codesign_identity=None,
    entitlements_file=None,
    icon='1.ico'
)
coll = COLLECT(
    exe,
    a.binaries,
    a.datas,
    strip=False,
    upx=True,
    upx_exclude=[],
    name='main',
)

在项目目录下放置一个图标命名为1.ico,最好是48*48像素

然后命令行运行

pyinstaller main.spec

然后在venv中找到 onnxruntime_gpu-1.18.1.dist-info 文件夹,复制到 dist/main/_internal 中

同时在cuDNN中找到如下几个动态链接库,复制到 dist/main/_internal 中

cudnn_ops_infer64_8.dll
cudnn_cnn_infer64_8.dll
cudnn_adv_infer64_8.dll
cudnn64_8.dll
cudart64_110.dll
cublasLt64_11.dll
cublas64_11.dll
cufft64_10.dll

然后将model.onnx放到 dist/main/ ,并在该目录创建一个目录temp

最后处理的结果如下

XXX
    dist
        main
            _internal
            main.exe
            model.onnx
            temp

生成安装包

使用into setup软件,并在网站找到中文的语言包下载为 Chinese.isl 文件,放到intosetup软件安装目录的 Languages 文件夹下

接着如图所示

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

 

取消立即编译,先进入文件里修改一些东西

修改成下面这样 

点击编译

然后就生成了安装包,就可以在任何win10,win11电脑里用CPU跑了,如果安装的电脑 有显卡和CUDA并把CUDA添加到了环境变量,就可以用GPU跑了

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/780264.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

数字化精益生产系统--QMS质量管理系统

QMS质量管理系统&#xff08;Quality Management System&#xff09;是现代企业管理的关键组成部分&#xff0c;旨在确保产品和服务的质量达到或超过客户需求和期望。 以下是对QMS质量管理系统的功能设计&#xff1a;

ip地址突然变了一个城市怎么办

在数字化日益深入的今天&#xff0c;IP地址不仅是网络连接的标识&#xff0c;更是我们网络行为的“身份证”。然而&#xff0c;当您突然发现您的IP地址从一个城市跳转到另一个城市时&#xff0c;这可能会引发一系列的疑问和担忧。本文将带您深入了解IP地址突变的可能原因&#…

软件系统架构的一些常见专业术语

分层架构是逻辑上的&#xff0c;在物理部署上&#xff0c;三层结构可以部署在同一个物理机器上&#xff0c;但是随着网站业务的发展&#xff0c;必然需要对已经分层的模块分离部署&#xff0c;即三层结构分别部署在不同的服务器上&#xff0c;使网站拥有更多的计算资源以应对越…

信号与系统笔记分享

文章目录 一、导论信号分类周期问题能量信号和功率信号系统的线性判断时变&#xff0c;时不变系统因果系统判断记忆性系统判断稳定性系统判断 二、信号时域分析阶跃函数冲激函数取样性质四种特性1 筛选特性2 抽样特性3 展缩特性4 卷积特性卷积作用 冲激偶函数奇函数性质公式推导…

Java版Flink使用指南——安装Flink和使用IntelliJ制作任务包

大纲 安装Flink操作系统安装JDK安装Flink修改配置启动Flink测试 使用IntelliJ制作任务包新建工程Archetype 编写测试代码打包测试 参考资料 在《0基础学习PyFlink》专题中&#xff0c;我们熟悉了Flink的相关知识以及Python编码方案。这个系列我们将使用相对主流的Java语言&…

C++基础(十一):STL简介

从今天开始&#xff0c;我们正式步入STL的学习&#xff0c;STL&#xff08;标准模板库&#xff0c;Standard Template Library&#xff09;是C标准库的重要组成部分&#xff0c;提供了一系列通用的类和函数模板&#xff0c;包括容器、算法、迭代器等。它的设计极大地提高了代码…

中国科学技术大学发布了2024年少年班录取名单

7月7日&#xff0c;中国科学技术大学发布了2024年少年班录取名单公示&#xff0c;来自上海的12岁“小孩哥”刘尧进入名单。 据澎湃新闻此前报道&#xff0c;刘尧是因为此前通过了中科大少年班的校测考试&#xff0c;提前拿到了“高考体验券”。他所在的上海市实验学校&#xff…

柳叶刀:5Kg负重巡飞无人机技术详解

一、引言 随着无人机技术的不断发展&#xff0c;巡飞无人机在军事侦察、环境监测、边境巡逻等领域的应用日益广泛。其中&#xff0c;“柳叶刀”作为一款5Kg负重巡飞无人机&#xff0c;凭借其独特的机体结构、高效的动力系统、先进的飞行控制系统等技术优势&#xff0c;在众多无…

【位运算】基础算法总结

目录 基础位运算给一个数n&#xff0c;确定它的二进制表示的第x位是0还是1将一个数n的二进制表示的第x位修改成1将一个数n的二进制表示的第x位修改成0位图思想&#xff08;哈希表&#xff09;提取一个数&#xff08;n&#xff09;二进制表示中的最右侧的1&#xff08;lowbit&am…

KIVY 3D Rotating Monkey Head¶

7 Python Kivy Projects (With Full Tutorials) – Pythonista Planet KIVY 3D Rotating Monkey Head kivy 3D 旋转猴子头How to display rotating monkey example in a given layout. Issue #6688 kivy/kivy GitHub 3d 模型下载链接 P99 - Download Free 3D model by …

vue学习笔记(购物车小案例)

用一个简单的购物车demo来回顾一下其中需要注意的细节。 先看一下最终效果 功能&#xff1a; &#xff08;1&#xff09;全选按钮和下面的商品项的选中状态同步&#xff0c;当下面的商品全部选中时&#xff0c;全选勾选&#xff0c;反之&#xff0c;则不勾选。 &#xff08…

前端扫盲:cookie、localStorage和sessionStorage

cookie、localStorage和sessionStorage都是存储数据的方式&#xff0c;他们之间有什么不同&#xff0c;各有什么应用场景&#xff0c;本文为您一一解答。 一、什么是cookie、localStorage和sessionStorage 1. Cookie是一种存储在用户计算机上的小型文本文件&#xff0c;由服务…

和干瘪的列表说拜拜,看看卡片列表的精彩演绎

在移动UI设计中&#xff0c;卡片列表是一种常见的设计模式&#xff0c;可以将干瘪的列表变得更加生动和精彩。卡片列表通过使用卡片元素来呈现列表项&#xff0c;每个卡片可以包含图片、标题、描述、按钮等内容&#xff0c;使得列表项更加丰富和有趣。 以下是一些卡片列表的精彩…

网络防御保护——网络安全概述

一.网络安全概念 1.网络空间---一个由信息基础设施组成相互依赖的网络 。 网络空间&#xff0c;它跟以前我们所理解的网络不一样了&#xff0c;它不光是一个虚无缥缈的&#xff0c;虚拟的东西&#xff0c;它更多的是融入了我们这些真实的物理设备&#xff0c;也就意味着这个网…

数据库作业2

需求 一、在数据库中创建一个表student&#xff0c;用于存储学生信息 CREATE TABLE student( id INT PRIMARY KEY, name VARCHAR(20) NOT NULL, grade FLOAT ); 1、向student表中添加一条新记录 记录中id字段的值为1&#xff0c;name字段的值为"monkey"&#xff0…

1分钟了解LangChain是什么?

一: LangChain介绍 LangChain 是一个基于大型语言模型&#xff08;LLM&#xff09;开发应用程序的框架, 它旨在简化语言模型应用的开发流程&#xff0c;特别是在构建对话系统和其他基于语言的AI解决方案时.目标是将复杂的语言模型技术转化为可通过简单API调用实现的功能&#…

第T4周:使用TensorFlow实现猴痘病识别

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 文章目录 一、前期工作1.设置GPU&#xff08;如果使用的是CPU可以忽略这步&#xff09;2. 导入数据3. 查看数据 二、数据预处理1、加载数据2、数据可视化3、再…

Splunk Enterprise 中的严重漏洞允许远程执行代码

Splunk 是搜索、监控和分析机器生成大数据的软件领先提供商&#xff0c;为其旗舰产品 Splunk Enterprise 发布了紧急安全更新。 这些更新解决了几个构成重大安全风险的关键漏洞&#xff0c;包括远程代码执行 (RCE) 的可能性。 受影响的版本包括 * 9.0.x、9.1.x 和 9.2.x&…

竞赛 深度学习OCR中文识别 - opencv python

文章目录 0 前言1 课题背景2 实现效果3 文本区域检测网络-CTPN4 文本识别网络-CRNN5 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; **基于深度学习OCR中文识别系统 ** 该项目较为新颖&#xff0c;适合作为竞赛课题方向&#xff0c;…

STM32-SPI和W25Q64

本内容基于江协科技STM32视频学习之后整理而得。 文章目录 1. SPI&#xff08;串行外设接口&#xff09;通信1.1 SPI通信简介1.2 硬件电路1.3 移位示意图1.4 SPI时序基本单元1.5 SPI时序1.5.1 发送指令1.5.2 指定地址写1.5.3 指定地址读 2. W25Q642.1 W25Q64简介2.2 硬件电路2…