论文地址:https://arxiv.org/pdf/1911.08947.pdf
开源代码pytorch版本:GitHub - WenmuZhou/DBNet.pytorch: A pytorch re-implementation of Real-time Scene Text Detection with Differentiable Binarization
前言
在这篇论文之前,文字检测算法主要分为两类:基于回归的方法和基于分割的方法。基于分割的方法通常涉及以下流程,如下图蓝色箭头所示:首先,通过网络输出图像的文本分割结果,即概率图,其中每个像素表示是否属于正样本的概率。然后,通过使用预设的阈值将分割结果图转换为二值图。最后,通过一些聚合操作,例如连通域分析,将像素级的结果转换为最终的文本检测结果。然而,由于涉及使用阈值来判定前景和背景的不可微分操作,因此这一部分流程无法被直接放入网络中进行训练。所以本文引入了一种新的方法。具体而言,通过学习阈值映射(threshmap)并采用可微分的操作,将阈值的转换过程嵌入到网络中进行训练。这一创新的流程如下图中红色箭头所示,通过可微分的操作来处理阈值的学习,使得整个流程可以在神经网络的训练中进行端到端的优化。通过这种方式,文本检测模型能够自适应地学习阈值,更有效地捕捉文本的分割信息,提高了检测性能。这一方法有助于简化原有基于分割方法的后处理流程,同时使整个模型更具可训练性。
网络结构
其实从下图的网络结构中不难看出,相比较于PSENet,多了一条threshold map分支罢了,该分支的主要目的是和分割图联合得到更接近二值化的二值图,属于辅助分支。
整个网络结构流程:
图像输入特征提取主干: 使用图像输入,经过一个特征提取的主干网络,该网络负责从输入图像中提取高层次的语义特征。这可以是一个卷积神经网络(CNN)的主要部分,如ResNet或其他先进的架构。
特征金字塔上采样和级联: 从特征提取主干获得的特征被送入特征金字塔。在特征金字塔中,通过上采样将不同尺寸的特征图调整到相同的尺寸,并将它们级联在一起,形成一个具有丰富多尺度信息的特征F。这有助于模型对不同大小和尺度的目标进行有效的检测和分割。
预测概率图和阈值图: 利用级联的特征F,进行概率图(probability map P)和阈值图(threshold map T)的预测。概率图通常表示每个像素属于某个类别(在这里可能是目标文本与非文本的概率),而阈值图则用于指导后续的二值化操作。这一步的目的是产生用于后续计算的中间结果。
计算近似二值图: 利用概率图P和阈值图T,通过一定的计算过程(可能是使用阈值或其他运算),得到一个近似的二值图B。这个近似二值图用于最终的文本检测,其中文本区域被二值化为前景,而非文本区域为背景。
在训练过程中,该模型通过使用相同的监督信号对概率图 P 和近似二值图 B 进行监督训练,其中概率图表示文本区域的概率,而近似二值图是文本二值化结果。在推理阶段,只需使用概率图 P 或者近似二值图B 中的任一即可获取文本检测结果,无需依赖额外的阈值图。这种设计简化了推理流程,提高了模型的实际应用效率。
模型的输出
Probability Map(概率图): 这是一个大小为w×h×1 的张量,其中 w 和 ℎ分别表示图像的宽度和高度。概率图的每个像素表示相应位置是否为文本的概率。对于二进制文本检测任务,概率图的值通常在 0 到 1 之间,表示每个像素点属于文本的概率,1 表示高置信度是文本,0 表示低置信度是文本。
Threshold Map(阈值图): 阈值图也是一个大小为 w×h×1 的张量,其中每个像素点包含一个阈值。这些阈值用于二值化概率图,将其转换为最终的二值图。阈值图的每个值表示相应位置的二值化操作的阈值。
Binary Map(二值图): 由概率图和阈值图计算得到,也是一个大小为 w×h×1 的张量。它表示最终的文本检测结果,其中每个像素点被二值化为前景(文本)或背景(非文本)。这里提到使用了 "DB 公式" 来计算二值图,而 DB(Differentiable Binarization)通常是一个近似二值化的函数,通过可微分的操作来实现对阈值的学习和调整。
DB公式
标准二值化
一般使用分割网络(segmentation network)产生的概率图(probability map P),将P转化为一个二值图P,当像素为1的时候,认定其为有效的文本区域。i和j代表了坐标点的坐标,t是预定义的阈值
可微二值化(differentiable Binarization)
可微二值化的公式如下,其实就是带一个系数的sigmoid,其中其中T是阈值图,k取50
从图像上不难看出,二值化和标准二值化很相似,且可微分,因此可以和分割网络一起联合优化
从(b)(c)图我们不难看出通过增加参数 K,可以在模型的训练过程中加速对正确预测区域和错误预测区域的学习,以更快地收敛到最优解。这样的调整可以在某些情况下提高模型的训练效率和性能。原图,gt图,threshold map图如下所示
模型训练
自动下载的预训练模型下载地址:/home/xuzhen/.cache/torch/hub/checkpoints/resnet18-5c106cde.pth(我看了代码,他是判断有没有预训练模型没有的话才下载)
这个源代码在配置文件中加载的是train.txt和test.txt,所以我写了一个脚本,根据img文件夹和gt文件夹自动生成这两个文件的脚本
import os
def create_gt_file(img_dir, gt_dir, output_file_path):
# 检查文件夹是否存在
if not os.path.exists(img_dir) or not os.path.exists(gt_dir):
print("Error: One or both folders do not exist.")
return
img_paths = []
gt_paths = []
# 循环读取文件夹1中的文件名
for filename in os.listdir(img_dir):
img_path = os.path.join(img_dir,filename)
img_paths.append(img_path)
# 去掉后缀并在前面加上 "gt_"
gt_path = os.path.join(gt_dir, "gt_" + os.path.splitext(filename)[0] + ".txt")
gt_paths.append(gt_path)
# 写入文件
with open(output_file_path, 'w') as output_file:
# 将 img_paths 和 gt_paths 写入文件
for img_path, gt_path in zip(img_paths, gt_paths):
output_file.write(f"{img_path}\t{gt_path}\n")
print(f"{img_path}\t{gt_path}Strings written to {output_file_path}")
# 主函数
def main():
img_dir = "/data2/xuzhen8/yzh/datasets/ICDAR2015/test_images"
gt_dir = "/data2/xuzhen8/yzh/datasets/ICDAR2015/testing_localization_transcription_gt"
output_file_path = "/data2/xuzhen8/yzh/projects/DBNet.pytorch/datasets/test.txt"
create_gt_file(img_dir, gt_dir, output_file_path)
if __name__ == "__main__":
main()
每一轮训练都会打印信息,我想对这个打印信息说明一下,以便后面复习
FPS(Frames Per Second): 99.37
表示每秒处理的图像帧数。在这个上下文中,表示模型在测试阶段的推断速度。这是通过测量模型在测试集上处理图像的速度来得到的,其单位是帧数/秒。
test: recall: 0.031477, precision: 0.596330, f1: 0.059798
提供了模型在测试集上的性能指标。在这里,包括了召回率(recall)、精确度(precision)和 F1 分数(f1-score)。这些指标用于衡量模型在检测任务中的性能,其中:
召回率表示正确检测到的正类别样本占所有实际正类别样本的比例。
精确度表示模型正确检测的正类别样本占所有模型检测为正类别的样本的比例。
F1 分数是召回率和精确度的调和平均数,综合考虑了这两个指标。
current best, recall: 0.101695, precision: 0.726644, hmean: 0.178420, train_loss: 1.706732, best_model_epoch: 5.000000
提供了模型在测试集上的当前最佳性能以及训练期间的一些指标。其中:
recall、precision、hmean 是测试集上的召回率、精确度和 F1 分数。
train_loss 表示模型在训练集上的损失值,用于衡量训练过程中模型的拟合情况。
best_model_epoch 表示在训练过程中取得最佳性能的模型所对应的训练轮次。
Saving checkpoint:DBNet.pytorch/output/DBNet_resnet18_FPN_DBHead/checkpoint/
model_latest.pth
表示当前训练轮次的模型参数被保存到了指定路径下的 model_latest.pth 文件中。这通常发生在模型在测试集上取得了更好性能后,保存了当前状态的模型参数,以备将来使用或继续训练。
小辉问:能不能举个例子说明一下召回率、精确度、F1 分数。以便更好的理解
小G答:假设有一个二分类任务,目标是检测患有某种疾病的患者。我们的模型对每个样本都进行预测,可以分为以下四种情况:
True Positive (TP): 模型正确地预测了患有疾病的患者。
True Negative (TN): 模型正确地预测了没有患疾病的健康人。
False Positive (FP): 模型错误地预测了没有患疾病的健康人为患病。
False Negative (FN): 模型错误地预测了患有疾病的患者为健康人。
现在,我们可以使用这些概念来解释这些指标:
召回率(Recall):
召回率衡量了模型在所有实际患有疾病的样本中,有多少被成功地检测到。计算公式:
例如,如果总共有 100 名患有疾病的患者,而模型成功地检测到其中的 80 人,则召回率为 80/80+20 =0.8 或 80%。
精确度(Precision):
精确度衡量了模型在所有预测为患有疾病的样本中,有多少实际上是真正患有疾病的人。计算公式:
例如,如果模型预测了 90 个人患有疾病,而其中有 80 人确实是患有疾病的,则精确度为 80/80+10=0.888 或 88.8%。
F1 分数:
F1 分数是召回率和精确度的调和平均数,它综合考虑了两者的性能。计算公式:
F1 分数的取值范围在 [0,1],越接近 1 表示模型在召回率和精确度之间取得了更好的平衡。