C# YOLOv8 TensorRT + ByteTrack Demo

Contents

Results

Notes

Project

Code

Form2.cs

YoloV8.cs

ByteTracker.cs

Download

References


Results

Notes

Environment

NVIDIA GeForce RTX 4060 Laptop GPU

CUDA 12.1 + cuDNN 8.8.1 + TensorRT 8.6.1.6

If your versions differ from the above, you need to recompile TensorRtExtern.dll. TensorRtExtern source: TensorRT-CSharp-API/src/TensorRtExtern at TensorRtSharp2.0 · guojin-yan/TensorRT-CSharp-API · GitHub

For installing CUDA on Windows, see: Windows版 CUDA安装_win cuda安装-CSDN博客
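
A serialized .engine file is tied to the GPU, CUDA, cuDNN and TensorRT versions it was built with, which is why the demo rebuilds it from the ONNX model on first run (see Form1_Load in Form2.cs below). The following is a minimal sketch of that first-run step, assuming only the Nvinfer wrapper from TensorRT-CSharp-API that the demo already uses; EngineBootstrap is a hypothetical helper name, and the second argument to OnnxToEngine simply mirrors the value the demo passes.

using System.IO;
using TensorRtSharp.Custom;

static class EngineBootstrap   // hypothetical helper, not part of the project
{
    // Returns a ready-to-use Nvinfer instance, converting ONNX -> engine if needed.
    public static Nvinfer LoadOrBuild(string onnxPath, string enginePath)
    {
        if (!File.Exists(enginePath))
        {
            // Slow step (can take minutes). Form1_Load assumes the conversion
            // leaves the .engine file next to the ONNX model; "20" mirrors the demo.
            Nvinfer.OnnxToEngine(onnxPath, 20);
        }
        return new Nvinfer(enginePath);
    }
}

// Usage mirroring Form1_Load:
// Nvinfer predictor = EngineBootstrap.LoadOrBuild("model/yolov8n.onnx", "model/yolov8n.engine");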

Project

Code

Form2.cs

using ByteTrack;
using OpenCvSharp;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Drawing;
using System.IO;
using System.Threading;
using System.Windows.Forms;
using TensorRtSharp.Custom;

namespace yolov8_TensorRT_Demo
{
    public partial class Form2 : Form
    {
        public Form2()
        {
            InitializeComponent();
        }

        string imgFilter = "*.*|*.bmp;*.jpg;*.jpeg;*.tif;*.tiff;*.png";

        YoloV8 yoloV8;
        Mat image;

        string image_path = "";
        string model_path;

        string video_path = "";
        string videoFilter = "*.mp4|*.mp4;";
        VideoCapture vcapture;
        VideoWriter vwriter;
        bool saveDetVideo = false;
        ByteTracker tracker;


        /// <summary>
        /// Single-image inference
        /// </summary>
        /// <param name="sender"></param>
        /// <param name="e"></param>
        private void button2_Click(object sender, EventArgs e)
        {

            if (image_path == "")
            {
                return;
            }

            button2.Enabled = false;
            pictureBox2.Image = null;
            textBox1.Text = "";

            Application.DoEvents();

            image = new Mat(image_path);

            List<DetectionResult> detResults = yoloV8.Detect(image);

            // Draw detection results
            Mat result_image = image.Clone();
            foreach (DetectionResult r in detResults)
            {
                Cv2.PutText(result_image, $"{r.Class}:{r.Confidence:P0}", new OpenCvSharp.Point(r.Rect.TopLeft.X, r.Rect.TopLeft.Y - 10), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                Cv2.Rectangle(result_image, r.Rect, Scalar.Red, thickness: 2);
            }

            if (pictureBox2.Image != null)
            {
                pictureBox2.Image.Dispose();
            }
            pictureBox2.Image = new Bitmap(result_image.ToMemoryStream());
            textBox1.Text = yoloV8.DetectTime();

            button2.Enabled = true;

        }

        /// <summary>
        /// Form load: initialization
        /// </summary>
        /// <param name="sender"></param>
        /// <param name="e"></param>
        private void Form1_Load(object sender, EventArgs e)
        {
            image_path = "test/dog.jpg";
            pictureBox1.Image = new Bitmap(image_path);

            model_path = "model/yolov8n.engine";

            if (!File.Exists(model_path))
            {
                // Converting ONNX to an engine is slow; wait for it to finish
                Nvinfer.OnnxToEngine("model/yolov8n.onnx", 20);
            }

            yoloV8 = new YoloV8(model_path, "model/lable.txt");

        }

        /// <summary>
        /// Select an image
        /// </summary>
        /// <param name="sender"></param>
        /// <param name="e"></param>
        private void button1_Click_1(object sender, EventArgs e)
        {
            OpenFileDialog ofd = new OpenFileDialog();
            ofd.Filter = imgFilter;
            if (ofd.ShowDialog() != DialogResult.OK) return;

            pictureBox1.Image = null;

            image_path = ofd.FileName;
            pictureBox1.Image = new Bitmap(image_path);

            textBox1.Text = "";
            pictureBox2.Image = null;
        }

        /// <summary>
        /// Select a video
        /// </summary>
        /// <param name="sender"></param>
        /// <param name="e"></param>
        private void button4_Click(object sender, EventArgs e)
        {
            OpenFileDialog ofd = new OpenFileDialog();
            ofd.Filter = videoFilter;
            ofd.InitialDirectory = Application.StartupPath + "\\test";
            if (ofd.ShowDialog() != DialogResult.OK) return;

            video_path = ofd.FileName;

            textBox1.Text = "";
            pictureBox1.Image = null;
            pictureBox2.Image = null;

            button3_Click(null, null);

        }

        /// <summary>
        /// Video inference
        /// </summary>
        /// <param name="sender"></param>
        /// <param name="e"></param>
        private void button3_Click(object sender, EventArgs e)
        {
            if (video_path == "")
            {
                return;
            }

            textBox1.Text = "开始检测";

            Application.DoEvents();

            Thread thread = new Thread(new ThreadStart(VideoDetection));

            thread.Start();
            thread.Join();

            textBox1.Text = "检测完成!";
        }

        void VideoDetection()
        {
            vcapture = new VideoCapture(video_path);
            if (!vcapture.IsOpened())
            {
                MessageBox.Show("打开视频文件失败");
                return;
            }

            tracker = new ByteTracker((int)vcapture.Fps, 200);

            Mat frame = new Mat();
            List<DetectionResult> detResults;

            // Get the video fps
            double videoFps = vcapture.Get(VideoCaptureProperties.Fps);
            // Per-frame wait time in milliseconds
            int delay = (int)(1000 / videoFps);
            Stopwatch _stopwatch = new Stopwatch();

            if (checkBox1.Checked)
            {
                vwriter = new VideoWriter("out.mp4", FourCC.X264, vcapture.Fps, new OpenCvSharp.Size(vcapture.FrameWidth, vcapture.FrameHeight));
                saveDetVideo = true;
            }
            else
            {
                saveDetVideo = false;
            }

            while (vcapture.Read(frame))
            {
                if (frame.Empty())
                {
                    MessageBox.Show("读取失败");
                    return;
                }

                _stopwatch.Restart();

                delay = (int)(1000 / videoFps);

                detResults = yoloV8.Detect(frame);

                // Draw raw detection results (kept for reference; tracker output is drawn below)
                //foreach (DetectionResult r in detResults)
                //{
                //    Cv2.PutText(frame, $"{r.Class}:{r.Confidence:P0}", new OpenCvSharp.Point(r.Rect.TopLeft.X, r.Rect.TopLeft.Y - 10), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                //    Cv2.Rectangle(frame, r.Rect, Scalar.Red, thickness: 2);
                //}

                Cv2.PutText(frame, "preprocessTime:" + yoloV8.preprocessTime.ToString("F2") + "ms", new OpenCvSharp.Point(10, 30), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                Cv2.PutText(frame, "inferTime:" + yoloV8.inferTime.ToString("F2") + "ms", new OpenCvSharp.Point(10, 70), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                Cv2.PutText(frame, "postprocessTime:" + yoloV8.postprocessTime.ToString("F2") + "ms", new OpenCvSharp.Point(10, 110), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                Cv2.PutText(frame, "totalTime:" + yoloV8.totalTime.ToString("F2") + "ms", new OpenCvSharp.Point(10, 150), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                Cv2.PutText(frame, "video fps:" + videoFps.ToString("F2"), new OpenCvSharp.Point(10, 190), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                Cv2.PutText(frame, "det fps:" + yoloV8.detFps.ToString("F2"), new OpenCvSharp.Point(10, 230), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);

                List<Track> track = new List<Track>();
                Track temp;
                foreach (DetectionResult r in detResults)
                {
                    RectBox _box = new RectBox(r.Rect.X, r.Rect.Y, r.Rect.Width, r.Rect.Height);
                    temp = new Track(_box, r.Confidence, ("label", r.ClassId), ("name", r.Class));
                    track.Add(temp);
                }

                var trackOutputs = tracker.Update(track);

                foreach (var t in trackOutputs)
                {
                    Rect rect = new Rect((int)t.RectBox.X, (int)t.RectBox.Y, (int)t.RectBox.Width, (int)t.RectBox.Height);
                    //string txt = $"{t["name"]}-{t.TrackId}:{t.Score:P0}";
                    string txt = $"{t["name"]}-{t.TrackId}";
                    Cv2.PutText(frame, txt, new OpenCvSharp.Point(rect.TopLeft.X, rect.TopLeft.Y - 10), HersheyFonts.HersheySimplex, 1, Scalar.Red, 2);
                    Cv2.Rectangle(frame, rect, Scalar.Red, thickness: 2);
                }

                if (saveDetVideo)
                {
                    vwriter.Write(frame);
                }

                Cv2.ImShow("DetectionResult", frame);

                // for test
                // delay = 1;
                delay = (int)(delay - _stopwatch.ElapsedMilliseconds);
                if (delay <= 0)
                {
                    delay = 1;
                }
                //Console.WriteLine("delay:" + delay.ToString()) ;
                if (Cv2.WaitKey(delay) == 27)
                {
                    break; // exit the loop when ESC is pressed
                }
            }

            Cv2.DestroyAllWindows();
            vcapture.Release();
            if (saveDetVideo)
            {
                vwriter.Release();
            }

        }

    }

}

YoloV8.cs

using OpenCvSharp;
using OpenCvSharp.Dnn;
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Text;
using TensorRtSharp.Custom;

namespace yolov8_TensorRT_Demo
{
    public class YoloV8
    {

        float[] input_tensor_data;
        float[] outputData;
        List<DetectionResult> detectionResults;

        int input_height;
        int input_width;

        Nvinfer predictor;

        public string[] class_names;
        int class_num;
        int box_num;

        float conf_threshold;
        float nms_threshold;

        float ratio_height;
        float ratio_width;

        public double preprocessTime;
        public double inferTime;
        public double postprocessTime;
        public double totalTime;
        public double detFps;

        public String DetectTime()
        {
            StringBuilder stringBuilder = new StringBuilder();
            stringBuilder.AppendLine($"Preprocess: {preprocessTime:F2}ms");
            stringBuilder.AppendLine($"Infer: {inferTime:F2}ms");
            stringBuilder.AppendLine($"Postprocess: {postprocessTime:F2}ms");
            stringBuilder.AppendLine($"Total: {totalTime:F2}ms");

            return stringBuilder.ToString();
        }

        public YoloV8(string model_path, string classer_path)
        {
            predictor = new Nvinfer(model_path);

            class_names = File.ReadAllLines(classer_path, Encoding.UTF8);
            class_num = class_names.Length;

            input_height = 640;
            input_width = 640;

            box_num = 8400;

            conf_threshold = 0.25f;
            nms_threshold = 0.5f;

            detectionResults = new List<DetectionResult>();
        }

        void Preprocess(Mat image)
        {
            // Resize the image (letterbox into the 640x640 input)
            int height = image.Rows;
            int width = image.Cols;
            Mat temp_image = image.Clone();
            if (height > input_height || width > input_width)
            {
                float scale = Math.Min((float)input_height / height, (float)input_width / width);
                OpenCvSharp.Size new_size = new OpenCvSharp.Size((int)(width * scale), (int)(height * scale));
                Cv2.Resize(image, temp_image, new_size);
            }
            ratio_height = (float)height / temp_image.Rows;
            ratio_width = (float)width / temp_image.Cols;
            Mat input_img = new Mat();
            Cv2.CopyMakeBorder(temp_image, input_img, 0, input_height - temp_image.Rows, 0, input_width - temp_image.Cols, BorderTypes.Constant, 0);

            // Normalize to [0, 1]
            input_img.ConvertTo(input_img, MatType.CV_32FC3, 1.0 / 255);

            input_tensor_data = Common.ExtractMat(input_img);

            input_img.Dispose();
            temp_image.Dispose();
        }

        void Postprocess(float[] outputData)
        {
            detectionResults.Clear();

            float[] data = Common.Transpose(outputData, class_num + 4, box_num);

            float[] confidenceInfo = new float[class_num];
            float[] rectData = new float[4];

            List<DetectionResult> detResults = new List<DetectionResult>();

            for (int i = 0; i < box_num; i++)
            {
                Array.Copy(data, i * (class_num + 4), rectData, 0, 4);
                Array.Copy(data, i * (class_num + 4) + 4, confidenceInfo, 0, class_num);

                float score = confidenceInfo.Max(); // highest class confidence

                int maxIndex = Array.IndexOf(confidenceInfo, score); // index of the highest confidence

                int _centerX = (int)(rectData[0] * ratio_width);
                int _centerY = (int)(rectData[1] * ratio_height);
                int _width = (int)(rectData[2] * ratio_width);
                int _height = (int)(rectData[3] * ratio_height);

                detResults.Add(new DetectionResult(
                   maxIndex,
                   class_names[maxIndex],
                   new Rect(_centerX - _width / 2, _centerY - _height / 2, _width, _height),
                   score));
            }

            //NMS
            CvDnn.NMSBoxes(detResults.Select(x => x.Rect), detResults.Select(x => x.Confidence), conf_threshold, nms_threshold, out int[] indices);
            detResults = detResults.Where((x, index) => indices.Contains(index)).ToList();

            detectionResults = detResults;
        }

        internal List<DetectionResult> Detect(Mat image)
        {

            var t1 = Cv2.GetTickCount();

            Stopwatch stopwatch = new Stopwatch();
            stopwatch.Start();

            Preprocess(image);

            preprocessTime = stopwatch.Elapsed.TotalMilliseconds;
            stopwatch.Restart();

            predictor.LoadInferenceData("images", input_tensor_data);

            predictor.infer();

            inferTime = stopwatch.Elapsed.TotalMilliseconds;
            stopwatch.Restart();

            outputData = predictor.GetInferenceResult("output0");

            Postprocess(outputData);

            postprocessTime = stopwatch.Elapsed.TotalMilliseconds;
            stopwatch.Stop();

            totalTime = preprocessTime + inferTime + postprocessTime;

            var t2 = Cv2.GetTickCount();

            detFps = 1 / ((t2 - t1) / Cv2.GetTickFrequency());

            return detectionResults;

        }

    }
}
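
The listing above relies on a DetectionResult container and two Common helpers (ExtractMat, Transpose) that belong to the project but are not reproduced in this post. The sketch below shows what they need to provide, inferred from how YoloV8.cs uses them; the exact layout (CHW tensor extraction, row-major transpose) is an assumption consistent with that usage, and any BGR-to-RGB swap the project performs would also live in ExtractMat.

using OpenCvSharp;

// Sketch of the detection container (member names taken from its usage above).
public class DetectionResult
{
    public int ClassId { get; }
    public string Class { get; }
    public Rect Rect { get; }
    public float Confidence { get; }

    public DetectionResult(int classId, string className, Rect rect, float confidence)
    {
        ClassId = classId;
        Class = className;
        Rect = rect;
        Confidence = confidence;
    }
}

// Sketch of the helpers: flatten a 640x640 CV_32FC3 Mat into a CHW float tensor,
// and transpose the (class_num + 4) x box_num output into box-major order.
public static class Common
{
    public static float[] ExtractMat(Mat img)
    {
        int h = img.Rows, w = img.Cols;
        float[] data = new float[3 * h * w];
        for (int c = 0; c < 3; c++)
            for (int y = 0; y < h; y++)
                for (int x = 0; x < w; x++)
                    data[c * h * w + y * w + x] = img.At<Vec3f>(y, x)[c];
        return data;
    }

    public static float[] Transpose(float[] src, int rows, int cols)
    {
        // src is rows x cols (here 84 x 8400); the result is cols x rows (8400 x 84),
        // so Postprocess can read each box as a contiguous block of (class_num + 4) floats.
        float[] dst = new float[src.Length];
        for (int r = 0; r < rows; r++)
            for (int c = 0; c < cols; c++)
                dst[c * rows + r] = src[r * cols + c];
        return dst;
    }
}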

ByteTracker.cs

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;

namespace ByteTrack
{
    public class ByteTracker
    {
        readonly float _trackThresh;
        readonly float _highThresh;
        readonly float _matchThresh;
        readonly int _maxTimeLost;

        int _frameId = 0;
        int _trackIdCount = 0;

        readonly List<Track> _trackedTracks = new List<Track>(100);
        readonly List<Track> _lostTracks = new List<Track>(100);
        List<Track> _removedTracks = new List<Track>(100);

        public ByteTracker(int frameRate = 30, int trackBuffer = 30, float trackThresh = 0.5f, float highThresh = 0.6f, float matchThresh = 0.8f)
        {
            _trackThresh = trackThresh;
            _highThresh = highThresh;
            _matchThresh = matchThresh;
            _maxTimeLost = (int)(frameRate / 30.0 * trackBuffer);
        }

        /// <summary>
        /// Updates the tracker with the detections of the current frame.
        /// </summary>
        /// <param name="tracks">Detections for the current frame, wrapped as Track objects</param>
        /// <returns>The currently activated tracks</returns>
        public IList<Track> Update(List<Track> tracks)
        {
            #region Step 1: Get detections 
            _frameId++;

            // Create new Tracks using the result of object detection
            List<Track> detTracks = new List<Track>();
            List<Track> detLowTracks = new List<Track>();

            foreach (var obj in tracks)
            {
                if (obj.Score >= _trackThresh)
                {
                    detTracks.Add(obj);
                }
                else
                {
                    detLowTracks.Add(obj);
                }
            }

            // Create lists of existing STrack
            List<Track> activeTracks = new List<Track>();
            List<Track> nonActiveTracks = new List<Track>();

            foreach (var trackedTrack in _trackedTracks)
            {
                if (!trackedTrack.IsActivated)
                {
                    nonActiveTracks.Add(trackedTrack);
                }
                else
                {
                    activeTracks.Add(trackedTrack);
                }
            }

            var trackPool = activeTracks.Union(_lostTracks).ToArray();

            // Predict current pose by KF
            foreach (var track in trackPool)
            {
                track.Predict();
            }
            #endregion

            #region Step 2: First association, with IoU 
            List<Track> currentTrackedTracks = new List<Track>();
            Track[] remainTrackedTracks;
            Track[] remainDetTracks;
            List<Track> refindTracks = new List<Track>();
            {
                var dists = CalcIouDistance(trackPool, detTracks);
                LinearAssignment(dists, trackPool.Length, detTracks.Count, _matchThresh,
                    out var matchesIdx,
                    out var unmatchTrackIdx,
                    out var unmatchDetectionIdx);

                foreach (var matchIdx in matchesIdx)
                {
                    var track = trackPool[matchIdx[0]];
                    var det = detTracks[matchIdx[1]];
                    if (track.State == TrackState.Tracked)
                    {
                        track.Update(det, _frameId);
                        currentTrackedTracks.Add(track);
                    }
                    else
                    {
                        track.ReActivate(det, _frameId);
                        refindTracks.Add(track);
                    }
                }

                remainDetTracks = unmatchDetectionIdx.Select(unmatchIdx => detTracks[unmatchIdx]).ToArray();
                remainTrackedTracks = unmatchTrackIdx
                    .Where(unmatchIdx => trackPool[unmatchIdx].State == TrackState.Tracked)
                    .Select(unmatchIdx => trackPool[unmatchIdx])
                    .ToArray();
            }
            #endregion

            #region Step 3: Second association, using low score dets 
            List<Track> currentLostTracks = new List<Track>();
            {
                var dists = CalcIouDistance(remainTrackedTracks, detLowTracks);
                LinearAssignment(dists, remainTrackedTracks.Length, detLowTracks.Count, 0.5f,
                                 out var matchesIdx,
                                 out var unmatchTrackIdx,
                                 out var unmatchDetectionIdx);

                foreach (var matchIdx in matchesIdx)
                {
                    var track = remainTrackedTracks[matchIdx[0]];
                    var det = detLowTracks[matchIdx[1]];
                    if (track.State == TrackState.Tracked)
                    {
                        track.Update(det, _frameId);
                        currentTrackedTracks.Add(track);
                    }
                    else
                    {
                        track.ReActivate(det, _frameId);
                        refindTracks.Add(track);
                    }
                }

                foreach (var unmatchTrack in unmatchTrackIdx)
                {
                    var track = remainTrackedTracks[unmatchTrack];
                    if (track.State != TrackState.Lost)
                    {
                        track.MarkAsLost();
                        currentLostTracks.Add(track);
                    }
                }
            }
            #endregion

            #region Step 4: Init new tracks 
            List<Track> currentRemovedTracks = new List<Track>();
            {
                // Deal with unconfirmed tracks, usually tracks with only one beginning frame
                var dists = CalcIouDistance(nonActiveTracks, remainDetTracks);
                LinearAssignment(dists, nonActiveTracks.Count, remainDetTracks.Length, 0.7f,
                                 out var matchesIdx,
                                 out var unmatchUnconfirmedIdx,
                                 out var unmatchDetectionIdx);

                foreach (var matchIdx in matchesIdx)
                {
                    nonActiveTracks[matchIdx[0]].Update(remainDetTracks[matchIdx[1]], _frameId);
                    currentTrackedTracks.Add(nonActiveTracks[matchIdx[0]]);
                }

                foreach (var unmatchIdx in unmatchUnconfirmedIdx)
                {
                    var track = nonActiveTracks[unmatchIdx];
                    track.MarkAsRemoved();
                    currentRemovedTracks.Add(track);
                }

                // Add new stracks
                foreach (var unmatchIdx in unmatchDetectionIdx)
                {
                    var track = remainDetTracks[unmatchIdx];
                    if (track.Score < _highThresh)
                        continue;

                    _trackIdCount++;
                    track.Activate(_frameId, _trackIdCount);
                    currentTrackedTracks.Add(track);
                }
            }
            #endregion

            #region Step 5: Update state
            foreach (var lostTrack in _lostTracks)
            {
                if (_frameId - lostTrack.FrameId > _maxTimeLost)
                {
                    lostTrack.MarkAsRemoved();
                    currentRemovedTracks.Add(lostTrack);
                }
            }

            var trackedTracks = currentTrackedTracks.Union(refindTracks).ToArray();
            var lostTracks = _lostTracks.Except(trackedTracks).Union(currentLostTracks).Except(_removedTracks).ToArray();
            _removedTracks = _removedTracks.Union(currentRemovedTracks).ToList();
            RemoveDuplicateStracks(trackedTracks, lostTracks);
            #endregion

            return _trackedTracks.Where(track => track.IsActivated).ToArray();
        }

        /// <summary>
        /// Rebuilds _trackedTracks and _lostTracks, keeping the older track when two boxes nearly coincide.
        /// </summary>
        /// <param name="aTracks"></param>
        /// <param name="bTracks"></param>
        void RemoveDuplicateStracks(IList<Track> aTracks, IList<Track> bTracks)
        {
            _trackedTracks.Clear();
            _lostTracks.Clear();

            List<(int, int)> overlappingCombinations;
            var ious = CalcIouDistance(aTracks, bTracks);

            if (ious is null)
                overlappingCombinations = new List<(int, int)>();
            else
            {
                var rows = ious.GetLength(0);
                var cols = ious.GetLength(1);
                overlappingCombinations = new List<(int, int)>(rows * cols / 2);
                for (var i = 0; i < rows; i++)
                    for (var j = 0; j < cols; j++)
                        if (ious[i, j] < 0.15f)
                            overlappingCombinations.Add((i, j));
            }

            var aOverlapping = aTracks.Select(x => false).ToArray();
            var bOverlapping = bTracks.Select(x => false).ToArray();

            foreach (var (aIdx, bIdx) in overlappingCombinations)
            {
                var timep = aTracks[aIdx].FrameId - aTracks[aIdx].StartFrameId;
                var timeq = bTracks[bIdx].FrameId - bTracks[bIdx].StartFrameId;
                if (timep > timeq)
                    bOverlapping[bIdx] = true;
                else
                    aOverlapping[aIdx] = true;
            }

            for (var ai = 0; ai < aTracks.Count; ai++)
                if (!aOverlapping[ai])
                    _trackedTracks.Add(aTracks[ai]);

            for (var bi = 0; bi < bTracks.Count; bi++)
                if (!bOverlapping[bi])
                    _lostTracks.Add(bTracks[bi]);
        }

        /// <summary>
        /// 
        /// </summary>
        /// <param name="costMatrix"></param>
        /// <param name="costMatrixSize"></param>
        /// <param name="costMatrixSizeSize"></param>
        /// <param name="thresh"></param>
        /// <param name="matches"></param>
        /// <param name="aUnmatched"></param>
        /// <param name="bUnmatched"></param>
        void LinearAssignment(float[,] costMatrix, int costMatrixSize, int costMatrixSizeSize, float thresh, out IList<int[]> matches, out IList<int> aUnmatched, out IList<int> bUnmatched)
        {
            matches = new List<int[]>();
            if (costMatrix is null)
            {
                aUnmatched = Enumerable.Range(0, costMatrixSize).ToArray();
                bUnmatched = Enumerable.Range(0, costMatrixSizeSize).ToArray();
                return;
            }

            bUnmatched = new List<int>();
            aUnmatched = new List<int>();

            var (rowsol, colsol) = Lapjv.Exec(costMatrix, true, thresh);

            for (var i = 0; i < rowsol.Length; i++)
            {
                if (rowsol[i] >= 0)
                    matches.Add(new int[] { i, rowsol[i] });
                else
                    aUnmatched.Add(i);
            }

            for (var i = 0; i < colsol.Length; i++)
                if (colsol[i] < 0)
                    bUnmatched.Add(i);
        }

        /// <summary>
        /// 
        /// </summary>
        /// <param name="aRects"></param>
        /// <param name="bRects"></param>
        /// <returns></returns>
        static float[,] CalcIous(IList<RectBox> aRects, IList<RectBox> bRects)
        {
            if (aRects.Count * bRects.Count == 0) return null;

            var ious = new float[aRects.Count, bRects.Count];
            for (var bi = 0; bi < bRects.Count; bi++)
                for (var ai = 0; ai < aRects.Count; ai++)
                    ious[ai, bi] = bRects[bi].CalcIoU(aRects[ai]);

            return ious;
        }

        /// <summary>
        /// 
        /// </summary>
        /// <param name="aTtracks"></param>
        /// <param name="bTracks"></param>
        /// <returns></returns>
        static float[,] CalcIouDistance(IEnumerable<Track> aTtracks, IEnumerable<Track> bTracks)
        {
            var aRects = aTtracks.Select(x => x.RectBox).ToArray();
            var bRects = bTracks.Select(x => x.RectBox).ToArray();

            var ious = CalcIous(aRects, bRects);
            if (ious is null) return null;

            var rows = ious.GetLength(0);
            var cols = ious.GetLength(1);
            var matrix = new float[rows, cols];
            for (var i = 0; i < rows; i++)
                for (var j = 0; j < cols; j++)
                    matrix[i, j] = 1 - ious[i, j];

            return matrix;
        }
    }
}
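
ByteTracker also depends on a RectBox type with a CalcIoU method (used above in CalcIous), on the Track state machine (Predict/Update/ReActivate/Activate and the Tracked/Lost/Removed states) and on the Lapjv assignment solver, none of which are reproduced here; they follow the reference implementation linked in the References section. A minimal sketch of the RectBox part only, assuming X/Y are the top-left corner as the conversion code in Form2.cs implies:

using System;

namespace ByteTrack
{
    // Sketch of the box type consumed by the tracker; X/Y are the top-left corner,
    // matching how Form2.cs converts OpenCV Rects to RectBox and back.
    public class RectBox
    {
        public float X { get; set; }
        public float Y { get; set; }
        public float Width { get; set; }
        public float Height { get; set; }

        public RectBox(float x, float y, float width, float height)
        {
            X = x; Y = y; Width = width; Height = height;
        }

        // Intersection-over-union with another box; 0 when the boxes do not overlap.
        public float CalcIoU(RectBox other)
        {
            float x1 = Math.Max(X, other.X);
            float y1 = Math.Max(Y, other.Y);
            float x2 = Math.Min(X + Width, other.X + other.Width);
            float y2 = Math.Min(Y + Height, other.Y + other.Height);

            float inter = Math.Max(0f, x2 - x1) * Math.Max(0f, y2 - y1);
            float union = Width * Height + other.Width * other.Height - inter;
            return union <= 0f ? 0f : inter / union;
        }
    }
}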

Download

Source code download

References

https://github.com/devhxj/Yolo8-ByteTrack-CSharp

https://github.com/guojin-yan/TensorRT-CSharp-API
