文章目录
- 1YoloPredictor类——检测器
- 1.1继承BasePredictor解析
- 1.2继承QObject解析
- 2MainWindow类——主窗口
在前面两篇中,篇一介绍了启动界面的制作,篇二介绍了如何修改PyDracula的界面。这一篇我们学习一下把YOLO融合进软件时需要了解的两个类。
1YoloPredictor类——检测器
class YoloPredictor(BasePredictor, QObject):
    """YOLOv8 detector designed to run inside a Qt worker thread.

    Inherits BasePredictor (the ultralytics inference pipeline:
    input -> preprocess -> inference -> postprocess -> annotate/save)
    and QObject (so the worker can emit Qt signals back to the GUI).
    """
    # Signals pushing data from this worker thread to the GUI.
    yolo2main_pre_img = Signal(np.ndarray)  # original (pre-detection) frame as a NumPy array
    yolo2main_res_img = Signal(np.ndarray)  # annotated result frame as a NumPy array
    yolo2main_status_msg = Signal(str)      # status text (detecting / paused / stopped / error ...)
    yolo2main_fps = Signal(str)             # frames per second, pre-formatted as a string
    yolo2main_labels = Signal(dict)         # {class name: count} of detected targets
    yolo2main_progress = Signal(int)        # progress value (emitted in 0..1000 — see run())
    yolo2main_class_num = Signal(int)       # number of distinct classes in the current frame
    yolo2main_target_num = Signal(int)      # total number of targets in the current frame

    def __init__(self, cfg=DEFAULT_CFG, overrides=None):
        """Build the predictor.

        cfg:       base configuration (ultralytics DEFAULT_CFG).
        overrides: optional dict of config overrides passed to get_cfg().
        """
        super(YoloPredictor, self).__init__()  # BasePredictor initialisation
        QObject.__init__(self)                 # QObject initialisation (enables signals/slots)
        # Resolve the run configuration; ultimately read from
        # DEFAULT_CFG_PATH = ROOT / 'yolo/cfg/default.yaml'.
        self.args = get_cfg(cfg, overrides)
        ''''下面这个配置信息是为了填补yaml文件的空白'''
        # The settings below fill in values the yaml file leaves blank.
        # Project path: fall back to the runs dir plus the task name.
        project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
        # Run name: derived from the mode (e.g. 'predict').
        name = f'{self.args.mode}'
        # Save directory: increment_path avoids clobbering existing folders.
        self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
        self.done_warmup = False  # whether model warmup has been performed yet
        if self.args.show:
            # Only show images if imshow actually works in this environment.
            self.args.show = check_imshow(warn=True)

        # --- GUI-related state -------------------------------------------
        self.used_model_name = None   # model currently loaded
        self.new_model_name = None    # model requested by the GUI (hot-swapped in run())
        self.source = ''              # input source (video/image path, camera id, rtsp url)
        self.stop_dtc = False         # set True by the GUI to terminate detection
        self.continue_dtc = True      # set False by the GUI to pause detection
        self.save_res = False         # save annotated images/videos?
        self.save_txt = False         # save label .txt files?
        self.iou_thres = 0.45         # NMS IoU threshold
        self.conf_thres = 0.25        # confidence threshold
        self.speed_thres = 10         # per-frame delay in milliseconds (0 disables the delay)
        self.labels_dict = {}         # {class name: count} for the current frame
        self.progress_value = 0       # current progress (computed in run() on a 0..1000 scale)

        # --- state filled in later by the pipeline -----------------------
        self.model = None             # loaded YOLO model
        self.data = self.args.data    # dataset dict (dataset paths etc.)
        self.imgsz = None             # inference image size (e.g. 640x640)
        self.device = None            # inference device (GPU or CPU)
        self.dataset = None           # dataset/loader object
        self.vid_path, self.vid_writer = None, None  # video output path and cv2.VideoWriter
        self.annotator = None         # Annotator used to draw boxes/labels
        self.data_path = None         # current input path
        self.source_type = None       # source type flags (webcam / from_img / ...)
        self.batch = None             # current batch from the loader

        # Register callbacks used throughout the detection pipeline.
        self.callbacks = defaultdict(list, callbacks.default_callbacks)  # callback dict defaulting to []
        callbacks.add_integration_callbacks(self)  # attach integration callbacks to this instance

    # Main detection entry point (invoked via main2yolo_begin_sgl in the GUI).
    @smart_inference_mode()  # inference mode: disables gradient tracking for speed/memory
    def run(self):
        """Main detection loop; runs in the worker thread until stopped or finished."""
        try:
            # Verbose mode: emit a blank log line as a separator.
            if self.args.verbose:
                LOGGER.info('')
            # --- model setup ---------------------------------------------
            self.yolo2main_status_msg.emit('Loding Model...')  # NOTE(review): "Loding" typo kept — show_status may match this exact text
            if not self.model:  # no model loaded yet
                self.setup_model(self.new_model_name)  # load the model selected in the GUI
                self.used_model_name = self.new_model_name  # remember what is loaded
            # --- input source --------------------------------------------
            self.setup_source(self.source if self.source is not None else self.args.source)
            # Ensure the save/labels directories exist when saving is on.
            if self.save_res or self.save_txt:
                (self.save_dir / 'labels' if self.save_txt else self.save_dir).mkdir(parents=True, exist_ok=True)
            # --- warmup --------------------------------------------------
            if not self.done_warmup:
                self.model.warmup(imgsz=(1 if self.model.pt or self.model.triton else self.dataset.bs, 3, *self.imgsz))
                self.done_warmup = True
            # Profiling/state: frames seen, open windows, stage timers, batch.
            self.seen, self.windows, self.dt, self.batch = 0, [], (ops.Profile(), ops.Profile(), ops.Profile()), None
            # --- detection loop ------------------------------------------
            count = 0                   # frames processed so far
            start_time = time.time()    # start of the current FPS measurement window
            batch = iter(self.dataset)  # make the dataset iterable
            while True:
                # Termination requested by the GUI?
                if self.stop_dtc:
                    if isinstance(self.vid_writer[-1], cv2.VideoWriter):
                        self.vid_writer[-1].release()  # flush/close the output video
                    self.yolo2main_status_msg.emit('Detection terminated!')  # tell the GUI
                    break
                # Hot-swap the model when the GUI selected a different one.
                if self.used_model_name != self.new_model_name:
                    self.setup_model(self.new_model_name)
                    self.used_model_name = self.new_model_name
                # Pause switch: when continue_dtc is False the loop idles until resumed.
                if self.continue_dtc:
                    self.yolo2main_status_msg.emit('Detecting...')
                    batch = next(self.dataset)  # next (path, im, im0s, vid_cap, s) tuple
                    self.batch = batch  # keep the current batch (postprocess reads it)
                    path, im, im0s, vid_cap, s = batch  # path, tensor image, raw image, capture, log string
                    visualize = increment_path(self.save_dir / Path(path).stem,
                                               mkdir=True) if self.args.visualize else False  # feature-map dir or off
                    # Progress and FPS bookkeeping (marked "to be optimised" by the author).
                    count += 1
                    if vid_cap:
                        all_count = vid_cap.get(cv2.CAP_PROP_FRAME_COUNT)  # total frames in the video
                    else:
                        all_count = 1  # single image: treat as one frame
                    self.progress_value = int(count / all_count * 1000)  # progress scaled to 0..1000
                    if count % 5 == 0 and count >= 5:  # recompute FPS every 5 frames
                        self.yolo2main_fps.emit(str(int(5 / (time.time() - start_time))))
                        start_time = time.time()  # reset the FPS window
                    # --- preprocess --------------------------------------
                    with self.dt[0]:
                        im = self.preprocess(im)
                        if len(im.shape) == 3:  # CHW without batch dim
                            im = im[None]  # expand to batch dimension
                    # --- inference ---------------------------------------
                    with self.dt[1]:
                        preds = self.model(im, augment=self.args.augment, visualize=visualize)
                    # --- postprocess (NMS, box rescale) ------------------
                    with self.dt[2]:
                        self.results = self.postprocess(preds, im, im0s)
                    # --- annotate, save, emit ----------------------------
                    n = len(im)  # usually 1 unless batched inference
                    for i in range(n):
                        # Per-image stage timings in milliseconds.
                        self.results[i].speed = {
                            'preprocess': self.dt[0].dt * 1E3 / n,
                            'inference': self.dt[1].dt * 1E3 / n,
                            'postprocess': self.dt[2].dt * 1E3 / n}
                        # Webcam / image-list sources are batched; pick entry i.
                        p, im0 = (path[i], im0s[i].copy()) if self.source_type.webcam or self.source_type.from_img \
                            else (path, im0s.copy())  # select the path/raw frame for this image
                        p = Path(p)
                        # Draw the boxes and get the "count~label," summary string.
                        label_str = self.write_results(i, self.results, (p, im, im0))
                        # Parse the summary into per-class and total counts.
                        class_nums = 0   # distinct classes
                        target_nums = 0  # total targets
                        self.labels_dict = {}
                        if 'no detections' in label_str:  # sentinel from write_results
                            pass  # nothing detected in this frame
                        else:
                            # Each element is "N~name"; the trailing split element is empty.
                            for ii in label_str.split(',')[:-1]:
                                nums, label_name = ii.split('~')
                                self.labels_dict[label_name] = int(nums)
                                target_nums += int(nums)
                                class_nums += 1
                        # Save annotated image/video if requested.
                        if self.save_res:
                            self.save_preds(vid_cap, i, str(self.save_dir / p.name))
                        # Push the results to the main window.
                        self.yolo2main_res_img.emit(im0)  # annotated frame
                        self.yolo2main_pre_img.emit(im0s if isinstance(im0s, np.ndarray) else im0s[0])  # raw frame
                        self.yolo2main_class_num.emit(class_nums)   # distinct class count
                        self.yolo2main_target_num.emit(target_nums)  # total target count
                        if self.speed_thres != 0:
                            time.sleep(self.speed_thres / 1000)  # throttle (ms -> s)
                    # Progress bar update.
                    self.yolo2main_progress.emit(self.progress_value)
                    # All frames consumed -> finish up.
                    if count + 1 >= all_count:
                        if isinstance(self.vid_writer[-1], cv2.VideoWriter):
                            self.vid_writer[-1].release()  # flush/close the output video
                        self.yolo2main_status_msg.emit('Detection completed')
                        break  # leave the detection loop
        except Exception as e:
            # NOTE(review): broad catch keeps the worker thread alive and reports
            # the error to the GUI via the status signal instead of crashing.
            pass
            print(e)
            self.yolo2main_status_msg.emit('%s' % e)  # surface the error in the GUI

    def get_annotator(self, img):
        """Return an Annotator bound to img, using the configured line width."""
        return Annotator(img, line_width=self.args.line_thickness, example=str(self.model.names))

    def preprocess(self, img):
        """Move a uint8 ndarray to the model device and normalise to [0, 1].

        Assumes img is already laid out as the model expects (channel-first) —
        the loader produces it; TODO confirm against the dataset class.
        """
        img = torch.from_numpy(img).to(self.model.device)
        img = img.half() if self.model.fp16 else img.float()  # uint8 to fp16/32
        img /= 255  # 0 - 255 to 0.0 - 1.0
        return img

    def postprocess(self, preds, img, orig_img):
        """Apply NMS with the GUI-controlled thresholds and wrap predictions in Results."""
        ### important: conf_thres / iou_thres come live from the GUI sliders
        preds = ops.non_max_suppression(preds,
                                        self.conf_thres,
                                        self.iou_thres,
                                        agnostic=self.args.agnostic_nms,
                                        max_det=self.args.max_det,
                                        classes=self.args.classes)
        results = []
        for i, pred in enumerate(preds):
            orig_img = orig_img[i] if isinstance(orig_img, list) else orig_img
            shape = orig_img.shape
            # Rescale boxes from the inference size back to the original image size.
            pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], shape).round()
            path, _, _, _, _ = self.batch  # current batch saved by run()
            img_path = path[i] if isinstance(path, list) else path
            results.append(Results(orig_img=orig_img, path=img_path, names=self.model.names, boxes=pred))
        # print(results)
        return results

    def write_results(self, idx, results, batch):
        """Draw detections for image idx and return a "N~label," summary string.

        Returns '(no detections), ' (after any prefix) when the frame is empty;
        run() keys off that exact substring and the ','/'~' separators.
        """
        p, im, im0 = batch
        log_string = ''
        if len(im.shape) == 3:
            im = im[None]  # expand for batch dim
        self.seen += 1
        imc = im0.copy() if self.args.save_crop else im0  # clean copy for crops
        if self.source_type.webcam or self.source_type.from_img:  # batch_size >= 1 # attention
            log_string += f'{idx}: '
            frame = self.dataset.count
        else:
            frame = getattr(self.dataset, 'frame', 0)
        self.data_path = p
        # Label-file path; video frames get a per-frame suffix.
        self.txt_path = str(self.save_dir / 'labels' / p.stem) + ('' if self.dataset.mode == 'image' else f'_{frame}')
        # log_string += '%gx%g ' % im.shape[2:]  # !!! don't add img size~
        self.annotator = self.get_annotator(im0)
        det = results[idx].boxes  # TODO: make boxes inherit from tensors
        if len(det) == 0:
            return f'{log_string}(no detections), '  # if no, send this~~ (parsed by run())
        # Per-class counts, formatted "N~name," so run() can split on ',' then '~'.
        for c in det.cls.unique():
            n = (det.cls == c).sum()  # detections per class
            log_string += f"{n}~{self.model.names[int(c)]},"  # {'s' * (n > 1)}, " # don't add 's'
        # now log_string holds the per-class summary 👆
        # Write each box: optional txt line, bbox on the image, optional crop.
        for d in reversed(det):
            cls, conf = d.cls.squeeze(), d.conf.squeeze()
            if self.save_txt:  # write normalized xywh (+ conf) to the label file
                line = (cls, *(d.xywhn.view(-1).tolist()), conf) \
                    if self.args.save_conf else (cls, *(d.xywhn.view(-1).tolist()))  # label format
                with open(f'{self.txt_path}.txt', 'a') as f:
                    f.write(('%g ' * len(line)).rstrip() % line + '\n')
            if self.save_res or self.args.save_crop or self.args.show or True:  # Add bbox to image (always true — must draw)
                c = int(cls)  # integer class
                name = f'id:{int(d.id.item())} {self.model.names[c]}' if d.id is not None else self.model.names[c]
                label = None if self.args.hide_labels else (name if self.args.hide_conf else f'{name} {conf:.2f}')
                self.annotator.box_label(d.xyxy.squeeze(), label, color=colors(c, True))
            if self.args.save_crop:
                save_one_box(d.xyxy,
                             imc,
                             file=self.save_dir / 'crops' / self.model.model.names[c] / f'{self.data_path.stem}.jpg',
                             BGR=True)
        return log_string
class YoloPredictor(BasePredictor, QObject):
这里继承了BasePredictor和QObject。下面是自定义的一些参数:
'''GUI相关参数初始化'''
self.used_model_name = None # 当前使用的模型名称
self.new_model_name = None # 实时更新的模型名称
self.source = '' # 输入源路径(如视频、图像路径)
self.stop_dtc = False # 停止检测标志
self.continue_dtc = True # 是否继续检测(用于暂停检测)
self.save_res = False # 是否保存测试结果
self.save_txt = False # 是否保存标签文件(txt格式)
self.iou_thres = 0.45 # IoU阈值(用于判断是否为目标)
self.conf_thres = 0.25 # 置信度阈值(检测到目标的置信度要求)
self.speed_thres = 10 # 延迟阈值,单位毫秒(检测的最大延迟时间)
self.labels_dict = {} # 存储检测结果的字典(类别名和数量)
self.progress_value = 0 # 进度条的当前进度(0-100)
# 以下为后续任务需要的变量初始化
self.model = None # 存储YOLO检测模型
self.data = self.args.data # 存储数据字典(包含数据集路径等)
self.imgsz = None # 输入图像的大小(例如:640x640)
self.device = None # 设备(GPU或CPU)
self.dataset = None # 数据集对象
self.vid_path, self.vid_writer = None, None # 存储视频路径和视频写入器(如果需要处理视频)
self.annotator = None # 注释器对象(用于图像标注)
self.data_path = None # 数据路径(用于加载数据)
self.source_type = None # 输入源类型(视频、图像等)
self.batch = None # 批次大小(用于批处理预测)
1.1继承BasePredictor解析
- BasePredictor类是 YOLOv8 目标检测/分类/分割的推理代码,封装了从 数据输入 → 预处理 → 模型推理 → 后处理 → 结果可视化/保存的完整流程。
- 继承该类之后需要对其中的预处理函数、绘图函数、结果输出函数、后处理函数进行重写。整体代码中也包含了这几个函数的重写。
- 完善YoloPredictor检测器的参数
get_cfg函数会从ROOT / 'yolo/cfg/default.yaml'读取默认参数,但是这个文件中有很多参数是空白,需要之后填写的。
super(YoloPredictor, self).__init__() # 调用父类BasePredictor的构造函数进行初始化
QObject.__init__(self) # 初始化QObject,支持信号和槽机制
self.args = get_cfg(cfg, overrides) # 获取配置信息,最终会从这里获取DEFAULT_CFG_PATH = ROOT / 'yolo/cfg/default.yaml'
''''下面这个配置信息是为了填补yaml文件的空白'''
# 设置项目路径:如果args.project为空,则使用默认路径并加上任务名称
project = self.args.project or Path(SETTINGS['runs_dir']) / self.args.task
# 设置模型名称:根据模式(mode)命名
name = f'{self.args.mode}'
# 设置保存目录:调用increment_path以避免覆盖现有文件夹
self.save_dir = increment_path(Path(project) / name, exist_ok=self.args.exist_ok)
self.done_warmup = False # 标记是否完成预热阶段(如:模型加载、预处理等)
if self.args.show:
self.args.show = check_imshow(warn=True) # 如果需要显示图像,则检查imshow是否可用
1.2继承QObject解析
- YoloPredictor需要使用信号与槽函数。
yolo2main_pre_img = Signal(np.ndarray) # 原始图像信号(图像数据类型为NumPy数组)
yolo2main_res_img = Signal(np.ndarray) # 测试结果图像信号(图像数据类型为NumPy数组)
yolo2main_status_msg = Signal(str) # 状态消息信号(如:检测中、暂停、停止、错误等)
yolo2main_fps = Signal(str) # FPS(帧率)信号(类型为字符串,显示每秒处理的帧数)
yolo2main_labels = Signal(dict) # 检测到的目标类别及其数量(字典类型)
yolo2main_progress = Signal(int) # 检测进度信号(表示任务完成的进度百分比)
yolo2main_class_num = Signal(int) # 检测到的类别数量信号(整数类型)
yolo2main_target_num = Signal(int) # 检测到的目标数量信号(整数类型)
- 咱们单独看一下每个信号都在做些什么?
- yolo2main_pre_img
#向主界面发送检测结果
self.yolo2main_pre_img.emit(im0s if isinstance(im0s, np.ndarray) else im0s[0])#检测前图像
#信号触发后x会被显示出来,x 是 yolo2main_pre_img 这个信号 emit() 时传递的参数。
self.yolo_predict.yolo2main_pre_img.connect(lambda x: self.show_image(x, self.pre_video))
- yolo2main_res_img
#向主界面发送检测结果
self.yolo2main_res_img.emit(im0) # 检测后图像
#信号触发后x会被显示出来,x 是 yolo2main_res_img 这个信号 emit() 时传递的参数。
self.yolo_predict.yolo2main_res_img.connect(lambda x: self.show_image(x, self.res_video))
- yolo2main_status_msg
self.yolo2main_status_msg.emit('Loding Model...') # 向主界面发射模型加载的状态信息
self.yolo2main_status_msg.emit('Detection terminated!') # 向主界面发射终止信息
self.yolo2main_status_msg.emit('Detecting...') # 向主界面发射“正在检测”信息
self.yolo2main_status_msg.emit('Detection completed') # 向主界面发射“检测完成”信息
self.yolo2main_status_msg.emit('%s' % e) # 向主界面发射异常信息
#信号触发后x会被显示出来,x 是 yolo2main_status_msg这个信号 emit() 时传递的参数。
self.yolo_predict.yolo2main_status_msg.connect(lambda x: self.show_status(x))
- yolo2main_fps
#向主界面发送
self.yolo2main_fps.emit(str(int(5 / (time.time() - start_time)))) # 向主界面发射每秒帧率
#被主界面接收
self.yolo_predict.yolo2main_fps.connect(lambda x: self.fps_label.setText(x))
- yolo2main_progress
# 更新进度条
self.yolo2main_progress.emit(self.progress_value)
#主界面设置进度值
self.yolo_predict.yolo2main_progress.connect(lambda x: self.progress_bar.setValue(x))
- yolo2main_class_num
self.yolo2main_class_num.emit(class_nums) # 发送类别数量
self.yolo_predict.yolo2main_class_num.connect(lambda x:self.Class_num.setText(str(x)))
- yolo2main_target_num
self.yolo2main_target_num.emit(target_nums) # 发送目标数量
self.yolo_predict.yolo2main_target_num.connect(lambda x:self.Target_num.setText(str(x)))
至此,检测器的剖析已经到位了。
2MainWindow类——主窗口
class MainWindow(QMainWindow, Ui_MainWindow):
    """Main application window: builds the UI and wires it to the YoloPredictor worker."""
    main2yolo_begin_sgl = Signal()  # The main window sends an execution signal to the yolo instance

    def __init__(self, parent=None):
        super(MainWindow, self).__init__(parent)
        # basic interface
        self.setupUi(self)
        self.setAttribute(Qt.WA_TranslucentBackground)  # rounded transparent window
        self.setWindowFlags(Qt.FramelessWindowHint)  # Set window flag: hide window borders
        UIFuncitons.uiDefinitions(self)  # NOTE(review): "UIFuncitons" (sic) is the project helper's own spelling
        # Show drop shadows on the four info modules.
        UIFuncitons.shadow_style(self, self.Class_QF, QColor(162,129,247))
        UIFuncitons.shadow_style(self, self.Target_QF, QColor(251, 157, 139))
        UIFuncitons.shadow_style(self, self.Fps_QF, QColor(170, 128, 213))
        UIFuncitons.shadow_style(self, self.Model_QF, QColor(64, 186, 193))
        # Read the model folder and populate the model combo box.
        self.pt_list = os.listdir('./models')
        self.pt_list = [file for file in self.pt_list if file.endswith('.pt')]
        self.pt_list.sort(key=lambda x: os.path.getsize('./models/' + x))  # sort by file size
        self.model_box.clear()
        self.model_box.addItems(self.pt_list)
        self.Qtimer_ModelBox = QTimer(self)  # Timer: monitor model-folder changes every 2 seconds
        self.Qtimer_ModelBox.timeout.connect(self.ModelBoxRefre)
        self.Qtimer_ModelBox.start(2000)
        # --- Yolo-v8 worker thread ---
        self.yolo_predict = YoloPredictor()  # create the Yolo worker instance
        self.select_model = self.model_box.currentText()  # default model
        self.yolo_predict.new_model_name = "./models/%s" % self.select_model
        self.yolo_thread = QThread()  # thread the worker will live in
        # Worker -> GUI signal wiring (each lambda receives the emitted value).
        self.yolo_predict.yolo2main_pre_img.connect(lambda x: self.show_image(x, self.pre_video))
        self.yolo_predict.yolo2main_res_img.connect(lambda x: self.show_image(x, self.res_video))
        self.yolo_predict.yolo2main_status_msg.connect(lambda x: self.show_status(x))
        self.yolo_predict.yolo2main_fps.connect(lambda x: self.fps_label.setText(x))
        # self.yolo_predict.yolo2main_labels.connect(self.show_labels)
        self.yolo_predict.yolo2main_class_num.connect(lambda x:self.Class_num.setText(str(x)))
        self.yolo_predict.yolo2main_target_num.connect(lambda x:self.Target_num.setText(str(x)))
        self.yolo_predict.yolo2main_progress.connect(lambda x: self.progress_bar.setValue(x))
        # GUI -> worker: begin signal runs YoloPredictor.run inside yolo_thread.
        self.main2yolo_begin_sgl.connect(self.yolo_predict.run)
        self.yolo_predict.moveToThread(self.yolo_thread)
        # Model parameter widgets (spinbox/slider pairs kept in sync by change_val).
        self.model_box.currentTextChanged.connect(self.change_model)
        self.iou_spinbox.valueChanged.connect(lambda x:self.change_val(x, 'iou_spinbox'))    # iou box
        self.iou_slider.valueChanged.connect(lambda x:self.change_val(x, 'iou_slider'))     # iou scroll bar
        self.conf_spinbox.valueChanged.connect(lambda x:self.change_val(x, 'conf_spinbox')) # conf box
        self.conf_slider.valueChanged.connect(lambda x:self.change_val(x, 'conf_slider'))   # conf scroll bar
        self.speed_spinbox.valueChanged.connect(lambda x:self.change_val(x, 'speed_spinbox'))# speed box
        self.speed_slider.valueChanged.connect(lambda x:self.change_val(x, 'speed_slider')) # speed scroll bar
        # Prompt window initialization
        self.Class_num.setText('--')
        self.Target_num.setText('--')
        self.fps_label.setText('--')
        self.Model_name.setText(self.select_model)
        # Select detection source
        self.src_file_button.clicked.connect(self.open_src_file)  # select local file
        # self.src_cam_button.clicked.connect(self.show_status("The function has not yet been implemented."))#chose_cam
        # self.src_rtsp_button.clicked.connect(self.show_status("The function has not yet been implemented."))#chose_rtsp
        # start testing button
        self.run_button.clicked.connect(self.run_or_continue)  # pause/start
        self.stop_button.clicked.connect(self.stop)  # termination
        # Other function buttons
        self.save_res_button.toggled.connect(self.is_save_res)  # save image option
        self.save_txt_button.toggled.connect(self.is_save_txt)  # save label option
        self.ToggleBotton.clicked.connect(lambda: UIFuncitons.toggleMenu(self, True))  # left navigation button
        self.settings_button.clicked.connect(lambda: UIFuncitons.settingBox(self, True))  # top-right settings button
        # Load persisted settings (iou/conf/rate/save flags).
        self.load_config()

    # The main window displays the original image and detection results.
    @staticmethod
    def show_image(img_src, label):
        """Scale a BGR ndarray to fit `label` (keeping aspect ratio) and display it."""
        try:
            ih, iw, _ = img_src.shape
            w = label.geometry().width()
            h = label.geometry().height()
            # keep the original data ratio
            if iw/w > ih/h:
                # Image is relatively wider: fit to label width.
                scal = w / iw
                nw = w
                nh = int(scal * ih)
                img_src_ = cv2.resize(img_src, (nw, nh))
            else:
                # Image is relatively taller: fit to label height.
                scal = h / ih
                nw = int(scal * iw)
                nh = h
                img_src_ = cv2.resize(img_src, (nw, nh))
            # OpenCV frames are BGR; Qt expects RGB.
            frame = cv2.cvtColor(img_src_, cv2.COLOR_BGR2RGB)
            img = QImage(frame.data, frame.shape[1], frame.shape[0], frame.shape[2] * frame.shape[1],
                         QImage.Format_RGB888)
            label.setPixmap(QPixmap.fromImage(img))
        except Exception as e:
            print(repr(e))  # display errors must not crash the GUI

    # Control start/pause
    def run_or_continue(self):
        """Start detection, or toggle pause/resume, depending on the run button state."""
        if self.yolo_predict.source == '':
            self.show_status('Please select a video source before starting detection...')
            self.run_button.setChecked(False)
        else:
            self.yolo_predict.stop_dtc = False  # clear any previous stop request
            if self.run_button.isChecked():
                self.run_button.setChecked(True)  # start button stays pressed
                self.save_txt_button.setEnabled(False)  # forbid toggling save options mid-detection
                self.save_res_button.setEnabled(False)
                self.show_status('Detecting...')
                self.yolo_predict.continue_dtc = True  # controls whether Yolo is paused
                if not self.yolo_thread.isRunning():
                    self.yolo_thread.start()
                    self.main2yolo_begin_sgl.emit()  # kick off YoloPredictor.run in the thread
            else:
                self.yolo_predict.continue_dtc = False  # pause the worker loop
                self.show_status("Pause...")
                self.run_button.setChecked(False)  # start button released

    # bottom status bar information
    def show_status(self, msg):
        """Show msg in the status bar and react to the completion/termination sentinels."""
        self.status_bar.setText(msg)
        if msg == 'Detection completed' or msg == '检测完成':
            # Detection finished normally: re-enable options and reset progress.
            self.save_res_button.setEnabled(True)
            self.save_txt_button.setEnabled(True)
            self.run_button.setChecked(False)
            self.progress_bar.setValue(0)
            if self.yolo_thread.isRunning():
                self.yolo_thread.quit()  # end worker thread
        elif msg == 'Detection terminated!' or msg == '检测终止':
            # Detection aborted: same reset, plus clear the image panes.
            self.save_res_button.setEnabled(True)
            self.save_txt_button.setEnabled(True)
            self.run_button.setChecked(False)
            self.progress_bar.setValue(0)
            if self.yolo_thread.isRunning():
                self.yolo_thread.quit()  # end worker thread
            self.pre_video.clear()  # clear image display
            self.res_video.clear()
            self.Class_num.setText('--')
            self.Target_num.setText('--')
            self.fps_label.setText('--')

    # select local file
    def open_src_file(self):
        """Pick a local image/video file as the detection source; remember the folder."""
        config_file = 'config/fold.json'
        config = json.load(open(config_file, 'r', encoding='utf-8'))
        open_fold = config['open_fold']  # last-used folder
        if not os.path.exists(open_fold):
            open_fold = os.getcwd()
        name, _ = QFileDialog.getOpenFileName(self, 'Video/image', open_fold, "Pic File(*.mp4 *.mkv *.avi *.flv *.jpg *.png)")
        if name:
            self.yolo_predict.source = name
            self.show_status('Load File:{}'.format(os.path.basename(name)))
            # Persist the chosen folder for next time.
            config['open_fold'] = os.path.dirname(name)
            config_json = json.dumps(config, ensure_ascii=False, indent=2)
            with open(config_file, 'w', encoding='utf-8') as f:
                f.write(config_json)
            self.stop()  # reset any in-progress detection before switching source

    # Select camera source---- have one bug
    def chose_cam(self):
        """Pop a menu of local cameras and use the picked index as the source."""
        try:
            self.stop()
            # Transient "loading" dialog while cameras are enumerated.
            MessageBox(
                self.close_button, title='Note', text='loading camera...', time=2000, auto=True).exec()
            # get the number of local cameras
            _, cams = Camera().get_cam_num()
            popMenu = QMenu()
            popMenu.setFixedWidth(self.src_cam_button.width())
            popMenu.setStyleSheet('''
                                  QMenu {
                                  font-size: 16px;
                                  font-family: "Microsoft YaHei UI";
                                  font-weight: light;
                                  color:white;
                                  padding-left: 5px;
                                  padding-right: 5px;
                                  padding-top: 4px;
                                  padding-bottom: 4px;
                                  border-style: solid;
                                  border-width: 0px;
                                  border-color: rgba(255, 255, 255, 255);
                                  border-radius: 3px;
                                  background-color: rgba(200, 200, 200,50);}
                                  ''')
            # NOTE(review): exec-built QActions may be garbage-collected before the
            # menu shows — likely the "one bug" mentioned above; verify.
            for cam in cams:
                exec("action_%s = QAction('%s')" % (cam, cam))
                exec("popMenu.addAction(action_%s)" % cam)
            # Show the menu just below the camera button.
            x = self.src_cam_button.mapToGlobal(self.src_cam_button.pos()).x()
            y = self.src_cam_button.mapToGlobal(self.src_cam_button.pos()).y()
            y = y + self.src_cam_button.frameGeometry().height()
            pos = QPoint(x, y)
            action = popMenu.exec(pos)
            if action:
                self.yolo_predict.source = action.text()  # camera index as the source
                self.show_status('Loading camera:{}'.format(action.text()))
        except Exception as e:
            self.show_status('%s' % e)

    # select network source
    def chose_rtsp(self):
        """Open the RTSP dialog pre-filled with the last (or a default) address."""
        self.rtsp_window = Window()
        config_file = 'config/ip.json'
        if not os.path.exists(config_file):
            # First run: seed the config with a default address.
            ip = "rtsp://admin:admin888@192.168.1.2:555"
            new_config = {"ip": ip}
            new_json = json.dumps(new_config, ensure_ascii=False, indent=2)
            with open(config_file, 'w', encoding='utf-8') as f:
                f.write(new_json)
        else:
            config = json.load(open(config_file, 'r', encoding='utf-8'))
            ip = config['ip']
        self.rtsp_window.rtspEdit.setText(ip)
        self.rtsp_window.show()
        self.rtsp_window.rtspButton.clicked.connect(lambda: self.load_rtsp(self.rtsp_window.rtspEdit.text()))

    # load network sources
    def load_rtsp(self, ip):
        """Use the given RTSP address as the detection source and persist it."""
        try:
            self.stop()
            MessageBox(
                self.close_button, title='提示', text='加载 rtsp...', time=1000, auto=True).exec()
            self.yolo_predict.source = ip
            # Persist the address for the next session.
            new_config = {"ip": ip}
            new_json = json.dumps(new_config, ensure_ascii=False, indent=2)
            with open('config/ip.json', 'w', encoding='utf-8') as f:
                f.write(new_json)
            self.show_status('Loading rtsp:{}'.format(ip))
            self.rtsp_window.close()
        except Exception as e:
            self.show_status('%s' % e)

    # Save test result button--picture/video
    def is_save_res(self):
        """Sync the save-result checkbox state into the worker."""
        if self.save_res_button.checkState() == Qt.CheckState.Unchecked:
            self.show_status('NOTE: Run image results are not saved.')
            self.yolo_predict.save_res = False
        elif self.save_res_button.checkState() == Qt.CheckState.Checked:
            self.show_status('NOTE: Run image results will be saved.')
            self.yolo_predict.save_res = True

    # Save test result button -- label (txt)
    def is_save_txt(self):
        """Sync the save-labels checkbox state into the worker."""
        if self.save_txt_button.checkState() == Qt.CheckState.Unchecked:
            self.show_status('NOTE: Labels results are not saved.')
            self.yolo_predict.save_txt = False
        elif self.save_txt_button.checkState() == Qt.CheckState.Checked:
            self.show_status('NOTE: Labels results will be saved.')
            self.yolo_predict.save_txt = True

    # Configuration initialization ~~~wait to change~~~
    def load_config(self):
        """Load iou/conf/rate/save settings from config/setting.json (creating defaults)."""
        config_file = 'config/setting.json'
        if not os.path.exists(config_file):
            # No config yet: write the defaults.
            iou = 0.26
            conf = 0.33
            rate = 10
            save_res = 0
            save_txt = 0
            new_config = {"iou": iou,
                          "conf": conf,
                          "rate": rate,
                          "save_res": save_res,
                          "save_txt": save_txt
                          }
            new_json = json.dumps(new_config, ensure_ascii=False, indent=2)
            with open(config_file, 'w', encoding='utf-8') as f:
                f.write(new_json)
        else:
            config = json.load(open(config_file, 'r', encoding='utf-8'))
            if len(config) != 5:
                # Malformed/partial config: fall back to the defaults.
                iou = 0.26
                conf = 0.33
                rate = 10
                save_res = 0
                save_txt = 0
            else:
                iou = config['iou']
                conf = config['conf']
                rate = config['rate']
                save_res = config['save_res']
                save_txt = config['save_txt']
        # Apply the loaded (or default) state to the widgets and the worker.
        self.save_res_button.setCheckState(Qt.CheckState(save_res))
        self.yolo_predict.save_res = (False if save_res==0 else True )
        self.save_txt_button.setCheckState(Qt.CheckState(save_txt))
        self.yolo_predict.save_txt = (False if save_txt==0 else True )
        self.run_button.setChecked(False)
        self.show_status("Welcome~")

    # Terminate button and associated state
    def stop(self):
        """Stop detection, end the worker thread, and reset the UI."""
        if self.yolo_thread.isRunning():
            self.yolo_thread.quit()  # end thread
        self.yolo_predict.stop_dtc = True  # ask the worker loop to break
        self.run_button.setChecked(False)  # start-key recovery
        self.save_res_button.setEnabled(True)  # re-enable the save buttons
        self.save_txt_button.setEnabled(True)
        self.pre_video.clear()  # clear image display
        self.res_video.clear()  # clear image display
        self.progress_bar.setValue(0)
        self.Class_num.setText('--')
        self.Target_num.setText('--')
        self.fps_label.setText('--')

    # Change detection parameters
    def change_val(self, x, flag):
        """Keep each spinbox/slider pair in sync and forward values to the worker.

        Sliders are integer 0..100 for iou/conf (hence the /100 and *100);
        speed is passed through in milliseconds.
        """
        if flag == 'iou_spinbox':
            self.iou_slider.setValue(int(x*100))  # box value changed -> update slider
        elif flag == 'iou_slider':
            self.iou_spinbox.setValue(x/100)  # slider value changed -> update box
            self.show_status('IOU Threshold: %s' % str(x/100))
            self.yolo_predict.iou_thres = x/100
        elif flag == 'conf_spinbox':
            self.conf_slider.setValue(int(x*100))
        elif flag == 'conf_slider':
            self.conf_spinbox.setValue(x/100)
            self.show_status('Conf Threshold: %s' % str(x/100))
            self.yolo_predict.conf_thres = x/100
        elif flag == 'speed_spinbox':
            self.speed_slider.setValue(x)
        elif flag == 'speed_slider':
            self.speed_spinbox.setValue(x)
            self.show_status('Delay: %s ms' % str(x))
            self.yolo_predict.speed_thres = x  # ms

    # change model
    def change_model(self,x):
        """Point the worker at the newly selected model; run() hot-swaps it."""
        self.select_model = self.model_box.currentText()
        self.yolo_predict.new_model_name = "./models/%s" % self.select_model
        self.show_status('Change Model:%s' % self.select_model)
        self.Model_name.setText(self.select_model)

    # label result
    # def show_labels(self, labels_dic):
    #     try:
    #         self.result_label.clear()
    #         labels_dic = sorted(labels_dic.items(), key=lambda x: x[1], reverse=True)
    #         labels_dic = [i for i in labels_dic if i[1]>0]
    #         result = [' '+str(i[0]) + ':' + str(i[1]) for i in labels_dic]
    #         self.result_label.addItems(result)
    #     except Exception as e:
    #         self.show_status(e)

    # Cycle monitoring model file changes
    def ModelBoxRefre(self):
        """Timer slot: refresh the model combo box when the folder contents change."""
        pt_list = os.listdir('./models')
        pt_list = [file for file in pt_list if file.endswith('.pt')]
        pt_list.sort(key=lambda x: os.path.getsize('./models/' + x))
        # It must be sorted before comparing, otherwise the list will be refreshed all the time
        if pt_list != self.pt_list:
            self.pt_list = pt_list
            self.model_box.clear()
            self.model_box.addItems(self.pt_list)

    # Get the mouse position (used to hold down the title bar and drag the window)
    def mousePressEvent(self, event):
        p = event.globalPosition()
        globalPos = p.toPoint()
        self.dragPos = globalPos  # remembered for the drag handler

    # Optimize the adjustment when dragging the bottom and right edges of the window size
    def resizeEvent(self, event):
        # Update Size Grips
        UIFuncitons.resize_grips(self)

    # Exit: stop the worker thread and save settings
    def closeEvent(self, event):
        """Persist the current settings, stop the worker if needed, and exit."""
        config_file = 'config/setting.json'
        config = dict()
        config['iou'] = self.iou_spinbox.value()
        config['conf'] = self.conf_spinbox.value()
        config['rate'] = self.speed_spinbox.value()
        # Checkbox states persisted as 0 (unchecked) / 2 (checked), Qt's raw values.
        config['save_res'] = (0 if self.save_res_button.checkState()==Qt.Unchecked else 2)
        config['save_txt'] = (0 if self.save_txt_button.checkState()==Qt.Unchecked else 2)
        config_json = json.dumps(config, ensure_ascii=False, indent=2)
        with open(config_file, 'w', encoding='utf-8') as f:
            f.write(config_json)
        # Exit the worker thread before closing.
        if self.yolo_thread.isRunning():
            self.yolo_predict.stop_dtc = True
            self.yolo_thread.quit()
            # Brief auto-closing dialog while the thread winds down.
            MessageBox(
                self.close_button, title='Note', text='Exiting, please wait...', time=3000, auto=True).exec()
            sys.exit(0)
        else:
            sys.exit(0)
代码很多,有一些是准备工作,比如把控件绑定到模型文件夹。这个类比较简单,大部分内容还是比较好懂的,可以自行理解。下一篇我们基于PyDracula的代码来一步步实现yolo检测。