YOLOv10(1):初探,训练自己的数据_yolov10 训练-CSDN博客
YOLOv10(2):网络结构及其检测模型代码部分阅读-CSDN博客
目录
1. 写在前面
2. 双标签分配
3. 计算Loss的预备条件
(1)网络输出
(2)Ground Truth
4. Loss的计算
1. 写在前面
YOLOv10几乎完美继承了YOLOv8的Loss计算方法,不同点就是YOLOv8只需要计算一次的,YOLOv10需要计算两次。究其原因,就是YOLOv10采用了双标签分配策略。
2. 双标签分配
YOLOv10同样采用了解耦检测头,实现了anchor free的训练和推理策略。 对Decoupled-Head类型的算法基本都有这么一种观念,就是不同任务间,用于分类和定位的Anchor(或称作cell,指分类和回归前的特征图中的一个单元)往往不一致,需要经过对齐。所谓对齐,简单来看就是计算Loss时,将分类与回归所使用的预测Cell进行统一化。
TAL(Task Align Learning),最早出现在论文Task-aligned One-stage Object Detection中。该论文提供一种思想,即通过构建一种“对齐度量”,来统一分类和回归的Anchor,进而实现最终在推理时,获得一个更高得分的分类框以及更准确的定位框系数。
更通俗地讲,TAL就是给Feature Map中的每一个Cell(或称作Anchor)分配Ground Truth框。在这种前提下,有的Cell能够分配到Ground Truth框,有的Cell分配不到GT框。根据Feature Map与GT的分配情况,构建用于Loss计算的target_labels、target_bboxes和target_scores。
YOLOv10在对齐时采用了one2one和one2many两种策略。在训练时,同时使用one2one和one2many进行Loss的计算,在推理时,仅使用one2one分支。
3. 计算Loss的预备条件
(1)网络输出
分析head.py中的v10Detect,确认网络最终输出形式
v10Detect在执行forward时,如果是training模式,会输出一个字典{"one2many": one2many, "one2one": one2one},如果是export,只输出one2one经过解算的结果。
v10Detect继承自Detect,one2one与one2many采用了不同的分支。one2one使用了v10Detect的self.cv3。one2many则采用了Detect原有的self.cv2和self.cv3。
以上,第525行,one2one是一个列表,列表中含有三个元素。
其中,对应网络输入分辨率是640*640的话,one2one[0]是shape(N, 64+nc, 80, 80),one2one[1]是shape(N, 64+nc, 40, 40),one2one[2]是shape(N, 64+nc, 20, 20)。
one2many的shape与one2one是一样的。
(2)Ground Truth
在此,Ground Truth是一个字典,包含了用于训练输入的数据和用来计算Loss的标注信息。
GT预读和加载在./ultralytics/models/yolo/train.py中的get_dataloader中定义。重点关注“build_yolo_dataset”和“build_dataloader”两个接口就可以。
在此我们先做一个直接的说明,后续会针对数据的组织形式专门开展一期讲解。
我们在训练时,会生成一个batch数据,这是一个字典,其中包含了如下的信息:
batch[“img”]: 训练图像数据;
batch[“bboxes”]: 标注的边框信息;
batch[“cls”]: 标注的边框对应的类别;
4. Loss的计算
在前面我们已经知道了v8DetectionLoss中那些成员会参与Loss的计算,此处我们将大致梳理一下流程,后续将对其中比较重要的TAL和bbox_loss专门进行详细的讲解。
如下为v8DetectionLoss的__call__函数的前半部分,此处的__call__函数效果类似于C++中的仿函数,能够想调用普通函数一样调用类。
这一部分实际上是做一些准备工作。第218行和第219行是将输入的pred进行分解和变换,分别得到shape(N, 8400, 64)和shape(N, 8400, nc)的tensor,用于后续的TAL对其。注意,此处的8400是针对输入时640*640而言的,具体的,8400 = 80*80 + 40*40 + 20*20。
第229行做了一个Anchor框架,其中anchor_points是一个shape(8400, 2)的tensor,anchor_points[:, 2]表示一个cell的中心点坐标。stride_tensor是一个shape(8400, 1)的tensor,stride_tensor[:, 1]表示某一cell与原始尺度之间的stride信息。
如下是v8DetectionLoss的后半部分,其中涉及到了比较精髓的TAL和bbox_loss操作。