yolo系列模型训练数据集全流程制作方法(附数据增强代码)

yolo系列的模型在目标检测领域里面受众非常广,也十分流行,但是在使用yolo进行目标检测训练的时候,往往要将VOC格式的数据集转化为yolo专属的数据集,而yolo的训练数据集制作方法呢,最常见的也是有两种,下面我们只讲述一种最常用的方法,也是我最常使用的。

1. voc转yolo格式

我最常使用的目标检测数据集为VOC格式,而它的格式一般如下所示:

- dataset
  |- annotations
  |  |- image1.xml
  |  |- image2.xml
  |  |- ...
  |
  |- images
  |  |- image1.jpg
  |  |- image2.jpg
  |  |- ...
  • dataset 是数据集的根目录。
  • annotations 目录包含每个图像对应的 XML 注释文件。
  • images 目录包含每个图像文件。

而我们要转换的yolo格式如下所示:

- dataset
  |- images
  |  |- image1.jpg
  |  |- image2.jpg
  |  |- ...
  |
  |- labels
  |  |- image1.txt
  |  |- image2.txt
  |  |- ...
  • dataset 是数据集的根目录。
  • images 目录包含每个图像文件,通常是以 .jpg 或 .png 等格式保存的图像文件。
  • labels 目录包含每个图像对应的标签文件,通常是以 .txt 格式保存的文本文件。

而 labels 里面的内容填写格式为下图所示:
在这里插入图片描述

通常,每行的格式为:class x_center y_center width height,其中class代表的是图片中目标所对应的类别,x_center, y_center是边界框的中心点坐标相对于图像宽度和高度的归一化值,widthheight 是边界框的宽度和高度相对于图像宽度和高度的归一化值。
举例如下:
在这里插入图片描述

转换代码:

import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join


def convert(size, box):
    x_center = (box[0] + box[1]) / 2.0
    y_center = (box[2] + box[3]) / 2.0
    x = x_center / size[0]
    y = y_center / size[1]

    w = (box[1] - box[0]) / size[0]
    h = (box[3] - box[2]) / size[1]

    return (x, y, w, h)


def convert_annotation(xml_files_path, save_txt_files_path, classes):
    xml_files = os.listdir(xml_files_path)
    for xml_name in xml_files:
        xml_file = os.path.join(xml_files_path, xml_name)
        out_txt_path = os.path.join(save_txt_files_path, xml_name.split('.')[0] + '.txt')
        out_txt_f = open(out_txt_path, 'w')
        tree = ET.parse(xml_file)
        root = tree.getroot()
        size = root.find('size')
        w = int(size.find('width').text)
        h = int(size.find('height').text)

        for obj in root.iter('object'):
            #difficult = obj.find('difficult').text
            cls = obj.find('name').text
            #if cls not in classes or int(difficult) == 1:
                #continue
            cls_id = classes.index(cls)
            xmlbox = obj.find('bndbox')
            b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
                 float(xmlbox.find('ymax').text))
            # b=(xmin, xmax, ymin, ymax)
            # print(w, h, b)
            bb = convert((w, h), b)
            out_txt_f.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')


if __name__ == "__main__":
    # 把forklift_pallet的voc的xml标签文件转化为yolo的txt标签文件
    # 1、需要转化的类别,这里我直接用数字代表类别,由于我是八类,所以从0到7
    classes = ['0', '1', '2', '3', '4', '5', '6', '7']
    # 2、voc格式的xml标签文件路径
    xml_files1 = 'annotations'
    # 3、转化为yolo格式的txt标签文件存储路径
    save_txt_files1 = 'labels'

    convert_annotation(xml_files1, save_txt_files1, classes)

上面代码中注释了一部分内容,比如difficult这一项,由于我xml文件里面没有difficult,所以就注释掉了,大家按照自己的需求进行使用即可。

划分数据集

在我们进行yolo目标检测模型训练之前,需要先将数据集进行合理的划分,比如说划分为训练集:验证集=8:2,或者训练集:验证集:测试集=7:2:1。不过我一般习惯只划分训练集和验证集,也就是按8:2的比例进行划分,代码如下所示:

import os
import shutil
import random

# 定义数据集文件夹路径
dataset_path = 'dataset'
images_path = os.path.join(dataset_path, 'images')
labels_path = os.path.join(dataset_path, 'labels')

# 定义划分后的文件夹路径
new_path = 'mydata'
train_path = os.path.join(new_path, 'train')
val_path = os.path.join(new_path, 'val')

# 创建train和val文件夹
os.makedirs(os.path.join(train_path, 'images'), exist_ok=True)
os.makedirs(os.path.join(train_path, 'labels'), exist_ok=True)
os.makedirs(os.path.join(val_path, 'images'), exist_ok=True)
os.makedirs(os.path.join(val_path, 'labels'), exist_ok=True)

# 获取所有图片文件的文件名
image_files = os.listdir(images_path)
# 随机打乱文件顺序
random.shuffle(image_files)

# 定义验证集所占比例
val_split = 0.1
# 计算验证集大小
num_val = int(len(image_files) * val_split)

# 将数据集按照比例划分到train和val文件夹中
for i, image_file in enumerate(image_files):
    src_image = os.path.join(images_path, image_file)
    src_label = os.path.join(labels_path, image_file.replace('.jpg', '.txt'))
    if i < num_val:
        dst_image = os.path.join(val_path, 'images', image_file)
        dst_label = os.path.join(val_path, 'labels', image_file.replace('.jpg', '.txt'))
    else:
        dst_image = os.path.join(train_path, 'images', image_file)
        dst_label = os.path.join(train_path, 'labels', image_file.replace('.jpg', '.txt'))
    shutil.copy(src_image, dst_image)
    shutil.copy(src_label, dst_label)

划分完成以后的文件夹格式为:

- mydata
  |- train
  |  |- images
  |  |- labels
  |
  |- val
  |  |- images
  |  |- labels

images和labels分别是对应的数据集图片和txt标签。

数据增强

在我们参加一些目标检测类比赛的时候,往往会遇见比赛训练集不足的情况,这将极大程度上影响我们的模型精度,这时候可能就需要用到一些数据增强方法,如翻转、随机裁剪等等。当然,yolo系列的模型一般都自带有数据增强,但是我们也可以尝试训练前进行增强看看效果。
代码如下:

# -*- coding=utf-8 -*-

import time
import random
import copy
import cv2
import os
import math
import numpy as np
from skimage.util import random_noise
from lxml import etree, objectify
import xml.etree.ElementTree as ET
import argparse


# 显示图片
def show_pic(img, bboxes=None):
    '''
    输入:
        img:图像array
        bboxes:图像的所有boudning box list, 格式为[[x_min, y_min, x_max, y_max]....]
        names:每个box对应的名称
    '''
    for i in range(len(bboxes)):
        bbox = bboxes[i]
        x_min = bbox[0]
        y_min = bbox[1]
        x_max = bbox[2]
        y_max = bbox[3]
        cv2.rectangle(img, (int(x_min), int(y_min)), (int(x_max), int(y_max)), (0, 255, 0), 3)
    cv2.namedWindow('pic', 0)  # 1表示原图
    cv2.moveWindow('pic', 0, 0)
    cv2.resizeWindow('pic', 1200, 800)  # 可视化的图片大小
    cv2.imshow('pic', img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()


# 图像均为cv2读取
class DataAugmentForObjectDetection():
    def __init__(self, rotation_rate=0.5, max_rotation_angle=5,
                 crop_rate=0.5, shift_rate=0.5, change_light_rate=0.5,
                 add_noise_rate=0.5, flip_rate=0.5,
                 cutout_rate=0.5, cut_out_length=50, cut_out_holes=1, cut_out_threshold=0.5,
                 is_addNoise=True, is_changeLight=True, is_cutout=True, is_rotate_img_bbox=True,
                 is_crop_img_bboxes=True, is_shift_pic_bboxes=True, is_filp_pic_bboxes=True):

        # 配置各个操作的属性
        self.rotation_rate = rotation_rate
        self.max_rotation_angle = max_rotation_angle
        self.crop_rate = crop_rate
        self.shift_rate = shift_rate
        self.change_light_rate = change_light_rate
        self.add_noise_rate = add_noise_rate
        self.flip_rate = flip_rate
        self.cutout_rate = cutout_rate

        self.cut_out_length = cut_out_length
        self.cut_out_holes = cut_out_holes
        self.cut_out_threshold = cut_out_threshold

        # 是否使用某种增强方式
        self.is_addNoise = is_addNoise
        self.is_changeLight = is_changeLight
        self.is_cutout = is_cutout
        self.is_rotate_img_bbox = is_rotate_img_bbox
        self.is_crop_img_bboxes = is_crop_img_bboxes
        self.is_shift_pic_bboxes = is_shift_pic_bboxes
        self.is_filp_pic_bboxes = is_filp_pic_bboxes

    # ----1.加噪声---- #
    def _addNoise(self, img):
        '''
        输入:
            img:图像array
        输出:
            加噪声后的图像array,由于输出的像素是在[0,1]之间,所以得乘以255
        '''
        # return cv2.GaussianBlur(img, (11, 11), 0)
        return random_noise(img, mode='gaussian', seed=int(time.time()), clip=True) * 255

    # ---2.调整亮度--- #
    def _changeLight(self, img):
        alpha = random.uniform(0.35, 1)
        blank = np.zeros(img.shape, img.dtype)
        return cv2.addWeighted(img, alpha, blank, 1 - alpha, 0)

    # ---3.cutout--- #
    def _cutout(self, img, bboxes, length=100, n_holes=1, threshold=0.5):
        '''
        原版本:https://github.com/uoguelph-mlrg/Cutout/blob/master/util/cutout.py
        Randomly mask out one or more patches from an image.
        Args:
            img : a 3D numpy array,(h,w,c)
            bboxes : 框的坐标
            n_holes (int): Number of patches to cut out of each image.
            length (int): The length (in pixels) of each square patch.
        '''

        def cal_iou(boxA, boxB):
            '''
            boxA, boxB为两个框,返回iou
            boxB为bouding box
            '''
            # determine the (x, y)-coordinates of the intersection rectangle
            xA = max(boxA[0], boxB[0])
            yA = max(boxA[1], boxB[1])
            xB = min(boxA[2], boxB[2])
            yB = min(boxA[3], boxB[3])

            if xB <= xA or yB <= yA:
                return 0.0

            # compute the area of intersection rectangle
            interArea = (xB - xA + 1) * (yB - yA + 1)

            # compute the area of both the prediction and ground-truth
            # rectangles
            boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
            boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)
            iou = interArea / float(boxBArea)
            return iou

        # 得到h和w
        if img.ndim == 3:
            h, w, c = img.shape
        else:
            _, h, w, c = img.shape
        mask = np.ones((h, w, c), np.float32)
        for n in range(n_holes):
            chongdie = True  # 看切割的区域是否与box重叠太多
            while chongdie:
                y = np.random.randint(h)
                x = np.random.randint(w)

                y1 = np.clip(y - length // 2, 0,
                             h)  # numpy.clip(a, a_min, a_max, out=None), clip这个函数将将数组中的元素限制在a_min, a_max之间,大于a_max的就使得它等于 a_max,小于a_min,的就使得它等于a_min
                y2 = np.clip(y + length // 2, 0, h)
                x1 = np.clip(x - length // 2, 0, w)
                x2 = np.clip(x + length // 2, 0, w)

                chongdie = False
                for box in bboxes:
                    if cal_iou([x1, y1, x2, y2], box) > threshold:
                        chongdie = True
                        break
            mask[y1: y2, x1: x2, :] = 0.
        img = img * mask
        return img

    # ---4.旋转--- #
    def _rotate_img_bbox(self, img, bboxes, angle=5, scale=1.):
        '''
        参考:https://blog.csdn.net/u014540717/article/details/53301195crop_rate
        输入:
            img:图像array,(h,w,c)
            bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值
            angle:旋转角度
            scale:默认1
        输出:
            rot_img:旋转后的图像array
            rot_bboxes:旋转后的boundingbox坐标list
        '''
        # 旋转图像
        w = img.shape[1]
        h = img.shape[0]
        # 角度变弧度
        rangle = np.deg2rad(angle)  # angle in radians
        # now calculate new image width and height
        nw = (abs(np.sin(rangle) * h) + abs(np.cos(rangle) * w)) * scale
        nh = (abs(np.cos(rangle) * h) + abs(np.sin(rangle) * w)) * scale
        # ask OpenCV for the rotation matrix
        rot_mat = cv2.getRotationMatrix2D((nw * 0.5, nh * 0.5), angle, scale)
        # calculate the move from the old center to the new center combined
        # with the rotation
        rot_move = np.dot(rot_mat, np.array([(nw - w) * 0.5, (nh - h) * 0.5, 0]))
        # the move only affects the translation, so update the translation
        rot_mat[0, 2] += rot_move[0]
        rot_mat[1, 2] += rot_move[1]
        # 仿射变换
        rot_img = cv2.warpAffine(img, rot_mat, (int(math.ceil(nw)), int(math.ceil(nh))), flags=cv2.INTER_LANCZOS4)

        # 矫正bbox坐标
        # rot_mat是最终的旋转矩阵
        # 获取原始bbox的四个中点,然后将这四个点转换到旋转后的坐标系下
        rot_bboxes = list()
        for bbox in bboxes:
            xmin = bbox[0]
            ymin = bbox[1]
            xmax = bbox[2]
            ymax = bbox[3]
            point1 = np.dot(rot_mat, np.array([(xmin + xmax) / 2, ymin, 1]))
            point2 = np.dot(rot_mat, np.array([xmax, (ymin + ymax) / 2, 1]))
            point3 = np.dot(rot_mat, np.array([(xmin + xmax) / 2, ymax, 1]))
            point4 = np.dot(rot_mat, np.array([xmin, (ymin + ymax) / 2, 1]))
            # 合并np.array
            concat = np.vstack((point1, point2, point3, point4))
            # 改变array类型
            concat = concat.astype(np.int32)
            # 得到旋转后的坐标
            rx, ry, rw, rh = cv2.boundingRect(concat)
            rx_min = rx
            ry_min = ry
            rx_max = rx + rw
            ry_max = ry + rh
            # 加入list中
            rot_bboxes.append([rx_min, ry_min, rx_max, ry_max])

        return rot_img, rot_bboxes

    # ---5.裁剪--- #
    def _crop_img_bboxes(self, img, bboxes):
        '''
        裁剪后的图片要包含所有的框
        输入:
            img:图像array
            bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值
        输出:
            crop_img:裁剪后的图像array
            crop_bboxes:裁剪后的bounding box的坐标list
        '''
        # 裁剪图像
        w = img.shape[1]
        h = img.shape[0]
        x_min = w  # 裁剪后的包含所有目标框的最小的框
        x_max = 0
        y_min = h
        y_max = 0
        for bbox in bboxes:
            x_min = min(x_min, bbox[0])
            y_min = min(y_min, bbox[1])
            x_max = max(x_max, bbox[2])
            y_max = max(y_max, bbox[3])

        d_to_left = x_min  # 包含所有目标框的最小框到左边的距离
        d_to_right = w - x_max  # 包含所有目标框的最小框到右边的距离
        d_to_top = y_min  # 包含所有目标框的最小框到顶端的距离
        d_to_bottom = h - y_max  # 包含所有目标框的最小框到底部的距离

        # 随机扩展这个最小框
        crop_x_min = int(x_min - random.uniform(0, d_to_left))
        crop_y_min = int(y_min - random.uniform(0, d_to_top))
        crop_x_max = int(x_max + random.uniform(0, d_to_right))
        crop_y_max = int(y_max + random.uniform(0, d_to_bottom))

        # 随机扩展这个最小框 , 防止别裁的太小
        # crop_x_min = int(x_min - random.uniform(d_to_left//2, d_to_left))
        # crop_y_min = int(y_min - random.uniform(d_to_top//2, d_to_top))
        # crop_x_max = int(x_max + random.uniform(d_to_right//2, d_to_right))
        # crop_y_max = int(y_max + random.uniform(d_to_bottom//2, d_to_bottom))

        # 确保不要越界
        crop_x_min = max(0, crop_x_min)
        crop_y_min = max(0, crop_y_min)
        crop_x_max = min(w, crop_x_max)
        crop_y_max = min(h, crop_y_max)

        crop_img = img[crop_y_min:crop_y_max, crop_x_min:crop_x_max]

        # 裁剪boundingbox
        # 裁剪后的boundingbox坐标计算
        crop_bboxes = list()
        for bbox in bboxes:
            crop_bboxes.append([bbox[0] - crop_x_min, bbox[1] - crop_y_min, bbox[2] - crop_x_min, bbox[3] - crop_y_min])

        return crop_img, crop_bboxes

    # ---6.平移--- #
    def _shift_pic_bboxes(self, img, bboxes):
        '''
        平移后的图片要包含所有的框
        输入:
            img:图像array
            bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值
        输出:
            shift_img:平移后的图像array
            shift_bboxes:平移后的bounding box的坐标list
        '''
        # 平移图像
        w = img.shape[1]
        h = img.shape[0]
        x_min = w  # 裁剪后的包含所有目标框的最小的框
        x_max = 0
        y_min = h
        y_max = 0
        for bbox in bboxes:
            x_min = min(x_min, bbox[0])
            y_min = min(y_min, bbox[1])
            x_max = max(x_max, bbox[2])
            y_max = max(y_max, bbox[3])

        d_to_left = x_min  # 包含所有目标框的最大左移动距离
        d_to_right = w - x_max  # 包含所有目标框的最大右移动距离
        d_to_top = y_min  # 包含所有目标框的最大上移动距离
        d_to_bottom = h - y_max  # 包含所有目标框的最大下移动距离

        x = random.uniform(-(d_to_left - 1) / 3, (d_to_right - 1) / 3)
        y = random.uniform(-(d_to_top - 1) / 3, (d_to_bottom - 1) / 3)

        M = np.float32([[1, 0, x], [0, 1, y]])  # x为向左或右移动的像素值,正为向右负为向左; y为向上或者向下移动的像素值,正为向下负为向上
        shift_img = cv2.warpAffine(img, M, (img.shape[1], img.shape[0]))

        #  平移boundingbox
        shift_bboxes = list()
        for bbox in bboxes:
            shift_bboxes.append([bbox[0] + x, bbox[1] + y, bbox[2] + x, bbox[3] + y])

        return shift_img, shift_bboxes

    # ---7.镜像--- #
    def _filp_pic_bboxes(self, img, bboxes):
        '''
            平移后的图片要包含所有的框
            输入:
                img:图像array
                bboxes:该图像包含的所有boundingboxs,一个list,每个元素为[x_min, y_min, x_max, y_max],要确保是数值
            输出:
                flip_img:平移后的图像array
                flip_bboxes:平移后的bounding box的坐标list
        '''
        # 翻转图像

        flip_img = copy.deepcopy(img)
        h, w, _ = img.shape

        sed = random.random()

        if 0 < sed < 0.33:  # 0.33的概率水平翻转,0.33的概率垂直翻转,0.33是对角反转
            flip_img = cv2.flip(flip_img, 0)  # _flip_x
            inver = 0
        elif 0.33 < sed < 0.66:
            flip_img = cv2.flip(flip_img, 1)  # _flip_y
            inver = 1
        else:
            flip_img = cv2.flip(flip_img, -1)  # flip_x_y
            inver = -1

        # 调整boundingbox
        flip_bboxes = list()
        for box in bboxes:
            x_min = box[0]
            y_min = box[1]
            x_max = box[2]
            y_max = box[3]

            if inver == 0:
                # 0:垂直翻转
                flip_bboxes.append([x_min, h - y_max, x_max, h - y_min])
            elif inver == 1:
                # 1:水平翻转
                flip_bboxes.append([w - x_max, y_min, w - x_min, y_max])
            elif inver == -1:
                # -1:水平垂直翻转
                flip_bboxes.append([w - x_max, h - y_max, w - x_min, h - y_min])
        return flip_img, flip_bboxes

    # 图像增强方法
    def dataAugment(self, img, bboxes):
        '''
        图像增强
        输入:
            img:图像array
            bboxes:该图像的所有框坐标
        输出:
            img:增强后的图像
            bboxes:增强后图片对应的box
        '''
        change_num = 0  # 改变的次数
        # print('------')
        while change_num < 1:  # 默认至少有一种数据增强生效

            if self.is_rotate_img_bbox:
                if random.random() > self.rotation_rate:  # 旋转
                    change_num += 1
                    angle = random.uniform(-self.max_rotation_angle, self.max_rotation_angle)
                    scale = random.uniform(0.7, 0.8)
                    img, bboxes = self._rotate_img_bbox(img, bboxes, angle, scale)

            if self.is_shift_pic_bboxes:
                if random.random() < self.shift_rate:  # 平移
                    change_num += 1
                    img, bboxes = self._shift_pic_bboxes(img, bboxes)

            if self.is_changeLight:
                if random.random() > self.change_light_rate:  # 改变亮度
                    change_num += 1
                    img = self._changeLight(img)

            if self.is_addNoise:
                if random.random() < self.add_noise_rate:  # 加噪声
                    change_num += 1
                    img = self._addNoise(img)
            if self.is_cutout:
                if random.random() < self.cutout_rate:  # cutout
                    change_num += 1
                    img = self._cutout(img, bboxes, length=self.cut_out_length, n_holes=self.cut_out_holes,
                                       threshold=self.cut_out_threshold)
            if self.is_filp_pic_bboxes:
                if random.random() < self.flip_rate:  # 翻转
                    change_num += 1
                    img, bboxes = self._filp_pic_bboxes(img, bboxes)

        return img, bboxes


# xml解析工具
class ToolHelper():
    # 从xml文件中提取bounding box信息, 格式为[[x_min, y_min, x_max, y_max, name]]
    def parse_xml(self, path):
        '''
        输入:
            xml_path: xml的文件路径
        输出:
            从xml文件中提取bounding box信息, 格式为[[x_min, y_min, x_max, y_max, name]]
        '''
        tree = ET.parse(path)
        root = tree.getroot()
        objs = root.findall('object')
        coords = list()
        for ix, obj in enumerate(objs):
            name = obj.find('name').text
            box = obj.find('bndbox')
            x_min = int(box[0].text)
            y_min = int(box[1].text)
            x_max = int(box[2].text)
            y_max = int(box[3].text)
            coords.append([x_min, y_min, x_max, y_max, name])
        return coords

    # 保存图片结果
    def save_img(self, file_name, save_folder, img):
        cv2.imwrite(os.path.join(save_folder, file_name), img)

    # 保持xml结果
    def save_xml(self, file_name, save_folder, img_info, height, width, channel, bboxs_info):
        '''
        :param file_name:文件名
        :param save_folder:#保存的xml文件的结果
        :param height:图片的信息
        :param width:图片的宽度
        :param channel:通道
        :return:
        '''
        folder_name, img_name = img_info  # 得到图片的信息

        E = objectify.ElementMaker(annotate=False)

        anno_tree = E.annotation(
            E.folder(folder_name),
            E.filename(img_name),
            E.path(os.path.join(folder_name, img_name)),
            E.source(
                E.database('Unknown'),
            ),
            E.size(
                E.width(width),
                E.height(height),
                E.depth(channel)
            ),
            E.segmented(0),
        )

        labels, bboxs = bboxs_info  # 得到边框和标签信息
        for label, box in zip(labels, bboxs):
            anno_tree.append(
                E.object(
                    E.name(label),
                    E.pose('Unspecified'),
                    E.truncated('0'),
                    E.difficult('0'),
                    E.bndbox(
                        E.xmin(box[0]),
                        E.ymin(box[1]),
                        E.xmax(box[2]),
                        E.ymax(box[3])
                    )
                ))

        etree.ElementTree(anno_tree).write(os.path.join(save_folder, file_name), pretty_print=True)


if __name__ == '__main__':

    need_aug_num = 5  # 每张图片需要增强的次数

    is_endwidth_dot = True  # 文件是否以.jpg或者png结尾

    dataAug = DataAugmentForObjectDetection()  # 数据增强工具类

    toolhelper = ToolHelper()  # 工具

    # 获取相关参数
    parser = argparse.ArgumentParser()
    parser.add_argument('--source_img_path', type=str, default='images')
    parser.add_argument('--source_xml_path', type=str, default='Annotations')
    parser.add_argument('--save_img_path', type=str, default='enhance_images')
    parser.add_argument('--save_xml_path', type=str, default='enhance_Annotations')
    args = parser.parse_args()
    source_img_path = args.source_img_path  # 图片原始位置
    source_xml_path = args.source_xml_path  # xml的原始位置

    save_img_path = args.save_img_path  # 图片增强结果保存文件
    save_xml_path = args.save_xml_path  # xml增强结果保存文件

    # 如果保存文件夹不存在就创建
    if not os.path.exists(save_img_path):
        os.mkdir(save_img_path)

    if not os.path.exists(save_xml_path):
        os.mkdir(save_xml_path)

    for parent, _, files in os.walk(source_img_path):
        files.sort()
        for file in files:
            cnt = 0
            pic_path = os.path.join(parent, file)
            xml_path = os.path.join(source_xml_path, file[:-4] + '.xml')
            values = toolhelper.parse_xml(xml_path)  # 解析得到box信息,格式为[[x_min,y_min,x_max,y_max,name]]
            coords = [v[:4] for v in values]  # 得到框
            labels = [v[-1] for v in values]  # 对象的标签

            # 如果图片是有后缀的
            if is_endwidth_dot:
                # 找到文件的最后名字
                dot_index = file.rfind('.')
                _file_prefix = file[:dot_index]  # 文件名的前缀
                _file_suffix = file[dot_index:]  # 文件名的后缀
            img = cv2.imread(pic_path)

            # show_pic(img, coords)  # 显示原图
            while cnt < need_aug_num:  # 继续增强
                auged_img, auged_bboxes = dataAug.dataAugment(img, coords)
                auged_bboxes_int = np.array(auged_bboxes).astype(np.int32)
                height, width, channel = auged_img.shape  # 得到图片的属性
                img_name = '{}_{}{}'.format(_file_prefix, cnt + 1, _file_suffix)  # 图片保存的信息
                toolhelper.save_img(img_name, save_img_path,
                                    auged_img)  # 保存增强图片

                toolhelper.save_xml('{}_{}.xml'.format(_file_prefix, cnt + 1),
                                    save_xml_path, (save_img_path, img_name), height, width, channel,
                                    (labels, auged_bboxes_int))  # 保存xml文件
                # show_pic(auged_img, auged_bboxes)  # 强化后的图
                print(img_name)
                cnt += 1  # 继续增强下一张

增强后的效果图如下所示:
在这里插入图片描述

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:/a/168565.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

练习六-使用Questasim来用verilog使用function函数

[TOC](使用Questasim来用verilog使用function函数 1&#xff0c;verilog中使用函数function2&#xff0c;RTL代码3&#xff0c;测试代码4&#xff0c;输出波形 1&#xff0c;verilog中使用函数function 目的&#xff1a; &#xff08;1&#xff09;了解函数的定义和在模块设计中…

欧拉操作系统下离线安装字体的操作步骤

背景 某 Web 应用部署到欧拉操作系统后&#xff0c;应用中导出的 PDF 文件中文全部显示乱码&#xff0c;原因是字体缺失&#xff0c;但是目标系统上并没有联网&#xff0c;必须找到字体的离线安装包。 CSDN 上还有40个积分&#xff0c;下载了两个相关的资源后&#xff0c;目标…

目标检测框存在内嵌情况分析与解决

这里写目录标题 问题描述原因分析与解决方法&#xff1a;后续及思考参考文档 问题描述 目标检测模型输出的检测框存在内嵌情况。 原因分析与解决方法&#xff1a; 根据经验&#xff0c;第一感觉是后处理nms部分出了问题。来看下对应的代码&#xff1a; static float CalcIou…

GaussDB SQL基础语法示例-GOTO语句

目录 一、前言 二、在GaussDB数据库中的概念及语法 1、基本概念 2、语法 三、在GaussDB数据库中的基础示例和限制场景说明 1、基础示例 2、限制场景说明 四、小结 一、前言 SQL是用于访问和处理数据库的标准计算机语言。GaussDB支持SQL标准&#xff08;默认支持SQL2、…

新版Testwell CTC++带来哪些新变化?

Testwell CTC在版本10中引入了新的工具ctcreport来直接从符号和数据文件生成HTML报告。详细的特性描述可以在测试井CTC帮助中找到。在本文档中&#xff0c;描述了与前一代报告相比的改进和变化。 Adaptable Layout可调整布局 您可以选择一个适合于项目结构的布局。布局决定了报…

已超1000+测试员分享!Python自动化测试案例实战

随着企业对测试工程师的能力要求日渐增长&#xff0c;对我们每一位测试工程师而言既是压力也是提升的动力&#xff0c;不提升就意味着没有出路&#xff0c;没有发展&#xff01;我们职业发展的命运是靠自己的能力来把握的&#xff0c;而不是一味的惧怕高要求&#xff0c;惧怕难…

vue和uni-app的递归组件排坑

有这样一个数组数据&#xff0c;实际可能有很多级。 tree: [{id: 1,name: 1,children: [{ id: 2, name: 1-1, children: [{id: 7, name: 1-1-1,children: []}]},{ id: 3, name: 1-2 }]},{id: 4,name: 2,children: [{ id: 5, name: 2-1 },{ id: 6, name: 2-2 }]} ]要渲染为下面…

KaiwuDB 监控组件及辅助 SQL 调优介绍

一、介绍 KaiwuDB 具备完善的行为数据采集功能&#xff0c;此功能要求 KaiwuDB 数据库系统 C/E/T 端不同进程的不同维度的指标采集功能十分完善&#xff1b;在不同进程完成指标采集后&#xff0c;会通过 Opentelemetry 和 Collector 将指标存入 Prometheus&#xff0c;以便查找…

单脉冲测角-和差比幅法-方向图传播因子-函数编写

方向图传播因子-函数编写 和差比幅法单脉冲测角原理代码仿真结果参数说明 和差比幅法单脉冲测角原理 有关单脉冲测角和差比幅法的原理已经在博文单脉冲测角-和差比幅法中详细介绍了&#xff0c;我们在实际仿真的时候&#xff0c;往往需要在给定来波方向下方向图转化因子&#…

安防视频监控平台EasyCVR服务器部署后出现报错,导致无法级联到域名服务器,该如何解决?

视频监控平台EasyCVR能在复杂的网络环境中&#xff0c;将分散的各类视频资源进行统一汇聚、整合、集中管理&#xff0c;在视频监控播放上&#xff0c;安防监控平台可支持1、4、9、16个画面窗口播放&#xff0c;可同时播放多路视频流&#xff0c;也能支持视频定时轮播。视频监控…

005 OpenCV直方图

目录 一、环境 二、直方图原理概述 三、代码 一、环境 本文使用环境为&#xff1a; Windows10Python 3.9.17opencv-python 4.8.0.74 二、直方图原理概述 OpenCV是一个广泛使用的开源计算机视觉库&#xff0c;它提供了许多用于图像处理和分析的函数和算法。其中&#xff…

虚拟机里为什么桥接模式可以广播,NAT模式不能广播?

在虚拟机网络配置中&#xff0c;桥接模式&#xff08;Bridged mode&#xff09;允许虚拟机在与主机相同的网络上作为一个独立的设备出现。这意味着虚拟机可以接收和发送广播消息&#xff0c;就像物理机器一样&#xff0c;因为它们处于同一个物理网络上。 相反&#xff0c;NAT模…

单片非晶磁性测量系统典型磁参数的不确定度与重复性

典型磁参数的不确定度与重复性 典型的测试点 最佳不确定度 ( k 2 ) 最佳重复性 损耗Ps P1.0 ④ 3.0% 1.0% P1.3 3.0% 1.0% P1.4 3.0% 1.0% P1.5 3.0% 1.0% 磁感Bm B25 ⑤ 1.0% 0.3% B50 1.0% 0.3% B80 1.0% 0.3% 单片非晶磁性测量系统测量条件 &…

著名的勃艮第葡萄酒是如何分类的?

勃艮第代表了与他们的地理位置密切相关的所有葡萄酒和葡萄酒风格&#xff0c;1936年法国根据产地对勃艮第葡萄酒进行了分类&#xff0c;勃艮第地区内的100个被批准的葡萄酒种植区被界定&#xff0c;这些地块被分为四个等级&#xff0c;最高等级代表了种植最高品质葡萄酒的最佳土…

亚马逊防关联如何做?看这一篇就够了

我们都知道亚马逊在众多跨境电商平台里属于严格的那个&#xff0c;商家们常常调侃亚马逊死法千万种&#xff0c;但最惨的还是账户被平台关联封号。有的新手刚注册还没开始就被关联封号了&#xff0c;有的业绩不错的店铺操作没注意&#xff0c;在别的地方登录了一下就被封了&…

软件定义卫星:数字卫星实践

随着巨型低轨卫星星座、卫星互联网等计划的推进&#xff0c;近年来全球卫星产业迅速发展&#xff0c;在轨卫星呈现规模化、网络化以及智能化趋势。大规模卫星系统为飞机、船舶、车辆等提供了各种各样的天基服务&#xff0c;对国防、科研、生产生活具有重要意义。 与此同时&…

Python基础:迭代器(Iterators)详解

什么是迭代器&#xff1f; 迭代器是一个可以记住遍历的位置的对象。迭代器对象从集合的第一个元素开始访问&#xff0c;直到所有的元素被访问完结束。迭代器只能往前不会后退。 1. 迭代对象 Python中使用迭代器的地方很多&#xff0c;大多数的容器对象都是可迭代对象&#xff…

Analyzing coredump file by gdb

1 coredump文件简介 在Linux中&#xff0c;当进程崩溃或异常终止时&#xff0c;系统会将进程的内存状态写入一个coredump文件中&#xff0c;这个文件包含了进程崩溃时的内存映像&#xff0c;可以用于分析进程崩溃的原因。 2 coredump文件的存储路径 执行如下指令查询coredum…

猫罐头哪个牌子质量好性价比高?推荐十款猫罐头品牌排行榜!

当前口罩形势严峻&#xff0c;海外罐头的价格飙升&#xff0c;且市场上充斥着难以辨别真伪的产品。除了海外罐头&#xff0c;您是否有其他口碑良好的国产罐头推荐&#xff1f; 作为从事了6年猫咪铲屎官的我来说&#xff0c;对于猫咪的日常饮食来源有一些经验。我也给我家的猫咪…

pikach靶场暴力破解

pikach靶场暴力破解 文章目录 pikach靶场暴力破解安装pikach靶场暴力破解第一关第二关第三关第四关 安装pikach靶场 进入github下载pikach的源码 不是linux推荐下载压缩包 下载完成后放入phpstudy中进行解压放入www网站根目录下 在数据库中新建数据库为pikachu create data…