Source code for doctr.models.detection.zoo

# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, List

from doctr.file_utils import is_tf_available, is_torch_available

from .. import detection
from ..detection.fast import reparameterize
from ..preprocessor import PreProcessor
from .predictor import DetectionPredictor

__all__ = ["detection_predictor"]

ARCHS: List[str]


if is_tf_available():
    ARCHS = [
        "db_resnet50",
        "db_mobilenet_v3_large",
        "linknet_resnet18",
        "linknet_resnet34",
        "linknet_resnet50",
        "fast_tiny",
        "fast_small",
        "fast_base",
    ]
elif is_torch_available():
    ARCHS = [
        "db_resnet34",
        "db_resnet50",
        "db_mobilenet_v3_large",
        "linknet_resnet18",
        "linknet_resnet34",
        "linknet_resnet50",
        "fast_tiny",
        "fast_small",
        "fast_base",
    ]


def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
    if isinstance(arch, str):
        if arch not in ARCHS:
            raise ValueError(f"unknown architecture '{arch}'")

        _model = detection.__dict__[arch](
            pretrained=pretrained,
            pretrained_backbone=kwargs.get("pretrained_backbone", True),
            assume_straight_pages=assume_straight_pages,
        )
        # Reparameterize FAST models by default to lower inference latency and memory usage
        if isinstance(_model, detection.FAST):
            _model = reparameterize(_model)
    else:
        if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
            raise ValueError(f"unknown architecture: {type(arch)}")

        _model = arch
        _model.assume_straight_pages = assume_straight_pages
        _model.postprocessor.assume_straight_pages = assume_straight_pages

    kwargs.pop("pretrained_backbone", None)

    kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
    kwargs["std"] = kwargs.get("std", _model.cfg["std"])
    kwargs["batch_size"] = kwargs.get("batch_size", 2)
    predictor = DetectionPredictor(
        PreProcessor(_model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs),
        _model,
    )
    return predictor


[docs] def detection_predictor( arch: Any = "fast_base", pretrained: bool = False, assume_straight_pages: bool = True, **kwargs: Any, ) -> DetectionPredictor: """Text detection architecture. >>> import numpy as np >>> from doctr.models import detection_predictor >>> model = detection_predictor(arch='db_resnet50', pretrained=True) >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) >>> out = model([input_page]) Args: ---- arch: name of the architecture or model itself to use (e.g. 'db_resnet50') pretrained: If True, returns a model pre-trained on our text detection dataset assume_straight_pages: If True, fit straight boxes to the page **kwargs: optional keyword arguments passed to the architecture Returns: ------- Detection predictor """ return _predictor(arch, pretrained, assume_straight_pages, **kwargs)