# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
# Inspired by: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/hub.py
import json
import logging
import os
import subprocess
import textwrap
from pathlib import Path
from typing import Any

from huggingface_hub import (
    HfApi,
    Repository,
    get_token,
    get_token_permission,
    hf_hub_download,
    login,
)

from doctr import models
from doctr.file_utils import is_tf_available, is_torch_available

if is_torch_available():
    import torch

__all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]

AVAILABLE_ARCHS = {
    "classification": models.classification.zoo.ARCHS + models.classification.zoo.ORIENTATION_ARCHS,
    "detection": models.detection.zoo.ARCHS,
    "recognition": models.recognition.zoo.ARCHS,
}


def login_to_hub() -> None:  # pragma: no cover
    """Login to huggingface hub"""
    access_token = get_token()
    if access_token is not None and get_token_permission(access_token):
        logging.info("Huggingface Hub token found and valid")
        login(token=access_token, write_permission=True)
    else:
        login()

    # check if git lfs is installed
    try:
        subprocess.call(["git", "lfs", "version"])
    except FileNotFoundError:
        raise OSError(
            "Looks like you do not have git-lfs installed, please install. "
            "You can install from https://git-lfs.github.com/. "
            "Then run `git lfs install` (you only have to do this once)."
        )
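

# A minimal usage sketch (not part of the module): login_to_hub() is typically run once
# before push_to_hf_hub so that a write-enabled token is cached and git-lfs is available.
# The model name below is hypothetical.
# >>> from doctr.models import login_to_hub, push_to_hf_hub
# >>> from doctr.models.detection import db_mobilenet_v3_large
# >>> login_to_hub()  # prompts for a token if none is cached
# >>> model = db_mobilenet_v3_large(pretrained=True)
# >>> push_to_hf_hub(model, 'my-detection-model', 'detection', arch='db_mobilenet_v3_large')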


def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task: str) -> None:
    """Save model and config to disk for pushing to huggingface hub

    Args:
    ----
        model: TF or PyTorch model to be saved
        save_dir: directory to save model and config
        arch: architecture name
        task: task name
    """
    save_directory = Path(save_dir)

    if is_torch_available():
        weights_path = save_directory / "pytorch_model.bin"
        torch.save(model.state_dict(), weights_path)
    elif is_tf_available():
        weights_path = save_directory / "tf_model.weights.h5"
        model.save_weights(str(weights_path))

    config_path = save_directory / "config.json"

    # add model configuration
    model_config = model.cfg
    model_config["arch"] = arch
    model_config["task"] = task

    with config_path.open("w") as f:
        json.dump(model_config, f, indent=2, ensure_ascii=False)
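

# A minimal sketch of what the helper above writes to disk (illustrative only): the
# framework-specific weights file plus a config.json carrying the architecture and task.
# >>> import tempfile
# >>> from doctr.models.recognition import crnn_mobilenet_v3_small
# >>> model = crnn_mobilenet_v3_small(pretrained=True)
# >>> with tempfile.TemporaryDirectory() as tmp:
# ...     _save_model_and_config_for_hf_hub(model, tmp, arch='crnn_mobilenet_v3_small', task='recognition')
# ...     # tmp now holds pytorch_model.bin (or tf_model.weights.h5) and config.json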


def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None:  # pragma: no cover
    """Save model and its configuration on HF hub

    >>> from doctr.models import login_to_hub, push_to_hf_hub
    >>> from doctr.models.recognition import crnn_mobilenet_v3_small
    >>> login_to_hub()
    >>> model = crnn_mobilenet_v3_small(pretrained=True)
    >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')

    Args:
    ----
        model: TF or PyTorch model to be saved
        model_name: name of the model which is also the repository name
        task: task name
        **kwargs: keyword arguments for push_to_hf_hub
    """
    run_config = kwargs.get("run_config", None)
    arch = kwargs.get("arch", None)

    if run_config is None and arch is None:
        raise ValueError("run_config or arch must be specified")
    if task not in ["classification", "detection", "recognition"]:
        raise ValueError("task must be one of classification, detection, recognition")

    # default readme
    readme = textwrap.dedent(
        f"""
        ---
        language: en
        ---

        <p align="center">
        <img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
        </p>

        **Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch**

        ## Task: {task}

        https://github.com/mindee/doctr

        ### Example usage:

        ```python
        >>> from doctr.io import DocumentFile
        >>> from doctr.models import ocr_predictor, from_hub

        >>> img = DocumentFile.from_images(['<image_path>'])
        >>> # Load your model from the hub
        >>> model = from_hub('mindee/my-model')

        >>> # Pass it to the predictor
        >>> # If your model is a recognition model:
        >>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large',
        >>>                           reco_arch=model,
        >>>                           pretrained=True)

        >>> # If your model is a detection model:
        >>> predictor = ocr_predictor(det_arch=model,
        >>>                           reco_arch='crnn_mobilenet_v3_small',
        >>>                           pretrained=True)

        >>> # Get your predictions
        >>> res = predictor(img)
        ```
        """
    )

    # add run configuration to readme if available
    if run_config is not None:
        arch = run_config.arch
        readme += textwrap.dedent(
            f"""### Run Configuration
            \n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}"""
        )

    if arch not in AVAILABLE_ARCHS[task]:
        raise ValueError(
            f"Architecture: {arch} for task: {task} not found.\n"
            f"Available architectures: {AVAILABLE_ARCHS}"
        )

    commit_message = f"Add {model_name} model"

    local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
    repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False)
    repo = Repository(local_dir=local_cache_dir, clone_from=repo_url)

    with repo.commit(commit_message):
        _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task)
        readme_path = Path(repo.local_dir) / "README.md"
        readme_path.write_text(readme)

    repo.git_push()
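

# A hedged sketch of pushing with a full training run configuration instead of a bare
# `arch` (values below are made up): any object that exposes an `arch` attribute and can
# be serialised with vars() works, e.g. an argparse.Namespace.
# >>> from argparse import Namespace
# >>> run_config = Namespace(arch='crnn_mobilenet_v3_small', epochs=10, lr=0.001)
# >>> push_to_hf_hub(model, 'my-recognition-model', 'recognition', run_config=run_config)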


def from_hub(repo_id: str, **kwargs: Any):
    """Instantiate & load a pretrained model from HF hub.

    >>> from doctr.models import from_hub
    >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")

    Args:
    ----
        repo_id: HuggingFace model hub repo
        kwargs: kwargs of `hf_hub_download` or `snapshot_download`

    Returns:
    -------
        Model loaded with the checkpoint
    """
    # Get the config
    with open(hf_hub_download(repo_id, filename="config.json", **kwargs), "rb") as f:
        cfg = json.load(f)

    arch = cfg["arch"]
    task = cfg["task"]
    cfg.pop("arch")
    cfg.pop("task")

    if task == "classification":
        model = models.classification.__dict__[arch](
            pretrained=False, classes=cfg["classes"], num_classes=cfg["num_classes"]
        )
    elif task == "detection":
        model = models.detection.__dict__[arch](pretrained=False)
    elif task == "recognition":
        model = models.recognition.__dict__[arch](pretrained=False, input_shape=cfg["input_shape"], vocab=cfg["vocab"])

    # update model cfg
    model.cfg = cfg

    # Load checkpoint
    if is_torch_available():
        state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu")
        model.load_state_dict(state_dict)
    else:  # tf
        weights = hf_hub_download(repo_id, filename="tf_model.weights.h5", **kwargs)
        model.load_weights(weights)

    return model
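

# A hedged sketch: the extra keyword arguments are forwarded to hf_hub_download, so a
# specific revision or cache directory can be pinned when downloading (repo id below is
# hypothetical).
# >>> model = from_hub('mindee/my-model', revision='main', cache_dir='/tmp/doctr-hub')
# >>> sorted(model.cfg)  # remaining config.json entries ('arch' and 'task' are popped above)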