OCR经典神经网络(一)文本识别算法CRNN算法原理及其在icdar15数据集上的应用

OCR经典神经网络(一)文本识别算法CRNN算法原理及其在icdar15数据集上的应用

  • 文本识别是OCR(Optical Character Recognition)的一个子任务,其任务为:识别一个固定区域的的文本内容。

    • 在OCR的两阶段方法里,文本识别模型接在文本检测(如DB算法)后面,将图像信息转换为文字信息
    • 具体来讲:如下图所示,文本识别模型的输入是一张经过文本检测后的文本行图片,输出图片中的文字内容和置信度。

    在这里插入图片描述

    ('实力活力', 0.9861845970153809)
    
  • 文本识别的应用场景很多,如:文档识别、路标识别、车牌识别、工业编号识别等等。下表展示了主流的算法类别和主要论文。

    • 今天我们了解下由华中科技大学白翔老师团队在2015年提出的CRNN模型。
    • 论文链接:https://arxiv.org/pdf/1507.05717
    • 百度开源的paddleocr中集成了此算法:https://github.com/PaddlePaddle/PaddleOCR
算法类别主要思路主要论文
传统算法滑动窗口、字符提取、动态规划-
ctc基于ctc的方法,序列不对齐,更快速识别CRNN, Rosetta
Attention基于attention的方法,应用于非常规文本RARE, DAN, PREN
Transformer基于transformer的方法SRN, NRTR, Master, ABINet
校正校正模块学习文本边界并校正成水平方向RARE, ASTER, SAR
分割基于分割的方法,提取字符位置再做分类Text Scanner, Mask TextSpotter

1 CRNN网络结构

如下图所示,CRNN由CNN+RNN+CTC三部分组成。

  • 特征提取部分使用主流的卷积结构,常用的有ResNet、MobileNet、VGG等,原始论文中使用的VGG。
  • 由于文本识别任务的特殊性,输入数据中存在大量的上下文信息,卷积神经网络的卷积核特性使其更关注于局部信息,缺乏长依赖的建模能力,因此仅使用卷积网络很难挖掘到文本之间的上下文联系。为了解决这一问题,CRNN文本识别算法引入了双向 LSTM用来增强上下文建模,通过实验证明双向LSTM模块可以有效的提取出图片中的上下文信息。
  • 时间步输出是24个,但是图片中字符数不一定都是24,长短不一(注:是0-9数字以及a-z字母组合,还有一个blank标识符,总共37类)。因此,最终将输出的特征序列输入到CTC模块,直接解码序列结果。该结构被验证有效,并广泛应用在文本识别任务中。

在这里插入图片描述

1.1 CNN结构

CNN结构采用的是VGG的结构,并且对VGG网络做了一些微调:

  • 为了能将CNN提取的特征作为输入,输入到RNN网络中,将第三和第四个maxpooling的核尺度 2 × 2 2 × 2 2×2改为了 1 × 2 1 × 2 1×2

    • 将第三和第四个maxpooling改变的原因:为了方便的将CNN的提取特征作为RNN的输入。
    • 首先要注意的是这个网络的输入为W × 32,也就是说该网络对输入图片的宽没有特殊的要求,但是高都必须resize到32。
    • 文中举例说明:如果一张包含10个字符的图片大小为100 × 32,经过上述的CNN网络得到的特征尺度为24 × 1(忽略通道数),这样得到一个序列。每一列特征对应原图的一个矩形区域(如下图所示),这样就很方便作为RNN的输入进行下一步的计算了,而且每个特征与输入有一个一对一的对应关系。

    在这里插入图片描述

  • 为了加速网络的训练,在第五和第六个卷积层后面加上了BN层;

在这里插入图片描述

1.2 RNN结构

为了防止训练时梯度的消失,采用了LSTM神经单元作为RNN的单元。作者认为对于序列的预测,序列的前向信息和后向信息都有助于序列的预测,所以作者采用了双向RNN网络

在这里插入图片描述

1.3 CTC转录层

  • 如果使用传统的loss function,需要对齐训练样本,有24个时间步,就需要有24个对应的标签,在该任务中显然不合适,除非可以把图片中的每一个字符都单独检测出来,一个字符对应一个标签,但这需要很强大的文字检测算法,而CTCLoss不需要对齐样本。

  • 24个时间步得到24个标签,再进行一个β变换,才得到最终标签。

    • 24个时间步可以看作原图中分成24列,每一列输出一个标签,有时一个字母占据好几列,例如字母S占据三列,则这三列输出类别都应该是S,有的列没有字母,则输出空白类别。
    • 得到最终类别时将连续重复的字符去重(空白符两侧的相同字符不去重,因为真实标签中可能存在连续重复字符,例如green中的两个连续的e不应该去重,则生成标签的时候就该是类似e-e这种,则不会去重),最终去除空白符即可得到最终标签。
  • 由于CTCLoss计算有些复杂,具体可参考:CTC算法详解。深度学习框架中,一般都集成了CTC Loss,我们直接使用即可。

def ctc_loss():
    import torch
    # https://pytorch.org/docs/1.13/generated/torch.nn.CTCLoss.html#torch.nn.CTCLoss
    # Target are to be padded
    T = 50  # Input sequence length
    C = 20  # Number of classes (including blank)
    N = 2   # Batch size
    S = 30  # Target sequence length of longest target in batch (padding length)
    S_min = 10  # 目标序列的最小长度,这里仅用于生成随机目标长度时的范围限制

     # Initialize random batch of input vectors, for *size = (T,N,C)
    # 对输入进行了softmax操作并取了对数,以符合CTC损失函数的输入要求
    # input shape = (50, 2, 20)
    input = torch.randn(T, N, C).log_softmax(2).detach().requires_grad_()

     # Initialize random batch of targets (0 = blank, 1:C = classes)
    # target shape = (2, 30)
    target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)

    # 长度为N的张量,表示每个输入序列的实际长度
    # 这里假设所有输入序列都是完整的T长度
    input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
    # 长度为N的张量,表示每个目标序列的实际长度。
    # 这里通过随机生成一个介于S_min和S之间的整数来模拟不同长度的目标序列
    target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)

    ctc_loss = nn.CTCLoss()
    # input是模型的输出(对数概率),target是目标序列
    # input_lengths和target_lengths分别指定了输入序列和目标序列的实际长度。
    loss = ctc_loss(input, target, input_lengths, target_lengths)
    loss.backward()

if __name__ == '__main__':
    ctc_loss()

1.4 CRNN网络的简单实现

使用pytorch框架实现CRNN网络如下:

import torch.nn as nn
from torchinfo import summary

class CRNN(nn.Module):

    def __init__(self, img_channel, img_height, img_width, num_class,
                 map_to_seq_hidden=64, rnn_hidden=256, leaky_relu=False):
        super(CRNN, self).__init__()

        self.cnn, (output_channel, output_height, output_width) = \
            self._cnn_backbone(img_channel, img_height, img_width, leaky_relu)

        self.map_to_seq = nn.Linear(output_channel * output_height, map_to_seq_hidden)

        self.rnn1 = nn.LSTM(map_to_seq_hidden, rnn_hidden, bidirectional=True)

        # 如果接双向lstm输出,则要 *2,固定用法
        self.rnn2 = nn.LSTM(2 * rnn_hidden, rnn_hidden, bidirectional=True)

        self.dense = nn.Linear(2 * rnn_hidden, num_class)

    # CNN主干网络
    def _cnn_backbone(self, img_channel, img_height, img_width, leaky_relu):
        assert img_height % 16 == 0
        assert img_width % 4 == 0

        # 超参设置
        channels = [img_channel, 64, 128, 256, 256, 512, 512, 512]
        kernel_sizes = [3, 3, 3, 3, 3, 3, 2]
        strides = [1, 1, 1, 1, 1, 1, 1]
        paddings = [1, 1, 1, 1, 1, 1, 0]

        cnn = nn.Sequential()

        def conv_relu(i, batch_norm=False):
            # shape of input: (batch, input_channel, height, width)
            input_channel = channels[i]
            output_channel = channels[i + 1]

            cnn.add_module(
                f'conv{i}',
                nn.Conv2d(input_channel, output_channel, kernel_sizes[i], strides[i], paddings[i])
            )

            if batch_norm:
                cnn.add_module(f'batchnorm{i}', nn.BatchNorm2d(output_channel))

            relu = nn.LeakyReLU(0.2, inplace=True) if leaky_relu else nn.ReLU(inplace=True)
            cnn.add_module(f'relu{i}', relu)

        # size of image: (channel, height, width) = (img_channel, img_height, img_width)
        conv_relu(0)
        cnn.add_module('pooling0', nn.MaxPool2d(kernel_size=2, stride=2))
        # (64, img_height // 2, img_width // 2)

        conv_relu(1)
        cnn.add_module('pooling1', nn.MaxPool2d(kernel_size=2, stride=2))
        # (128, img_height // 4, img_width // 4)

        conv_relu(2)
        conv_relu(3)
        cnn.add_module(
            'pooling2',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (256, img_height // 8, img_width // 4)

        conv_relu(4, batch_norm=True)
        conv_relu(5, batch_norm=True)
        cnn.add_module(
            'pooling3',
            nn.MaxPool2d(kernel_size=(2, 1))
        )  # (512, img_height // 16, img_width // 4)

        conv_relu(6)  # (512, img_height // 16 - 1, img_width // 4 - 1)

        output_channel, output_height, output_width = \
            channels[-1], img_height // 16 - 1, img_width // 4 - 1
        return cnn, (output_channel, output_height, output_width)

    # CNN+LSTM前向计算
    def forward(self, images):
        # shape of images: (batch, channel, height, width)

        conv = self.cnn(images)
        batch, channel, height, width = conv.size()

        conv = conv.view(batch, channel * height, width)
        conv = conv.permute(2, 0, 1)  # (width, batch, feature)

        # 卷积接全连接。全连接输入形状为(width, batch, channel*height),
        # 输出形状为(width, batch, hidden_layer),分别对应时序长度,batch,特征数,符合LSTM输入要求
        seq = self.map_to_seq(conv)

        recurrent, _ = self.rnn1(seq)
        recurrent, _ = self.rnn2(recurrent)

        output = self.dense(recurrent)
        return output  # shape: (seq_len, batch, num_class)

if __name__ == '__main__':
    net = CRNN(img_channel=3, img_height=32, img_width=100, num_class=37)
    summary(net, input_size=(1, 3, 32, 100))
==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
CRNN                                     [24, 1, 37]               --
├─Sequential: 1-1                        [1, 512, 1, 24]           --
│    └─Conv2d: 2-1                       [1, 64, 32, 100]          1,792
│    └─ReLU: 2-2                         [1, 64, 32, 100]          --
│    └─MaxPool2d: 2-3                    [1, 64, 16, 50]           --
│    └─Conv2d: 2-4                       [1, 128, 16, 50]          73,856
│    └─ReLU: 2-5                         [1, 128, 16, 50]          --
│    └─MaxPool2d: 2-6                    [1, 128, 8, 25]           --
│    └─Conv2d: 2-7                       [1, 256, 8, 25]           295,168
│    └─ReLU: 2-8                         [1, 256, 8, 25]           --
│    └─Conv2d: 2-9                       [1, 256, 8, 25]           590,080
│    └─ReLU: 2-10                        [1, 256, 8, 25]           --
│    └─MaxPool2d: 2-11                   [1, 256, 4, 25]           --
│    └─Conv2d: 2-12                      [1, 512, 4, 25]           1,180,160
│    └─BatchNorm2d: 2-13                 [1, 512, 4, 25]           1,024
│    └─ReLU: 2-14                        [1, 512, 4, 25]           --
│    └─Conv2d: 2-15                      [1, 512, 4, 25]           2,359,808
│    └─BatchNorm2d: 2-16                 [1, 512, 4, 25]           1,024
│    └─ReLU: 2-17                        [1, 512, 4, 25]           --
│    └─MaxPool2d: 2-18                   [1, 512, 2, 25]           --
│    └─Conv2d: 2-19                      [1, 512, 1, 24]           1,049,088
│    └─ReLU: 2-20                        [1, 512, 1, 24]           --
├─Linear: 1-2                            [24, 1, 64]               32,832
├─LSTM: 1-3                              [24, 1, 512]              659,456
├─LSTM: 1-4                              [24, 1, 512]              1,576,960
├─Linear: 1-5                            [24, 1, 37]               18,981
==========================================================================================
Total params: 7,840,229
Trainable params: 7,840,229
Non-trainable params: 0
Total mult-adds (M): 675.96
==========================================================================================
Input size (MB): 0.04
Forward/backward pass size (MB): 5.23
Params size (MB): 31.36
Estimated Total Size (MB): 36.63
==========================================================================================

2 CRNN在icdar15数据集上的微调(paddleocr)

  • 我们这里使用百度开源的paddleocr来对CRNN模型有更深的认识:

    • paddleocr地址:https://github.com/PaddlePaddle/PaddleOCR
    • paddleocr中集成的算法列表:https://github.com/PaddlePaddle/PaddleOCR/blob/main/docs/algorithm/overview.md
  • # git拉取下来,解压
    git clone https://gitee.com/paddlepaddle/PaddleOCR
    
    # 然后进入PaddleOCR目录,安装PaddleOCR第三方依赖
    pip install -r requirements.txt
    
  • 我们在paddleocr/tests目录下,创建py文件进行下图的测试

    在这里插入图片描述

from paddleocr import PaddleOCR

ocr = PaddleOCR()

# 默认会下载官方训练好的模型,并将下载的模型放到用户目录下(我这里是:C:\\Users\\Undo/.paddleocr)
result = ocr.ocr(img='./rec_img.png'
                 , det=False  # 文本检测器,默认算法为DBNet
                 , rec=True   # 方向分类器
                 , cls=True   # 文本识别,默认为CRNN模型
                 )

for line in result:
    print(line)
[('实力活力', 0.7939199805259705)]

2.1 CRNN网络的搭建

2.1.1 backbone

  • MobileNet模型是Google针对手机等嵌入式设备提出的一种轻量化深度神经网络,使用的核心思想是:深度可分离卷积。MobileNet系列中主要包括MobileNet V1、MobileNet V2、MobileNet V3。

    • 论文链接:Searching for MobileNetV3

    • MobileNetV3 有两个版本,MobileNetV3-Small 与 MobileNetV3-Large 分别对应对计算和存储要求低和高的版本。

    • MobileNetV3继续采用了轻量级的深度可分离卷积和残差块等结构,依然是由多个模块组成,但是每个模块得到了优化和升级,包括瓶颈结构、SE模块和NL模块(如下图)。

      在这里插入图片描述

    • 整体来说MobileNetV3有两大创新点:

      • 互补搜索技术组合:由资源受限的NAS执行模块级搜索,NetAdapt执行局部搜索。

      • 网络结构改进(如下图):将最后一步的平均池化层前移并移除最后一个卷积层,引入h-swish激活函数。

        在这里插入图片描述

  • PaddleOCR 使用 MobileNetV3 作为骨干网络,paddleocr/configs/rec/rec_icdar15_train.yml中默认使用MobileNetV3-Large版本。同样,需要对MobileNetV3的结构做一些微调,主要对下采样进行修改。

# paddleocr/configs/rec/rec_icdar15_train.yml
Architecture:
  model_type: rec
  algorithm: CRNN         # 使用CRNN模型
  Transform:
  Backbone:               
    name: MobileNetV3     # ppocr/modeling/backbones/rec_mobilenet_v3.py
    scale: 0.5            # 这里进行了缩放
    model_name: large     # 骨干网络使用MobileNetV3的large版本
  Neck:                   
    name: SequenceEncoder # ppocr/modeling/necks/rnn.py
    encoder_type: rnn
    hidden_size: 96
  Head:
    name: CTCHead         # ppocr/modeling/heads/rec_ctc_head.py
    fc_decay: 0

Loss:
  name: CTCLoss           # ppocr/losses/rec_ctc_loss.py 损失计算

PostProcess:
  name: CTCLabelDecode    # ppocr/postprocess/rec_postprocess.py  后处理

Metric:
  name: RecMetric         # ppocr/metrics/rec_metric.py 指标评估
  main_indicator: acc

在这里插入图片描述

# ppocr/modeling/backbones/rec_mobilenet_v3.py

from paddle import nn

from ppocr.modeling.backbones.det_mobilenet_v3 import (
    ResidualUnit,
    ConvBNLayer,
    make_divisible,
)

__all__ = ["MobileNetV3"]


class MobileNetV3(nn.Layer):
    def __init__(
        self,
        in_channels=3,
        model_name="small",
        scale=0.5,
        large_stride=None,
        small_stride=None,
        disable_se=False,
        **kwargs,
    ):
        super(MobileNetV3, self).__init__()
        self.disable_se = disable_se
        if small_stride is None:
            small_stride = [2, 2, 2, 2]
        if large_stride is None:
            large_stride = [1, 2, 2, 2]

        assert isinstance(
            large_stride, list
        ), "large_stride type must " "be list but got {}".format(type(large_stride))
        assert isinstance(
            small_stride, list
        ), "small_stride type must " "be list but got {}".format(type(small_stride))
        assert (
            len(large_stride) == 4
        ), "large_stride length must be " "4 but got {}".format(len(large_stride))
        assert (
            len(small_stride) == 4
        ), "small_stride length must be " "4 but got {}".format(len(small_stride))

        if model_name == "large":
            cfg = [
                # k, exp, c,  se,     nl,  s,
                [3, 16, 16, False, "relu", large_stride[0]],            # step 1 高下采样2倍,宽下采样2倍
                [3, 64, 24, False, "relu", (large_stride[1], 1)],       # step 2 高下采样2倍,宽不变
                [3, 72, 24, False, "relu", 1],
                [5, 72, 40, True, "relu", (large_stride[2], 1)],        # step 3 高下采样2倍,宽不变
                [5, 120, 40, True, "relu", 1],
                [5, 120, 40, True, "relu", 1],
                [3, 240, 80, False, "hardswish", 1],
                [3, 200, 80, False, "hardswish", 1],
                [3, 184, 80, False, "hardswish", 1],
                [3, 184, 80, False, "hardswish", 1],
                [3, 480, 112, True, "hardswish", 1],
                [3, 672, 112, True, "hardswish", 1],
                [5, 672, 160, True, "hardswish", (large_stride[3], 1)],  # step 4 高下采样2倍,宽不变
                [5, 960, 160, True, "hardswish", 1],
                [5, 960, 160, True, "hardswish", 1],
            ]
            cls_ch_squeeze = 960
        elif model_name == "small":
            cfg = [
                # k, exp, c,  se,     nl,  s,
                [3, 16, 16, True, "relu", (small_stride[0], 1)],
                [3, 72, 24, False, "relu", (small_stride[1], 1)],
                [3, 88, 24, False, "relu", 1],
                [5, 96, 40, True, "hardswish", (small_stride[2], 1)],
                [5, 240, 40, True, "hardswish", 1],
                [5, 240, 40, True, "hardswish", 1],
                [5, 120, 48, True, "hardswish", 1],
                [5, 144, 48, True, "hardswish", 1],
                [5, 288, 96, True, "hardswish", (small_stride[3], 1)],
                [5, 576, 96, True, "hardswish", 1],
                [5, 576, 96, True, "hardswish", 1],
            ]
            cls_ch_squeeze = 576
        else:
            raise NotImplementedError(
                "mode[" + model_name + "_model] is not implemented!"
            )

        supported_scale = [0.35, 0.5, 0.75, 1.0, 1.25]
        assert (
            scale in supported_scale
        ), "supported scales are {} but input scale is {}".format(
            supported_scale, scale
        )

        inplanes = 16
        # conv1
        self.conv1 = ConvBNLayer(
            in_channels=in_channels,
            out_channels=make_divisible(inplanes * scale),
            kernel_size=3,
            stride=2,
            padding=1,
            groups=1,
            if_act=True,
            act="hardswish",
        )
        i = 0
        block_list = []
        inplanes = make_divisible(inplanes * scale)
        for k, exp, c, se, nl, s in cfg:
            se = se and not self.disable_se
            block_list.append(
                ResidualUnit(
                    in_channels=inplanes,
                    mid_channels=make_divisible(scale * exp),
                    out_channels=make_divisible(scale * c),
                    kernel_size=k,
                    stride=s,
                    use_se=se,
                    act=nl,
                )
            )
            inplanes = make_divisible(scale * c)
            i += 1
        self.blocks = nn.Sequential(*block_list)

        self.conv2 = ConvBNLayer(
            in_channels=inplanes,
            out_channels=make_divisible(scale * cls_ch_squeeze),
            kernel_size=1,
            stride=1,
            padding=0,
            groups=1,
            if_act=True,
            act="hardswish",
        )

        self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)  # step 5 高下采样2倍,宽下采样2倍
        self.out_channels = make_divisible(scale * cls_ch_squeeze)

    def forward(self, x):
        x = self.conv1(x)
        x = self.blocks(x)
        x = self.conv2(x)
        x = self.pool(x)
        return x


if __name__ == '__main__':
    import paddle as torch
    x = torch.rand((1, 3, 32, 100))
    net = MobileNetV3(in_channels=3, model_name='large', scale=0.5)
    # [1, 480, 1, 25]
    print(net(x).shape)

2.1.2 neck

neck 部分将backbone输出的视觉特征图转换为1维向量输入送到 LSTM 网络中,输出序列特征

# ppocr/modeling/necks/rnn.py
class SequenceEncoder(nn.Layer):
    def __init__(self, in_channels, encoder_type, hidden_size=48, **kwargs):
        super(SequenceEncoder, self).__init__()
        self.encoder_reshape = Im2Seq(in_channels)
        self.out_channels = self.encoder_reshape.out_channels
        ......

    def forward(self, x):
        if self.encoder_type != "svtr":
            # 1、neck部分将backbone输出的视觉特征图转换为1维向量
            # (bs, channels, H, W) -> (bs, H * W, channels),即(bs, seq_len, embedding_dim)
            # reshape后的x = (bs, 25, 480) , 480 = channels(960) * scale(0.5)
            x = self.encoder_reshape(x)

            # 2、转换后,送到LSTM网络中,输出序列特征
            # self.encoder
            # EncoderWithRNN(
            #   (lstm): LSTM(480, 96, num_layers=2
            #     (0): BiRNN(
            #       (cell_fw): LSTMCell(480, 96)
            #       (cell_bw): LSTMCell(480, 96)
            #     )
            #     (1): BiRNN(
            #       (cell_fw): LSTMCell(192, 96)
            #       (cell_bw): LSTMCell(192, 96)
            #     )
            #   )
            # )
            if not self.only_reshape:
                x = self.encoder(x)  # [bs, 25, 96 * 2]
            return x
        else:
            x = self.encoder(x)
            x = self.encoder_reshape(x)
            return x

2.1.3 head

预测头部分由全连接层和softmax组成,用于计算序列特征时间步上的标签概率分布

class CTCHead(nn.Layer):
    def __init__(
        self,
        in_channels,
        out_channels,
        fc_decay=0.0004,
        mid_channels=None,
        return_feats=False,
        **kwargs,
    ):
        super(CTCHead, self).__init__()
        if mid_channels is None:
            weight_attr, bias_attr = get_para_bias_attr(
                l2_decay=fc_decay, k=in_channels
            )
            # Linear(in_features=192, out_features=37, dtype=float32)
            self.fc = nn.Linear(
                in_channels, out_channels, weight_attr=weight_attr, bias_attr=bias_attr
            )
        else:
            ......
        self.out_channels = out_channels
        self.mid_channels = mid_channels
        self.return_feats = return_feats

    def forward(self, x, targets=None):
        if self.mid_channels is None:
            # (bs, 25, 192) -> predicts (bs, 25, 37)
            predicts = self.fc(x)
        else:
            x = self.fc1(x)
            predicts = self.fc2(x)

        if self.return_feats:
            result = (x, predicts)
        else:
            result = predicts
        if not self.training:  # 非训练时,经过SoftMax,可以得到各时间步上的概率最大的预测结果
            predicts = F.softmax(predicts, axis=2)
            result = predicts

        return result

2.2 数据集加载及模型训练

2.2.1 数据集的下载

提供一份处理过的icdar15数据集:

下载地址:https://pan.baidu.com/s/1VP2Y_IhxAUwQABDmbXrgIg 提取码: ek25

数据集应有如下文件结构:

|-train_data
  |-ic15_data
    |- rec_gt_train.txt
    |- train
        |- word_001.png
        |- word_002.jpg
        |- word_003.jpg
        | ...
    |- rec_gt_test.txt
    |- test
        |- word_001.png
        |- word_002.jpg
        |- word_003.jpg
        | ...

其中txt文件里的内容如下:

" 图像文件名         图像标注信息 "

train/word_1.png	Genaxis Theatre
train/word_2.png	[06]
...

下载完数据集后,我们复制一份paddleocr/configs/rec/rec_icdar15_train.yml文件到paddleocr\tests\configs进行修改:

......
Train: # 修改训练集的路径及其他相关信息
  dataset:
    name: SimpleDataSet
    data_dir: D:/python/datas/cv/ic15_data  
    label_file_list: ["D:/python/datas/cv/ic15_data/rec_gt_train.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: True             # 是否打乱
    batch_size_per_card: 256  # 批次大小
    drop_last: True           # 最后1个批次是否删除
    num_workers: 0            # 读取数据的进程数
    use_shared_memory: False

Eval: # 修改验证集的路径及其他相关信息
  dataset:
    name: SimpleDataSet
    data_dir: D:/python/datas/cv/ic15_data
    label_file_list: ["D:/python/datas/cv/ic15_data/rec_gt_test.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 100]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 256
    num_workers: 0
    use_shared_memory: False

2.2.2 模型的训练与预测

  • 我这里不用命令行执行,在paddleocr\tests目录下创建一个py文件执行训练过程
  • 通过下面的py文件,我们就可以愉快的查看源码了。
  • 模型训练、评估细节,可参考官方文档:https://github.com/PaddlePaddle/PaddleOCR/blob/main/docs/ppocr/model_train/recognition.md
def train_rec():
    from tools.train import program, set_seed, main
    # 配置文件的源地址地址: paddleocr/configs/rec/rec_icdar15_train.yml
    config, device, logger, vdl_writer = program.preprocess(is_train=True)

    ###############修改配置(也可在yml文件中修改)##################
    # 加载预训练模型
    # https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar
    # 或者 https://paddleocr.bj.bcebos.com/PP-OCRv3/english/en_PP-OCRv3_rec_train.tar
    config["Global"]["pretrained_model"] = r"D:\python\models\layout_ocr\rec_mv3_none_bilstm_ctc_v2.0_train\best_accuracy"
    # 字典路径(这里只支持26个小写字母+10个数字)
    config["Global"]["character_dict_path"] = r"D:\python\py_works\paddleocr\ppocr\utils\ic15_dict.txt"
    # 评估频率
    config["Global"]["eval_batch_step"] = [0, 200]
    # log的打印频率
    config["Global"]["print_batch_step"] = 50
    # 训练的epochs
    config["Global"]["epoch_num"] = 1
    # 随机种子
    seed = config["Global"]["seed"] if "seed" in config["Global"] else 1024
    set_seed(seed)

    ###############模型训练##################
    main(config, device, logger, vdl_writer, seed)

    
def infer_rec_img():
    # 加载自己训练的模型
    from tools.infer_rec import main, program

    config, device, logger, vdl_writer = program.preprocess()
    config["Global"]["use_gpu"] = False
    config["Global"]["infer_img"] = r"D:/python/py_works/paddleocr/tests/rec_img_slow.png"
    config["Global"]["checkpoints"] = r"D:\python\py_works\paddleocr\tests\output\rec\ic15\best_accuracy"
    config["Global"]["character_dict_path"] = r"D:\python\py_works\paddleocr\ppocr\utils\ic15_dict.txt"
    # 这里为了能python文件执行,加了add_config这个参数,源码中没有
    main(add_config=(config, device, logger, vdl_writer))    

if __name__ == '__main__':
    # 模型训练
    train_rec()
    # 加载自己模型进行推理
    # ppocr INFO: 	 result: slow	0.863487958908081
    infer_rec_img()

2.2.3 数据集的加载

# paddleocr/tools/train.py
def main(config, device, logger, vdl_writer, seed):
    # init dist environment
    if config["Global"]["distributed"]:
        dist.init_parallel_env()

    global_config = config["Global"]

    # build dataloader
    set_signal_handlers()
    # 1、创建dataloader
    train_dataloader = build_dataloader(config, "Train", device, logger, seed)

    ......
    if config["Eval"]:
        valid_dataloader = build_dataloader(config, "Eval", device, logger, seed)
    else:
        valid_dataloader = None
    step_pre_epoch = len(train_dataloader)

    # 2、后处理程序
    # build post process
    post_process_class = build_post_process(config["PostProcess"], global_config)

    # 3、模型构建
    # build model
    .....
    model = build_model(config["Architecture"])

    use_sync_bn = config["Global"].get("use_sync_bn", False)
    if use_sync_bn:
        model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)
        logger.info("convert_sync_batchnorm")

    model = apply_to_static(model, config, logger)

    # 4、构建损失函数
    # build loss
    loss_class = build_loss(config["Loss"])

    # 5、构建优化器
    # build optim
    optimizer, lr_scheduler = build_optimizer(
        config["Optimizer"],
        epochs=config["Global"]["epoch_num"],
        step_each_epoch=len(train_dataloader),
        model=model,
    )

    # 6、创建评估函数
    # build metric
    eval_class = build_metric(config["Metric"])
    ......

    # 7、加载预训练模型
    # load pretrain model
    pre_best_model_dict = load_model(
        config, model, optimizer, config["Architecture"]["model_type"]
    )

    if config["Global"]["distributed"]:
        model = paddle.DataParallel(model)

    # 8、模型训练
    # start train
    program.train(
        config,
        train_dataloader,
        valid_dataloader,
        device,
        model,
        loss_class,
        optimizer,
        lr_scheduler,
        post_process_class,
        eval_class,
        pre_best_model_dict,
        logger,
        step_pre_epoch,
        vdl_writer,
        scaler,
        amp_level,
        amp_custom_black_list,
        amp_custom_white_list,
        amp_dtype,
    )
  • 通过build_dataloader加载数据集
# paddleocr/ppocr/data/__init__.py
def build_dataloader(config, mode, device, logger, seed=None):
    config = copy.deepcopy(config)

    support_dict = [
        "SimpleDataSet", # 配置文件中为SimpleDataSet
        "LMDBDataSet",
        "PGDataSet",
        "PubTabDataSet",
        "LMDBDataSetSR",
        "LMDBDataSetTableMaster",
        "MultiScaleDataSet",
        "TextDetDataset",
        "TextRecDataset",
        "MSTextRecDataset",
        "PubTabTableRecDataset",
        "KieDataset",
        "LaTeXOCRDataSet",
    ]
    module_name = config[mode]["dataset"]["name"]
    assert module_name in support_dict, Exception(
        "DataSet only support {}".format(support_dict)
    )
    assert mode in ["Train", "Eval", "Test"], "Mode should be Train, Eval or Test."
    # 1、创建dataset
    dataset = eval(module_name)(config, mode, logger, seed)
    ......
    # 2、创建data_loader
    data_loader = DataLoader(
        dataset=dataset,
        batch_sampler=batch_sampler,
        places=device,
        num_workers=num_workers,
        return_list=True,
        use_shared_memory=use_shared_memory,
        collate_fn=collate_fn,
    )

    return data_loader
  • SimpleDataSet中定义了数据的预处理:

    • 1、通过ppocr.data.imaug.operators.DecodeImage 将读取的image(二进制数据)转换为numpy数组, 即(高度H、宽度W、通道数C)

    • 2、通过ppocr.data.imaug.label_ops.CTCLabelEncode将label进行标签编码+one-hot编码(data[“length”] + data[“label”] + data[“label_ace”])

    • 3、通过ppocr.data.imaug.rec_img_aug.RecResizeImg将image进行缩放+归一化+padding, image shape=(3, 32, 100)

    • 4、通过ppocr.data.imaug.operators.KeepKeys 仅仅将['image', 'label', 'length']保存到list中

class SimpleDataSet(Dataset):
    ......
    
    def __getitem__(self, idx):
        file_idx = self.data_idx_order_list[idx]
        data_line = self.data_lines[file_idx]
        try:
            data_line = data_line.decode("utf-8")
            substr = data_line.strip("\n").strip("\r").split(self.delimiter)
            file_name = substr[0]
            file_name = self._try_parse_filename_list(file_name)
            label = substr[1]
            img_path = os.path.join(self.data_dir, file_name)
            data = {"img_path": img_path, "label": label}
            if not os.path.exists(img_path):
                raise Exception("{} does not exist!".format(img_path))
            with open(data["img_path"], "rb") as f:
                img = f.read()
                data["image"] = img
            data["ext_data"] = self.get_ext_data()
            # 数据transform
            # ppocr.data.imaug.operators.DecodeImage    1、将读取的image(二进制数据)转换为numpy数组, 即(高度H、宽度W、通道数C)
            # ppocr.data.imaug.label_ops.CTCLabelEncode 2、将label进行标签编码+one-hot编码(data["length"] + data["label"] + data["label_ace"])
            # ppocr.data.imaug.rec_img_aug.RecResizeImg 3、将image进行缩放+归一化+padding, image shape=(3, 32, 100)
            # ppocr.data.imaug.operators.KeepKeys       4、仅仅将['image', 'label', 'length']保存到list中
            outs = transform(data, self.ops)
        except:
           ......
        return outs

    def __len__(self):
        return len(self.data_idx_order_list)

2.2.4 后处理函数

通过build_post_process构建后处理,

# ppocr/postprocess/rec_postprocess.py  后处理
class CTCLabelDecode(BaseRecLabelDecode):
    """Convert between text-label and text-index"""

    def __init__(self, character_dict_path=None, use_space_char=False, **kwargs):
        super(CTCLabelDecode, self).__init__(character_dict_path, use_space_char)

    def __call__(self, preds, label=None, return_word_box=False, *args, **kwargs):
        if isinstance(preds, tuple) or isinstance(preds, list):
            preds = preds[-1]
        if isinstance(preds, paddle.Tensor):
            preds = preds.numpy()
        # 1、获取各个时间步上的最大索引值以及概率
        preds_idx = preds.argmax(axis=2)
        preds_prob = preds.max(axis=2)
        # 2、将索引值进行解码(转换为文字)
        text = self.decode(
            preds_idx,
            preds_prob,
            is_remove_duplicate=True,
            return_word_box=return_word_box,
        )
        if return_word_box:
            for rec_idx, rec in enumerate(text):
                wh_ratio = kwargs["wh_ratio_list"][rec_idx]
                max_wh_ratio = kwargs["max_wh_ratio"]
                rec[2][0] = rec[2][0] * (wh_ratio / max_wh_ratio)
        if label is None:
            # 3、推理时,直接返回text[('解码后的文字', 文字置信度的平均值)]
            return text
        # 4、训练时, 还需要将label进行解码
        label = self.decode(label)
        return text, label

    def add_special_char(self, dict_character):
        dict_character = ["blank"] + dict_character
        return dict_character
  • 核心函数就是BaseRecLabelDecode类中的decode函数
    # ppocr/postprocess/rec_postprocess.py中的BaseRecLabelDecode类
    def decode(
        self,
        text_index,
        text_prob=None,
        is_remove_duplicate=False,
        return_word_box=False,
    ):
        """convert text-index into text-label."""
        result_list = []
        ignored_tokens = self.get_ignored_tokens() # 忽略tokens, 其中[0]代表ctc中的blank位
        batch_size = len(text_index)
        for batch_idx in range(batch_size):
            selection = np.ones(len(text_index[batch_idx]), dtype=bool)
            if is_remove_duplicate:
                # 1、合并blank之间相同的字符,即【当前位置索引】和【下一位置索引】不相同的就保留
                selection[1:] = text_index[batch_idx][1:] != text_index[batch_idx][:-1]
            for ignored_token in ignored_tokens:
                selection &= text_index[batch_idx] != ignored_token
            # 2、将解码的结果存在char_list中
            char_list = [
                self.character[text_id] for text_id in text_index[batch_idx][selection]
            ]
            if text_prob is not None:
                conf_list = text_prob[batch_idx][selection]
            else:
                conf_list = [1] * len(selection)
            if len(conf_list) == 0:
                conf_list = [0]
            # 3、char_list合并为字符串
            text = "".join(char_list)

            if self.reverse:  # for arabic rec
                text = self.pred_reverse(text)

            if return_word_box:
               ......
            else:
                # 置信度为每个识别的字符置信度的平均值
                result_list.append((text, np.mean(conf_list).tolist()))
        return result_list

2.2.5 CTC Loss

  • 通过build_loss函数构建CTC Loss
  • CRNN 模型的损失函数为 CTC loss, 飞桨集成了常用的 Loss 函数,只需调用实现即可
#  paddleocr/ppocr/losses/rec_ctc_loss.py
class CTCLoss(nn.Layer):
    def __init__(self, use_focal_loss=False, **kwargs):
        super(CTCLoss, self).__init__()
         # blank 是 ctc 的无意义连接符
        self.loss_func = nn.CTCLoss(blank=0, reduction="none")
        self.use_focal_loss = use_focal_loss

    def forward(self, predicts, batch):
        if isinstance(predicts, (list, tuple)):
            predicts = predicts[-1]
        # 转置模型 head 层的预测结果,沿channel层排列 
        # (bs, 25, 37) -> (25, bs, 37)
        predicts = predicts.transpose((1, 0, 2)) 
        N, B, _ = predicts.shape
        # [N, N, ..., N]一共bs个,每个长度都为N
        preds_lengths = paddle.to_tensor(
            [N] * B, dtype="int64", place=paddle.CPUPlace()
        )
        # batch一个list
        # batch[0]为bs个预处理好image Tensor
        # batch[1]为bs个编码好的token序列,即label,shape = (bs, seq_len)
        # batch[2]为bs个token序列的实际长度(因为有填充)
        labels = batch[1].astype("int32")
        label_lengths = batch[2].astype("int64")
        # 计算损失函数
        loss = self.loss_func(predicts, labels, preds_lengths, label_lengths)
        if self.use_focal_loss:
            weight = paddle.exp(-loss)
            weight = paddle.subtract(paddle.to_tensor([1.0]), weight)
            weight = paddle.square(weight)
            loss = paddle.multiply(loss, weight)
        loss = loss.mean()
        return {"loss": loss}    

其他训练细节诸如:构建优化器、创建评估函数、加载预训练模型、模型训练等,大家可以查看源码,不再赘述。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/872404.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

七,Spring Boot 当中的 yaml 语法使用

七,Spring Boot 当中的 yaml 语法使用 文章目录 七,Spring Boot 当中的 yaml 语法使用1. yaml 的介绍2. yaml 基本语法3. yaml 数据类型4. 学习测试的准备工作4.1 yaml 字面量4.2 yaml 数组4.3 yaml 对象 5. yaml 使用细节和注意事项6. 总结:…

2024高教社杯数学建模竞赛解题思路

高教社杯数学建模竞赛解题思路:独家出版,思路解析模型代码结果可视化。 A题思路及程序链接:https://mbd.pub/o/bread/ZpqblJZs B题思路及程序链接:https://mbd.pub/o/bread/ZpqblJZx D题思路及程序链接:https://mbd.pu…

常用排序算法(上)

目录 前言: 1.排序的概念及其运用 1.1排序的概念 1.2排序运用 1.3 常见的排序算法 2.常见排序算法的实现 2.1 堆排序 2.1 1 向下调整算法 2.1 2 建堆 2.1 3 排序 2.2 插入排序 2.1.1基本思想: 2.1.2直接插入排序: 2.1.3 插…

elementUI——checkbox复选框监听不到change事件,通过watch监听来解决——基础积累

今天在写后台管理系统的时候,遇到一个需求,就是要求监听复选框的change事件,场景就是:两个复选框互斥,且可以取消勾选。 就是这两个复选框可以同时都不勾选,如果勾选的话,另一个一定要取消勾选。…

具身智能猜想 ——机器人进化

设想一个机器人进化的仿真模拟环境,可以通过 “基因突变” 产生新功能,让机器人逐步进化。以下是这个进化系统的关键要素和可能的实现步骤: 1. 仿真环境 虚拟世界:创建一个包含多样化任务和挑战的虚拟环境,如探索、抓…

多智能体强化学习:citylearn城市建筑能量优化和需求响应

今天分享一个用于能量优化的强化学习框架,citylearn 代码量非常庞大,我都不敢看,看也看不完,不花一定的时间难以搞懂它的原理。 CityLearn(CL)环境是一个类似 OpenAI Gym 的环境,它通过控制不…

UE5 C++ 读取图片插件(一)

原来UE可以使用 static,之前不知道&#xff0c;一用就报错。 static TSharedPtr<IImageWrapper> GetImageWrapperByExtention(const FString InImagePath); //智能指针&#xff0c;方便追寻引用C,加载ImageWrapperstatic UTexture2D* LoadTexture2D(const FString& …

代码随想录 刷题记录-28 图论 (5)最短路径

一、dijkstra&#xff08;朴素版&#xff09;精讲 47. 参加科学大会 思路 本题就是求最短路&#xff0c;最短路是图论中的经典问题即&#xff1a;给出一个有向图&#xff0c;一个起点&#xff0c;一个终点&#xff0c;问起点到终点的最短路径。 接下来讲解最短路算法中的 d…

matter的Commissioning(入网过程)整体流程、加密方式、通信信息结构

在Matter协议中&#xff0c;**控制器负责将新设备加入网络&#xff08;commissioning&#xff09;**的整个流程&#xff0c;这一过程包括设备的发现、验证、授权、加入Fabric&#xff0c;以及最终建立数据通信的步骤。配网完成后的数据通信过程同样遵循严格的加密方式&#xff…

C语言 | Leetcode C语言题解之第385题迷你语法分析器

题目&#xff1a; 题解&#xff1a; struct NestedInteger* helper(const char * s, int * index){if (s[*index] [) {(*index);struct NestedInteger * ni NestedIntegerInit();while (s[*index] ! ]) {NestedIntegerAdd(ni, helper(s, index));if (s[*index] ,) {(*index…

TCP的流量控制深入理解

在理解流量控制之前我们先需要理解TCP的发送缓冲区和接收缓冲区&#xff0c;也称为套接字缓冲区。首先我们先知道缓冲区存在于哪个位置&#xff1f; 其中缓冲区存在于Socket Library层。 而我们的发送窗口和接收窗口就存在于缓冲区当中。在实现滑动窗口时则将两个指针指向缓冲区…

社交媒体的智能变革:Facebook AI优化用户体验

Facebook作为全球领先的社交平台&#xff0c;一直致力于通过人工智能&#xff08;AI&#xff09;技术提升用户体验。AI技术在Facebook的应用涵盖了推荐系统、自然语言处理、广告投放和用户反馈等多个方面&#xff0c;使平台的互动和内容体验更加智能和个性化。 推荐系统的智能化…

火焰传感器详解(STM32)

目录 一、介绍 二、传感器原理 1.原理图 2.引脚描述 三、程序设计 main.c文件 IR.h文件 IR.c文件 四、实验效果 五、资料获取 项目分享 一、介绍 火焰传感器是一种常用于检测火焰或特定波长&#xff08;760nm-1100nm&#xff09;红外光的传感器。探测角度60左右&am…

高压喷雾车的功能与应用_鼎跃安全

在一次森林火灾中&#xff0c;位于山区的一个小型度假村附近突然起火&#xff0c;由于山风强劲&#xff0c;火势迅速蔓延&#xff0c;消防部门立即调派多辆高压喷雾车赶往现场。在扑救过程中&#xff0c;传统消防车难以进入崎岖的山路&#xff0c;但高压喷雾车凭借其高机动性顺…

大模型笔记01--基于ollama和open-webui快速部署chatgpt

大模型笔记01--基于ollama和open-webui快速部署chatgpt 介绍部署&测试安装ollama运行open-webui测试 注意事项说明 介绍 近年来AI大模型得到快速发展&#xff0c;各种大模型如雨后春笋一样涌出&#xff0c;逐步融入各行各业。与之相关的各类开源大模型系统工具也得到了快速…

【神经网络系列(高级)】神经网络Grokking现象的电路效率公式——揭秘学习飞跃的秘密【通俗理解】

【通俗理解】神经网络Grokking现象的电路效率公式 论文地址&#xff1a; https://arxiv.org/abs/2309.02390 参考链接&#xff1a; [1]https://x.com/VikrantVarma_/status/1699823229307699305 [2]https://pair.withgoogle.com/explorables/grokking/ 关键词提炼 #Grokkin…

【办公效率】Axure会议室预订小程序原型图,含PRD需求文档和竞品分析

作品说明 作品页数&#xff1a;共50页 兼容版本&#xff1a;Axure RP 8/9/10 应用领域&#xff1a;中小型企业的会议室在线预订 作品申明&#xff1a;页面内容仅用于功能演示&#xff0c;无实际功能 作品特色 本作品为会议室预订小程序原型图&#xff0c;定位于拥有中大型…

Python 人脸识别实战教程

引言 在本教程中&#xff0c;我们将深入探讨如何使用Python和OpenCV库来实现人脸检测与识别。本文从基础知识入手&#xff0c;逐步构建一个简单的人脸识别系统。本教程假设读者已经熟悉Python编程&#xff0c;并具备一定的OpenCV使用经验。 环境配置 安装必要的库 确保您的…

GitLab 是什么?GitLab使用常见问题解答

GitLab 是什么 GitLab是由GitLab Inc.开发&#xff0c;使用MIT许可证的基于网络的Git仓库管理工具开源项目&#xff0c;且具有wiki和issue跟踪功能&#xff0c;使用Git作为代码管理工具&#xff0c;并在此基础上搭建起来的web服务。 ​GitLab 是由 GitLab Inc.开发&#xff0c…

COD论文笔记 ECCV2024 Just a Hint: Point-Supervised Camouflaged Object Detection

这篇论文的主要动机、现有方法的不足、拟解决的问题、主要贡献和创新点&#xff1a; 1. 动机 伪装物体检测&#xff08;Camouflaged Object Detection, COD&#xff09;旨在检测隐藏在环境中的伪装物体&#xff0c;这是一个具有挑战性的任务。由于伪装物体与背景的细微差别和…