# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
from typing import Any, List
from doctr.file_utils import is_tf_available, is_torch_available
from .. import detection
from ..detection.fast import reparameterize
from ..preprocessor import PreProcessor
from .predictor import DetectionPredictor
__all__ = ["detection_predictor"]
ARCHS: List[str]
if is_tf_available():
ARCHS = [
"db_resnet50",
"db_mobilenet_v3_large",
"linknet_resnet18",
"linknet_resnet34",
"linknet_resnet50",
"fast_tiny",
"fast_small",
"fast_base",
]
elif is_torch_available():
ARCHS = [
"db_resnet34",
"db_resnet50",
"db_mobilenet_v3_large",
"linknet_resnet18",
"linknet_resnet34",
"linknet_resnet50",
"fast_tiny",
"fast_small",
"fast_base",
]
def _predictor(arch: Any, pretrained: bool, assume_straight_pages: bool = True, **kwargs: Any) -> DetectionPredictor:
if isinstance(arch, str):
if arch not in ARCHS:
raise ValueError(f"unknown architecture '{arch}'")
_model = detection.__dict__[arch](
pretrained=pretrained,
pretrained_backbone=kwargs.get("pretrained_backbone", True),
assume_straight_pages=assume_straight_pages,
)
# Reparameterize FAST models by default to lower inference latency and memory usage
if isinstance(_model, detection.FAST):
_model = reparameterize(_model)
else:
if not isinstance(arch, (detection.DBNet, detection.LinkNet, detection.FAST)):
raise ValueError(f"unknown architecture: {type(arch)}")
_model = arch
_model.assume_straight_pages = assume_straight_pages
_model.postprocessor.assume_straight_pages = assume_straight_pages
kwargs.pop("pretrained_backbone", None)
kwargs["mean"] = kwargs.get("mean", _model.cfg["mean"])
kwargs["std"] = kwargs.get("std", _model.cfg["std"])
kwargs["batch_size"] = kwargs.get("batch_size", 2)
predictor = DetectionPredictor(
PreProcessor(_model.cfg["input_shape"][:-1] if is_tf_available() else _model.cfg["input_shape"][1:], **kwargs),
_model,
)
return predictor
[docs]
def detection_predictor(
arch: Any = "fast_base",
pretrained: bool = False,
assume_straight_pages: bool = True,
**kwargs: Any,
) -> DetectionPredictor:
"""Text detection architecture.
>>> import numpy as np
>>> from doctr.models import detection_predictor
>>> model = detection_predictor(arch='db_resnet50', pretrained=True)
>>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8)
>>> out = model([input_page])
Args:
----
arch: name of the architecture or model itself to use (e.g. 'db_resnet50')
pretrained: If True, returns a model pre-trained on our text detection dataset
assume_straight_pages: If True, fit straight boxes to the page
**kwargs: optional keyword arguments passed to the architecture
Returns:
-------
Detection predictor
"""
return _predictor(arch, pretrained, assume_straight_pages, **kwargs)