第99步 深度学习图像目标检测:SSDlite建模

基于WIN10的64位系统演示

一、写在前面

本期,我们继续学习深度学习图像目标检测系列,SSD(Single Shot MultiBox Detector)模型的后续版本,SSDlite模型。

二、SSDlite简介

SSDLite 是 SSD 模型的一个变种,旨在为移动设备和边缘计算设备提供更高效的目标检测。SSDLite 的主要特点是使用了轻量级的骨干网络和特定的卷积操作来减少计算复杂性,从而提高检测速度,同时在大多数情况下仍保持了较高的准确性。

以下是 SSDLite 的主要特性和组件:

(1)轻量级骨干:

SSDLite 不使用 VGG 或 ResNet 这样的重量级骨干。相反,它使用 MobileNet 作为骨干,特别是 MobileNetV2 或 MobileNetV3。这些网络使用深度可分离的卷积和其他轻量级操作来减少计算成本。

(2)深度可分离的卷积:

这是 MobileNet 的核心组件,也被用于 SSDLite。深度可分离的卷积将传统的卷积操作分解为两个较小的操作:一个深度卷积和一个点卷积,这大大减少了计算和参数数量。

(3)多尺度特征映射:

与原始的 SSD 相似,SSDLite 也从不同的层级提取特征图以检测不同大小的物体。

(4)默认框:

SSDLite 也使用默认框(或称为锚框)来进行边界框预测。

(5)单阶段检测:

与 SSD 相同,SSDLite 也是一个单阶段检测器,同时进行边界框回归和分类。

(6)损失函数:

SSDLite 使用与 SSD 相同的组合损失,包括平滑 L1 损失和交叉熵损失。

综上,SSDLite 是为了速度和效率而设计的,特别是针对计算和内存资源有限的设备。通过使用轻量级的骨干和深度可分离的卷积,它能够在减少计算负担的同时,仍然保持合理的检测准确性。

三、数据源

来源于公共数据,文件设置如下:

大概的任务就是:用一个框框标记出MTB的位置。

四、SSDlite实战

直接上代码:

import os
import random
import torch
import torchvision
from torchvision.models.detection import ssdlite320_mobilenet_v3_large
from torchvision.transforms import functional as F
from PIL import Image
from torch.utils.data import DataLoader
import xml.etree.ElementTree as ET
import matplotlib.pyplot as plt
from torchvision import transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
import numpy as np

# Function to parse XML annotations
def parse_xml(xml_path):
    tree = ET.parse(xml_path)
    root = tree.getroot()

    boxes = []
    for obj in root.findall("object"):
        bndbox = obj.find("bndbox")
        xmin = int(bndbox.find("xmin").text)
        ymin = int(bndbox.find("ymin").text)
        xmax = int(bndbox.find("xmax").text)
        ymax = int(bndbox.find("ymax").text)

        # Check if the bounding box is valid
        if xmin < xmax and ymin < ymax:
            boxes.append((xmin, ymin, xmax, ymax))
        else:
            print(f"Warning: Ignored invalid box in {xml_path} - ({xmin}, {ymin}, {xmax}, {ymax})")

    return boxes

# Function to split data into training and validation sets
def split_data(image_dir, split_ratio=0.8):
    all_images = [f for f in os.listdir(image_dir) if f.endswith(".jpg")]
    random.shuffle(all_images)
    split_idx = int(len(all_images) * split_ratio)
    train_images = all_images[:split_idx]
    val_images = all_images[split_idx:]
    
    return train_images, val_images


# Dataset class for the Tuberculosis dataset
class TuberculosisDataset(torch.utils.data.Dataset):
    def __init__(self, image_dir, annotation_dir, image_list, transform=None):
        self.image_dir = image_dir
        self.annotation_dir = annotation_dir
        self.image_list = image_list
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        image_path = os.path.join(self.image_dir, self.image_list[idx])
        image = Image.open(image_path).convert("RGB")
        
        xml_path = os.path.join(self.annotation_dir, self.image_list[idx].replace(".jpg", ".xml"))
        boxes = parse_xml(xml_path)
        
        # Check for empty bounding boxes and return None
        if len(boxes) == 0:
            return None
        
        boxes = torch.as_tensor(boxes, dtype=torch.float32)
        labels = torch.ones((len(boxes),), dtype=torch.int64)
        iscrowd = torch.zeros((len(boxes),), dtype=torch.int64)
        
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        target["image_id"] = torch.tensor([idx])
        target["iscrowd"] = iscrowd
        
        # Apply transformations
        if self.transform:
            image = self.transform(image)
    
        return image, target

# Define the transformations using torchvision
data_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),  # Convert PIL image to tensor
    torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Normalize the images
])


# Adjusting the DataLoader collate function to handle None values
def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    return tuple(zip(*batch))


def get_ssdlite_model_for_finetuning(num_classes):
    # Load an SSDlite model with a MobileNetV3 Large backbone without pre-trained weights
    model = ssdlite320_mobilenet_v3_large(pretrained=False, num_classes=num_classes)
    return model

# Function to save the model
def save_model(model, path="SSDlite_mtb.pth", save_full_model=False):
    if save_full_model:
        torch.save(model, path)
    else:
        torch.save(model.state_dict(), path)
    print(f"Model saved to {path}")

# Function to compute Intersection over Union
def compute_iou(boxA, boxB):
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])
    
    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)
    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
    
    iou = interArea / float(boxAArea + boxBArea - interArea)
    return iou

# Adjusting the DataLoader collate function to handle None values and entirely empty batches
def collate_fn(batch):
    batch = list(filter(lambda x: x is not None, batch))
    if len(batch) == 0:
        # Return placeholder batch if entirely empty
        return [torch.zeros(1, 3, 224, 224)], [{}]
    return tuple(zip(*batch))

#Training function with modifications for collecting IoU and loss
def train_model(model, train_loader, optimizer, device, num_epochs=10):
    model.train()
    model.to(device)
    loss_values = []
    iou_values = []
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        total_ious = 0
        num_boxes = 0
        for images, targets in train_loader:
            # Skip batches with placeholder data
            if len(targets) == 1 and not targets[0]:
                continue
            # Skip batches with empty targets
            if any(len(target["boxes"]) == 0 for target in targets):
                continue
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            
            loss_dict = model(images, targets)
            losses = sum(loss for loss in loss_dict.values())
            
            optimizer.zero_grad()
            losses.backward()
            optimizer.step()
            
            epoch_loss += losses.item()
            
            # Compute IoU for evaluation
            with torch.no_grad():
                model.eval()
                predictions = model(images)
                for i, prediction in enumerate(predictions):
                    pred_boxes = prediction["boxes"].cpu().numpy()
                    true_boxes = targets[i]["boxes"].cpu().numpy()
                    for pred_box in pred_boxes:
                        for true_box in true_boxes:
                            iou = compute_iou(pred_box, true_box)
                            total_ious += iou
                            num_boxes += 1
                model.train()
        
        avg_loss = epoch_loss / len(train_loader)
        avg_iou = total_ious / num_boxes if num_boxes != 0 else 0
        loss_values.append(avg_loss)
        iou_values.append(avg_iou)
        print(f"Epoch {epoch+1}/{num_epochs} Loss: {avg_loss} Avg IoU: {avg_iou}")
    
    # Plotting loss and IoU values
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(loss_values, label="Training Loss")
    plt.title("Training Loss across Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    
    plt.subplot(1, 2, 2)
    plt.plot(iou_values, label="IoU")
    plt.title("IoU across Epochs")
    plt.xlabel("Epochs")
    plt.ylabel("IoU")
    plt.show()

    # Save model after training
    save_model(model)

# Validation function
def validate_model(model, val_loader, device):
    model.eval()
    model.to(device)
    
    with torch.no_grad():
        for images, targets in val_loader:
            images = [image.to(device) for image in images]
            targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
            model(images)

# Paths to your data
image_dir = "tuberculosis-phonecamera"
annotation_dir = "tuberculosis-phonecamera"

# Split data
train_images, val_images = split_data(image_dir)

# Create datasets and dataloaders
train_dataset = TuberculosisDataset(image_dir, annotation_dir, train_images, transform=data_transform)
val_dataset = TuberculosisDataset(image_dir, annotation_dir, val_images, transform=data_transform)

# Updated DataLoader with new collate function
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, collate_fn=collate_fn)

# Model and optimizer
model = get_ssdlite_model_for_finetuning(2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Train and validate
train_model(model, train_loader, optimizer, device="cuda", num_epochs=10)
validate_model(model, val_loader, device="cuda")

需要从头训练的,就不跑了,摆烂了。

五、写在后面

目标检测模型门槛更高了,运行起来对硬件要求也很高,时间也很久,都是小时起步的。因此只是简单介绍,算是入个门了。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/188390.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

2017年4月10日 Go生态洞察:开发者体验工作组介绍

&#x1f337;&#x1f341; 博主猫头虎&#xff08;&#x1f405;&#x1f43e;&#xff09;带您 Go to New World✨&#x1f341; &#x1f984; 博客首页——&#x1f405;&#x1f43e;猫头虎的博客&#x1f390; &#x1f433; 《面试题大全专栏》 &#x1f995; 文章图文…

高级驾驶辅助系统 (ADAS)介绍

随着汽车技术持续快速发展,推动更安全、更智能、更高效的驾驶体验一直是汽车创新的前沿。高级驾驶辅助系统( ADAS ) 是这场技术革命的关键参与者,是 指集成到现代车辆中的一组技术和功能,用于增强驾驶员安全、改善驾驶体验并协助完成各种驾驶任务。它使用传感器、摄像头、雷…

SQL Injection (Blind)`

SQL Injection (Blind) SQL Injection (Blind) SQL盲注&#xff0c;是一种特殊类型的SQL注入攻击&#xff0c;它的特点是无法直接从页面上看到注入语句的执行结果。在这种情况下&#xff0c;需要利用一些方法进行判断或者尝试&#xff0c;这个过程称之为盲注。 盲注的主要形式有…

​root账号登录群晖NAS教程​

用WinSCPPuTTY以root账号登录群晖NAS保姆教程用WinSCPPuTTY可SecureCRT 以root账号登录群晖NAS 1、先用自己的用户名 密码登陆。 2、切换到root权限 输入sudo -i,按回车,然后也是输入群辉登录的密码。成功之后,显示$ 变成 #号

SpringCloud实用-OpenFeign整合okHttp

文章目录 前言正文一、OkHttpFeignConfiguration 的启用1.1 分析配置类1.2 得出结论&#xff0c;需要增加配置1.3 调试 二、OkHttpFeignLoadBalancerConfiguration 的启用2.1 分析配置类2.2 得出结论2.3 测试 附录附1&#xff1a;本系列文章链接附2&#xff1a;OkHttpClient 增…

代码随想录算法训练营第四十五天|57. 爬楼梯、322.零钱兑换、279. 完全平方数

KamaCoder 57. 爬楼梯 题目链接&#xff1a;题目页面 (kamacoder.com) 这道题使用完全背包来实现&#xff0c;我们首先考虑的是总的楼梯数&#xff0c;因此dp数组大小为n 1 &#xff0c;其意义是&#xff0c;在n阶时有多少种方法爬到楼顶&#xff0c;因此&#xff0c;当前n状…

【高级网络程序设计】Week2-1 Sockets

一、The Basics 1. Sockets 定义An abstraction of a network interface应用 use the Socket API to create connections to remote computers send data(bytes) receive data(bytes) 2. Java network programming the java network libraryimport java.net.*;similar to…

面试:双线程交替打印奇偶数

代码如下&#xff1a; package practice1;/*** 0-100的奇数偶数打印* 1、通过对象的wait和notify进行线程阻塞* 2、通过对num%2的结果进行奇数偶数的判断输出**/ public class JiOuOne {private static volatile int num 0;private static final int max 100;public static …

【Docker】从零开始:12.容器数据卷

【Docker】从零开始&#xff1a;12.容器数据卷 1.什么是容器数据库卷2.数据的覆盖问题3.为什么要用数据卷4.Docker提供了两种卷&#xff1a;5.两种卷的区别6.bind mount7.Docker managed volumevolume 语法volume 操作参数 1.什么是容器数据库卷 卷 就是目录或文件&#xff0c…

js粒子效果(一)

效果: 代码: <!doctype html> <html> <head><meta charset"utf-8"><title>HTML5鼠标经过粒子散开动画特效</title><style>html, body {position: absolute;overflow: hidden;margin: 0;padding: 0;width: 100%;height: 1…

本地部署 ComfyUI

本地部署 ComfyUI ComfyUI 介绍ComfyUI Github 地址部署 ComfyUI配置模型地址 or 下载模型启动 ComfyUI访问 ComfyUI使用技巧页面底部显示图片预览改变连接线的格式配置 prompt 自动补全 安装 ComfyUI-Manager安装 AIGODLIKE-COMFYUI-TRANSLATION安装 ComfyUI-Custom-Scripts安…

【 拓扑排序】

文章目录 拓扑排序AOV-网拓扑排序的方法拓扑排序的一个重要应用&#xff1a;拓扑排序的算法 拓扑排序 AOV-网 无环的有向图称作有向无环图。 这种用顶点表示活动&#xff0c;用弧表示活动间的优先关系的有向图称为以顶点为活动的网&#xff08;Activity On Vertex Network&am…

centos7搭建ftp服务

一、安装 yum -y install vsftpd vi /etc/vsftpd/vsftpd.conf二、编辑配置文件 /etc/vsftpd/vsftpd.conf 内容如下 #是否允许匿名&#xff0c;默认no anonymous_enableNO#这个设定值必须要为YES 时&#xff0c;在/etc/passwd内的账号才能以实体用户的方式登入我们的vsftpd主机…

运行软件报错找不到vcruntime140_1.dll无法继续执行代码如何解决?-常见问题

关于vcruntime140_1.dll丢失的6个解决方法。在我们使用电脑的过程中&#xff0c;有时候会遇到一些错误提示&#xff0c;其中之一就是“vcruntime140_1.dll丢失”。那么&#xff0c;究竟什么是vcruntime140_1.dll文件呢&#xff1f;又是什么原因导致了它的丢失&#xff1f;接下来…

软件测试 | MySQL 唯一约束详解

&#x1f4e2;专注于分享软件测试干货内容&#xff0c;欢迎点赞 &#x1f44d; 收藏 ⭐留言 &#x1f4dd; 如有错误敬请指正&#xff01;&#x1f4e2;交流讨论&#xff1a;欢迎加入我们一起学习&#xff01;&#x1f4e2;资源分享&#xff1a;耗时200小时精选的「软件测试」资…

DBeaver连接Oracle时报错:Undefined Error

连接信息检查了很多遍&#xff0c;应该是没问题的&#xff0c;而且驱动也正常下载了&#xff0c;但是就是连不上。 找了好久&#xff0c;终于找到一个可用的方式了&#xff0c;记录一下。 在安装目录修改dbeave.ini文件&#xff0c;最后一行添加 -Duser.nameTest。重启就可以…

Python基础语法之判断语句

1.布尔类型和比较运算符 布尔类型&#xff1a;数字类型的一种。 比较运算符&#xff1a; > < > < ! 2.if语句基本格式 if 要判断的条件&#xff1a; 条件成立&#xff0c;即做~ 例子&#xff1a; 注意&#xff1a;格式上冒号和缩进 3.if else组合…

每日一题(LeetCode)----链表--链表中的下一个更大节点

每日一题(LeetCode)----链表–链表中的下一个更大节点 1.题目&#xff08;1019. 链表中的下一个更大节点&#xff09; 给定一个长度为 n 的链表 head 对于列表中的每个节点&#xff0c;查找下一个 更大节点 的值。也就是说&#xff0c;对于每个节点&#xff0c;找到它旁边的第…

基于单片机压力传感器MPX4115检测-报警系统proteus仿真+源程序

一、系统方案 1、本设计采用这51单片机作为主控器。 2、MPX4115采集压力值、DS18B20采集温度值送到液晶1602显示。 3、按键设置报警值。 4、蜂鸣器报警。 二、硬件设计 原理图如下&#xff1a; 三、单片机软件设计 1、首先是系统初始化 /*********************************…

【小沐学写作】原型设计工具汇总(Axure RP)

文章目录 1、简介2、Axure RP2.1 工具简介2.2 工具特点2.2.1 互动事件2.2.2 条件逻辑2.2.4 工作表格2.2.5 多状态容器2.2.6 数据驱动接口2.2.7 自适应视图2.2.8 流程图 2.3 工具安装2.3.1 安装2.3.2 运行 2.4 使用费用2.5 工具体验2.5.1 登陆框制作 3、其他3.1 Figma3.2 Adobe …