基于SSD的安全帽检测

目录

  • 1. 作者介绍
  • 2. SSD算法介绍
    • 2.1 SSD算法网络结构
    • 2.2 SSD算法训练过程
    • 2.3 SSD算法优缺点
  • 3. 基于SSD的安全帽检测实验
    • 3.1 VOC 2007安全帽数据集
    • 3.2 SSD网络架构
    • 3.3 训练和验证所需的2007_train.txt和2007_val.txt文件生成
    • 3.4 模型训练
    • 3.5 GUI界面
    • 3.6 结果展示
    • 3.7 文件下载
  • 4. 参考连接

1. 作者介绍

胡振远,男,西安工程大学电子信息学院,2023级研究生
研究方向:机器视觉与人工智能
电子邮件:zhenyuan@stu.xpu.edu.cn

张思怡,女,西安工程大学电子信息学院,2022级研究生,张宏伟人工智能课题组
研究方向:机器视觉与人工智能
电子邮件:981664791@qq.com

2. SSD算法介绍

单次检测多框检测器(Single Shot MultiBox Detector,SSD)是一种目标检测算法,它可以在一张图像上同时检测多个目标,并返回它们的位置和类别。SSD算法的基本原理是将图像分成多个不同尺度的特征图,然后在每个特征图上使用卷积操作来检测目标。这种方法能够捕捉不同尺度的目标,因为小尺寸的目标可能只出现在高分辨率的特征图中,而大尺寸的目标可能会出现在低分辨率的特征图中。

SSD算法首先使用一个基础网络(如图VGG16)来提取图像的特征。然后,在基础网络的顶层添加了一系列卷积层,这些层会生成不同分辨率的特征图。这些多尺度的特征图能够覆盖不同大小和形状的目标。在每个特征图的每个位置,SSD生成一组预定义的锚框(也称为先验框)。这些锚框具有不同的宽高比和大小,用于捕捉各种可能的目标。在每个特征图上,SSD使用卷积滤波器进行检测。对于每个锚框,SSD通过分类分支预测其目标类别的概率分布,并通过回归分支预测锚框的边界框偏移量。分类分支负责判断锚框内是否存在目标及其类别,而回归分支负责调整锚框的位置和大小以更准确地围绕目标。

在所有锚框的预测完成后,SSD应用非极大值抑制(Non-Maximum Suppression, NMS)算法来删除重叠的锚框。NMS根据预测的置信度分数选择最优的锚框,同时抑制与其重叠较大的其他锚框,确保每个目标只保留一个检测框。最终,SSD算法将经过过滤的锚框作为检测结果,返回目标的类别和精确位置。通过这种多尺度的检测机制,SSD能够高效地进行目标检测,兼顾检测速度和精度,广泛应用于实时检测任务中。
VGG16 网络架构

2.1 SSD算法网络结构

SSD算法的网络结构可以分为两个部分:基础网络和检测网络。基础网络通常采用经典的卷积神经网络(如VGG16),用于提取图像特征。检测网络由一系列卷积层和池化层组成,用于在特征图上执行目标检测任务。
具体来说,SSD算法的检测网络通常包括以下几个部分:

  1. 特征提取层:在基础网络的基础上,添加一些卷积层和池化层来进一步提取图像的深层特征。这些层通过不断的卷积和下采样操作,提取出更加抽象且具有语义信息的特征表示。
  2. 卷积特征图层:将特征映射到不同的尺度,生成一系列的特征图。这些特征图对应于不同的分辨率,使得网络可以在不同尺度上进行检测,从而能够检测出不同大小的目标。
  3. 检测层:在每个特征图上执行目标检测任务,具体包括分类分支和回归分支。分类分支负责确定特定位置上是否存在目标及其类别,回归分支则负责预测目标的边界框坐标。每个检测层通常包括多个先验框(default boxes),并针对每个先验框进行分类和回归预测。
  4. 后处理层:应用非极大值抑制(Non-Maximum Suppression, NMS)算法来过滤掉冗余的检测结果,并返回最终的目标检测结果。NMS通过选择置信度最高的边界框,并抑制与其重叠较大的其他框,确保每个目标只保留一个最优检测框,从而提高检测结果的准确性和鲁棒性。
    在这里插入图片描述

2.2 SSD算法训练过程

SSD算法的训练过程可以分为两个阶段:预训练和微调。

在预训练阶段,SSD算法使用基础网络进行图像分类任务的训练,以提取图像特征。这个阶段通常采用经典的卷积神经网络(如VGG16或ResNet),并在大型图像分类数据集(如ImageNet)上进行训练,从而获得高质量的图像特征表示。这些预训练的特征随后用于构建SSD的检测网络,在此基础上添加额外的卷积层以生成多尺度特征图。

在微调阶段,SSD算法对整个网络进行微调,以优化目标检测性能。这个阶段的训练数据集通常包括带有目标检测标注的真实图像,如PASCAL VOC或COCO数据集。数据集中的每张图像都包含目标的位置和类别标签。为了增强模型的泛化能力和鲁棒性,常采用各种数据增强技术,如随机裁剪、水平翻转、颜色抖动和缩放等。

在微调阶段,SSD算法通过优化损失函数来学习网络的权重和偏置参数,以最小化目标检测误差。SSD的损失函数通常由两部分组成:分类损失和回归损失。分类损失用于评估分类分支的预测结果是否正确,常使用交叉熵损失函数来衡量。回归损失用于评估回归分支的预测结果是否准确,通常采用平滑L1损失函数(又称Huber损失)来度量预测边界框与真实边界框之间的偏差。

为了加快训练速度和提高模型性能,SSD算法还使用了一些技术措施。例如,数据增强通过增加训练数据的多样性来提高模型的泛化能力;批量归一化(Batch Normalization)用于加速训练过程并稳定模型的训练;Dropout技术用于防止过拟合;学习率调整策略(如学习率衰减或自适应学习率优化器)用于在训练过程中动态调整学习率,从而更有效地找到最优解。

2.3 SSD算法优缺点

SSD算法作为一种用于对象检测的深度学习模型,其主要优点在于高效的检测速度和较高的检测精度。SSD通过在单次前向传递中同时预测多个尺度和纵横比的边界框,从而实现了实时检测。它采用不同尺度的特征图进行多尺度检测,使得对不同大小的对象具有较好的适应性。

此外,SSD的网络结构相对简单,不需要像R-CNN系列那样进行候选区域的生成和分类,因此在推理速度上具有显著优势,适合于实时应用场景,如自动驾驶、视频监控和移动设备上的应用。

然而,SSD算法也存在一些不足之处。首先,在检测小目标时,SSD的表现往往不如一些更加复杂的算法,因为较早层的特征图分辨率较低,导致小目标信息容易丢失。其次,虽然SSD在速度上占优,但在极高精度要求的任务中,其检测精度可能不如一些后续发展的检测算法,如RetinaNet和YOLOv4。最后,由于SSD直接对特征图进行检测,对于背景复杂的场景,其误检率可能较高,需要进一步的后处理步骤来提高精度。

3. 基于SSD的安全帽检测实验

3.1 VOC 2007安全帽数据集

在这里插入图片描述
VOC 2007安全帽数据集是Pascal Visual Object Classes (VOC) Challenge 2007的一部分,旨在为对象检测和分类任务提供标准化的数据集和评估框架。该数据集包含一系列具有复杂场景和多样化物体的图像,其中安全帽是一个具体的目标类别。VOC 2007数据集包括训练集、验证集和测试集,分别用于模型训练、参数调优和性能评估。每个图像都附带有详细的标注信息,包括物体类别、边界框位置等,这些标注信息是由人工精确标记的,以确保高质量的标签数据。此外,数据集还提供了预定义的评价指标,如平均精度(AP),用于衡量模型在对象检测任务中的表现。VOC 2007安全帽数据集在计算机视觉领域广泛应用,尤其是在训练和测试对象检测算法方面,是许多研究工作的基准数据集之一。

3.2 SSD网络架构

所需环境:
torch == 1.2.0

创建一个用于图像检测和预测的SSD对象。首先导入了必要的库和模块,然后定义了一些默认参数,包括模型路径、类别文件路径、输入图像尺寸、主干网络类型、置信度阈值、非极大抑制阈值、先验框尺寸、是否使用无失真缩放以及是否使用CUDA加速。在初始化SSD对象时,更新这些默认参数,计算类别数量并加载先验框,同时初始化用于绘制边界框的颜色,并加载模型和预训练权重。对于图像检测,代码将输入图像转换为RGB格式并调整尺寸,然后进行预处理,包括归一化和添加batch维度,接着将图像输入网络获取预测结果并解码,最后在原图像上绘制边界框和标签。提供计算模型每秒处理帧数的方法,以及生成用于评估模型性能的mAP评估结果的功能。

import colorsys
import os
import time
import warnings

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from PIL import Image, ImageDraw, ImageFont

from nets.ssd import SSD300
from utils.anchors import get_anchors
from utils.utils import cvtColor, get_classes, resize_image, preprocess_input
from utils.utils_bbox import BBoxUtility

warnings.filterwarnings("ignore")

class SSD(object):
    _defaults = {
        #--------------------------------------------------------------------------#
        #   使用自己训练好的模型进行预测要修改model_path和classes_path
        #   model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt
        #
        #--------------------------------------------------------------------------#
        "model_path": "E:\yanyi\yanyixia\AI\ssd-pytorch-bilibili\logs\ep023-loss3.253-val_loss3.239.pth",
        "classes_path": 'model_data/voc_classes.txt',
        #---------------------------------------------------------------------#
        #   用于预测的图像大小,和train时使用同一个即可
        #---------------------------------------------------------------------#
        "input_shape": [300, 300],
        #-------------------------------#
        #   主干网络的选择
        #   vgg
        #-------------------------------#
        "backbone": "vgg",
        #---------------------------------------------------------------------#
        #   只有得分大于置信度的预测框会被保留下来
        #---------------------------------------------------------------------#
        "confidence": 0.5,
        #---------------------------------------------------------------------#
        #   非极大抑制所用到的nms_iou大小
        #---------------------------------------------------------------------#
        "nms_iou": 0.45,
        #---------------------------------------------------------------------#
        #   用于指定先验框的大小
        #---------------------------------------------------------------------#
        'anchors_size': [30, 60, 111, 162, 213, 264, 315],
        #---------------------------------------------------------------------#
        #   该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize,
        #---------------------------------------------------------------------#
        "letterbox_image": False,
        #-------------------------------#
        #   是否使用Cuda
        #   没有GPU可以设置成False
        #-------------------------------#
        "cuda": False,
    }

    @classmethod
    def get_defaults(cls, n):
        if n in cls._defaults:
            return cls._defaults[n]
        else:
            return "Unrecognized attribute name '" + n + "'"

    #---------------------------------------------------#
    #   初始化ssd
    #---------------------------------------------------#
    def __init__(self, **kwargs):
        self.__dict__.update(self._defaults)
        for name, value in kwargs.items():
            setattr(self, name, value)
        #---------------------------------------------------#
        #   计算总的类的数量
        #---------------------------------------------------#
        self.class_names, self.num_classes = get_classes(self.classes_path)
        self.anchors = torch.from_numpy(get_anchors(self.input_shape, self.anchors_size, self.backbone)).type(torch.FloatTensor)
        if self.cuda:
            self.anchors = self.anchors.cuda()
        self.num_classes                    = self.num_classes + 1

        #---------------------------------------------------#
        #   画框设置不同的颜色
        #---------------------------------------------------#
        hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
        self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
        self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))

        self.bbox_util = BBoxUtility(self.num_classes)
        self.generate()

    #---------------------------------------------------#
    #   载入模型
    #---------------------------------------------------#
    def generate(self):
        #-------------------------------#
        #   载入模型与权值
        #-------------------------------#
        self.net    = SSD300(self.num_classes, self.backbone)
        device      = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.net.load_state_dict(torch.load(self.model_path, map_location=device))
        self.net    = self.net.eval()
        print('{} model, anchors, and classes loaded.'.format(self.model_path))

        if self.cuda:
            self.net = torch.nn.DataParallel(self.net)
            cudnn.benchmark = True
            self.net = self.net.cuda()

    #---------------------------------------------------#
    #   检测图片
    #---------------------------------------------------#
    def detect_image(self, image):
        #---------------------------------------------------#
        #   计算输入图片的高和宽
        #---------------------------------------------------#
        image_shape = np.array(np.shape(image)[0:2])
        #---------------------------------------------------------#
        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------------#
        #   给图像增加灰条,实现不失真的resize
        #   也可以直接resize进行识别
        #---------------------------------------------------------#
        image_data  = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        #---------------------------------------------------------#
        #   添加上batch_size维度,图片预处理,归一化。
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            #---------------------------------------------------#
            #   转化成torch的形式
            #---------------------------------------------------#
            images = torch.from_numpy(image_data).type(torch.FloatTensor)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------------#
            #   将图像输入网络当中进行预测!
            #---------------------------------------------------------#
            outputs = self.net(images)
            #-----------------------------------------------------------#
            #   将预测结果进行解码
            #-----------------------------------------------------------#
            results = self.bbox_util.decode_box(outputs, self.anchors, image_shape, self.input_shape, self.letterbox_image,
                                                    nms_iou = self.nms_iou, confidence = self.confidence)
            #--------------------------------------#
            #   如果没有检测到物体,则返回原图
            #--------------------------------------#
            if len(results[0]) <= 0:
                return image

            top_label   = np.array(results[0][:, 4], dtype = 'int32')
            top_conf    = results[0][:, 5]
            top_boxes   = results[0][:, :4]
        #---------------------------------------------------------#
        #   设置字体与边框厚度
        #---------------------------------------------------------#
        font = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * np.shape(image)[1] + 0.5).astype('int32'))
        thickness = max((np.shape(image)[0] + np.shape(image)[1]) // self.input_shape[0], 1)
        
        #---------------------------------------------------------#
        #   图像绘制
        #---------------------------------------------------------#


        for i, c in enumerate(top_label):
            predicted_class = self.class_names[int(c)]
            box = top_boxes[i]
            score = top_conf[i]

            top, left, bottom, right = box

            # 确保坐标在图像边界内
            top = max(0, np.floor(top).astype('int32'))
            left = max(0, np.floor(left).astype('int32'))
            bottom = min(image.size[1], np.floor(bottom).astype('int32'))
            right = min(image.size[0], np.floor(right).astype('int32'))

            label = '{} {:.2f}'.format(predicted_class, score)
            draw = ImageDraw.Draw(image)

            # 获取文本的尺寸
            label_size = draw.textbbox((0, 0), label, font=font)[2:]  # 返回值是(left, top, right, bottom),我们只需要宽高

            label = label.encode('utf-8')
            print(label, top, left, bottom, right)

            if top - label_size[1] >= 0:
                text_origin = np.array([left, top - label_size[1]])
            else:
                text_origin = np.array([left, top + 1])

            # 绘制边界框
            for j in range(thickness):
                draw.rectangle([left + j, top + j, right - j, bottom - j], outline=self.colors[int(c)])
            # 绘制标签背景
            draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[int(c)])
            # 绘制文本
            draw.text(text_origin, str(label, 'UTF-8'), fill=(0, 0, 0), font=font)
            del draw

        return image

    def get_FPS(self, image, test_interval):
        #---------------------------------------------------#
        #   计算输入图片的高和宽
        #---------------------------------------------------#
        image_shape = np.array(np.shape(image)[0:2])
        #---------------------------------------------------------#
        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------------#
        #   给图像增加灰条,实现不失真的resize
        #   也可以直接resize进行识别
        #---------------------------------------------------------#
        image_data  = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        #---------------------------------------------------------#
        #   添加上batch_size维度,图片预处理,归一化。
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            #---------------------------------------------------#
            #   转化成torch的形式
            #---------------------------------------------------#
            images = torch.from_numpy(image_data).type(torch.FloatTensor)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------------#
            #   将图像输入网络当中进行预测!
            #---------------------------------------------------------#
            outputs     = self.net(images)
            #-----------------------------------------------------------#
            #   将预测结果进行解码
            #-----------------------------------------------------------#
            results     = self.bbox_util.decode_box(outputs, self.anchors, image_shape, self.input_shape, self.letterbox_image, 
                                                    nms_iou = self.nms_iou, confidence = self.confidence)

        t1 = time.time()
        for _ in range(test_interval):
            with torch.no_grad():
                #---------------------------------------------------------#
                #   将图像输入网络当中进行预测!
                #---------------------------------------------------------#
                outputs     = self.net(images)
                #-----------------------------------------------------------#
                #   将预测结果进行解码
                #-----------------------------------------------------------#
                results     = self.bbox_util.decode_box(outputs, self.anchors, image_shape, self.input_shape, self.letterbox_image, 
                                                        nms_iou = self.nms_iou, confidence = self.confidence)

        t2 = time.time()
        tact_time = (t2 - t1) / test_interval
        return tact_time

    def get_map_txt(self, image_id, image, class_names, map_out_path):
        f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w") 
        #---------------------------------------------------#
        #   计算输入图片的高和宽
        #---------------------------------------------------#
        image_shape = np.array(np.shape(image)[0:2])
        #---------------------------------------------------------#
        #   在这里将图像转换成RGB图像,防止灰度图在预测时报错。
        #   代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
        #---------------------------------------------------------#
        image       = cvtColor(image)
        #---------------------------------------------------------#
        #   给图像增加灰条,实现不失真的resize
        #   也可以直接resize进行识别
        #---------------------------------------------------------#
        image_data  = resize_image(image, (self.input_shape[1], self.input_shape[0]), self.letterbox_image)
        #---------------------------------------------------------#
        #   添加上batch_size维度,图片预处理,归一化。
        #---------------------------------------------------------#
        image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)

        with torch.no_grad():
            #---------------------------------------------------#
            #   转化成torch的形式
            #---------------------------------------------------#
            images = torch.from_numpy(image_data).type(torch.FloatTensor)
            if self.cuda:
                images = images.cuda()
            #---------------------------------------------------------#
            #   将图像输入网络当中进行预测!
            #---------------------------------------------------------#
            outputs     = self.net(images)
            #-----------------------------------------------------------#
            #   将预测结果进行解码
            #-----------------------------------------------------------#
            results     = self.bbox_util.decode_box(outputs, self.anchors, image_shape, self.input_shape, self.letterbox_image, 
                                                    nms_iou = self.nms_iou, confidence = self.confidence)
            #--------------------------------------#
            #   如果没有检测到物体,则返回原图
            #--------------------------------------#
            if len(results[0]) <= 0:
                return 

            top_label   = np.array(results[0][:, 4], dtype = 'int32')
            top_conf    = results[0][:, 5]
            top_boxes   = results[0][:, :4]
        
        for i, c in list(enumerate(top_label)):
            predicted_class = self.class_names[int(c)]
            box             = top_boxes[i]
            score           = str(top_conf[i])

            top, left, bottom, right = box
            if predicted_class not in class_names:
                continue

            f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))

        f.close()
        return

3.3 训练和验证所需的2007_train.txt和2007_val.txt文件生成

该部分代码用于处理和生成用于训练目标检测模型的数据集标签文件,支持不同模式的操作。首先定义了几个参数,如annotation_modeclasses_pathtrainval_percenttrain_percent,并指定了VOC数据集的路径。根据annotation_mode的值,代码可以生成不同的文件。在annotation_mode为0或1时,代码从VOC数据集中的Annotations文件夹读取XML文件,生成训练集、验证集和测试集的图片ID列表文件(trainval.txt、train.txt、val.txt、test.txt)。在annotation_mode为0或2时,代码进一步处理生成2007_train.txt和2007_val.txt文件,这些文件包含每个图像的路径以及目标边界框和类别信息。通过解析XML文件中的目标对象标签,代码提取边界框坐标和类别信息,并将其写入对应的训练或验证文件中。总之,这段代码主要用于将VOC格式的数据集转换为适合目标检测模型训练的格式。

import os
import random
import xml.etree.ElementTree as ET

from utils.utils import get_classes

#--------------------------------------------------------------------------------------------------------------------------------#
#   annotation_mode用于指定该文件运行时计算的内容
#   annotation_mode为0代表整个标签处理过程,包括获得VOCdevkit/VOC2007/ImageSets里面的txt以及训练用的2007_train.txt、2007_val.txt
#   annotation_mode为1代表获得VOCdevkit/VOC2007/ImageSets里面的txt
#   annotation_mode为2代表获得训练用的2007_train.txt、2007_val.txt
#--------------------------------------------------------------------------------------------------------------------------------#
annotation_mode     = 0
#----------------------------------------------------------------
#   仅在annotation_mode为0和2的时候有效
#-------------------------------------------------------------------#
classes_path        = 'model_data/voc_classes.txt'

trainval_percent    = 0.9
train_percent       = 0.9
#-------------------------------------------------------#
#   指向VOC数据集所在的文件夹
#   默认指向根目录下的VOC数据集
#-------------------------------------------------------#
VOCdevkit_path  = 'VOCdevkit'

VOCdevkit_sets  = [('2007', 'train'), ('2007', 'val')]
classes, _      = get_classes(classes_path)

def convert_annotation(year, image_id, list_file):
    in_file = open(os.path.join(VOCdevkit_path, 'VOC%s/Annotations/%s.xml'%(year, image_id)), encoding='utf-8')
    tree=ET.parse(in_file)
    root = tree.getroot()

    for obj in root.iter('object'):
        difficult = 0 
        if obj.find('difficult')!=None:
            difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult)==1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (int(float(xmlbox.find('xmin').text)), int(float(xmlbox.find('ymin').text)), int(float(xmlbox.find('xmax').text)), int(float(xmlbox.find('ymax').text)))
        list_file.write(" " + ",".join([str(a) for a in b]) + ',' + str(cls_id))
        
if __name__ == "__main__":
    random.seed(0)
    if annotation_mode == 0 or annotation_mode == 1:
        print("Generate txt in ImageSets.")
        xmlfilepath     = os.path.join(VOCdevkit_path, 'VOC2007/Annotations')
        saveBasePath    = os.path.join(VOCdevkit_path, 'VOC2007/ImageSets/Main')
        temp_xml        = os.listdir(xmlfilepath)
        total_xml       = []
        for xml in temp_xml:
            if xml.endswith(".xml"):
                total_xml.append(xml)

        num     = len(total_xml)  
        list    = range(num)  
        tv      = int(num*trainval_percent)  
        tr      = int(tv*train_percent)  
        trainval= random.sample(list,tv)  
        train   = random.sample(trainval,tr)  
        
        print("train and val size",tv)
        print("train size",tr)
        ftrainval   = open(os.path.join(saveBasePath,'trainval.txt'), 'w')  
        ftest       = open(os.path.join(saveBasePath,'test.txt'), 'w')  
        ftrain      = open(os.path.join(saveBasePath,'train.txt'), 'w')  
        fval        = open(os.path.join(saveBasePath,'val.txt'), 'w')  
        
        for i in list:  
            name=total_xml[i][:-4]+'\n'  
            if i in trainval:  
                ftrainval.write(name)  
                if i in train:  
                    ftrain.write(name)  
                else:  
                    fval.write(name)  
            else:  
                ftest.write(name)  
        
        ftrainval.close()  
        ftrain.close()  
        fval.close()  
        ftest.close()
        print("Generate txt in ImageSets done.")

    if annotation_mode == 0 or annotation_mode == 2:
        print("Generate 2007_train.txt and 2007_val.txt for train.")
        for year, image_set in VOCdevkit_sets:
            image_ids = open(os.path.join(VOCdevkit_path, 'VOC%s/ImageSets/Main/%s.txt'%(year, image_set)), encoding='utf-8').read().strip().split()
            list_file = open('%s_%s.txt'%(year, image_set), 'w', encoding='utf-8')
            for image_id in image_ids:
                list_file.write('%s/VOC%s/JPEGImages/%s.jpg'%(os.path.abspath(VOCdevkit_path), year, image_id))

                convert_annotation(year, image_id, list_file)
                list_file.write('\n')
            list_file.close()
        print("Generate 2007_train.txt and 2007_val.txt for train done.")

3.4 模型训练

该部分代码使用SSD模型进行目标检测的训练过程,主要包括冻结阶段和解冻阶段的训练。首先,代码加载必要的库并设置了一些超参数,比如输入形状、骨干网络类型、预训练模型路径等。然后,代码获取类别和锚点配置,并初始化SSD模型。如果预训练权重存在,则会加载这些权重。

接下来定义了损失函数和历史记录器,并从训练和验证数据集对应的txt文件中读取数据。对于冻结阶段训练,代码设置了批次大小和学习率,并创建数据加载器。通过设置网络的一部分参数不可训练,代码实现了冻结部分网络的功能。在训练过程中,代码使用Adam优化器和学习率调度器进行优化。

在解冻阶段训练,再次设置了批次大小和学习率,重新创建数据加载器,并解冻之前冻结的网络部分,使其参与训练。最后,代码在两个阶段内都调用fit_one_epoch函数来执行实际的训练过程,包括前向传播、计算损失、反向传播和参数更新。通过这种方式,代码逐步调整模型的权重,使其在训练数据上表现良好。

import warnings
import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.hub import load_state_dict_from_url
from nets.ssd import SSD300
from nets.ssd_training import MultiboxLoss, weights_init
from utils.anchors import get_anchors
from utils.callbacks import LossHistory
from utils.dataloader import SSDDataset, ssd_dataset_collate
from utils.utils import get_classes
from utils.utils_fit import fit_one_epoch

warnings.filterwarnings("ignore")

if __name__ == "__main__":
    # 参数配置
    Cuda = False
    classes_path = 'model_data/voc_classes.txt'
    model_path = 'E:\yanyi\yanyixia\AI\model\ssd_weights.pth'
    input_shape = [300, 300]
    backbone = "vgg"
    pretrained = False
    anchors_size = [30, 60, 111, 162, 213, 264, 315]
    Init_Epoch = 0
    Freeze_Epoch = 50
    Freeze_batch_size = 16
    Freeze_lr = 5e-4
    UnFreeze_Epoch = 100
    Unfreeze_batch_size = 4
    Unfreeze_lr = 1e-4
    Freeze_Train = True
    num_workers = 4
    train_annotation_path = '2007_train.txt'
    val_annotation_path = '2007_val.txt'

    # 获取classes和anchor
    class_names, num_classes = get_classes(classes_path)
    num_classes += 1
    anchors = get_anchors(input_shape, anchors_size, backbone)

    model = SSD300(num_classes, backbone, pretrained)
    if not pretrained:
        weights_init(model)
    if model_path != '':
        # 加载预训练权重
        print('Load weights {}.'.format(model_path))
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        model_dict = model.state_dict()
        pretrained_dict = torch.load(model_path, map_location=device)
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if np.shape(model_dict[k]) == np.shape(v)}
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    model_train = model.train()
    if Cuda:
        model_train = torch.nn.DataParallel(model)
        cudnn.benchmark = True
        model_train = model_train.cuda()

    criterion = MultiboxLoss(num_classes, neg_pos_ratio=3.0)
    loss_history = LossHistory("logs/")

    # 读取数据集对应的txt
    with open(train_annotation_path) as f:
        train_lines = f.readlines()
    with open(val_annotation_path) as f:
        val_lines = f.readlines()
    num_train = len(train_lines)
    num_val = len(val_lines)

    # 冻结阶段训练
    if True:
        batch_size = Freeze_batch_size
        lr = Freeze_lr
        start_epoch = Init_Epoch
        end_epoch = Freeze_Epoch

        epoch_step = num_train // batch_size
        epoch_step_val = num_val // batch_size

        if epoch_step == 0 or epoch_step_val == 0:
            raise ValueError("数据集过小,无法进行训练,请扩充数据集。")

        optimizer = optim.Adam(model_train.parameters(), lr, weight_decay=5e-4)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.94)

        train_dataset = SSDDataset(train_lines, input_shape, anchors, batch_size, num_classes, train=True)
        val_dataset = SSDDataset(val_lines, input_shape, anchors, batch_size, num_classes, train=False)

        gen = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                         drop_last=True, collate_fn=ssd_dataset_collate)
        gen_val = DataLoader(val_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                             drop_last=True, collate_fn=ssd_dataset_collate)

        # 冻结一定部分训练
        if Freeze_Train:
            if backbone == "vgg":
                for param in model.vgg[:28].parameters():
                    param.requires_grad = False
            else:
                for param in model.mobilenet.parameters():
                    param.requires_grad = False

        for epoch in range(start_epoch, end_epoch):
            fit_one_epoch(model_train, model, criterion, loss_history, optimizer, epoch,
                          epoch_step, epoch_step_val, gen, gen_val, end_epoch, Cuda)
            lr_scheduler.step()

    # 解冻阶段训练
    if True:
        batch_size = Unfreeze_batch_size
        lr = Unfreeze_lr
        start_epoch = Freeze_Epoch
        end_epoch = UnFreeze_Epoch

        epoch_step = num_train // batch_size
        epoch_step_val = num_val // batch_size

        if epoch_step == 0 or epoch_step_val == 0:
            raise ValueError("数据集过小,无法进行训练,请扩充数据集。")

        optimizer = optim.Adam(model_train.parameters(), lr, weight_decay=5e-4)
        lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.94)

        train_dataset = SSDDataset(train_lines, input_shape, anchors, batch_size, num_classes, train=True)
        val_dataset = SSDDataset(val_lines, input_shape, anchors, batch_size, num_classes, train=False)

        gen = DataLoader(train_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                         drop_last=True, collate_fn=ssd_dataset_collate)
        gen_val = DataLoader(val_dataset, shuffle=True, batch_size=batch_size, num_workers=num_workers, pin_memory=True,
                             drop_last=True, collate_fn=ssd_dataset_collate)

        # 解冻后训练
        if Freeze_Train:
            if backbone == "vgg":
                for param in model.vgg[:28].parameters():
                    param.requires_grad = True
            else:
                for param in model.mobilenet.parameters():
                    param.requires_grad = True

        for epoch in range(start_epoch, end_epoch):
            fit_one_epoch(model_train, model, criterion, loss_history, optimizer, epoch,
                          epoch_step, epoch_step_val, gen, gen_val, end_epoch, Cuda)
            lr_scheduler.step()

3.5 GUI界面

该部分代码实现了一个基于SSD模型的安全帽检测系统的图形用户界面(GUI),使用PyQt5框架。它提供了两个功能:检测图像中的安全帽和检测视频中的安全帽。

import sys
import cv2
import numpy as np
from PyQt5.QtWidgets import QApplication, QMainWindow, QPushButton, QLabel, QFileDialog, QVBoxLayout, QWidget, QHBoxLayout, QFrame
from PyQt5.QtGui import QPixmap, QImage, QFont
from PyQt5.QtCore import Qt, QTimer
from PIL import Image
from ssd import SSD

class MainWindow(QMainWindow):
    def __init__(self):
        super().__init__()

        self.ssd = SSD()
        self.initUI()
        self.timer = QTimer()
        self.timer.timeout.connect(self.update_frame)

    def initUI(self):
        self.setWindowTitle('基于SSD的安全帽检测系统')
        self.setGeometry(100, 100, 1200, 800)  # 设置窗口大小

        mainLayout = QVBoxLayout()
        titleLabel = QLabel('基于SSD的安全帽检测系统', self)
        titleLabel.setAlignment(Qt.AlignCenter)
        titleLabel.setFont(QFont('Arial', 24))
        mainLayout.addWidget(titleLabel)

        centerLayout = QHBoxLayout()
        centerLayout.setAlignment(Qt.AlignCenter)

        self.imageLabel = QLabel(self)
        self.imageLabel.setAlignment(Qt.AlignCenter)
        self.imageLabel.setFrameShape(QFrame.Box)
        self.imageLabel.setFixedSize(1100, 600)  # 设置固定大小
        centerLayout.addWidget(self.imageLabel)

        mainLayout.addLayout(centerLayout)

        buttonLayout = QHBoxLayout()

        self.selectImageButton = QPushButton('请选择图片进行检测', self)
        self.selectImageButton.setStyleSheet("font-size: 30px;")
        self.selectImageButton.setFont(QFont('Arial', 20))
        self.selectImageButton.clicked.connect(self.select_image)
        buttonLayout.addWidget(self.selectImageButton)

        self.selectVideoButton = QPushButton('请选择视频进行检测', self)
        self.selectVideoButton.setStyleSheet("font-size: 30px;")
        self.selectVideoButton.setFont(QFont('Arial', 20))
        self.selectVideoButton.clicked.connect(self.select_video)
        buttonLayout.addWidget(self.selectVideoButton)

        mainLayout.addLayout(buttonLayout)

        container = QWidget()
        container.setLayout(mainLayout)
        self.setCentralWidget(container)

        self.setStyleSheet("""
            QPushButton {
                background-color: #4CAF50;
                color: white;
                border: none;
                padding: 15px 32px;
                text-align: center;
                text-decoration: none;
                display: inline-block;
                font-size: 16px;
                margin: 4px 2px;
                transition-duration: 0.4s;
                cursor: pointer;
                border-radius: 12px;
            }
            QPushButton:hover {
                background-color: white;
                color: black;
                border: 2px solid #4CAF50;
            }
            QLabel {
                background-color: white;
            }
            QFrame {
                border: 2px solid #4CAF50;
                border-radius: 15px;
            }
        """)

    def select_image(self):
        imagePath, _ = QFileDialog.getOpenFileName(self, "选择图片", "",
                                                   "Images (*.png *.xpm *.jpg *.jpeg *.bmp *.tif *.tiff)")
        if imagePath:
            image = Image.open(imagePath)
            result_image = self.ssd.detect_image(image)
            result_image = result_image.convert("RGB")
            result_image = np.array(result_image)
            height, width, channel = result_image.shape
            bytesPerLine = 3 * width
            qImg = QImage(result_image.data, width, height, bytesPerLine, QImage.Format_RGB888)
            pixmap = QPixmap.fromImage(qImg)

            self.imageLabel.setPixmap(pixmap.scaled(self.imageLabel.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation))
            self.imageLabel.adjustSize()

    def select_video(self):
        videoPath, _ = QFileDialog.getOpenFileName(self, "选择视频", "", "Videos (*.mp4 *.avi *.mkv *.mov)")
        if videoPath:
            self.video_path = videoPath
            self.capture = cv2.VideoCapture(self.video_path)
            self.timer.start(30)  # 每30ms更新一次

    def update_frame(self):
        ret, frame = self.capture.read()
        if not ret:
            self.timer.stop()
            return

        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image = Image.fromarray(frame)
        result_image = self.ssd.detect_image(image)
        result_frame = np.array(result_image)
        # result_frame = cv2.cvtColor(result_frame, cv2.COLOR_RGB2BGR)

        height, width, channel = result_frame.shape
        bytesPerLine = 3 * width
        qImg = QImage(result_frame.data, width, height, bytesPerLine, QImage.Format_RGB888)
        pixmap = QPixmap.fromImage(qImg)

        self.imageLabel.setPixmap(pixmap.scaled(self.imageLabel.size(), Qt.KeepAspectRatio, Qt.SmoothTransformation))
        self.imageLabel.adjustSize()


if __name__ == '__main__':
    app = QApplication(sys.argv)
    mainWindow = MainWindow()
    mainWindow.show()
    sys.exit(app.exec_())

3.6 结果展示

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

3.7 文件下载

训练所需的ssd_weights.pth和主干的权值可以在百度云下载。

链接: https://pan.baidu.com/s/1iUVE50oLkzqhtZbUL9el9w
提取码: jgn8

4. 参考连接

睿智的目标检测23——Pytorch搭建SSD目标检测平台

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/705865.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Unity+AI01】在Unity中调用DeepSeek大模型!实现AI对话功能!

要在Unity中调用DeepSeek的API并实现用户输入文本后返回对话的功能&#xff0c;你需要遵循以下步骤&#xff1a; 获取API密钥&#xff1a; 首先&#xff0c;你需要从DeepSeek获取API密钥。这通常涉及到注册账户&#xff0c;并可能需要订阅相应的服务。 集成HTTP请求库&#xf…

基于python多光谱遥感数据处理、图像分类、定量评估及机器学习

原文链接&#xff1a;基于python多光谱遥感数据处理、图像分类、定量评估及机器学习 普通数码相机记录了红、绿、蓝三种波长的光&#xff0c;多光谱成像技术除了记录这三种波长光之外&#xff0c;还可以记录其他波长&#xff08;例如&#xff1a;近红外、热红外等&#xff09;光…

pytorch神经网络训练(AlexNet)

导包 import osimport torchimport torch.nn as nnimport torch.optim as optimfrom torch.utils.data import Dataset, DataLoaderfrom PIL import Imagefrom torchvision import models, transforms 定义自定义图像数据集 class CustomImageDataset(Dataset): 定义一个自…

Ubuntu22.04 下 pybind11 搭建,示例

Pybind11 是一个轻量级的库&#xff0c;用于在 C 中创建 Python 绑定。Ubuntu22下安装pybind11步骤如下&#xff1a; 1. 安装 pybind11 1.1 pip 命令安装 pip3 install pybind11 1.2 源代码安装 安装依赖库&#xff1a; sudo pip install -i https://pypi.tuna.tsinghua.e…

AVR晶体管测试仪开源项目编译

AVR晶体管测试仪开源项目编译 &#x1f4cd;原项目地址&#xff1a;https://github.com/Mikrocontroller-net/transistortester/tree/master&#x1f33f; https://github.com/svn2github/transistortester&#x1f33f; https://github.com/wagiminator/ATmega-Transistor-Tes…

2. Revit API UI 之 IExternalCommand 和 IExternalApplication

2. Revit API UI 之 IExternalCommand 和 IExternalApplication 上一篇我们大致看了下 RevitAPI 的一级命名空间划分&#xff0c;再简单讲了一下Attributes命名空间下的3个类&#xff0c;并从一个代码样例&#xff0c;提到了Attributes和IExternalCommand &#xff0c;前者是指…

vite配置unocss

在vue3vitetseslintprettierstylelinthuskylint-stagedcommitlintcommitizencz-git介绍了关于vitevue工程化搭建&#xff0c;现在在这个基础上&#xff0c;我们增加一下unocss unocss官方文档 具体开发中使用遇到的问题可以参考不喜欢原子化CSS得我&#xff0c;还是在新项目中使…

NumPy和数组

1.NumPy是什么 NumPy&#xff08;Numerical Python的缩写&#xff09;是一个开源的Python科学计算模块&#xff0c;其中包含了许多实用的数学函数&#xff0c;用来处理数值型数据。NumPy中&#xff0c;最重要和使用最频繁的对象就是N维数组。 为什么要学习NumPy&#xff1f; …

Java高级技术探索:深入理解JVM内存分区与GC机制

文章目录 引言JVM内存分区概览垃圾回收机制&#xff08;GC&#xff09;GC算法基础常见垃圾回收器ParNew /Serial old 收集器运行示意图 优化实践结语 引言 Java作为一门广泛应用于企业级开发的编程语言&#xff0c;其背后的Java虚拟机&#xff08;JVM&#xff09;扮演着至关重…

TikTok Ads广告综合指南:竞价策略及效果建议

作为全球最受欢迎的应用程序之一&#xff0c;TikTok不仅为用户提供了记录分享生活中美好时刻、交流全球创意的平台&#xff0c;也给全球的企业提供了一个直接触达用户的平台。随着Z时代用户人群的购买力不断上升&#xff0c;出海广告主们也逐渐将目光放在TikTok方面的营销。 上…

【Linux系统编程】线程

Linux线程 文章目录 Linux线程1.进程与线程区别2.线程优点3.API概要4.线程1.线程的创建2.线程等待内存共享验证3.线程退出关于对void** &的理解拓展 4.互斥锁1.创建及销毁互斥锁2.加锁及解锁 5.什么情况下会造成死锁6.条件**1. 创建及销毁条件变量****2. 等待****3. 触发**…

基于大数据的主流电商平台获取商品详情数据SKU数据价格数据

主流电商平台&#xff1a;淘宝 1688 闲鱼 京东 唯品会 蘑菇街 一号店 阿里妈妈 阿里巴巴 苏宁 亚马逊 易贝 速卖通 电子元件 网易考拉 洋码头 VVIC MIC Lazada 拼多多 ​ ​​​​​​​关于电商大数据的介绍&#xff1a; 主流电商大数据的采集&#xff1a;电商API接口的接入…

潮玩宇宙大逃杀APP系统开发成品案例分享指南

这是一款多人游戏&#xff0c;玩家需要选择一个房间躲避杀手。满足人数后&#xff0c;杀手会随机挑选一个房间杀掉里面所有的参与者&#xff0c;其他房间的幸存者将平均瓜分被杀房间的元宝。玩家在选中房间后&#xff0c;倒计时结束前可以自由切换不同房间。 软件项目开发成品…

【Linux】进程控制3——进程程序替换

一&#xff0c;前言 创建子进程的目的之一就是为了代劳父进程执行父进程的部分代码&#xff0c;也就是说本质上来说父子进程都是执行的同一个代码段的数据&#xff0c;在子进程修改数据的时候进行写时拷贝修改数据段的部分数据。 但是还有一个目的——将子进程在运行时指向一个…

自动控制原理【期末复习】(二)

无人机上桨之后可以在调试架上先调试&#xff1a; 1.根轨迹的绘制 /// 前面针对的是时域分析&#xff0c;下面针对频域分析&#xff1a; 2.波特图 3.奈维斯特图绘制 1.奈氏稳定判据 2.对数稳定判据 3.相位裕度和幅值裕度

JavaScript的数组排序

天行健&#xff0c;君子以自强不息&#xff1b;地势坤&#xff0c;君子以厚德载物。 每个人都有惰性&#xff0c;但不断学习是好好生活的根本&#xff0c;共勉&#xff01; 文章均为学习整理笔记&#xff0c;分享记录为主&#xff0c;如有错误请指正&#xff0c;共同学习进步。…

Sora和快手可灵背后的核心技术 | 3DVAE:通过小批量特征交换实现身体和面部的三维形状变分自动编码器

【摘要】学习3D脸部和身体生成模型中一个解开的、可解释的和结构化的潜在表示仍然是一个开放的问题。当需要控制身份特征时,这个问题尤其突出。在本文中,论文提出了一种直观而有效的自监督方法来训练一个3D形状变分自动编码器(VAE),以鼓励身份特征的解开潜在表示。通过交换不同…

自学网络安全的三个必经阶段(含路线图)

一、为什么选择网络安全&#xff1f; 这几年随着我国《国家网络空间安全战略》《网络安全法》《网络安全等级保护2.0》等一系列政策/法规/标准的持续落地&#xff0c;网络安全行业地位、薪资随之水涨船高。 未来3-5年&#xff0c;是安全行业的黄金发展期&#xff0c;提前踏入…

Python:基础爬虫

Python爬虫学习&#xff08;网络爬虫&#xff08;又称为网页蜘蛛&#xff0c;网络机器人&#xff0c;在FOAF社区中间&#xff0c;更经常的称为网页追逐者&#xff09;&#xff0c;是一种按照一定的规则&#xff0c;自动地抓取万维网信息的程序或者脚本。另外一些不常使用的名字…

上海晋名室外危废品暂存柜助力储能电站行业危废品安全储存

近日又有一台SAVEST室外危废暂存柜项目成功验收交付使用&#xff0c;此次项目主要用于储能电站行业废油、废锂电池等危废品的安全储存。 用户单位在日常工作运营中涉及到废油、废锂电池等危废品的室外安全储存问题。4月中旬用户技术总工在寻找解决方案的过程中搜索到上海晋名的…