1、前期准备
准备好目录结构、数据集和关于YOLOv1的基础认知
1.1 创建目录结构
自己创建项目目录结构,结构目录如下:
network CNN Backbone 存放位置
weights 权重存放的位置
test_images 测试用的图片
utils 辅助功能的代码存放位置models 保存模型位置
data 训练的数据集
1.2 数据集介绍与下载
1.2.1 数据集介绍
首先了解数据集,对数据集了解后方便对数据进行相应处理。数据集详细介绍直通车:https://blog.csdn.net/qq_41946216/article/details/137683750?spm=1001.2014.3001.5501
1.2.1 数据集下载
本次采用数据集: VOC2012数据集。
数据集下载方式一:http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
数据集下载方式二:
下载并构建VOC2012数据集,从:https://gitee.com/ppov-nuc/pascal-vocdataset_-for_-yolo.git, 下载get_data文件和generate_csv.py文件到本地,放到创建的目录结构中,修改get_data中下载的内容和相应路径,然后运行批处理文件get_data,在get_dat中会自动执行generate_csv.py,如下图所示。
2. 数据集处理
在utils目录下创建工具类 generate_txt_file.py,主要用于数据集的划分和解析 Annotations/xxxxx.xml 文件中的类别和bbox信息,并将信息存入voctrain.txt和voctest.txt文件,如下图所示:
具体代码:
# author: baiCai
# 1. 导包
from xml.etree import ElementTree as ET
import os
import random
# 2. 定义一些基本的参数
# 定义所有的类名
VOC_CLASSES = (
'aeroplane', 'bicycle', 'bird', 'boat',
'bottle', 'bus', 'car', 'cat', 'chair',
'cow', 'diningtable', 'dog', 'horse',
'motorbike', 'person', 'pottedplant',
'sheep', 'sofa', 'train', 'tvmonitor')
'''
读取所有 xml 文件,存入列表
'''
# 要读取的xml文件路径,记得自己修改路径
Annotations = '../data/VOC2012/Annotations/'
# 列出所有的xml文件
xml_files = os.listdir(Annotations)
# 打乱数据集
random.shuffle(xml_files)
'''
定义训练集和测试比例
划分Annotations中的训练集和测试集文件列表
'''
# 训练集数量
train_num = int(len(xml_files) * 0.7)
# 训练列表
train_file_list = xml_files[:train_num]
# 测测试列表
test_file_list = xml_files[train_num:]
'''
定义 xml 解析后的信息存储路径和写对象
'''
# 训练集和测试集文件名字
train_set_path = './voctrain.txt'
test_set_path = './voctest.txt'
# 3. 定义解析xml文件的函数
'''
主要解析 xml 获取 类别名字和bbox,如
{'name': 'person','bbox': [174, 101, 349, 351]}
'''
def parse_rec(filename):
# 参数:输入xml文件名
# 创建xml对象
tree = ET.parse(filename)
objects = []
# 迭代读取xml文件中的object节点,即物体信息
for obj in tree.findall('object'):
obj_struct = {}
# difficult属性,即这里不需要那些难判断的对象
difficult = int(obj.find('difficult').text)
if difficult == 1: # 若为1则跳过本次循环
continue
# 开始收集信息
obj_struct['name'] = obj.find('name').text
bbox = obj.find('bndbox')
obj_struct['bbox'] =\
[int(float(bbox.find('xmin').text)),
int(float(bbox.find('ymin').text)),
int(float(bbox.find('xmax').text)),
int(float(bbox.find('ymax').text))]
objects.append(obj_struct)
return objects
# 4. 把信息保存入文件中
def write_txt(file_list,set_path):
# # 生成训练集txt
count = 0
with open(set_path, 'w') as wt:
for xml_file in file_list:
count += 1
# 获取图片名字
image_name = xml_file.split('.')[0] + '.jpg' # 图片文件名
# 对xml_file进行解析
results = parse_rec(Annotations + xml_file)
# 如果返回的对象为空,表示张图片难以检测,因此直接跳过
if len(results) == 0:
print(xml_file)
continue
# 否则,则写入文件中
# 先写入图片名字
wt.write(image_name)
# 接着指定下面写入的格式
for result in results:
class_name = result['name']
bbox = result['bbox']
class_name = VOC_CLASSES.index(class_name) # 名字在类别中是下标位置
wt.write(' ' + str(bbox[0]) +
' ' + str(bbox[1]) +
' ' + str(bbox[2]) +
' ' + str(bbox[3]) +
' ' + str(class_name))
wt.write('\n')
wt.close()
# 5. 运行
if __name__ == '__main__':
write_txt(train_file_list,train_set_path)
write_txt(test_file_list,test_set_path)
3. 构建数据加载器
3.1定义初始化方法
读取xxxx.xml解析后的文件
对每行数据(每个图片信息)的所有中心点信息以【x,y,w,h】和标签分别存入box列表和label列表。
当前图片的边界框和标签信息即box列表和label列表,转换为LongTensor格式添加到对应的boxex列表和labels列表。
3.2 定义增强图片方法
增加方法名称 | 定义的函数 |
随机翻转图片和边界框 | random_flip(img, boxes) |
随机缩放图片和边界框 | randomScale(img, boxes) |
随机模糊图片 | randomBlur(img) |
随机调整图片亮度 | RandomBrightness(img) |
随机调整图片色调 | RandomHue(img) |
随机调整图片饱和度 | RandomSaturation(img) |
随机移动图片和边界框 | randomShift(img, boxes, labels) |
随机裁剪图片和边界框 | randomCrop(img, boxes, labels) |
用于从图像中减去均值 | subMean(self, bgr, mean) |
将BGR图像转换为RGB图像 | BGR2RGB(self, img) |
将BGR图像转换为HSV图像 | BGR2HSV(self, img) |
将HSV图像转换为BGR图像 | HSV2BGR(self, img) |