Paddle实现单目标检测

单目标检测

单目标检测(Single Object Detection)是人工智能领域中的一个重要研究方向,旨在通过计算机视觉技术,识别和定位图像中的特定目标物体。单目标检测可以应用于各种场景,如智能监控、自动驾驶、医疗影像分析等。

简单来说,单目标检测就是在确定一个目标在图片中的位置:

检测亮起的信号灯在图像中的位置

 本文将以信号灯检测为例,介绍单目标检测的方法

环境准备

这个案例需要安装以下两个库:

pip install paddlepaddle-gpu
pip install lxml

数据集准备

本文采用如下数据集:红绿灯检测_练习_训练集(非比赛数据)_数据集-飞桨AI Studio星河社区 (baidu.com)

这个数据集共有2000张信号灯的照片,其中1000张绿灯,1000张红灯。每张照片都对应着一个xml文件,标注着信号灯在图片中的位置:

<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<annotation>
    <folder>Images</folder>
    <filename>green_0.jpg</filename>
    <source>
        <database>Unknown</database>
    </source>
    <size>
        <width>424</width>
        <height>240</height>
        <depth>3</depth>
    </size>
    <segmented>0</segmented>
    <object>
        <name>green</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <occluded>0</occluded>
        <bndbox>
            <xmin>247</xmin>
            <ymin>147</ymin>
            <xmax>301</xmax>
            <ymax>190</ymax>
        </bndbox>
    </object>
</annotation>

这里面,<width>和<height>标签分别定义了宽和高,<name>定义了样本的类别(red或者green),<bndbox>里的标签则是定义了信号灯的位置(矩形框)

接下来我们编写dataset.py,用于定义数据集类:

import paddle
import glob
from lxml import etree
from PIL import Image  
import numpy as np 
  
# 定义一个字典,将颜色名称映射到ID  
name_to_id = {'red': 0, 'green': 1}  
  
# 将绝对坐标转换为相对坐标  
def to_labels(path):  
    # 读取XML文件内容  
    text = open(f'{path}').read().encode('utf8')  
    # 解析XML内容  
    xml = etree.HTML(text)  
    # 提取图像的宽度和高度  
    width = int(xml.xpath('//size/width/text()')[0])  
    height = int(xml.xpath('//size/height/text()')[0])  
    # 提取边界框的坐标  
    xmin = int(xml.xpath('//bndbox/xmin/text()')[0])  
    xmax = int(xml.xpath('//bndbox/xmax/text()')[0])  
    ymin = int(xml.xpath('//bndbox/ymin/text()')[0])  
    ymax = int(xml.xpath('//bndbox/ymax/text()')[0])  
    # 将绝对坐标转换为相对坐标  
    return xmin / width, ymin / height, xmax / width, ymax / height  
  
  
# 定义一个PaddlePaddle数据集类  
class Dataset(paddle.io.Dataset):  
    def __init__(self, pos='training_data'):  
        super().__init__()  # 调用父类构造函数  
        # 查找指定目录下的所有.jpg图片和.xml标签文件  
        self.imgs = glob.glob(f'{pos}/*.jpg')  
        self.labels = glob.glob(f'{pos}/*.xml')  
  
    def __getitem__(self, idx):  
        # 根据索引获取图片和标签  
        img = self.imgs[idx]  
        label = to_labels(self.labels[idx])  
        # 打开图片并转换为RGB模式  
        pil_img = Image.open(img).convert('RGB')  
        # 将PIL图片转换为numpy数组,并转换为float32类型  
        # 同时将通道顺序从HWC转换为CHW(PaddlePaddle默认输入格式)  
        t = paddle.to_tensor(np.array(pil_img, dtype=np.float32).transpose((2, 0, 1)))  
        # 返回图片张量和标签张量  
        return t, paddle.to_tensor(label[:4])  
  
    def __len__(self):  
        # 返回数据集中图片的数量  
        return len(self.imgs)

训练脚本

单目标检测可以看作一个回归问题,输出4个值,用于确定目标的坐标,因此我们可以使用resnet,并指定其类别数量为4(即输出4个值),并采用MSE损失函数(因为这是回归问题),据此,可以写出训练脚本的代码:

import paddle  
from dataset import Dataset  
  
# 初始化Dataset实例,设置数据位置为'training_data'  
dataset = Dataset(pos='training_data')  
  
# 使用ResNet18网络结构,并设置输出类别数为4  
net = paddle.vision.resnet18(num_classes=4)  
# 将网络封装为PaddlePaddle的Model对象  
model = paddle.Model(net)  
  
# 准备模型训练,包括优化器(Adam)和损失函数(均方误差损失)  
model.prepare(  
    paddle.optimizer.Adam(parameters=model.parameters()),  
    paddle.nn.MSELoss(),  
)  
  
# 训练模型,设置训练轮数为160,批处理大小为16 
model.fit(dataset, epochs=160, batch_size=16, verbose=1)  
  
# 保存模型到'output/model'路径  
model.save('output/model')

可以看到,训练脚本还是非常简单的。

简单使用

使用脚本也很简单:

import matplotlib.pyplot as plt  
import matplotlib.patches as patches  
import numpy as np  
from PIL import Image  
import paddle  
  
# 图片路径  
img_path = 'testing_data/red_1003.jpg' 
# 打开图片并转换为RGB格式  
pil_img = Image.open(img_path).convert('RGB')  
# 将PIL图片转换为Paddle Tensor,并调整通道顺序  
t = paddle.to_tensor([np.array(pil_img, dtype=np.float32).transpose((2, 0, 1))])  
  
# 加载ResNet18模型,并设置为4个类别  
net = paddle.vision.resnet18(num_classes=4)  
model = paddle.Model(net)  
# 加载训练好的模型权重  
model.load('output/model')  
  
# 预测图片  
pred = model.predict_batch(t)[0][0]  
print(f'预测结果:{pred}')  
  
# 根据预测结果计算边界框坐标  
xmin = float(pred[0]) * 424  
ymin = float(pred[1]) * 240  
xmax = float(pred[2]) * 424  
ymax = float(pred[3]) * 240  
  
# 显示原始图片  
plt.imshow(np.array(t[0], dtype=np.int32).transpose((1, 2, 0)))  
  
# 定义多边形的顶点坐标(这里是预测的边界框)  
vertices = np.array([[xmin, ymin], [xmin, ymax], [xmax, ymax], [xmax, ymin]])  
# 创建一个多边形对象,用于绘制边界框  
polygon = patches.Polygon(vertices, closed=True, edgecolor='black', facecolor='none')  
# 将多边形添加到当前坐标轴上  
plt.gca().add_patch(polygon)  
# 显示图片和边界框  
plt.show()

输出:

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/675016.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

玩转Linux进度条

准备工作&#xff1a; 一.关于缓冲区 首先&#xff0c;咱们先来一段有意思的代码&#xff1a; #include<stdio.h> #include<unistd.h> int main() {printf("you can see me");sleep(5);} 你可以在你的本地运行一下&#xff0c;这里我告诉大家运行结果…

android睡眠分期图

一、效果图 做医疗类项目&#xff0c;经常会遇到做各种图表&#xff0c;本文做的睡眠分期图。 二、代码 引入用到的库 api joda-time:joda-time:2.10.1 调用代码 /*** 睡眠* 分期*/private SleepChartAdapter mAdapter;private SleepChartAttrs mAttrs;private List<SleepI…

d2-crud-plus 使用小技巧(六)—— 表单下拉选择 行样式 溢出时显示异常优化

问题 vue2 elementUI d2-crud-plus&#xff0c;数据类型为select时&#xff0c;行样式显示为tag样式&#xff0c;但是如果选择内容过长就会出现下面这种bug&#xff0c;显然用户体验不够友好。 期望 代码 js export const crudOptions (vm) > {return {...columns:…

成功解决“ModuleNotFoundError: No Module Named Pycocotools”错误的全面指南

成功解决“ModuleNotFoundError: No Module Named Pycocotools”错误的全面指南 在Python的数据科学、计算机视觉和机器学习项目中&#xff0c;经常需要用到各种工具和库来加速开发过程。其中&#xff0c;pycocotools 是一个专门用于处理 COCO 数据集的库&#xff0c;它提供了多…

2024年Google算法更新打击低质量(如AI生成)内容后,英文SEO优化人员该如何调整谷歌SEO优化策略?

3月5日&#xff0c;谷歌发布了2024年的首次算法更新。与以往更新不同&#xff0c;本次更新更加复杂&#xff0c;这次更新旨在提高搜索结果的质量和相关性&#xff0c;可能对外贸网站排名和流量产生显著影响。也将产生更大的网站数据波动。但在担心自己的网站数据受到影响之前&a…

Django 里的增删改查

下面是步骤 先更新 urls.py 来添加新的url from django.contrib import admin from django.urls import path from app01 import viewsurlpatterns [path(demo/, views.demo), ]在 models.py 里创建表 from django.db import models# Create your models here. class UserI…

毫米级精度3D人脸扫描设备,助推打造元宇宙虚拟分身

在元宇宙中&#xff0c;虚拟分身对应的是一个三维模型&#xff0c;数字化的过程则是三维重建过程&#xff0c;通过3D人脸扫描可以通过多相机同步采集人脸部&#xff0c;可快速、准确地重建出真人地脸部模型及贴图&#xff0c;通过3D人脸扫描设备可快速重建出高逼真的虚拟分身。…

BioTech - 使用 Kubeflow 多机多卡 运行 高精度蛋白质结构的迭代预测

欢迎关注我的CSDN:https://spike.blog.csdn.net/ 本文地址:https://blog.csdn.net/caroline_wendy/article/details/139418138 核心逻辑,参考:使用 循环(Recycle)迭代的蛋白质结构预测 获取 高精度结构 Kubeflow 是一个开源的 Kubernetes 原生框架,专注于简化、可移植和可…

【反悔贪心】算法讲解

目录 cf865D 环形喂猪 建筑抢修 cf865D 思路&#xff1a; 我们贪心的原则是尽可能的多卖&#xff0c;而且尽可能的卖的多。 整体的贪心思路就是能卖就卖&#xff0c;卖完后放入堆中&#xff08;以便反悔&#xff09;&#xff0c;先不考虑能卖多少&#xff0c;因为堆是按照价…

02--nginx代理缓存

前言&#xff1a;比较常用的用法反向代理&#xff0c;和缓存的一些操作&#xff0c;用虚拟环境复刻出来&#xff0c;里面参数不用详细记录&#xff0c;用作复习&#xff0c;使用时直接查找即可。环境搭建过程参考前一篇文章nginx基础。 1、基础环境 IP角色作用192.168.189.143…

AWR设置工程仿真频率、原理图仿真频率、默认单位

AWR设置工程仿真频率、原理图仿真频率、默认单位 生活不易&#xff0c;喵喵叹气。马上就要上班了&#xff0c;公司的ADS的版权紧缺&#xff0c;主要用的软件都是NI 的AWR&#xff0c;只能趁着现在没事做先学习一下子了&#xff0c;希望不要裁我。 最近稍微学习了一下AWR这个软…

参加质量源于设计QbD培训能学到什么

近年来&#xff0c;产品质量已经成为了企业能否立足市场的关键。因此&#xff0c;质量源于设计&#xff08;QbD&#xff09;的理念应运而生&#xff0c;它强调在产品开发初期就注重质量设计&#xff0c;以最大限度地降低潜在风险&#xff0c;提高产品的稳定性和可靠性。参加质量…

诺亚财富——财富管理行业的进化逻辑

詹姆斯•卡斯的著作《有限与无限的游戏》中&#xff0c;传递出这样一种观点&#xff1a; “有限的游戏&#xff0c;其目的在于赢得胜利&#xff1b;无限的游戏&#xff0c;却旨在让游戏永远进行下去。有限的游戏在边界内玩&#xff0c;无限的游戏玩的就是边界。” 在商业社会…

我的app开始养活我了

大家在日常使用各类 app 时应该会发现&#xff0c;进入 app 会有个开屏广告&#xff0c;在使用 app 中&#xff0c;时不时的也会有广告被我们刷到。 这时候如果我们看完了这个广告&#xff0c;或者点击了这个广告的话&#xff0c;app商家就会获得这个广告的佣金。 这个佣金就是…

用WebStorm和VS Code断点调试Vue

大家好&#xff0c;我是咕噜铁蛋&#xff01;。今天&#xff0c;我想和大家分享一下如何在WebStorm和VS Code这两款流行的开发工具中&#xff0c;使用断点调试Vue.js项目。Vue.js作为前端三大框架之一&#xff0c;以其轻量级和组件化的特性&#xff0c;受到了广大开发者的喜爱。…

18、matlab信号生成与预处理--剔除异常值:hampel()函数

1、语法 说明&#xff1a;对输入向量x应用Hampel滤波器来检测和去除异常值。 1&#xff09;y hampel(x) 参数&#xff1a;x&#xff1a;输入信号 y:预处理的输出信号 对于x的每个样本&#xff0c;函数计算由样本及其周围的六个样本组成的窗口的中位数&#xff0c;每边三…

Linux下的Git应用及配置

1、卸载 2、安装 3、创建并初始化 4、配置 &#xff08;附加删除语句&#xff09; 5、查看&#xff08;tree .git/&#xff09; 6、增加和提交 7、打印日志 8、验证已操作工作

LeetCode刷题:反转链表

leetCode真题 206. 反转链表 属于基础简单题目 常见的做法有递归和while循环 递归 // 1. 递归参数和返回值public static ListNode reverseList(ListNode head) {// 1. 递归终止条件if (head null || head.next null) {return head;}// 递归逻辑ListNode last reverseL…

安全攻防知识——CTF之MISC

前言&#xff1a; 本周技术分享将介绍安全攻防知识中的MISC部分。MISC&#xff0c;中文即杂项&#xff0c;包括隐写、数据还原、脑洞、社会工程、压缩包解密、流量分析取证、与信息安全相关的大数据等。让我们一起来了解更多吧&#xff01; 一&#xff09;文件结构简介 1.常见…

手把手制作Vue3+Flask全栈项目 全栈开发之路实战篇 问卷网站(一)login页面

全栈开发一条龙——前端篇 第一篇&#xff1a;框架确定、ide设置与项目创建 第二篇&#xff1a;介绍项目文件意义、组件结构与导入以及setup的引入。 第三篇&#xff1a;setup语法&#xff0c;设置响应式数据。 第四篇&#xff1a;数据绑定、计算属性和watch监视 第五篇 : 组件…