# Copyright (C) 2021-2025, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
# Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py
from copy import deepcopy
from typing import Any
from torchvision.models import mobilenetv3
from torchvision.models.mobilenetv3 import MobileNetV3
from doctr.datasets import VOCABS
from ...utils import load_pretrained_params
# Public API: the torchvision ``MobileNetV3`` class re-exported from this module,
# followed by one constructor per configured variant (see ``default_cfgs``).
__all__ = [
    "MobileNetV3",
    # small / large backbones, each with a "_r" (rectangular-stride) counterpart
    "mobilenet_v3_small",
    "mobilenet_v3_small_r",
    "mobilenet_v3_large",
    "mobilenet_v3_large_r",
    # small backbones configured with orientation classes
    "mobilenet_v3_small_crop_orientation",
    "mobilenet_v3_small_page_orientation",
]
# Default configuration of every MobileNetV3 variant, keyed by architecture name.
# All variants share the same normalization statistics; each row of the table below
# only spells out what differs: expected input shape, class list and checkpoint URL.
# ``list(...)`` builds a fresh classes list for every entry, so no two entries share one.
default_cfgs: dict[str, dict[str, Any]] = {
    arch_name: {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": input_shape,
        "classes": list(class_values),
        "url": checkpoint_url,
    }
    for arch_name, input_shape, class_values, checkpoint_url in (
        (
            "mobilenet_v3_large",
            (3, 32, 32),
            VOCABS["french"],
            "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large-11fc8cb9.pt&src=0",
        ),
        (
            "mobilenet_v3_large_r",
            (3, 32, 32),
            VOCABS["french"],
            "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_large_r-74a22066.pt&src=0",
        ),
        (
            "mobilenet_v3_small",
            (3, 32, 32),
            VOCABS["french"],
            "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small-6a4bfa6b.pt&src=0",
        ),
        (
            "mobilenet_v3_small_r",
            (3, 32, 32),
            VOCABS["french"],
            "https://doctr-static.mindee.com/models?id=v0.4.1/mobilenet_v3_small_r-1a8a3530.pt&src=0",
        ),
        (
            "mobilenet_v3_small_crop_orientation",
            (3, 256, 256),
            (0, -90, 180, 90),
            "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_crop_orientation-f0847a18.pt&src=0",
        ),
        (
            "mobilenet_v3_small_page_orientation",
            (3, 512, 512),
            (0, -90, 180, 90),
            "https://doctr-static.mindee.com/models?id=v0.8.1/mobilenet_v3_small_page_orientation-8e60325c.pt&src=0",
        ),
    )
}
def _mobilenet_v3(
    arch: str,
    pretrained: bool,
    rect_strides: list[str] | None = None,
    ignore_keys: list[str] | None = None,
    **kwargs: Any,
) -> mobilenetv3.MobileNetV3:
    """Build a torchvision MobileNetV3 for one of the entries of ``default_cfgs``.

    Args:
        arch: key of ``default_cfgs``; names starting with ``"mobilenet_v3_small"`` build the
            small network, any other name builds the large one
        pretrained: whether to load the checkpoint found at ``default_cfgs[arch]["url"]``
        rect_strides: dotted attribute paths of submodules whose ``stride`` is overwritten
            with ``(2, 1)``; ignored unless it is a list
        ignore_keys: passed to ``load_pretrained_params`` as ``ignore_keys``, but only when
            ``num_classes`` differs from the default number of classes of ``arch``
        **kwargs: forwarded to the torchvision constructor. ``num_classes`` defaults to the
            number of default classes; ``classes`` is consumed here and only stored in
            ``model.cfg``

    Returns:
        the model, with a ``cfg`` attribute holding its own copy of the configuration
    """
    _cfg = deepcopy(default_cfgs[arch])
    default_num_classes = len(_cfg["classes"])

    kwargs["num_classes"] = kwargs.get("num_classes", default_num_classes)
    # ``classes`` is not a torchvision argument, so remove it before building the model.
    # Fall back to the deep-copied default list rather than default_cfgs[arch]["classes"]:
    # otherwise model.cfg["classes"] would alias the module-level defaults, and mutating it
    # would leak into every model built afterwards.
    _cfg["classes"] = kwargs.pop("classes", _cfg["classes"])
    _cfg["num_classes"] = kwargs["num_classes"]

    if arch.startswith("mobilenet_v3_small"):
        model = mobilenetv3.mobilenet_v3_small(**kwargs, weights=None)
    else:
        model = mobilenetv3.mobilenet_v3_large(**kwargs, weights=None)

    # Rectangular strides: resolve each dotted path on the model and replace its stride
    if isinstance(rect_strides, list):
        for layer_name in rect_strides:
            m = model
            for child in layer_name.split("."):
                m = getattr(m, child)
            m.stride = (2, 1)

    # Load pretrained parameters
    if pretrained:
        # The number of classes is not the same as the number of classes in the pretrained model =>
        # remove the last layer weights
        _ignore_keys = ignore_keys if kwargs["num_classes"] != default_num_classes else None
        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

    model.cfg = _cfg

    return model
[docs]
def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
    """MobileNetV3-Small architecture as described in
    `"Searching for MobileNetV3",
    <https://arxiv.org/pdf/1905.02244.pdf>`_.

    >>> import torch
    >>> from doctr.models import mobilenet_v3_small
    >>> model = mobilenet_v3_small(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the MobileNetV3 architecture (``num_classes`` and
            ``classes`` override the defaults of the ``"mobilenet_v3_small"`` configuration)

    Returns:
        a torch.nn.Module
    """
    # The classifier keys are only skipped at load time when num_classes differs from the default
    return _mobilenet_v3(
        "mobilenet_v3_small", pretrained, ignore_keys=["classifier.3.weight", "classifier.3.bias"], **kwargs
    )
[docs]
def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
    """MobileNetV3-Small architecture as described in
    `"Searching for MobileNetV3",
    <https://arxiv.org/pdf/1905.02244.pdf>`_, with rectangular strides: the convolutions
    listed in ``rect_strides`` below get a ``(2, 1)`` stride, so they downsample the height
    of the feature maps but not their width.

    >>> import torch
    >>> from doctr.models import mobilenet_v3_small_r
    >>> model = mobilenet_v3_small_r(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the MobileNetV3 architecture (``num_classes`` and
            ``classes`` override the defaults of the ``"mobilenet_v3_small_r"`` configuration)

    Returns:
        a torch.nn.Module
    """
    return _mobilenet_v3(
        "mobilenet_v3_small_r",
        pretrained,
        rect_strides=["features.2.block.1.0", "features.4.block.1.0", "features.9.block.1.0"],
        ignore_keys=["classifier.3.weight", "classifier.3.bias"],
        **kwargs,
    )
[docs]
def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
    """Build the MobileNetV3-Large architecture introduced in
    `"Searching for MobileNetV3",
    <https://arxiv.org/pdf/1905.02244.pdf>`_.

    >>> import torch
    >>> from doctr.models import mobilenet_v3_large
    >>> model = mobilenet_v3_large(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: if True, load the weights of the default ``"mobilenet_v3_large"`` checkpoint
        **kwargs: extra keyword arguments for the MobileNetV3 architecture

    Returns:
        a torch.nn.Module
    """
    # Classifier keys handed to the loader; they are only skipped when num_classes is customized
    head_keys = ["classifier.3.weight", "classifier.3.bias"]
    return _mobilenet_v3("mobilenet_v3_large", pretrained=pretrained, ignore_keys=head_keys, **kwargs)
[docs]
def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
    """MobileNetV3-Large architecture as described in
    `"Searching for MobileNetV3",
    <https://arxiv.org/pdf/1905.02244.pdf>`_, with rectangular strides: the convolutions
    listed in ``rect_strides`` below get a ``(2, 1)`` stride, so they downsample the height
    of the feature maps but not their width.

    >>> import torch
    >>> from doctr.models import mobilenet_v3_large_r
    >>> model = mobilenet_v3_large_r(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the MobileNetV3 architecture (``num_classes`` and
            ``classes`` override the defaults of the ``"mobilenet_v3_large_r"`` configuration)

    Returns:
        a torch.nn.Module
    """
    return _mobilenet_v3(
        "mobilenet_v3_large_r",
        pretrained,
        rect_strides=["features.4.block.1.0", "features.7.block.1.0", "features.13.block.1.0"],
        ignore_keys=["classifier.3.weight", "classifier.3.bias"],
        **kwargs,
    )
[docs]
def mobilenet_v3_small_crop_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
    """Build a MobileNetV3-Small (`"Searching for MobileNetV3",
    <https://arxiv.org/pdf/1905.02244.pdf>`_) configured for crop orientation.

    Its default configuration uses the orientation classes ``[0, -90, 180, 90]`` and an
    input shape of ``(3, 256, 256)``.

    >>> import torch
    >>> from doctr.models import mobilenet_v3_small_crop_orientation
    >>> model = mobilenet_v3_small_crop_orientation(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 256, 256), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: if True, load the weights of the default crop-orientation checkpoint
        **kwargs: extra keyword arguments for the MobileNetV3 architecture

    Returns:
        a torch.nn.Module
    """
    # Classifier keys handed to the loader; they are only skipped when num_classes is customized
    head_keys = ["classifier.3.weight", "classifier.3.bias"]
    return _mobilenet_v3(
        "mobilenet_v3_small_crop_orientation", pretrained=pretrained, ignore_keys=head_keys, **kwargs
    )
[docs]
def mobilenet_v3_small_page_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3:
    """Build a MobileNetV3-Small (`"Searching for MobileNetV3",
    <https://arxiv.org/pdf/1905.02244.pdf>`_) configured for page orientation.

    Its default configuration uses the orientation classes ``[0, -90, 180, 90]`` and an
    input shape of ``(3, 512, 512)``.

    >>> import torch
    >>> from doctr.models import mobilenet_v3_small_page_orientation
    >>> model = mobilenet_v3_small_page_orientation(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 512, 512), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: if True, load the weights of the default page-orientation checkpoint
        **kwargs: extra keyword arguments for the MobileNetV3 architecture

    Returns:
        a torch.nn.Module
    """
    # Classifier keys handed to the loader; they are only skipped when num_classes is customized
    head_keys = ["classifier.3.weight", "classifier.3.bias"]
    return _mobilenet_v3(
        "mobilenet_v3_small_page_orientation", pretrained=pretrained, ignore_keys=head_keys, **kwargs
    )