1. OCR文字检测与辨认体系:交融文字检测、文字辨认和方向分类器的综合处理方案

前两章首要介绍了DBNet文字检测算法以及CRNN文字辨认算法。可是关于咱们实践场景中的一张图画,想要独自依据文字检测或许辨认模型,是无法一起获取文字位置与文字内容的,因而,咱们将文字检测算法以及文字辨认算法进行串联,构建了PP-OCR文字检测与辨认体系。在实践运用进程中,检测出的文字方向可能不是咱们期望的方向,终究导致文字辨认过错,因而咱们在PP-OCR体系中也引进了方向分类器。

本章首要介绍PP-OCR文字检测与辨认体系以及该体系中涉及到的优化战略。经过本节课的学习,您能够获得:

  • PaddleOCR战略调优技巧
  • 文本检测、辨认、方向分类器模型的优化技巧和优化办法

PP-OCR体系共经历了2次优化,下面对PP-OCR体系和这2次优化进行简略介绍。

1.1 PP-OCR体系与优化战略简介

PP-OCR中,关于一张图画,假如期望提取其间的文字信息,需求完结以下几个进程:

  • 运用文本检测的办法,获取文本区域多边形信息(PP-OCR中文本检测运用的是DBNet,因而获取的是四点信息)。
  • 对上述文本多边形区域进行裁剪与透视改换校对,将文本区域转化成矩形框,再运用方向分类器对方向进行校对。
  • 依据包括文字区域的矩形框进行文本辨认,得到终究辨认成果。

上面便完结了关于一张图画的文本检测与辨认进程。

PP-OCR的体系框图如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案

PP-OCR体系框图

文本检测依据后处理方案比较简略的DBNet,文字区域校对首要运用几许改换以及方向分类器,文本辨认运用了依据交融了卷积特征与序列特征的CRNN模型,运用CTC loss处理猜测成果与标签不共同的问题。

PP-OCR从主干网络、学习率战略、数据增广、模型裁剪量化等方面,共运用了19个战略,对模型进行优化减肥,终究打造了面向服务器端的PP-OCR server体系以及面向移动端的PP-OCR mobile体系。

1.2 PP-OCRv2体系与优化战略简介

比较于PP-OCR, PP-OCRv2 在主干网络、数据增广、丢失函数这三个方面进行进一步优化,处理端侧猜测功率较差、布景杂乱以及类似字符的误识等问题,一起引进了常识蒸馏练习战略,进一步进步模型精度。详细地:

  • 检测模型优化: (1) 选用 CML 协同互学习常识蒸馏战略;(2) CopyPaste 数据增广战略;
  • 辨认模型优化: (1) PP-LCNet 轻量级主干网络;(2) U-DML 改善常识蒸馏战略; (3) Enhanced CTC loss 丢失函数改善。

从作用上看,首要有三个方面进步:

  • 在模型作用上,相关于 PP-OCR mobile 版别进步超7%;
  • 在速度上,相关于 PP-OCR server 版别进步逾越220%;
  • 在模型巨细上,11.6M 的总巨细,服务器端和移动端都能够轻松布置。

PP-OCRv2 模型与之前 PP-OCR 系列模型的精度、猜测耗时、模型巨细比照图如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案

PP-OCRv2与PP-OCR的速度、精度、模型巨细比照

PP-OCRv2的体系框图如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案

PP-OCRv2体系框图

2. PP-OCR 优化战略

PP-OCR体系包括文本检测器、方向分类器以及文本辨认器。本节针对这三个方向的模型优化战略进行详细介绍。

2.1 文本检测

PP-OCR中的文本检测依据DBNet (Differentiable Binarization)模型,它依据切割方案,后处理简略。DBNet的详细模型结构如下图。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
DBNet框图

DBNet经过主干网络(backbone)提取特征,运用DBFPN的结构(neck)对各阶段的特征进行交融,得到交融后的特征。交融后的特征经过卷积等操作(head)进行解码,生成概率图和阈值图,二者交融后核算得到一个近似的二值图。核算丢失函数时,对这三个特征图均核算丢失函数,这儿把二值化的监督也也参加练习进程,然后让模型学习到更准确的鸿沟。

DBNet中运用了6种优化战略用于进步模型精度与速度,包括主干网络、特征金字塔网络、头部结构、学习率战略、模型裁剪等战略。在验证集上,不同模块的融化试验定论如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
DBNet融化试验

下面进行详细阐明。

2.1.1 轻量级主干网络

主干网络的巨细对文本检测器的模型巨细有重要影响。因而,在构建超轻量检测模型时,应挑选轻量的主干网络。跟着图画分类技能的开展,MobileNetV1、MobileNetV2、MobileNetV3和ShuffleNetV2系列常用作轻量主干网络。每个系列都有不同的模型巨细和功能体现。PaddeClas供给了20多种轻量级主干网络。他们在ARM上的精度-速度曲线如下图所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
PaddleClas中主干网络的”速度-精度”曲线

在猜测时刻相同的情况下,MobileNetV3系列能够完结更高的精度。作者在规划的时分为了覆盖尽可能多的场景,运用scale这个参数来调整特征图通道数,规范为1x,假如是0.5x,则表明该网络中部分特征图通道数为1x对应网络的0.5倍。为了进一步平衡准确率和功率,在V3的规范挑选上,咱们选用了MobileNetV3_large 0.5x的结构。

下面打印出DBNet中MobileNetV3各个阶段的特征图规范。

2.1.2 轻量级特征金字塔网络DBFPN结构

文本检测器的特征交融(neck)部分DBFPN与方针检测使命中的FPN结构类似,交融不同规范的特征图,以进步不同规范的文本区域检测作用。

为了方便兼并不同通道的特征图,这儿运用11的卷积将特征图削减到相同数量的通道。

概率图和阈值图是由卷积交融的特征图生成的,卷积也与inner_channels相关联。因而,inner_channels对模型规范有很大的影响。当inner_channels由256减小到96时,模型规范由7M减小到4.1M,速度进步48%,但精度仅仅略有下降。

下面打印DBFPN的结构以及关于主干网络特征图的交融成果。

2.1.3 主干网络中SE模块分析

SE是squeeze-and-excitation的缩写(Hu, Shen, and Sun 2018)。如图所示

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
SE模块示意图

SE块显式地建模通道之间的相互依靠关系,并自适应地从头校准通道特征呼应。在网络中运用SE块能够显着进步视觉使命的准确性,因而MobileNetV3的查找空间包括了SE模块,终究MobileNetV3中也包括很多个SE模块。可是,当输入分辨率较大时,例如640640,运用SE模块较难估计通道的特征呼应,精度进步有限,但SE模块的时刻本钱十分高。在DBNet中,咱们将SE模块从主干网络中移除,模型巨细从4.1M降到2.6M,但精度没有影响。

PaddleOCR中能够经过设置disable_se=True来移除主干网络中的SE模块,运用办法如下所示。

2.1.4 学习率战略优化

  • Cosine 学习率下降战略

梯度下降算法需求咱们设置一个值,用来控制权重更新幅度,咱们将其称之为学习率。它是控制模型学习速度的超参数。学习率越小,loss的改变越慢。尽管运用较低的学习速率能够确保不会错过任何局部极小值,但这也意味着模型收敛速度较慢。

因而,在练习前期,权重处于随机初始化状态,咱们能够设置一个相对较大的学习速率以加速收敛速度。在练习后期,权重接近最优值,运用相对较小的学习率能够防止模型在收敛的进程中产生震荡。

Cosine学习率战略也就应运而生,Cosine学习率战略指的是学习率在练习的进程中,依照余弦的曲线改变。在整个练习进程中,Cosine学习率衰减战略使得在网络在练习初期坚持了较大的学习速率,在后期学习率会逐渐衰减至0,其收敛速度相对较慢,但终究收敛精度较好。下图比较了两种不同的学习率衰减战略piecewise decaycosine decay

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
Cosine与Piecewise学习率下降战略
  • 学习率预热战略

模型刚开始练习时,模型权重是随机初始化的,此刻若挑选一个较大的学习率,可能造成模型练习不稳定的问题,因而学习率预热的概念被提出,用于处理模型练习初期不收敛的问题。

学习率预热指的是将学习率从一个很小的值开始,逐步增加到初始较大的学习率。它能够确保模型在练习初期的稳定性。运用学习率预热战略有助于进步图画分类使命的准确性。在DBNet中,试验标明该战略也是有用的。学习率预热战略与Cosine学习率结合时,学习率的改变趋势如下代码演示。

2.1.5 模型裁剪战略-FPGM

深度学习模型中一般有比较多的参数冗余,咱们能够运用一些办法,去除模型中比较冗余的当地,然后进步模型推理功率。

模型裁剪指的是经过去除网络中冗余的通道(channel)、滤波器(filter)、神经元(neuron)等,来得到一个更轻量的网络,一起尽可能确保模型精度。

比较于裁剪通道或许特征图的办法,裁剪滤波器的办法能够得到更加规矩的模型,因而削减内存消耗,加速模型推理进程。

之前的裁剪滤波器的办法大多依据范数进行裁剪,即,认为范数较小的滤波器重要程度较小,可是这种办法要求存在的滤波器的最小范数应该趋近于0,不然咱们难以去除。

针对上面的问题,依据几许中心点的裁剪算法(Filter Pruning via Geometric Median, FPGM)被提出。FPGM将卷积层中的每个滤波器都作为欧几里德空间中的一个点,它引进了几许中位数这样一个概念,即与一切采样点间隔之和最小的点。假如一个滤波器的接近这个几许中位数,那咱们能够认为这个滤波器的信息和其他滤波器重合,能够去掉。

FPGM与依据范数的裁剪算法的比照如下图所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
FPGM裁剪示意图

在PP-OCR中,咱们运用FPGM对检测模型进行剪枝,终究DBNet的模型精度只要轻微下降,可是模型巨细减小46%,猜测速度加速19%

关于FPGM模型裁剪完结的更多细节能够参阅PaddleSlim。

留意:

  1. 模型裁剪需求从头练习模型,能够参阅PaddleOCR剪枝教程。
  2. 裁剪代码是依据DBNet进行适配,假如您需求对自己的模型进行剪枝,需求从头分析模型结构、参数的敏感度,咱们通常情况下只主张裁剪相对敏感度低的参数,而跳过敏感度高的参数。
  3. 每个卷积层的剪枝率关于裁剪后模型的功能也很重要,用完全相同的裁剪率去进行模型裁剪通常会导致显着的功能下降。
  4. 模型裁剪不是一蹴即至的,需求进行重复的试验,才干得到符合要求的模型。

2.1.6 文本检测装备阐明

下面给出DBNet的练习装备扼要阐明,完好的装备文件能够参阅:ch_det_mv3_db_v2.0.yml。

Architecture:                       # 模型结构界说
  model_type: det
  algorithm: DB
  Transform:
  Backbone:
    name: MobileNetV3               # 装备主干网络
    scale: 0.5
    model_name: large
    disable_se: True                # 去除SE模块
  Neck:
    name: DBFPN                     # 装备DBFPN
    out_channels: 96                # 装备 inner_channels
  Head:
    name: DBHead
    k: 50
Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine                   # 装备cosine学习率下降战略
    learning_rate: 0.001           # 初始学习率
    warmup_epoch: 2                # 装备学习率预热战略
  regularizer:
    name: 'L2'                     # 装备L2正则
    factor: 0                      # 正则项的权重

2.1.7 PP-OCR 检测优化总结

上面给咱们介绍了PP-OCR中文字检测算法的优化战略,这儿再给咱们回顾一下不同优化战略对应的融化试验与定论。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
DBNet融化试验

经过轻量级主干网络、轻量级neck结构、SE模块的分析和去除、学习率调整及优化、模型裁剪等战略,DBNet的模型巨细从7M削减至1.5M。经过学习率战略优化等练习战略优化,DBNet的模型精度进步逾越1%

PP-OCR中,超轻量DBNet检测作用如下所示:

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案

下面展现快速运用文字检测模型的猜测作用。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案

2.2 方向分类器

方向分类器的使命是用于分类出文本检测出的文本实例的方向,将文本旋转到0度之后,再送入后续的文本辨认器中。PP-OCR中,咱们考虑了0度和180度2个方向。下面详细介绍针对方向分类器的速度、精度优化战略。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
方向分类器融化试验

2.2.1 轻量级主干网络

与文本检测器相同,咱们仍然选用MobileNetV3作为方向分类器的主干网络。因为方向分类的使命相对简略,咱们运用MobileNetV3 small 0.35x来平衡模型精度与猜测功率。试验标明,即便当运用更大的主干时,精度不会有进一步的进步。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
不同主干网络下的方向分类器精度比照

2.2.2 数据增强

数据增强指的是对图画改换,送入网络进行练习,它能够进步网络的泛化功能。常用的数据增强包括旋转、透视失真改换、运动含糊改换和高斯噪声改换等,PP-OCR中,咱们统称这些数据增强办法为BDA(Base Data Augmentation)。成果标明,BDA能够显着进步方向分类器的精度。

下面展现一些BDA数据增广办法的作用

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
BDA数据增广作用

除了BDA外,咱们还参加了一些更高阶的数据增强操作来进步分类的作用,例如 AutoAugment (Cubuk et al. 2019), RandAugment (Cubuk et al. 2020), CutOut (DeVries and Taylor 2017), RandErasing (Zhong et al. 2020), HideAndSeek (Singh and Lee 2017), GridMask (Chen 2020), Mixup (Zhang et al. 2017) 和 Cutmix (Yun et al. 2019)。

这些数据增广大体分为3个类别:

(1)图画改换类:AutoAugment、RandAugment

(2)图画裁剪类:CutOut、RandErasing、HideAndSeek、GridMask

(3)图画混叠类:Mixup、Cutmix

下面给出不同高阶数据增广的可视化比照成果。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
高阶数据增广可视化作用

可是试验标明,除了RandAugment 和 RandErasing 外,大多数办法都不适用于方向分类器。下图也给出了在不同数据增强战略下,模型精度的改变。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案

终究,咱们在练习时结合BDA和RandAugment,作为方向分类器的数据增强战略。

  • RandAugment代码演示

2.2.3 输入分辨率优化

一般来说,当图画的输入分辨率进步时,精度也会进步。因为方向分类器的主干网络参数量很小,即便进步了分辨率也不会导致推理时刻的显着增加。咱们将方向分类器的输入图画规范从3x32x100增加到3x48x192,方向分类器的精度从92.1%进步至94.0%,可是猜测耗时仅仅从3.19ms进步至3.21ms

下面给出两种规范下的图画巨细比照。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
32×100和48×192规范下的图画巨细比照

2.2.4 模型量化战略-PACT

模型量化是一种将浮点核算转成低比特定点核算的技能,能够使神经网络模型具有更低的推迟、更小的体积以及更低的核算功耗。

模型量化首要分为离线量化和在线量化。其间,离线量化是指一种利用KL散度等办法来确定量化参数的定点量化办法,量化后不需求再次练习;在线量化是指在练习进程中确定量化参数,比较离线量化模式,它的精度丢失更小。

PACT(PArameterized Clipping acTivation)是一种新的在线量化办法,能够提前从激活层中去除一些极点值。在去除极点值后,模型能够学习更适宜的量化参数。一般PACT办法的激活值的预处理是依据RELU函数的,公式如下:

y=PACT(x)=0.5(∣x∣−∣x−∣+)={0x∈(−∞,0)xx∈[0,)x∈[,+∞) y=P A C T(x)=0.5(|x|-|x-\alpha|+\alpha)=\left\{\begin{array}{cc} 0 & x \in(-\infty, 0) \\ x & x \in[0, \alpha) \\ \alpha & x \in[\alpha,+\infty) \end{array}\right.

一切大于特定阈值的激活值都会被重置为一个常数。可是,MobileNetV3中的激活函数不仅是ReLU,还包括hardswish。因而运用一般的PACT量化会导致更高的精度丢失。因而,为了削减量化丢失,咱们将激活函数的公式修正为:

y=PACT(x)={−x∈(−∞,−)xx∈[−,)x∈[,+∞) y=P A C T(x)=\left\{\begin{array}{rl} -\alpha & x \in(-\infty,-\alpha) \\ x & x \in[-\alpha, \alpha) \\ \alpha & x \in[\alpha,+\infty) \end{array}\right.

PaddleOCR中供给了适用于PP-OCR套件的量化脚本。详细链接能够参阅PaddleOCR模型量化教程。

2.2.5 方向分类器装备阐明

练习方向分类器时,装备文件中的部分关键字段和阐明如下所示。完好装备文件能够参阅cls_mv3.yml。

Architecture:
  model_type: cls
  algorithm: CLS
  Transform:
  Backbone:
    name: MobileNetV3                                                 # 装备分类模型为MobileNetV3
    scale: 0.35
    model_name: small
  Neck:
  Head:
    name: ClsHead
    class_dim: 2
Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/cls
    label_file_list:
      - ./train_data/cls/train.txt
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - ClsLabelEncode: # Class handling label
      - RecAug:                                                    
          use_tia: False                                             # 装备BDA数据增强,不运用TIA数据增强
      - RandAugment:                                                 # 装备随机增强数据增强办法
      - ClsResizeImg:
          image_shape: [3, 48, 192]                                  # 这儿将[3, 32, 100]修正为[3, 48, 192],进行输入分辨率优化
      - KeepKeys:
          keep_keys: ['image', 'label'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 512
    drop_last: True
    num_workers: 8

2.2.5 方向分类器试验总结

在方向分类器模型优化中,咱们运用轻量化主干网络以及模型量化,终究将模型从0.85M降低到了0.46M,运用组合数据增广、高分辨率等特征,终究将模型精度进步了逾越2%。融化试验比照如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
方向分类器融化试验

2.3 文本辨认

PP-OCR中,文本辨认器运用的是CRNN模型。练习的时分运用CTC loss去处理不定长文本的猜测问题。

CRNN模型结构如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
CRNN结构图

PP-OCR针对文本辨认器,从主干网络、头部结构优化、数据增强、正则化战略、特征图下采样战略、量化等多个视点进行模型优化,详细融化试验如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
CRNN辨认模型融化试验

下面详细介绍文本辨认模型的详细优化战略。

2.3.1 轻量级主干网络和头部结构

  • 轻量级主干网络

在文本辨认中,仍然选用了与文本检测相同的MobileNetV3作为backbone。选自MobileNetV3_small_x0.5进一步地平衡精度和功率。假如不要求模型巨细的话,能够挑选MobileNetV3_small_x1,模型巨细仅增加5M,精度显着进步。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
不同主干网络下的辨认模型精度比照
  • 轻量级头部结构

CRNN中,用于解码的轻量级头(head)是一个全衔接层,用于将序列特征解码为一般的猜测字符。序列特征的维数对文本辨认器的模型巨细影响十分大,特别是关于6000多个字符的中文辨认场景(序列特征维度若设置为256,则仅仅是head部分的模型巨细就为6.7M)。在PP-OCR中,咱们针对序列特征的维度打开试验,终究将其设置为48,平衡了精度与功率。部分融化试验定论如下。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
不同序列特征维度的精度比照

2.3.2 数据增强

除了前面提到的常常用于文本辨认的BDA(基本数据增强),TIA(Luo等人,2020)也是一种有用的文本辨认数据增强办法。TIA是一种针对场景文字的数据增强办法,它在图画中设置了多个基准点,然后随机移动点,经过几许改换生成新图画,这样大大进步了数据的多样性以及模型的泛化才能。TIA的基本流程图如图所示:

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案

试验证明,运用TIA数据增广,能够协助文本辨认模型的精度在一个极高的baseline上面进一步进步0.9%

下面是TIA中三种涉及到的数据增广的可视化作用图。

2.3.3 学习率战略和正则化

在辨认模型练习中,学习率下降战略与文本检测相同,也运用了Cosine+Warmup的学习率战略。

正则化是一种广泛运用的防止过度拟合的办法,一般包括L1正则化和L2正则化。在大多数运用场景中,咱们都运用L2正则化。它首要的原理就是核算网络中权重的L2范数,增加到丢失函数中。在L2正则化的协助下,网络的权重趋向于挑选一个较小的值,终究整个网络中的参数趋向于0,然后缓解模型的过拟合问题,进步了模型的泛化功能。

咱们试验发现,关于文本辨认,L2正则化对辨认准确率有很大的影响。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
CRNN辨认模型融化试验

2.3.4 特征图降采样战略

咱们在做检测、切割、OCR等下流视觉使命时,主干网络一般都是运用的图画分类使命中的主干网络,它的输入分辨率一般设置为224×224,降采样时,一般宽度和高度会一起降采样。

可是关于文本辨认使命来说,因为输入图画一般是32×100,长宽比十分不平衡,此刻对宽度和高度一起降采样,会导致特征丢失严峻,因而图画分类使命中的主干网络应用到文本辨认使命中需求进行特征图降采样方面的适配(假如咱们自己换主干网络的话,这儿也需求留意一下)。

在PaddleOCR中,CRNN中文文本辨认模型设置的输入图画的高度和宽度设置为32和320。原始MobileNetV3来自分类模型,如前文所述,需求调整降采样的步长,适配文本图画输入分辨率。详细地,为了保存更多的水平信息,咱们将下采样特征图的步长从 (2,2) 修正为 (2,1) ,第一次下采样除外。终究如下图所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
降采样步长战略优化可视化

为了保存更多的垂直信息,咱们进一步将第二次下采样特征图的步长从 (2,1) 修正为 (1,1)。因而,第二个下采样特征图的步长s2会显著影响整个特征图的分辨率和文本辨认器的准确性。在PP-OCR中,s2被设置为(1,1),能够获得更好的功能。一起,因为水平的分辨率增加,CPU的推理时刻从11.84ms 增加到 12.96ms

下面给出了stride优化前后的特征图规范比照。尽管终究输出特征图规范相同,可是stride从(2,1)修正为(1,1)之后,特征信息在编码的进程中被保存得更为完好。

2.3.5 PACT 在线量化战略

咱们选用与方向分类器量化类似的方案来减小文本辨认器的模型巨细。因为LSTM量化的杂乱性,PP-OCR中没有对LSTM进行量化。运用该量化战略之后,模型巨细减小67.4%、猜测速度加速8%、准确率进步1.6%,量化能够削减模型冗余,增强模型的表达才能。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
模型量化融化试验

2.3.6 文字辨认预练习模型

运用适宜的预练习模型能够加速模型的收敛速度。在实在场景中,用于文本辨认的数据通常是有限的。PP-OCR中,咱们组成了千万等级的数据,对模型进行练习,之后再依据该模型,在实在数据上微调,终究辨认准确率从从65.81%进步到69%

2.3.7 文本辨认装备阐明

下面给出CRNN的练习装备扼要阐明,完好的装备文件能够参阅:rec_chinese_lite_train_v2.0.yml。

Optimizer:
  name: Adam
  beta1: 0.9
  beta2: 0.999
  lr:
    name: Cosine                                 # 装备Cosine 学习率下降战略
    learning_rate: 0.001 
    warmup_epoch: 5                              # 装备预热学习率
  regularizer:    
    name: 'L2'                                   # 装备L2正则
    factor: 0.00001
Architecture:
  model_type: rec
  algorithm: CRNN
  Transform:
  Backbone:
    name: MobileNetV3                             # 装备Backbone
    scale: 0.5
    model_name: small
    small_stride: [1, 2, 2, 2]                     # 装备下采样的stride
  Neck:
    name: SequenceEncoder
    encoder_type: rnn
    hidden_size: 48                               # 装备终究一层全衔接层的维度
  Head:
    name: CTCHead
    fc_decay: 0.00001
 Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/
    label_file_list: ["./train_data/train_list.txt"]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - RecAug:                                  # 装备数据增强BDA和TIA,TIA默认运用
      - CTCLabelEncode: # Class handling label
      - RecResizeImg:
          image_shape: [3, 32, 320]
      - KeepKeys:
          keep_keys: ['image', 'label', 'length'] # dataloader will return list in this order
  loader:
    shuffle: True
    batch_size_per_card: 256
    drop_last: True
    num_workers: 8

2.3.8 辨认优化小结

在模型体积方面,PP-OCR运用轻量级主干网络、序列维度裁剪、模型量化的战略,将模型巨细从4.5M减小至1.6M。在精度方面,运用TIA数据增强、Cosine-warmup学习率战略、L2正则、特征图分辨率改善、预练习模型等优化战略,终究在验证集上进步15.4%

PP-OCR中部分辨认作用如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案

文本辨认模型的代码演示如下。

3. PP-OCRv2优化战略解读

第2节的内容首要是对PP-OCR以及它的19个优化战略进行了详细介绍。

比较于PP-OCR, PP-OCRv2 在主干网络、数据增广、丢失函数这三个方面进行进一步优化,处理端侧猜测功率较差、布景杂乱以及类似字符的误识等问题,一起引进了常识蒸馏练习战略,进一步进步模型精度。详细地:

  • 检测模型优化: (1) 选用 CML 协同互学习常识蒸馏战略;(2) CopyPaste 数据增广战略;
  • 辨认模型优化: (1) PP-LCNet 轻量级主干网络;(2) U-DML 改善常识蒸馏战略; (3) Enhanced CTC loss 丢失函数改善。

本节首要依据文字检测和辨认模型的优化进程,去解读PP-OCRv2的优化战略。

3.1 文字检测模型优化详解

文字检测模型优化进程中,选用 CML 协同互学习常识蒸馏以及 CopyPaste 数据增广战略;终究将文字检测模型在巨细不变的情况下,Hmean从 0.759 进步至 0.795,详细融化试验如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
PP-OCRv2检测模型融化试验

3.1.1 CML常识蒸馏战略

常识蒸馏的办法在布置中十分常用,经过运用大模型辅导小模型学习的办法,在通常情况下能够使得小模型在猜测耗时不变的情况下,精度得到进一步的进步,然后进一步进步实践布置的体会。

规范的蒸馏办法是经过一个大模型作为 Teacher 模型来辅导 Student 模型进步作用,而后来又开展出 DML 互学习蒸馏办法,即经过两个结构相同的模型相互学习,比较于前者,DML 脱离了对大的 Teacher 模型的依靠,蒸馏练习的流程更加简略,模型产出功率也要更高一些。

PP-OCRv2 文字检测模型中运用的是三个模型之间的 CML (Collaborative Mutual Learning) 协同互蒸馏办法,既包括两个相同结构的 Student 模型之间互学习,一起还引进了较大模型结构的 Teacher 模型。CML与其他蒸馏算法的比照如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
CML与其他常识蒸馏算法的比照

详细地,文本检测使命中,CML的结构框图如下所示。这儿的 response maps 指的就是DBNet终究一层的概率图输出 (Probability map) 。在整个练习进程中,一共包括3个丢失函数。

  • GT loss
  • DML loss
  • Distill loss

这儿的 Teacher 模型的主干网络为 ResNet18_vd,2 个 Student 模型的主干网络为 MobileNetV3。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
CML结构框图
  • GT loss

两个 Student 模型中大部分的参数都是从头初始化的,因而它们在练习的进程中需求受到 groundtruth (GT) 信息 的监督。DBNet 练习使命的 pipeline 如下所示。其输出首要包括 3 种 feature map,详细如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
DBNet头部结构

对这 3 种 feature map 运用不同的 loss function 进行监督,详细如下表所示。

Feature map Loss function weight
Probability map Binary cross-entropy loss 1.0
Binary map Dice loss \alpha
Threshold map L1 loss \beta

终究GT loss能够表明为如下所示。

Lossgt(Tout,gt)=lp(Sout,gt)+lb(Sout,gt)+lt(Sout,gt)Loss_{gt}(T_{out}, gt) = l_{p}(S_{out}, gt) + \alpha l_{b}(S_{out}, gt) + \beta l_{t}(S_{out}, gt)

  • DML loss

关于 2 个完全相同的 Student 模型来说,因为它们的结构完全相同,因而关于相同的输入,应该具有相同的输出,DBNet 终究输出的是概率图 (response maps),因而依据 KL 散度,核算 2 个 Student 模型的 DML loss,详细核算办法如下。

Lossdml=KL(S1pout∣∣S2pout)+KL(S2pout∣∣S1pout)2Loss_{dml} = \frac{KL(S1_{pout} || S2_{pout}) + KL(S2_{pout} || S1_{pout})}{2}

其间 KL(|)是 KL 散度的核算公式,终究这种形式的 DML loss 具有对称性。

  • Distill loss

CML 中,引进了 Teacher 模型,来一起监督 2 个 Student 模型。PP-OCRv2 中只对特征 Probability map 进行蒸馏的监督。详细地,关于其间一个 Student 模型,核算办法如下所示, lp() 和 lb() 别离表明 Binary cross-entropy loss 和 Dice loss。另一个 Student 模型的 loss 核算进程完全相同。

Lossdistill=lp(Sout,fdila(Tout))+lb(Sout,fdila(Tout))Loss_{distill} = \gamma l_{p}(S_{out}, f_{dila}(T_{out})) + l_{b}(S_{out}, f_{dila}(T_{out}))

终究,将上述三个 loss 相加,就得到了用于 CML 练习的丢失函数。

检测装备文件为ch_PP-OCRv2_det_cml.yml,蒸馏结构部分的装备和部分解释如下。

Architecture:
  name: DistillationModel     # 模型称号,这是通用的蒸馏模型表明。
  algorithm: Distillation     # 算法称号,
  Models:                     # 模型,包括子网络的装备信息
    Teacher:                  # Teacher子网络,包括`pretrained`与`freeze_params`信息以及其他用于构建子网络的参数
      freeze_params: true     # 是否固定Teacher网络的参数
      pretrained: ./pretrain_models/ch_ppocr_server_v2.0_det_train/best_accuracy # 预练习模型
      return_all_feats: false # 是否回来一切的特征,为True时,会将backbone、neck、head等模块的输出都回来
      model_type: det         # 模型类别
      algorithm: DB           # Teacher网络的算法称号
      Transform:
      Backbone:
        name: ResNet
        layers: 18
      Neck:
        name: DBFPN
        out_channels: 256
      Head:
        name: DBHead
        k: 50
    Student:                   # Student子网络
      freeze_params: false
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      return_all_feats: false
      model_type: det
      algorithm: DB
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: True
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50
    Student2:                  # Student2子网络
      freeze_params: false
      pretrained: ./pretrain_models/MobileNetV3_large_x0_5_pretrained
      return_all_feats: false
      model_type: det
      algorithm: DB
      Transform:
      Backbone:
        name: MobileNetV3
        scale: 0.5
        model_name: large
        disable_se: True
      Neck:
        name: DBFPN
        out_channels: 96
      Head:
        name: DBHead
        k: 50

DistillationModel类的完结在distillation_model.py文件中,DistillationModel类的完结与部分解说如下。

class DistillationModel(nn.Layer):
    def __init__(self, config):
        """
        the module for OCR distillation.
        args:
            config (dict): the super parameters for module.
        """
        super().__init__()
        self.model_list = []
        self.model_name_list = []
        # 依据Models中的每个字段,抽取出子网络的称号以及对应的装备
        for key in config["Models"]:
            model_config = config["Models"][key]
            freeze_params = False
            pretrained = None
            if "freeze_params" in model_config:
                freeze_params = model_config.pop("freeze_params")
            if "pretrained" in model_config:
                pretrained = model_config.pop("pretrained")
            # 依据每个子网络的装备,依据BaseModel生成子网络
            model = BaseModel(model_config)
            # 判别是否加载预练习模型
            if pretrained is not None:
                load_pretrained_params(model, pretrained)
            # 判别是否需求固定该子网络的模型参数
            if freeze_params:
                for param in model.parameters():
                    param.trainable = False
            self.model_list.append(self.add_sublayer(key, model))
            self.model_name_list.append(key)
    def forward(self, x):
        result_dict = dict()
        for idx, model_name in enumerate(self.model_name_list):
            result_dict[model_name] = self.model_list[idx](x)
        return result_dict

运用下面的指令,能够快速完结蒸馏模型的初始化进程。

3.1.2 数据增广

数据增广是进步模型泛化才能重要的手法之一,CopyPaste 是一种新颖的数据增强技巧,已经在方针检测和实例切割使命中验证了有用性。利用 CopyPaste,能够组成文本实例来平衡练习图画中的正负样本之间的份额。比较而言,传统图画旋转、随机翻转和随机裁剪是无法做到的。

CopyPaste 首要进程包括:

  1. 随机挑选两幅练习图画;
  2. 随机规范抖动缩放;
  3. 随机水平翻转;
  4. 随机挑选一幅图画中的方针子集;
  5. 粘贴在另一幅图画中随机的位置。

这样就比较好地进步了样本丰厚度,一起也增加了模型对环境的鲁棒性。如下图所示,经过在左下角的图中裁剪出来的文本,随机旋转缩放之后粘贴到左上角的图画中,进一步丰厚了该文本在不同布景下的多样性。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案

假如期望在模型练习中运用CopyPaste,只需在Train.transforms装备字段中增加CopyPaste即可,如下所示。

Train:
  dataset:
    name: SimpleDataSet
    data_dir: ./train_data/icdar2015/text_localization/
    label_file_list:
      - ./train_data/icdar2015/text_localization/train_icdar2015_label.txt
    ratio_list: [1.0]
    transforms:
      - DecodeImage: # load image
          img_mode: BGR
          channel_first: False
      - DetLabelEncode: # Class handling label
      - CopyPaste:  # 增加CopyPaste
      - IaaAugment:
          augmenter_args:
            - { 'type': Fliplr, 'args': { 'p': 0.5 } }
            - { 'type': Affine, 'args': { 'rotate': [-10, 10] } }
            - { 'type': Resize, 'args': { 'size': [0.5, 3] } }
      - EastRandomCropData:
          size: [960, 960]
          max_tries: 50
          keep_ratio: true
      - MakeBorderMap:
          shrink_ratio: 0.4
          thresh_min: 0.3
          thresh_max: 0.7
      - MakeShrinkMap:
          shrink_ratio: 0.4
          min_text_size: 8
      - NormalizeImage:
          scale: 1./255.
          mean: [0.485, 0.456, 0.406]
          std: [0.229, 0.224, 0.225]
          order: 'hwc'
      - ToCHWImage:
      - KeepKeys:
          keep_keys: ['image', 'threshold_map', 'threshold_mask', 'shrink_map', 'shrink_mask'] # the order of the dataloader list
  loader:
    shuffle: True
    drop_last: False
    batch_size_per_card: 8
    num_workers: 4

CopyPaste的详细完结能够参阅copy_paste.py。

下面依据icdar2015检测数据集,演示CopyPaste的实践运行进程。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案

3.1.3 文字检测优化小结

PP-OCRv2中,对文字检测模型选用运用常识蒸馏方案以及数据增广战略,增加模型的泛化功能。终究文字检测模型在巨细不变的情况下,Hmean从 0.759 进步至 0.795,详细融化试验如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
PP-OCRv2检测模型融化试验

PP-OCRv2中检测作用如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案

3.2 文本辨认模型优化详解

PP-OCRv2文字辨认模型优化进程中,选用主干网络优化、UDML常识蒸馏战略、CTC loss改善等技巧,终究将辨认精度从 66.7% 进步至 74.8%,详细融化试验如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
PP-OCRv2辨认模型融化试验

3.2.1 PP-LCNet轻量级主干网络

百度提出了一种依据 MKLDNN 加速战略的轻量级 CPU 网络,即 PP-LCNet,大幅进步了轻量级模型在图画分类使命上的功能,关于核算机视觉的下流使命,如文本辨认、方针检测、语义切割等,有很好的体现。这儿需求留意的是,PP-LCNet是针对CPU+MKLDNN这个场景进行定制优化,在分类使命上的速度和精度都远远优于其他模型,因而咱们假如有这个运用场景的模型需求的话,也推荐咱们去运用。

PP-LCNet 论文地址:PP-LCNet: A Lightweight CPU Convolutional Neural Network

PP-LCNet依据MobileNetV1改善得到,其结构图如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案

比较于MobileNetV1,PP-LCNet中交融了MobileNetV3结构中激活函数、头部结构、SE模块等战略优化技巧,一起分析了终究阶段卷积层的卷积核巨细,终究该模型在确保速度优势的基础上,精度大幅逾越MobileNet、GhostNet等轻量级模型。

详细地,PP-LCNet中共涉及到下面4个优化点。

  • 除了 SE 模块,网络中一切的 relu 激活函数替换为 h-swish,精度进步1%-2%
  • PP-LCNet 第五阶段,DW 的 kernel size 变为5×5,精度进步0.5%-1%
  • PP-LCNet 第五阶段的终究两个 DepthSepConv block 增加 SE 模块, 精度进步0.5%-1%
  • GAP 后增加 1280 维的 FC 层,增加特征表达才能,精度进步2%-3%

在ImageNet1k数据集上,PP-LCNet比较于其他目前比较常用的轻量级分类模型,Top1-Acc 与猜测耗时如下图所示。能够看出,猜测耗时和精度都是要更优的。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案

经过下面这种办法,便能够快速完结PP-LCNet辨认模型的界说。

3.2.2 U-DML 常识蒸馏战略

关于规范的 DML 战略,蒸馏的丢失函数仅包括终究输出层监督,可是关于 2 个结构完全相同的模型来说,关于完全相同的输入,它们的中心特征输出期望也完全相同,因而在终究输出层监督的监督上,能够进一步增加中心输出的特征图的监督信号,作为丢失函数,即 PP-OCRv2 中的 U-DML (Unified-Deep Mutual Learning) 常识蒸馏办法。

U-DML 常识蒸馏的算法流程图如下所示。 Teacher 模型与 Student 模型的网络结构完全相同,初始化参数不同,此外,在新增在规范的 DML 常识蒸馏的基础上,新增引进了关于 Feature Map 的监督机制,新增 Feature Loss。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案

在练习的进程中,一共包括 3 种 loss: GT loss,DML loss,Feature loss。

  • GT loss

文本辨认使命运用的模型结构是 CRNN,因而运用 CTC loss 作为 GT loss, GT loss 核算办法如下所示。

Lossctc=CTC(Shout,gt)+CTC(Thout,gt)Loss_{ctc} = CTC(S_{hout}, gt) + CTC(T_{hout}, gt)

  • DML loss

DML loss 核算办法如下,这儿 Teacher 模型与 Student 模型相互核算 KL 散度,终究 DML loss具有对称性。

Lossdml=KL(Spout∣∣Tpout)+KL(Tpout∣∣Spout)2Loss_{dml} = \frac{KL(S_{pout} || T_{pout}) + KL(T_{pout} || S_{pout})}{2}

  • Feature loss

Feature loss 运用的是 L2 loss,详细核算办法如下所示。

Lossfeat=L2(Sbout,Tbout)Loss_{feat} = L2(S_{bout}, T_{bout})

终究,练习进程中的 loss function 核算办法如下所示。

Losstotal=Lossctc+Lossdml+LossfeatLoss_{total} = Loss_{ctc} + Loss_{dml} + Loss_{feat}

此外,在练习进程中经过增加迭代次数,在 Head 部分增加 FC 层等 trick,平衡模型的特征编码与解码的才能,进一步进步了模型作用。

装备文件在ch_PP-OCRv2_rec_distillation.yml。

Architecture:
  model_type: &model_type "rec"    # 模型类别,rec、det等,每个子网络的的模型类别都与
  name: DistillationModel          # 结构称号,蒸馏使命中,为DistillationModel,用于构建对应的结构
  algorithm: Distillation          # 算法称号
  Models:                          # 模型,包括子网络的装备信息
    Teacher:                       # 子网络称号,至少需求包括`pretrained`与`freeze_params`信息,其他的参数为子网络的构造参数
      pretrained:                  # 该子网络是否需求加载预练习模型
      freeze_params: false         # 是否需求固定参数
      return_all_feats: true       # 子网络的参数,表明是否需求回来一切的features,假如为False,则只回来终究的输出
      model_type: *model_type      # 模型类别
      algorithm: CRNN              # 子网络的算法称号,该子网络剩下参与均为构造参数,与一般的模型练习装备共同
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96           # Head解码进程中穿插一层
        fc_decay: 0.00002
    Student:                       # 别的一个子网络,这儿给的是DML的蒸馏示例,两个子网络结构相同,均需求学习参数
      pretrained:                  # 下面的组网参数同上
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00002

当然,这儿假如期望增加更多的子网络进行练习,也能够依照StudentTeacher的增加办法,在装备文件中增加相应的字段。比如说假如期望有3个模型相互监督,共同练习,那么Architecture能够写为如下格局。

Architecture:
  model_type: &model_type "rec"
  name: DistillationModel
  algorithm: Distillation
  Models:
    Teacher:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00002
    Student:
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00002
    Student2:                       # 常识蒸馏使命中引进的新的子网络,其他部分与上述装备相同
      pretrained:
      freeze_params: false
      return_all_feats: true
      model_type: *model_type
      algorithm: CRNN
      Transform:
      Backbone:
        name: MobileNetV1Enhance
        scale: 0.5
      Neck:
        name: SequenceEncoder
        encoder_type: rnn
        hidden_size: 64
      Head:
        name: CTCHead
        mid_channels: 96
        fc_decay: 0.00002

终究该模型练习时,包括3个子网络:Teacher, Student, Student2

蒸馏模型DistillationModel类的详细完结代码能够参阅distillation_model.py。

终究模型forward输出为一个字典,key为一切的子网络称号,例如这儿为StudentTeacher,value为对应子网络的输出,能够为Tensor(只回来该网络的终究一层)和dict(也回来了中心的特征信息)。

在辨认使命中,为了增加更多丢失函数,确保蒸馏办法的可扩展性,将每个子网络的输出保存为dict,其间包括子模块输出。以该辨认模型为例,每个子网络的输出成果均为dict,key包括backbone_out,neck_out, head_outvalue为对应模块的tensor,终究关于上述装备文件,DistillationModel的输出格局如下。

{
  "Teacher": {
    "backbone_out": tensor,
    "neck_out": tensor,
    "head_out": tensor,
  },
  "Student": {
    "backbone_out": tensor,
    "neck_out": tensor,
    "head_out": tensor,
  }
}

常识蒸馏使命中,丢失函数装备如下所示。

Loss:
  name: CombinedLoss                           # 丢失函数称号,依据改称号,构建用于丢失函数的类
  loss_config_list:                            # 丢失函数装备文件列表,为CombinedLoss的必备函数
  - DistillationCTCLoss:                       # 依据蒸馏的CTC丢失函数,继承自规范的CTC loss
      weight: 1.0                              # 丢失函数的权重,loss_config_list中,每个丢失函数的装备都必须包括该字段
      model_name_list: ["Student", "Teacher"]  # 关于蒸馏模型的猜测成果,提取这两个子网络的输出,与gt核算CTC loss
      key: head_out                            # 取子网络输出dict中,该key对应的tensor
  - DistillationDMLLoss:                       # 蒸馏的DML丢失函数,继承自规范的DMLLoss
      weight: 1.0                              # 权重
      act: "softmax"                           # 激活函数,对输入运用激活函数处理,能够为softmax, sigmoid或许为None,默认为None
      model_name_pairs:                        # 用于核算DML loss的子网络称号对,假如期望核算其他子网络的DML loss,能够在列表下面持续填充
      - ["Student", "Teacher"]
      key: head_out                            # 取子网络输出dict中,该key对应的tensor
  - DistillationDistanceLoss:                  # 蒸馏的间隔丢失函数
      weight: 1.0                              # 权重
      mode: "l2"                               # 间隔核算办法,目前支持l1, l2, smooth_l1
      model_name_pairs:                        # 用于核算distance loss的子网络称号对
      - ["Student", "Teacher"]
      key: backbone_out                        # 取子网络输出dict中,该key对应的tensor

上述丢失函数中,一切的蒸馏丢失函数均继承自规范的丢失函数类,首要功能为: 对蒸馏模型的输出进行解析,找到用于核算丢失的中心节点(tensor),再运用规范的丢失函数类去核算。

以上述装备为例,终究蒸馏练习的丢失函数包括下面3个部分。

  • StudentTeacher的终究输出(head_out)与gt的CTC loss,权重为1。在这儿因为2个子网络都需求更新参数,因而2者都需求核算与gt的loss。
  • StudentTeacher的终究输出(head_out)之间的DML loss,权重为1。
  • StudentTeacher的主干网络输出(backbone_out)之间的l2 loss,权重为1。

CombinedLoss类完结如下。

class CombinedLoss(nn.Layer):
    """
    CombinedLoss:
        a combionation of loss function
    """
    def __init__(self, loss_config_list=None):
        super().__init__()
        self.loss_func = []
        self.loss_weight = []
        assert isinstance(loss_config_list, list), (
            'operator config should be a list')
        for config in loss_config_list:
            assert isinstance(config,
                              dict) and len(config) == 1, "yaml format error"
            name = list(config)[0]
            param = config[name]
            assert "weight" in param, "weight must be in param, but param just contains {}".format(
                param.keys())
            self.loss_weight.append(param.pop("weight"))
            self.loss_func.append(eval(name)(**param))
    def forward(self, input, batch, **kargs):
        loss_dict = {}
        loss_all = 0.
        for idx, loss_func in enumerate(self.loss_func):
            loss = loss_func(input, batch, **kargs)
            if isinstance(loss, paddle.Tensor):
                loss = {"loss_{}_{}".format(str(loss), idx): loss}
            weight = self.loss_weight[idx]
            loss = {key: loss[key] * weight for key in loss}
            if "loss" in loss:
                loss_all += loss["loss"]
            else:
                loss_all += paddle.add_n(list(loss.values()))
            loss_dict.update(loss)
        loss_dict["loss"] = loss_all
        return loss_dict

关于CombinedLoss更加详细的完结能够参阅: combined_loss.py。关于DistillationCTCLoss等蒸馏丢失函数更加详细的完结能够参阅distillation_loss.py。

关于上面3个模型的蒸馏,Loss字段也需求相应修正,一起考虑3个子网络之间的丢失,如下所示。

Loss:
  name: CombinedLoss                           # 丢失函数称号,依据改称号,构建用于丢失函数的类
  loss_config_list:                            # 丢失函数装备文件列表,为CombinedLoss的必备函数
  - DistillationCTCLoss:                       # 依据蒸馏的CTC丢失函数,继承自规范的CTC loss
      weight: 1.0                              # 丢失函数的权重,loss_config_list中,每个丢失函数的装备都必须包括该字段
      model_name_list: ["Student", "Student2", "Teacher"]  # 关于蒸馏模型的猜测成果,提取这三个子网络的输出,与gt核算CTC loss
      key: head_out                            # 取子网络输出dict中,该key对应的tensor
  - DistillationDMLLoss:                       # 蒸馏的DML丢失函数,继承自规范的DMLLoss
      weight: 1.0                              # 权重
      act: "softmax"                           # 激活函数,对输入运用激活函数处理,能够为softmax, sigmoid或许为None,默认为None
      model_name_pairs:                        # 用于核算DML loss的子网络称号对,假如期望核算其他子网络的DML loss,能够在列表下面持续填充
      - ["Student", "Teacher"]
      - ["Student2", "Teacher"]
      - ["Student", "Student2"]
      key: head_out                            # 取子网络输出dict中,该key对应的tensor
  - DistillationDistanceLoss:                  # 蒸馏的间隔丢失函数
      weight: 1.0                              # 权重
      mode: "l2"                               # 间隔核算办法,目前支持l1, l2, smooth_l1
      model_name_pairs:                        # 用于核算distance loss的子网络称号对
      - ["Student", "Teacher"]
      - ["Student2", "Teacher"]
      - ["Student", "Student2"]
      key: backbone_out                        # 取子网络输出dict中,该key对应的tensor

3.2.3 Enhanced CTC loss 改善

中文 OCR 使命常常遇到的辨认难点是类似字符数太多,容易误识。学习 Metric Learning 中的想法,引进 Center loss,进一步增大类间间隔,核心公式如下所示。

L=Lctc+∗LcenterL = L_{ctc} + \lambda * L_{center} Lcenter=∑t=1T∣∣xt−cyt∣∣22L_{center} =\sum_{t=1}^T||x_{t} – c_{y_{t}}||_{2}^{2}

这儿 xtx_t 表明时刻步长 tt 处的标签,cytc_{y_{t}} 表明标签 yty_t 对应的 center。

Enhance CTC 中,center 的初始化对成果也有较大影响,在 PP-OCRv2 中,center 初始化的详细进程如下所示。

  1. 依据规范的 CTC loss,练习一个网络;
  2. 提取出练习调集中辨认正确的图画调集,记为 G ;
  3. 将 G 中的图片依次输入网络, 提取head输出时序特征的 xtx_tyty_t 的对应关系,其间 yty_t 核算办法如下:

yt=argmax(W∗xt)y_{t} = argmax(W * x_{t})

  1. 将相同 yty_t 对应的 xtx_t 聚合在一起,取其平均值,作为初始 center。

首先需求依据configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec.yml练习一个基础网络

更多关于Center loss的练习进程能够参阅:Enhanced CTC Loss运用文档

终究,运用configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml进行练习,指令如下所示。

python tools/train.py -c configs/rec/ch_PP-OCRv2/ch_PP-OCRv2_rec_enhanced_ctc_loss.yml

首要改善点为Loss字段,比较于规范的CTCLoss,增加了CenterLoss。装备类别数、特征维度、center途径即可。

Loss:
  name: CombinedLoss
  loss_config_list:
  - CTCLoss:
      use_focal_loss: false
      weight: 1.0
  - CenterLoss:
      weight: 0.05
      num_classes: 6625
      feat_dim: 96
      center_file_path: "./train_center.pkl"

3.2.4 文本辨认优化小结

PP-OCRv2文字辨认模型优化进程中,对模型从主干网络、丢失函数等视点进行改善,并引进常识蒸馏的练习办法,终究将辨认精度从 66.7% 进步至 74.8%,详细融化试验如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案
PP-OCRv2辨认模型融化试验

在PP-OCRv2文字检测的基础上,辨认模型的试验作用如下所示。

OCR文字检测与识别系统:融合文字检测、文字识别和方向分类器的综合解决方案

4. 总结

本章首要介绍PP-OCR以及PP-OCRv2的优化战略。

PP-OCR从主干网络、学习率战略、数据增广、模型裁剪量化等方面,共运用了19个对战略,对模型进行优化减肥,终究打造了面向服务器端的PP-OCR server体系以及面向移动端的PP-OCR mobile体系。

比较于PP-OCR, PP-OCRv2 在主干网络、数据增广、丢失函数这三个方面进行进一步优化,处理端侧猜测功率较差、布景杂乱以及类似字符的误识等问题,一起引进了常识蒸馏练习战略,进一步进步模型精度,终究打造了精度、速度远超PP-OCR的文字检测与辨认体系。

参阅链接

aistudio.baidu.com/education/g…

github.com/PaddlePaddl…

更多优质内容请重视公号:汀丶人工智能;会供给一些相关的资源和优质文章,免费获取阅览。