一.数据集介绍
CDlA数据集介绍:CDLA
CDLA是一个中文文档版面分析数据集,面向中文文献类(论文)场景。包含以下10个label:
数据量:
共包含5000张训练集和1000张验证集,分别在train和val目录下。每张图片对应一个同名的标注文件(.json)。
数据展示:
标注工具是labelme,所以标注格式和labelme格式一致。
数据结构:
train和val里面分别存放图片及标注结果json文件
二. 数据预处理
将json文件转换成txt文件
import json
import os
import argparse
from tqdm import tqdm
import glob
import cv2
import numpy as np
def convert_label_json(json_dir,save_dir,classes):
files=os.listdir(json_dir)
#删选出json文件
jsonFiles=[]
for file in files:
if os.path.splitext(file)[1]==".json":
jsonFiles.append(file)
#获取类型
classes=classes.split(',')
#获取json对应中对应元素
for json_path in tqdm(jsonFiles):
path=os.path.join(json_dir,json_path)
with open(path,'r') as loadFile:
print(loadFile)
json_dict=json.load(loadFile)
h,w=json_dict['imageHeight'],json_dict['imageWidth']
txt_path=os.path.join(save_dir,json_path.replace('json','txt'))
txt_file=open(txt_path,'w')
for shape_dict in json_dict['shapes']:
label=shape_dict['label']
label_index=classes.index(label)
points=shape_dict['points']
points_nor_list=[]
for point in points:
points_nor_list.append(point[0]/w)
points_nor_list.append(point[1]/h)
points_nor_list=list(map(lambda x:str(x),points_nor_list))
points_nor_str=' '.join(points_nor_list)
label_str=str(label_index)+' '+points_nor_str+'\n'
txt_file.writelines(label_str)
if __name__=="__main__":
parser=argparse.ArgumentParser(description="json convert to txt params")
#设json文件所在地址
parser.add_argument('-json',type=str,default='cdla_data/label_data/val',help='json path')
#设置txt文件保存地址
parser.add_argument('-save',type=str,default='layout_analysis/cdla_data/val',help='save path')
#设置label类型,用“,”分隔
parser.add_argument('-classes',type=str,default='Header,Text,Reference,Figure caption,Figure,Table caption,Table,Title,Footer,Equation',help='classes')
args=parser.parse_args()
print(args.json,args.save,args.classes)
convert_label_json(args.json,args.save,args.classes)
三.yoloV8模型环境搭建
采用ultralytics集成的代码进行训练:ultralytics
pip install ultralytics
yolov8预训练权重:yolov8预训练权重
四.模型训练
import sys
import os
sys.path.insert(0, os.path.dirname(os.getcwd()))
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
from ultralytics import YOLO
def train_model():
# 加载模型
# model = YOLO("yolov8n.yaml") # 从头开始构建新模型
#print('model load。。。')
model = YOLO("8npt/best.pt") # 加载模型
print('model load completed。。。')
# 使用模型
model.train(data="img-layout.yaml", epochs=300, device=1)# , lr0=0.0001) # 训练模型
metrics = model.val() # 在验证集上评估模型性能
print('metric : {}'.format(metrics))
# results = model("https://ultralytics.com/images/bus.jpg") # 对图像进行预测
success = model.export(format="onnx") # 将模型导出为 ONNX 格式
if __name__ == '__main__':
train_model()
以上参数解释如下:
task:选择任务类型,可选[‘detect’, ‘segment’, ‘classify’, ‘init’]
mode: 选择是训练、验证还是预测的任务蕾西 可选[‘train’, ‘val’, ‘predict’]
model: 选择yolov8不同的模型配置文件,可选yolov8s.yaml、yolov8m.yaml、yolov8l.yaml、yolov8x.yaml
data: 选择生成的数据集配置文件
epochs:指的就是训练过程中整个数据集将被迭代多少次,显卡不行你就调小点。
batch:一次看完多少张图片才进行权重更新,梯度下降的mini-batch,显卡不行你就调小点。
五.模型测试
import os
import cv2
import sys
sys.path.insert(0, os.path.dirname(os.getcwd()))
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
from ultralytics import YOLO
def infer(image_dir):
"""
检测一个文件夹下的所有图片
Args:
image_dir: 图片文件夹路径
Returns:
None
"""
model = YOLO('train2/weights/best.pt') #模型
for filename in os.listdir(image_dir):
image_path = os.path.join(image_dir, filename)
results = model(image_path)
print(results[0].plot())
cv2.imwrite('test_result_v1/' + filename, results[0].plot()) #保存地址
if __name__ == '__main__':
image_dir = 'test_data_20240304_pic/' #图片数据
infer(image_dir)
测试效果: