# Copyright (C) 2021-2025, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

# Inspired by: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/hub.py

import json
import logging
import os
import subprocess
import textwrap
from pathlib import Path
from typing import Any

from huggingface_hub import (
    HfApi,
    Repository,
    get_token,
    get_token_permission,
    hf_hub_download,
    login,
)

from doctr import models
from doctr.file_utils import is_tf_available, is_torch_available

if is_torch_available():
    import torch

__all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]

AVAILABLE_ARCHS = {
    "classification": models.classification.zoo.ARCHS + models.classification.zoo.ORIENTATION_ARCHS,
    "detection": models.detection.zoo.ARCHS,
    "recognition": models.recognition.zoo.ARCHS,
}

def login_to_hub() -> None:  # pragma: no cover
    """Login to huggingface hub"""
    access_token = get_token()
    if access_token is not None and get_token_permission(access_token):
        logging.info("Huggingface Hub token found and valid")
        login(token=access_token, write_permission=True)
    else:
        login()
    # check if git lfs is installed
    try:
        subprocess.call(["git", "lfs", "version"])
    except FileNotFoundError:
        raise OSError(
            "Looks like you do not have git-lfs installed, please install. "
            "You can install from https://git-lfs.github.com/. "
            "Then run `git lfs install` (you only have to do this once)."
        )

def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task: str) -> None:
    """Save model and config to disk for pushing to huggingface hub

    Args:
        model: TF or PyTorch model to be saved
        save_dir: directory to save model and config
        arch: architecture name
        task: task name
    """
    save_directory = Path(save_dir)

    if is_torch_available():
        weights_path = save_directory / "pytorch_model.bin"
        torch.save(model.state_dict(), weights_path)
    elif is_tf_available():
        weights_path = save_directory / "tf_model.weights.h5"
        model.save_weights(str(weights_path))

    config_path = save_directory / "config.json"

    # add model configuration
    model_config = model.cfg
    model_config["arch"] = arch
    model_config["task"] = task

    with config_path.open("w") as f:
        json.dump(model_config, f, indent=2, ensure_ascii=False)

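# A minimal usage sketch for the private helper above (normally it is only called from
# `push_to_hf_hub`). The target directory and the chosen architecture are illustrative
# assumptions; `save_dir` is expected to already exist. The helper writes the
# framework-specific weights file ("pytorch_model.bin" or "tf_model.weights.h5") plus a
# "config.json" containing `model.cfg` extended with the "arch" and "task" keys.
#
# >>> from doctr.models.recognition import crnn_mobilenet_v3_small
# >>> model = crnn_mobilenet_v3_small(pretrained=True)
# >>> _save_model_and_config_for_hf_hub(
# ...     model, save_dir="/tmp/my-model", arch="crnn_mobilenet_v3_small", task="recognition"
# ... )
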
def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None:  # pragma: no cover
    """Save model and its configuration on HF hub

    >>> from doctr.models import login_to_hub, push_to_hf_hub
    >>> from doctr.models.recognition import crnn_mobilenet_v3_small
    >>> login_to_hub()
    >>> model = crnn_mobilenet_v3_small(pretrained=True)
    >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')

    Args:
        model: TF or PyTorch model to be saved
        model_name: name of the model which is also the repository name
        task: task name
        **kwargs: keyword arguments for push_to_hf_hub
    """
    run_config = kwargs.get("run_config", None)
    arch = kwargs.get("arch", None)

    if run_config is None and arch is None:
        raise ValueError("run_config or arch must be specified")
    if task not in ["classification", "detection", "recognition"]:
        raise ValueError("task must be one of classification, detection, recognition")

    # default readme
    readme = textwrap.dedent(
        f"""
    language: en

    <p align="center">
      <img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
    </p>

    **Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch**

    ## Task: {task}

    https://github.com/mindee/doctr

    ### Example usage:

    ```python
    >>> from doctr.io import DocumentFile
    >>> from doctr.models import ocr_predictor, from_hub

    >>> img = DocumentFile.from_images(['<image_path>'])
    >>> # Load your model from the hub
    >>> model = from_hub('mindee/my-model')

    >>> # Pass it to the predictor
    >>> # If your model is a recognition model:
    >>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large',
    >>>                           reco_arch=model,
    >>>                           pretrained=True)

    >>> # If your model is a detection model:
    >>> predictor = ocr_predictor(det_arch=model,
    >>>                           reco_arch='crnn_mobilenet_v3_small',
    >>>                           pretrained=True)

    >>> # Get your predictions
    >>> res = predictor(img)
    ```
    """
    )

    # add run configuration to readme if available
    if run_config is not None:
        arch = run_config.arch
        readme += textwrap.dedent(
            f"""### Run Configuration\n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}"""
        )

    if arch not in AVAILABLE_ARCHS[task]:
        raise ValueError(
            f"Architecture: {arch} for task: {task} not found.\n"
            f"Available architectures: {AVAILABLE_ARCHS}"
        )

    commit_message = f"Add {model_name} model"

    local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
    repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False)
    repo = Repository(local_dir=local_cache_dir, clone_from=repo_url)

    with repo.commit(commit_message):
        _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task)
        readme_path = Path(repo.local_dir) / "README.md"
        readme_path.write_text(readme)

    repo.git_push()

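# Hedged sketch of the `run_config` branch of `push_to_hf_hub`, which the docstring above
# does not illustrate. Any object whose attributes can be read with `vars()` works (e.g.
# an `argparse.Namespace`); the field names and values below are purely illustrative. The
# architecture is then taken from `run_config.arch` and the full configuration is appended
# to the generated README under "Run Configuration".
#
# >>> from types import SimpleNamespace
# >>> run_config = SimpleNamespace(arch="crnn_mobilenet_v3_small", epochs=10, lr=0.001)
# >>> push_to_hf_hub(model, "my-model", "recognition", run_config=run_config)
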
def from_hub(repo_id: str, **kwargs: Any):
    """Instantiate & load a pretrained model from HF hub.

    >>> from doctr.models import from_hub
    >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")

    Args:
        repo_id: HuggingFace model hub repo
        kwargs: kwargs of `hf_hub_download` or `snapshot_download`

    Returns:
        Model loaded with the checkpoint
    """
    # Get the config
    with open(hf_hub_download(repo_id, filename="config.json", **kwargs), "rb") as f:
        cfg = json.load(f)

    arch = cfg["arch"]
    task = cfg["task"]
    cfg.pop("arch")
    cfg.pop("task")

    if task == "classification":
        model = models.classification.__dict__[arch](
            pretrained=False, classes=cfg["classes"], num_classes=cfg["num_classes"]
        )
    elif task == "detection":
        model = models.detection.__dict__[arch](pretrained=False)
    elif task == "recognition":
        model = models.recognition.__dict__[arch](
            pretrained=False, input_shape=cfg["input_shape"], vocab=cfg["vocab"]
        )

    # update model cfg
    model.cfg = cfg

    # Load checkpoint
    if is_torch_available():
        state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu")
        model.load_state_dict(state_dict)
    else:  # tf
        weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
        model.load_weights(weights)

    return model

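# Hedged usage sketch: the keyword arguments of `from_hub` are forwarded to
# `hf_hub_download`, so options such as `revision` or `cache_dir` can pin a specific
# model version or cache location. The repository name, revision and cache directory
# below are illustrative assumptions.
#
# >>> model = from_hub("mindee/my-model", revision="main", cache_dir="/tmp/doctr-hub")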