# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, List

from doctr.file_utils import is_tf_available
from doctr.models.preprocessor import PreProcessor

from .. import recognition
from .predictor import RecognitionPredictor

__all__ = ["recognition_predictor"]


# Architecture names accepted when `arch` is given as a string.
ARCHS: List[str] = [
    "crnn_vgg16_bn",
    "crnn_mobilenet_v3_small",
    "crnn_mobilenet_v3_large",
    "sar_resnet31",
    "master",
    "vitstr_small",
    "vitstr_base",
    "parseq",
]


def _predictor(arch: Any, pretrained: bool, **kwargs: Any) -> RecognitionPredictor:
    """Build a recognition predictor from an architecture name or a model instance.

    Args:
    ----
        arch: either one of the names listed in ``ARCHS`` or an already instantiated
            CRNN, SAR, MASTER, ViTSTR or PARSeq model
        pretrained: forwarded to the model constructor when ``arch`` is a name; not used
            when ``arch`` is already a model instance
        **kwargs: ``pretrained_backbone`` (default ``True``) is forwarded to the model
            constructor when ``arch`` is a name; every other keyword argument is passed to
            the ``PreProcessor``. ``mean``, ``std`` and ``batch_size`` default to the model
            config values and 128 respectively.

    Returns:
    -------
        Recognition predictor wrapping the pre-processor and the model

    Raises:
    ------
        ValueError: if ``arch`` is an unknown name or an unsupported model type
    """
    if isinstance(arch, str):
        if arch not in ARCHS:
            raise ValueError(f"unknown architecture '{arch}'")

        _model = recognition.__dict__[arch](
            pretrained=pretrained, pretrained_backbone=kwargs.get("pretrained_backbone", True)
        )
    else:
        if not isinstance(
            arch, (recognition.CRNN, recognition.SAR, recognition.MASTER, recognition.ViTSTR, recognition.PARSeq)
        ):
            raise ValueError(f"unknown architecture: {type(arch)}")
        _model = arch

    # `pretrained_backbone` only concerns model construction: drop it so that it is
    # not forwarded to the pre-processor below.
    kwargs.pop("pretrained_backbone", None)

    # Caller-provided values take precedence over the defaults.
    kwargs.setdefault("mean", _model.cfg["mean"])
    kwargs.setdefault("std", _model.cfg["std"])
    kwargs.setdefault("batch_size", 128)

    # The first two entries of the configured input shape are used when TensorFlow is
    # available, the last two otherwise.
    # NOTE(review): presumably (H, W, C) for TensorFlow vs (C, H, W) for PyTorch - confirm.
    input_shape = _model.cfg["input_shape"][:2] if is_tf_available() else _model.cfg["input_shape"][-2:]
    predictor = RecognitionPredictor(PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), _model)

    return predictor
def recognition_predictor(arch: Any = "crnn_vgg16_bn", pretrained: bool = False, **kwargs: Any) -> RecognitionPredictor:
    """Text recognition architecture.

    Example::
        >>> import numpy as np
        >>> from doctr.models import recognition_predictor
        >>> model = recognition_predictor(pretrained=True)
        >>> input_page = (255 * np.random.rand(32, 128, 3)).astype(np.uint8)
        >>> out = model([input_page])

    Args:
    ----
        arch: name of the architecture or model itself to use (e.g. 'crnn_vgg16_bn')
        pretrained: If True, returns a model pre-trained on our text recognition dataset
        **kwargs: optional parameters; `pretrained_backbone` is forwarded to the architecture when
            `arch` is a name, the remaining ones (e.g. `mean`, `std`, `batch_size`) are passed to
            the pre-processor

    Returns:
    -------
        Recognition predictor

    Raises:
    ------
        ValueError: if `arch` is an unknown architecture name or an unsupported model type
    """
    return _predictor(arch, pretrained, **kwargs)