Source code for doctr.models.detection.fast.tensorflow

# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization

from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import tensorflow as tf
from tensorflow.keras import Model, Sequential, layers

from doctr.file_utils import CLASS_NAME
from doctr.models.utils import IntermediateLayerGetter, _bf16_to_float32, _build_model, load_pretrained_params
from doctr.utils.repr import NestedObject

from ...classification import textnet_base, textnet_small, textnet_tiny
from ...modules.layers import FASTConvLayer
from .base import _FAST, FASTPostProcessor

__all__ = ["FAST", "fast_tiny", "fast_small", "fast_base", "reparameterize"]


default_cfgs: Dict[str, Dict[str, Any]] = {
    "fast_tiny": {
        "input_shape": (1024, 1024, 3),
        "mean": (0.798, 0.785, 0.772),
        "std": (0.264, 0.2749, 0.287),
        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_tiny-d7379d7b.weights.h5&src=0",
    },
    "fast_small": {
        "input_shape": (1024, 1024, 3),
        "mean": (0.798, 0.785, 0.772),
        "std": (0.264, 0.2749, 0.287),
        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_small-44b27eb6.weights.h5&src=0",
    },
    "fast_base": {
        "input_shape": (1024, 1024, 3),
        "mean": (0.798, 0.785, 0.772),
        "std": (0.264, 0.2749, 0.287),
        "url": "https://doctr-static.mindee.com/models?id=v0.9.0/fast_base-f2c6c736.weights.h5&src=0",
    },
}
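

# Illustrative sketch (not part of the library): the `mean`/`std` triplets above are
# assumed to be applied channel-wise to inputs already scaled to [0, 1] before they
# reach the model.
def _example_normalize(img: tf.Tensor, arch: str = "fast_tiny") -> tf.Tensor:
    cfg = default_cfgs[arch]
    mean = tf.constant(cfg["mean"], dtype=img.dtype)
    std = tf.constant(cfg["std"], dtype=img.dtype)
    return (img - mean) / std  # broadcasts over the trailing channel axis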


class FastNeck(layers.Layer, NestedObject):
    """Neck of the FAST architecture, composed of a series of 3x3 convolutions and upsampling layer.

    Args:
    ----
        in_channels: number of input channels
        out_channels: number of output channels
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = 128,
    ) -> None:
        super().__init__()
        self.reduction = [FASTConvLayer(in_channels * scale, out_channels, kernel_size=3) for scale in [1, 2, 4, 8]]

    def _upsample(self, x: tf.Tensor, y: tf.Tensor) -> tf.Tensor:
        return tf.image.resize(x, size=y.shape[1:3], method="bilinear")

    def call(self, x: tf.Tensor, **kwargs: Any) -> tf.Tensor:
        f1, f2, f3, f4 = x
        f1, f2, f3, f4 = [reduction(f, **kwargs) for reduction, f in zip(self.reduction, (f1, f2, f3, f4))]
        f2, f3, f4 = [self._upsample(f, f1) for f in (f2, f3, f4)]
        f = tf.concat((f1, f2, f3, f4), axis=-1)
        return f
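

# Illustrative shape sketch for the neck above (the pyramid shapes are assumptions,
# not part of the library): each map is reduced to `out_channels`, f2..f4 are
# bilinearly upsampled to f1's resolution, and the four results are concatenated
# channel-wise.
def _example_neck_shapes() -> tf.Tensor:
    neck = FastNeck(in_channels=64, out_channels=128)
    feats = [tf.random.uniform((1, 256 // 2**i, 256 // 2**i, 64 * 2**i)) for i in range(4)]
    return neck(feats)  # -> shape (1, 256, 256, 4 * 128)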


class FastHead(Sequential):
    """Head of the FAST architecture

    Args:
    ----
        in_channels: number of input channels
        num_classes: number of output classes
        out_channels: number of output channels
        dropout: dropout probability
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        out_channels: int = 128,
        dropout: float = 0.1,
    ) -> None:
        _layers = [
            FASTConvLayer(in_channels, out_channels, kernel_size=3),
            layers.Dropout(dropout),
            layers.Conv2D(num_classes, kernel_size=1, use_bias=False),
        ]
        super().__init__(_layers)
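

# Illustrative shape sketch for the head above (values are assumptions): a 3x3
# FASTConvLayer to `out_channels`, dropout, then a 1x1 conv projecting to one
# logit map per class.
def _example_head_shapes() -> tf.Tensor:
    head = FastHead(in_channels=512, num_classes=1)
    return head(tf.random.uniform((1, 256, 256, 512)))  # -> shape (1, 256, 256, 1)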


class FAST(_FAST, Model, NestedObject):
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_.

    Args:
    ----
        feature_extractor: the backbone serving as feature extractor
        bin_thresh: threshold for binarization
        box_thresh: minimal objectness score to consider a box
        dropout_prob: dropout probability
        pooling_size: size of the pooling layer
        assume_straight_pages: if True, fit straight bounding boxes only
        exportable: if True, the model returns only the logits (for ONNX export)
        cfg: the configuration dict of the model
        class_names: list of class names
    """

    _children_names: List[str] = ["feat_extractor", "neck", "head", "postprocessor"]

    def __init__(
        self,
        feature_extractor: IntermediateLayerGetter,
        bin_thresh: float = 0.1,
        box_thresh: float = 0.1,
        dropout_prob: float = 0.1,
        pooling_size: int = 4,  # differs from the paper; performs better on dense, text-rich images
        assume_straight_pages: bool = True,
        exportable: bool = False,
        cfg: Optional[Dict[str, Any]] = {},
        class_names: List[str] = [CLASS_NAME],
    ) -> None:
        super().__init__()
        self.class_names = class_names
        num_classes: int = len(self.class_names)
        self.cfg = cfg

        self.feat_extractor = feature_extractor
        self.exportable = exportable
        self.assume_straight_pages = assume_straight_pages

        # Identify the number of channels for the neck & head initialization
        feat_out_channels = [
            layers.Input(shape=in_shape[1:]).shape[-1] for in_shape in self.feat_extractor.output_shape
        ]
        # Initialize neck & head
        self.neck = FastNeck(feat_out_channels[0], feat_out_channels[1])
        self.head = FastHead(feat_out_channels[-1], num_classes, feat_out_channels[1], dropout_prob)

        # NOTE: The post-processing from the paper does not work well on text-rich images,
        # so we use a modified version adapted from DBNet
        self.postprocessor = FASTPostProcessor(
            assume_straight_pages=assume_straight_pages, bin_thresh=bin_thresh, box_thresh=box_thresh
        )

        # Max-pooling layer used to reverse the text-kernel erosion, as described in the paper
        # (with the default pooling_size=4, this is a 3x3 max-pool: 4 // 2 + 1 = 3)
        self.pooling = layers.MaxPooling2D(pool_size=pooling_size // 2 + 1, strides=1, padding="same")

    def compute_loss(
        self,
        out_map: tf.Tensor,
        target: List[Dict[str, np.ndarray]],
        eps: float = 1e-6,
    ) -> tf.Tensor:
        """Compute fast loss, 2 x Dice loss where the text kernel loss is scaled by 0.5.

        Args:
        ----
            out_map: output feature map of the model, of shape (N, H, W, num_classes)
            target: list of dictionary where each dict has a `boxes` and a `flags` entry
            eps: epsilon factor in dice loss

        Returns:
        -------
            A loss tensor
        """
        targets = self.build_target(target, out_map.shape[1:], True)

        seg_target = tf.convert_to_tensor(targets[0], dtype=out_map.dtype)
        seg_mask = tf.convert_to_tensor(targets[1], dtype=out_map.dtype)
        shrunken_kernel = tf.convert_to_tensor(targets[2], dtype=out_map.dtype)

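        # Online Hard Example Mining (OHEM): keep every positive pixel and only the
        # hardest negatives (up to 3x the positive count), so that easy background
        # pixels do not dominate the text Dice loss below.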
        def ohem(score: tf.Tensor, gt: tf.Tensor, mask: tf.Tensor) -> tf.Tensor:
            pos_num = tf.reduce_sum(tf.cast(gt > 0.5, dtype=tf.int32)) - tf.reduce_sum(
                tf.cast((gt > 0.5) & (mask <= 0.5), dtype=tf.int32)
            )
            neg_num = tf.reduce_sum(tf.cast(gt <= 0.5, dtype=tf.int32))
            neg_num = tf.minimum(pos_num * 3, neg_num)

            if neg_num == 0 or pos_num == 0:
                return mask

            neg_score_sorted, _ = tf.nn.top_k(-tf.boolean_mask(score, gt <= 0.5), k=neg_num)
            threshold = -neg_score_sorted[-1]

            selected_mask = tf.math.logical_and((score >= threshold) | (gt > 0.5), (mask > 0.5))
            return tf.cast(selected_mask, dtype=tf.float32)

        if len(self.class_names) > 1:
            kernels = tf.nn.softmax(out_map, axis=-1)
            prob_map = tf.nn.softmax(self.pooling(out_map), axis=-1)
        else:
            kernels = tf.sigmoid(out_map)
            prob_map = tf.sigmoid(self.pooling(out_map))

        # As described in the paper, the Dice loss on the text segmentation map is scaled by 0.5.
        selected_masks = tf.stack(
            [ohem(score, gt, mask) for score, gt, mask in zip(prob_map, seg_target, seg_mask)], axis=0
        )
        inter = tf.reduce_sum(selected_masks * prob_map * seg_target, axis=(0, 1, 2))
        cardinality = tf.reduce_sum(selected_masks * (prob_map + seg_target), axis=(0, 1, 2))
        text_loss = tf.reduce_mean((1 - 2 * inter / (cardinality + eps))) * 0.5

        # As described in the paper, we use the Dice loss for the text kernel map.
        selected_masks = seg_target * seg_mask
        inter = tf.reduce_sum(selected_masks * kernels * shrunken_kernel, axis=(0, 1, 2))
        cardinality = tf.reduce_sum(selected_masks * (kernels + shrunken_kernel), axis=(0, 1, 2))
        kernel_loss = tf.reduce_mean((1 - 2 * inter / (cardinality + eps)))

        return text_loss + kernel_loss

    def call(
        self,
        x: tf.Tensor,
        target: Optional[List[Dict[str, np.ndarray]]] = None,
        return_model_output: bool = False,
        return_preds: bool = False,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        feat_maps = self.feat_extractor(x, **kwargs)
        # Pass through the Neck & Head & Upsample
        feat_concat = self.neck(feat_maps, **kwargs)
        logits: tf.Tensor = self.head(feat_concat, **kwargs)
        logits = layers.UpSampling2D(size=x.shape[-2] // logits.shape[-2], interpolation="bilinear")(logits, **kwargs)

        out: Dict[str, tf.Tensor] = {}
        if self.exportable:
            out["logits"] = logits
            return out

        if return_model_output or target is None or return_preds:
            prob_map = _bf16_to_float32(tf.math.sigmoid(self.pooling(logits, **kwargs)))

        if return_model_output:
            out["out_map"] = prob_map

        if target is None or return_preds:
            # Post-process boxes (keep only text predictions)
            out["preds"] = [dict(zip(self.class_names, preds)) for preds in self.postprocessor(prob_map.numpy())]

        if target is not None:
            loss = self.compute_loss(logits, target)
            out["loss"] = loss

        return out
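

# Standalone sketch of the Dice loss used in `compute_loss` (illustrative values):
# dice = 1 - 2 * |P * G| / (|P| + |G|), so a perfect prediction gives a loss of 0.
def _example_dice_loss(eps: float = 1e-6) -> tf.Tensor:
    prob = tf.constant([[0.9, 0.1], [0.8, 0.2]])  # predicted probabilities
    gt = tf.constant([[1.0, 0.0], [1.0, 0.0]])  # ground-truth mask
    inter = tf.reduce_sum(prob * gt)  # 1.7
    cardinality = tf.reduce_sum(prob + gt)  # 4.0
    return 1.0 - 2.0 * inter / (cardinality + eps)  # 1 - 3.4 / 4.0 = 0.15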


def reparameterize(model: Union[FAST, layers.Layer]) -> FAST:
    """Fuse batchnorm and conv layers and reparameterize the model

    Args:
    ----
        model: the FAST model to reparameterize

    Returns:
    -------
        the reparameterized model
    """
    last_conv = None
    last_conv_idx = None

    for idx, layer in enumerate(model.layers):
        if hasattr(layer, "layers") or isinstance(
            layer, (FASTConvLayer, FastNeck, FastHead, layers.BatchNormalization, layers.Conv2D)
        ):
            if isinstance(layer, layers.BatchNormalization):
                # fuse batchnorm only if it directly follows a conv layer
                if last_conv is None:
                    continue
                conv_w = last_conv.kernel
                conv_b = last_conv.bias if last_conv.use_bias else tf.zeros_like(layer.moving_mean)

                factor = layer.gamma / tf.sqrt(layer.moving_variance + layer.epsilon)
                last_conv.kernel = conv_w * factor.numpy().reshape([1, 1, 1, -1])
                if last_conv.use_bias:
                    last_conv.bias.assign((conv_b - layer.moving_mean) * factor + layer.beta)
                model.layers[last_conv_idx] = last_conv  # Replace the last conv layer with the fused version
                model.layers[idx] = layers.Lambda(lambda x: x)
                last_conv = None
            elif isinstance(layer, layers.Conv2D):
                last_conv = layer
                last_conv_idx = idx
            elif isinstance(layer, FASTConvLayer):
                layer.reparameterize_layer()
            elif isinstance(layer, FastNeck):
                for reduction in layer.reduction:
                    reduction.reparameterize_layer()
            elif isinstance(layer, FastHead):
                reparameterize(layer)
            else:
                reparameterize(layer)
    return model
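

# Numerical check of the conv/BN fusion above (a sketch with assumed shapes): folding
# BN into the conv must leave outputs unchanged, since
#   gamma * (conv(x) - mu) / sqrt(var + eps) + beta == conv'(x)
# with W' = W * factor, b' = (b - mu) * factor + beta, factor = gamma / sqrt(var + eps).
def _example_bn_fusion() -> tf.Tensor:
    x = tf.random.uniform((1, 8, 8, 4))
    conv = layers.Conv2D(4, kernel_size=3, padding="same", use_bias=True)
    bn = layers.BatchNormalization()
    ref = bn(conv(x), training=False)  # builds both layers
    factor = bn.gamma / tf.sqrt(bn.moving_variance + bn.epsilon)
    fused = layers.Conv2D(4, kernel_size=3, padding="same", use_bias=True)
    fused(x)  # build the fused conv to create its variables
    fused.kernel.assign(conv.kernel * tf.reshape(factor, [1, 1, 1, -1]))
    fused.bias.assign((conv.bias - bn.moving_mean) * factor + bn.beta)
    return tf.reduce_max(tf.abs(ref - fused(x)))  # ~0 up to float precision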


def _fast(
    arch: str,
    pretrained: bool,
    backbone_fn,
    feat_layers: List[str],
    pretrained_backbone: bool = True,
    input_shape: Optional[Tuple[int, int, int]] = None,
    **kwargs: Any,
) -> FAST:
    pretrained_backbone = pretrained_backbone and not pretrained

    # Patch the config
    _cfg = deepcopy(default_cfgs[arch])
    _cfg["input_shape"] = input_shape or _cfg["input_shape"]
    if not kwargs.get("class_names", None):
        kwargs["class_names"] = _cfg.get("class_names", [CLASS_NAME])
    else:
        kwargs["class_names"] = sorted(kwargs["class_names"])

    # Feature extractor
    feat_extractor = IntermediateLayerGetter(
        backbone_fn(
            input_shape=_cfg["input_shape"],
            include_top=False,
            pretrained=pretrained_backbone,
        ),
        feat_layers,
    )

    # Build the model
    model = FAST(feat_extractor, cfg=_cfg, **kwargs)
    _build_model(model)

    # Load pretrained parameters
    if pretrained:
        # If the given class_names differ from the pretrained model's, skip the mismatching layers for fine-tuning
        load_pretrained_params(
            model,
            _cfg["url"],
            skip_mismatch=kwargs["class_names"] != default_cfgs[arch].get("class_names", [CLASS_NAME]),
        )

    return model


def fast_tiny(pretrained: bool = False, **kwargs: Any) -> FAST:
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_, using a tiny TextNet backbone.

    >>> import tensorflow as tf
    >>> from doctr.models import fast_tiny
    >>> model = fast_tiny(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the FAST architecture

    Returns:
    -------
        text detection architecture
    """
    return _fast(
        "fast_tiny",
        pretrained,
        textnet_tiny,
        ["stage_0", "stage_1", "stage_2", "stage_3"],
        **kwargs,
    )


def fast_small(pretrained: bool = False, **kwargs: Any) -> FAST:
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_, using a small TextNet backbone.

    >>> import tensorflow as tf
    >>> from doctr.models import fast_small
    >>> model = fast_small(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the FAST architecture

    Returns:
    -------
        text detection architecture
    """
    return _fast(
        "fast_small",
        pretrained,
        textnet_small,
        ["stage_0", "stage_1", "stage_2", "stage_3"],
        **kwargs,
    )


def fast_base(pretrained: bool = False, **kwargs: Any) -> FAST:
    """FAST as described in `"FAST: Faster Arbitrarily-Shaped Text Detector with Minimalist Kernel Representation"
    <https://arxiv.org/pdf/2111.02394.pdf>`_, using a base TextNet backbone.

    >>> import tensorflow as tf
    >>> from doctr.models import fast_base
    >>> model = fast_base(pretrained=True)
    >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    >>> out = model(input_tensor)

    Args:
    ----
        pretrained (bool): If True, returns a model pre-trained on our text detection dataset
        **kwargs: keyword arguments of the FAST architecture

    Returns:
    -------
        text detection architecture
    """
    return _fast(
        "fast_base",
        pretrained,
        textnet_base,
        ["stage_0", "stage_1", "stage_2", "stage_3"],
        **kwargs,
    )
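

# End-to-end usage sketch (illustrative; mirrors the docstring examples above):
# build an untrained tiny model, fuse its conv/BN branches for inference, and run
# a forward pass that also returns post-processed boxes.
def _example_inference() -> Dict[str, Any]:
    model = fast_tiny(pretrained=False)
    model = reparameterize(model)
    input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
    return model(input_tensor, return_preds=True, training=False)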