# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Any, List

from doctr.file_utils import is_tf_available

from .. import classification
from ..preprocessor import PreProcessor
from .predictor import OrientationPredictor

__all__ = ["crop_orientation_predictor", "page_orientation_predictor"]

# Names of the generic classification backbones.
ARCHS: List[str] = [
    "magc_resnet31",
    "mobilenet_v3_small",
    "mobilenet_v3_small_r",
    "mobilenet_v3_large",
    "mobilenet_v3_large_r",
    "resnet18",
    "resnet31",
    "resnet34",
    "resnet50",
    "resnet34_wide",
    "textnet_tiny",
    "textnet_small",
    "textnet_base",
    "vgg16_bn_r",
    "vit_s",
    "vit_b",
]
# Architectures accepted by the orientation predictor factories below.
ORIENTATION_ARCHS: List[str] = ["mobilenet_v3_small_crop_orientation", "mobilenet_v3_small_page_orientation"]


def _orientation_predictor(arch: str, pretrained: bool, **kwargs: Any) -> OrientationPredictor:
    """Build an orientation predictor around one of the orientation classifiers.

    Args:
    ----
        arch: name of the architecture to use, must be one of ``ORIENTATION_ARCHS``
        pretrained: forwarded as ``pretrained`` to the classification model builder
        **kwargs: keyword arguments forwarded to the ``PreProcessor``. ``mean`` and ``std`` default to
            the values of the model configuration; ``batch_size`` defaults to 128 when ``"crop"`` is part
            of ``arch`` and to 4 otherwise.

    Returns:
    -------
        OrientationPredictor

    Raises:
    ------
        ValueError: if ``arch`` is not listed in ``ORIENTATION_ARCHS``
    """
    if arch not in ORIENTATION_ARCHS:
        raise ValueError(f"unknown architecture '{arch}'")

    # Load directly classifier from backbone
    _model = classification.__dict__[arch](pretrained=pretrained)

    # Only fill in the preprocessing parameters the caller did not provide explicitly
    kwargs.setdefault("mean", _model.cfg["mean"])
    kwargs.setdefault("std", _model.cfg["std"])
    kwargs.setdefault("batch_size", 128 if "crop" in arch else 4)

    # The configured input shape loses its last entry under TensorFlow and its first entry otherwise
    # (presumably the channel axis, channels-last vs. channels-first -- TODO confirm)
    full_shape = _model.cfg["input_shape"]
    input_shape = full_shape[:-1] if is_tf_available() else full_shape[1:]

    pre_processor = PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs)
    return OrientationPredictor(pre_processor, _model)
def crop_orientation_predictor(
    arch: str = "mobilenet_v3_small_crop_orientation", pretrained: bool = False, **kwargs: Any
) -> OrientationPredictor:
    """Crop orientation classification architecture.

    >>> import numpy as np
    >>> from doctr.models import crop_orientation_predictor
    >>> model = crop_orientation_predictor(arch='mobilenet_v3_small_crop_orientation', pretrained=True)
    >>> input_crop = (255 * np.random.rand(256, 256, 3)).astype(np.uint8)
    >>> out = model([input_crop])

    Args:
    ----
        arch: name of the architecture to use (e.g. 'mobilenet_v3_small_crop_orientation')
        pretrained: If True, returns a model pre-trained on our recognition crops dataset
        **kwargs: keyword arguments forwarded to `_orientation_predictor`, which passes them on to the
            PreProcessor (e.g. `mean`, `std`, `batch_size`)

    Returns:
    -------
        OrientationPredictor
    """
    return _orientation_predictor(arch, pretrained, **kwargs)
def page_orientation_predictor(
    arch: str = "mobilenet_v3_small_page_orientation", pretrained: bool = False, **kwargs: Any
) -> OrientationPredictor:
    """Page orientation classification architecture.

    >>> import numpy as np
    >>> from doctr.models import page_orientation_predictor
    >>> model = page_orientation_predictor(arch='mobilenet_v3_small_page_orientation', pretrained=True)
    >>> input_page = (255 * np.random.rand(512, 512, 3)).astype(np.uint8)
    >>> out = model([input_page])

    Args:
    ----
        arch: name of the architecture to use (e.g. 'mobilenet_v3_small_page_orientation')
        pretrained: If True, returns a model with pre-trained weights for the selected architecture
        **kwargs: keyword arguments forwarded to `_orientation_predictor`, which passes them on to the
            PreProcessor (e.g. `mean`, `std`, `batch_size`)

    Returns:
    -------
        OrientationPredictor
    """
    return _orientation_predictor(arch, pretrained, **kwargs)