Source code for doctr.models.classification.vit.pytorch
# Copyright (C) 2021-2025, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from copy import deepcopy
from typing import Any

import torch
from torch import nn

from doctr.datasets import VOCABS
from doctr.models.modules.transformer import EncoderBlock
from doctr.models.modules.vision_transformer.pytorch import PatchEmbedding

from ...utils.pytorch import load_pretrained_params

__all__ = ["vit_s", "vit_b"]


default_cfgs: dict[str, dict[str, Any]] = {
    "vit_s": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_s-5d05442d.pt&src=0",
    },
    "vit_b": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 32),
        "classes": list(VOCABS["french"]),
        "url": "https://doctr-static.mindee.com/models?id=v0.6.0/vit_b-0fbef167.pt&src=0",
    },
}


class ClassifierHead(nn.Module):
    """Classifier head for Vision Transformer

    Args:
        in_channels: number of input channels
        num_classes: number of output classes
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
    ) -> None:
        super().__init__()
        self.head = nn.Linear(in_channels, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch_size, num_classes) cls token
        return self.head(x[:, 0])


class VisionTransformer(nn.Sequential):
    """VisionTransformer architecture as described in
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
    <https://arxiv.org/pdf/2010.11929.pdf>`_.

    Args:
        d_model: dimension of the transformer layers
        num_layers: number of transformer layers
        num_heads: number of attention heads
        ffd_ratio: multiplier for the hidden dimension of the feedforward layer
        patch_size: size of the patches
        input_shape: size of the input image
        dropout: dropout rate
        num_classes: number of output classes
        include_top: whether the classifier head should be instantiated
    """

    def __init__(
        self,
        d_model: int,
        num_layers: int,
        num_heads: int,
        ffd_ratio: int,
        patch_size: tuple[int, int] = (4, 4),
        input_shape: tuple[int, int, int] = (3, 32, 32),
        dropout: float = 0.0,
        num_classes: int = 1000,
        include_top: bool = True,
        cfg: dict[str, Any] | None = None,
    ) -> None:
        _layers: list[nn.Module] = [
            PatchEmbedding(input_shape, d_model, patch_size),
            EncoderBlock(num_layers, num_heads, d_model, d_model * ffd_ratio, dropout, nn.GELU()),
        ]
        if include_top:
            _layers.append(ClassifierHead(d_model, num_classes))

        super().__init__(*_layers)
        self.cfg = cfg

    def from_pretrained(self, path_or_url: str, **kwargs: Any) -> None:
        """Load pretrained parameters onto the model

        Args:
            path_or_url: the path or URL to the model parameters (checkpoint)
            **kwargs: additional arguments to be passed to `doctr.models.utils.load_pretrained_params`
        """
        load_pretrained_params(self, path_or_url, **kwargs)


def _vit(
    arch: str,
    pretrained: bool,
    ignore_keys: list[str] | None = None,
    **kwargs: Any,
) -> VisionTransformer:
    kwargs["num_classes"] = kwargs.get("num_classes", len(default_cfgs[arch]["classes"]))
    kwargs["input_shape"] = kwargs.get("input_shape", default_cfgs[arch]["input_shape"])
    kwargs["classes"] = kwargs.get("classes", default_cfgs[arch]["classes"])

    _cfg = deepcopy(default_cfgs[arch])
    _cfg["num_classes"] = kwargs["num_classes"]
    _cfg["input_shape"] = kwargs["input_shape"]
    _cfg["classes"] = kwargs["classes"]
    kwargs.pop("classes")

    # Build the model
    model = VisionTransformer(cfg=_cfg, **kwargs)
    # Load pretrained parameters
    if pretrained:
        # The number of classes is not the same as the number of classes in the pretrained model =>
        # remove the last layer
        _ignore_keys = ignore_keys if kwargs["num_classes"] != len(default_cfgs[arch]["classes"]) else None
        model.from_pretrained(default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

    return model

def vit_s(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
    """VisionTransformer-S architecture as described in
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
    <https://arxiv.org/pdf/2010.11929.pdf>`_. Patches: (H, W) -> (H/8, W/8)

    NOTE: unofficial config used in ViTSTR and ParSeq

    >>> import torch
    >>> from doctr.models import vit_s
    >>> model = vit_s(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the VisionTransformer architecture

    Returns:
        A feature extractor model
    """
    return _vit(
        "vit_s",
        pretrained,
        d_model=384,
        num_layers=12,
        num_heads=6,
        ffd_ratio=4,
        ignore_keys=["2.head.weight", "2.head.bias"],
        **kwargs,
    )

def vit_b(pretrained: bool = False, **kwargs: Any) -> VisionTransformer:
    """VisionTransformer-B architecture as described in
    `"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale",
    <https://arxiv.org/pdf/2010.11929.pdf>`_. Patches: (H, W) -> (H/8, W/8)

    >>> import torch
    >>> from doctr.models import vit_b
    >>> model = vit_b(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
    >>> out = model(input_tensor)

    Args:
        pretrained: boolean, True if model is pretrained
        **kwargs: keyword arguments of the VisionTransformer architecture

    Returns:
        A feature extractor model
    """
    return _vit(
        "vit_b",
        pretrained,
        d_model=768,
        num_layers=12,
        num_heads=12,
        ffd_ratio=4,
        ignore_keys=["2.head.weight", "2.head.bias"],
        **kwargs,
    )
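
Usage sketch (not part of the module above): a minimal example of building the small variant and running a forward pass, assuming docTR and PyTorch are installed. The entry point and input shape follow the docstrings above; the expected output size is an assumption based on the default French vocabulary in ``default_cfgs``.

import torch

from doctr.models import vit_s

# Build ViT-S with randomly initialized weights; pretrained=True would instead
# download the checkpoint referenced in default_cfgs["vit_s"]["url"].
model = vit_s(pretrained=False)
model.eval()

# Default input shape is (3, 32, 32); the classifier head returns one logit per class.
input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32)
with torch.no_grad():
    out = model(input_tensor)
print(out.shape)  # (1, num_classes); num_classes defaults to len(VOCABS["french"])

# Overriding num_classes while loading pretrained weights goes through the
# ignore_keys path in _vit, which skips "2.head.weight"/"2.head.bias" so the
# classifier head is re-initialized for the new number of classes:
# model_custom = vit_s(pretrained=True, num_classes=10)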