# Source code for doctr.models.factory.hub

# Copyright (C) 2021-2024, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

# Inspired by: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/hub.py

import json
import logging
import os
import subprocess
import textwrap
from pathlib import Path
from typing import Any

from huggingface_hub import (
    HfApi,
    Repository,
    get_token,
    get_token_permission,
    hf_hub_download,
    login,
)

from doctr import models
from doctr.file_utils import is_tf_available, is_torch_available

# torch is only imported when the PyTorch backend is available.
if is_torch_available():
    import torch

# Names exported by this module.
__all__ = [
    "login_to_hub",
    "push_to_hf_hub",
    "from_hub",
    "_save_model_and_config_for_hf_hub",
]


# Architectures accepted when pushing a model, keyed by task name.
# Classification combines the classification zoo's ARCHS and ORIENTATION_ARCHS lists.
AVAILABLE_ARCHS = {
    "classification": (
        models.classification.zoo.ARCHS
        + models.classification.zoo.ORIENTATION_ARCHS
    ),
    "detection": models.detection.zoo.ARCHS,
    "recognition": models.recognition.zoo.ARCHS,
}


def login_to_hub() -> None:  # pragma: no cover
    """Login to huggingface hub and check that git-lfs is available.

    If ``get_token()`` returns a token and ``get_token_permission`` reports a
    permission for it, that token is reused with write permission; otherwise
    the interactive ``login()`` flow is started.

    Raises:
        OSError: if ``git lfs version`` cannot be run successfully, i.e. git or
            its git-lfs extension is not installed.
    """
    access_token = get_token()
    if access_token is not None and get_token_permission(access_token):
        logging.info("Huggingface Hub token found and valid")
        login(token=access_token, write_permission=True)
    else:
        login()

    # check if git lfs is installed.
    # When git exists but the lfs extension does not, `git lfs version` does not
    # raise FileNotFoundError: it exits with a non-zero status. `check=True` turns
    # that status into CalledProcessError so both failure modes are caught.
    try:
        subprocess.run(["git", "lfs", "version"], check=True)
    except (FileNotFoundError, subprocess.CalledProcessError) as err:
        raise OSError(
            "Looks like you do not have git-lfs installed, please install. "
            "You can install from https://git-lfs.github.com/. "
            "Then run `git lfs install` (you only have to do this once)."
        ) from err
def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task: str) -> None:
    """Save model weights and config to disk for pushing to huggingface hub

    Weights are written to ``pytorch_model.bin`` (PyTorch backend) or
    ``tf_model.weights.h5`` (TensorFlow backend). The model's ``cfg`` is written
    to ``config.json`` together with the ``arch`` and ``task`` keys.

    Args:
        model: TF or PyTorch model to be saved; its ``cfg`` attribute must be a mapping
        save_dir: directory to save model and config
        arch: architecture name
        task: task name
    """
    save_directory = Path(save_dir)

    if is_torch_available():
        weights_path = save_directory / "pytorch_model.bin"
        torch.save(model.state_dict(), weights_path)
    elif is_tf_available():
        weights_path = save_directory / "tf_model.weights.h5"
        model.save_weights(str(weights_path))

    # Build a new dict rather than writing into `model.cfg`: saving must not leak
    # the "arch"/"task" bookkeeping keys into the live model's configuration.
    model_config = {**model.cfg, "arch": arch, "task": task}

    config_path = save_directory / "config.json"
    # ensure_ascii=False keeps non-ASCII characters (e.g. from a recognition vocab)
    # verbatim, so the file must be UTF-8 regardless of the platform default encoding.
    with config_path.open("w", encoding="utf-8") as f:
        json.dump(model_config, f, indent=2, ensure_ascii=False)
def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs: Any) -> None:  # pragma: no cover
    """Save model and its configuration on HF hub

    >>> from doctr.models import login_to_hub, push_to_hf_hub
    >>> from doctr.models.recognition import crnn_mobilenet_v3_small
    >>> login_to_hub()
    >>> model = crnn_mobilenet_v3_small(pretrained=True)
    >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')

    Args:
        model: TF or PyTorch model to be saved
        model_name: name of the model which is also the repository name
        task: task name, one of "classification", "detection", "recognition"
        **kwargs: keyword arguments for push_to_hf_hub; ``arch`` (architecture name)
            and/or ``run_config`` (an object exposing ``arch`` whose ``vars()`` are
            appended to the README). ``run_config.arch`` takes precedence over ``arch``.

    Raises:
        ValueError: if neither ``run_config`` nor ``arch`` is given, if ``task`` is
            unknown, or if the architecture is not available for ``task``.
    """
    run_config = kwargs.get("run_config", None)
    arch = kwargs.get("arch", None)

    if run_config is None and arch is None:
        raise ValueError("run_config or arch must be specified")
    if task not in ["classification", "detection", "recognition"]:
        raise ValueError("task must be one of classification, detection, recognition")

    # default readme; the "language" metadata must sit inside a YAML front-matter
    # block (between `---` lines) at the top of the model card to be parsed as metadata.
    readme = textwrap.dedent(
        f"""
    ---
    language: en
    ---

    <p align="center">
    <img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
    </p>

    **Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch**

    ## Task: {task}

    https://github.com/mindee/doctr

    ### Example usage:

    ```python
    >>> from doctr.io import DocumentFile
    >>> from doctr.models import ocr_predictor, from_hub

    >>> img = DocumentFile.from_images(['<image_path>'])
    >>> # Load your model from the hub
    >>> model = from_hub('mindee/my-model')

    >>> # Pass it to the predictor
    >>> # If your model is a recognition model:
    >>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large',
    >>>                           reco_arch=model,
    >>>                           pretrained=True)

    >>> # If your model is a detection model:
    >>> predictor = ocr_predictor(det_arch=model,
    >>>                           reco_arch='crnn_mobilenet_v3_small',
    >>>                           pretrained=True)

    >>> # Get your predictions
    >>> res = predictor(img)
    ```
    """
    ).lstrip()

    # add run configuration to readme if available
    if run_config is not None:
        arch = run_config.arch
        readme += textwrap.dedent(
            f"""### Run Configuration \n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}"""
        )

    if arch not in AVAILABLE_ARCHS[task]:
        # Only list the architectures valid for this task, not those of every task.
        raise ValueError(
            f"Architecture: {arch} for task: {task} not found.\n"
            f"Available architectures: {AVAILABLE_ARCHS[task]}"
        )

    commit_message = f"Add {model_name} model"

    local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
    repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False)
    repo = Repository(local_dir=local_cache_dir, clone_from=repo_url)

    with repo.commit(commit_message):
        _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task)
        readme_path = Path(repo.local_dir) / "README.md"
        # The README may contain non-ASCII text (run config dumped with
        # ensure_ascii=False), so write it as UTF-8 explicitly.
        readme_path.write_text(readme, encoding="utf-8")

    repo.git_push()
def from_hub(repo_id: str, **kwargs: Any):
    """Instantiate & load a pretrained model from HF hub.

    >>> from doctr.models import from_hub
    >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")

    Args:
        repo_id: HuggingFace model hub repo
        kwargs: kwargs of `hf_hub_download` or `snapshot_download`

    Returns:
        Model loaded with the checkpoint

    Raises:
        ValueError: if the ``task`` stored in the repo's config.json is not one of
            classification, detection or recognition.
    """
    # Get the config
    with open(hf_hub_download(repo_id, filename="config.json", **kwargs), "rb") as f:
        cfg = json.load(f)

    # "arch" and "task" are stored alongside the model config; remove them so that
    # `model.cfg` only holds the remaining configuration keys.
    arch = cfg.pop("arch")
    task = cfg.pop("task")

    if task == "classification":
        model = models.classification.__dict__[arch](
            pretrained=False, classes=cfg["classes"], num_classes=cfg["num_classes"]
        )
    elif task == "detection":
        model = models.detection.__dict__[arch](pretrained=False)
    elif task == "recognition":
        model = models.recognition.__dict__[arch](pretrained=False, input_shape=cfg["input_shape"], vocab=cfg["vocab"])
    else:
        # Without this branch `model` would be unbound and fail with UnboundLocalError below.
        raise ValueError(f"Unsupported task: {task}. Expected one of classification, detection, recognition")

    # update model cfg
    model.cfg = cfg

    # Load checkpoint
    if is_torch_available():
        # SECURITY: the checkpoint comes from an arbitrary hub repo and torch.load
        # unpickles it; weights_only=True restricts unpickling to tensors and
        # primitive containers (enough for a state_dict) to block code execution.
        state_dict = torch.load(
            hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs),
            map_location="cpu",
            weights_only=True,
        )
        model.load_state_dict(state_dict)
    else:  # tf
        weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
        model.load_weights(weights)

    return model