# Source code for doctr.models.classification.vit.pytorch

# Copyright (C) 2021-2025, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from copy import deepcopy
from typing import Any

import torch
from torch import nn

from doctr.datasets import VOCABS
from doctr.models.modules.transformer import EncoderBlock
from doctr.models.modules.vision_transformer.pytorch import PatchEmbedding

from ...utils.pytorch import load_pretrained_params

__all__ = ["vit_s", "vit_b"]


# Per-architecture defaults: both ViT variants share the same normalization statistics,
# input shape and French vocabulary; only the checkpoint URL differs.
# Each entry gets its own freshly built "classes" list.
default_cfgs: dict[str, dict[str, Any]] = {
    arch_name: {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": checkpoint_url,
    }
    for arch_name, checkpoint_url in (
        ("vit_s", "https://doctr-static.mindee.com/models?id=v0.6.0/vit_s-5d05442d.pt&src=0"),
        ("vit_b", "https://doctr-static.mindee.com/models?id=v0.6.0/vit_b-0fbef167.pt&src=0"),
    )
}


class ClassifierHead(nn.Module):
    """Linear classification head applied to the first token of a transformer sequence.

    Args:
        in_channels: size of the last dimension of the incoming features
        num_classes: number of output logits
    """

    def __init__(self, in_channels: int, num_classes: int) -> None:
        super().__init__()
        # The attribute name ``head`` appears in state-dict keys (e.g. "2.head.weight"),
        # so it must stay stable for pretrained checkpoints to load.
        self.head = nn.Linear(in_channels, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumes x is (batch_size, seq_len, in_channels) with the cls token at index 0
        # of dim 1 -- TODO confirm against PatchEmbedding. Output: (batch_size, num_classes).
        cls_token = x[:, 0]
        return self.head(cls_token)


class VisionTransformer(nn.Sequential):
    """Vision Transformer following
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
    <https://arxiv.org/pdf/2010.11929.pdf>`_.

    The model is an ``nn.Sequential`` made of a PatchEmbedding, an EncoderBlock and,
    when ``include_top`` is True, a ClassifierHead (in that order).

    Args:
        d_model: width (embedding dimension) of the transformer layers
        num_layers: number of transformer layers in the encoder
        num_heads: number of attention heads
        ffd_ratio: the feed-forward hidden size is ``d_model * ffd_ratio``
        patch_size: size of each patch, as a 2-tuple
        input_shape: shape of the input image, e.g. (3, 32, 32)
        dropout: dropout rate forwarded to the encoder
        num_classes: number of outputs of the classifier head
        include_top: if True, append the classifier head
        cfg: optional configuration dictionary, stored on the model as ``self.cfg``
    """

    def __init__(
        self,
        d_model: int,
        num_layers: int,
        num_heads: int,
        ffd_ratio: int,
        patch_size: tuple[int, int] = (4, 4),
        input_shape: tuple[int, int, int] = (3, 32, 32),
        dropout: float = 0.0,
        num_classes: int = 1000,
        include_top: bool = True,
        cfg: dict[str, Any] | None = None,
    ) -> None:
        patch_embedding = PatchEmbedding(input_shape, d_model, patch_size)
        # nn.GELU() is handed to EncoderBlock, presumably as its feed-forward activation.
        encoder = EncoderBlock(num_layers, num_heads, d_model, d_model * ffd_ratio, dropout, nn.GELU())
        # Sequential indices are 0=embedding, 1=encoder, 2=head; pretrained keys such as
        # "2.head.weight" rely on this ordering.
        top: list[nn.Module] = [ClassifierHead(d_model, num_classes)] if include_top else []

        super().__init__(patch_embedding, encoder, *top)
        self.cfg = cfg


def _vit(
    arch: str,
    pretrained: bool,
    ignore_keys: list[str] | None = None,
    **kwargs: Any,
) -> VisionTransformer:
    """Instantiate a VisionTransformer from the ``default_cfgs[arch]`` defaults.

    Args:
        arch: key into ``default_cfgs``
        pretrained: if True, load the checkpoint found at ``default_cfgs[arch]["url"]``
        ignore_keys: state-dict keys to skip while loading pretrained weights; only used when
            ``num_classes`` differs from the number of classes of the pretrained checkpoint
        **kwargs: keyword arguments forwarded to ``VisionTransformer``; ``classes`` is consumed
            here and only recorded in the model config

    Returns:
        the instantiated VisionTransformer, with its config available as ``model.cfg``
    """
    # Private copy of the defaults: nothing reachable from ``model.cfg`` may alias the
    # module-level ``default_cfgs`` (previously the default ``classes`` list was shared, so
    # mutating ``model.cfg["classes"]`` silently changed the defaults of every later model).
    _cfg = deepcopy(default_cfgs[arch])
    default_num_classes = len(_cfg["classes"])

    kwargs["num_classes"] = kwargs.get("num_classes", default_num_classes)
    kwargs["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])
    # VisionTransformer has no ``classes`` parameter, so it is removed from kwargs here.
    classes = kwargs.pop("classes", _cfg["classes"])

    _cfg["num_classes"] = kwargs["num_classes"]
    _cfg["input_shape"] = kwargs["input_shape"]
    _cfg["classes"] = classes

    # Build the model
    model = VisionTransformer(cfg=_cfg, **kwargs)
    # Load pretrained parameters
    if pretrained:
        # The checkpoint's classifier head only fits when the class count matches;
        # otherwise drop its weights (``ignore_keys``) and keep the fresh initialization.
        _ignore_keys = ignore_keys if kwargs["num_classes"] != default_num_classes else None
        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

    return model


def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
    """VisionTransformer-S architecture
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
    <https://arxiv.org/pdf/2010.11929.pdf>`_.
    Patches: (H, W) -> (H/4, W/4) with the default ``patch_size=(4, 4)``

    NOTE: unofficial config used in ViTSTR and ParSeq

    >>> import torch
    >>> from doctr.models import vit_s
    >>> model = vit_s(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: if True, load the pretrained weights of the ``vit_s`` checkpoint
        **kwargs: keyword arguments of the VisionTransformer architecture

    Returns:
        A VisionTransformer model (with a classification head unless ``include_top=False``)
    """
    return _vit(
        "vit_s",
        pretrained,
        d_model=384,
        num_layers=12,
        num_heads=6,
        ffd_ratio=4,
        # Layer 2 of the Sequential is the ClassifierHead; its weights are skipped when
        # the requested number of classes differs from the checkpoint's.
        ignore_keys=["2.head.weight", "2.head.bias"],
        **kwargs,
    )
def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
    """VisionTransformer-B architecture as described in
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
    <https://arxiv.org/pdf/2010.11929.pdf>`_.
    Patches: (H, W) -> (H/4, W/4) with the default ``patch_size=(4, 4)``

    >>> import torch
    >>> from doctr.models import vit_b
    >>> model = vit_b(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: if True, load the pretrained weights of the ``vit_b`` checkpoint
        **kwargs: keyword arguments of the VisionTransformer architecture

    Returns:
        A VisionTransformer model (with a classification head unless ``include_top=False``)
    """
    return _vit(
        "vit_b",
        pretrained,
        d_model=768,
        num_layers=12,
        num_heads=12,
        ffd_ratio=4,
        # Layer 2 of the Sequential is the ClassifierHead; its weights are skipped when
        # the requested number of classes differs from the checkpoint's.
        ignore_keys=["2.head.weight", "2.head.bias"],
        **kwargs,
    )