Source code for doctr.models.recognition.viptr.pytorch
# Copyright (C) 2021-2025, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from collections.abc import Callable
from copy import deepcopy
from itertools import groupby
from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models._utils import IntermediateLayerGetter

from doctr.datasets import VOCABS, decode_sequence

from ...classification import vip_base, vip_tiny
from ...utils.pytorch import _bf16_to_float32, load_pretrained_params
from ..core import RecognitionModel, RecognitionPostProcessor

__all__ = ["VIPTR", "viptr_base", "viptr_tiny"]


default_cfgs: dict[str, dict[str, Any]] = {
    "viptr_tiny": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 128),
        "vocab": VOCABS["french"],
        "url": None,
    },
    "viptr_base": {
        "mean": (0.694, 0.695, 0.693),
        "std": (0.299, 0.296, 0.301),
        "input_shape": (3, 32, 128),
        "vocab": VOCABS["french"],
        "url": None,
    },
}


class VIPTRPostProcessor(RecognitionPostProcessor):
    """Postprocess raw prediction of the model (logits) to a list of words using CTC decoding

    Args:
        vocab: string containing the ordered sequence of supported characters
    """

    @staticmethod
    def ctc_best_path(
        logits: torch.Tensor,
        vocab: str = VOCABS["french"],
        blank: int = 0,
    ) -> list[tuple[str, float]]:
        """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired by
        `<https://github.com/githubharald/CTCDecoder>`_.

        Args:
            logits: model output, shape: N x T x C
            vocab: vocabulary to use
            blank: index of blank label

        Returns:
            A list of tuples: (word, confidence)
        """
        # Gather the most confident characters, and assign the smallest conf among those to the sequence prob
        probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values

        # Collapse best path (using itertools.groupby), map to chars, join char list to string
        words = [
            decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab)
            for seq in torch.argmax(logits, dim=-1)
        ]

        return list(zip(words, probs.tolist()))

    def __call__(self, logits: torch.Tensor) -> list[tuple[str, float]]:
        """Performs CTC decoding of the raw model output and maps predictions back to characters
        using the label-to-index dictionary

        Args:
            logits: raw output of the model, shape (N, seq_len, C + 1)

        Returns:
            A list of tuples: (word, confidence)
        """
        # Decode CTC
        return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab))


class VIPTR(RecognitionModel, nn.Module):
    """Implements a VIPTR architecture as described in `"A Vision Permutable Extractor for Fast and
    Efficient Scene Text Recognition" <https://arxiv.org/abs/2401.10110>`_.
    Args:
        feature_extractor: the backbone serving as feature extractor
        vocab: vocabulary used for encoding
        input_shape: input shape of the image
        exportable: onnx exportable returns only logits
        cfg: configuration dictionary
    """

    def __init__(
        self,
        feature_extractor: nn.Module,
        vocab: str,
        input_shape: tuple[int, int, int] = (3, 32, 128),
        exportable: bool = False,
        cfg: dict[str, Any] | None = None,
    ):
        super().__init__()
        self.vocab = vocab
        self.exportable = exportable
        self.cfg = cfg
        self.max_length = 32
        self.vocab_size = len(vocab)
        self.feat_extractor = feature_extractor

        with torch.inference_mode():
            embedding_units = self.feat_extractor(torch.zeros((1, *input_shape)))["features"].shape[-1]

        self.postprocessor = VIPTRPostProcessor(vocab=self.vocab)
        self.head = nn.Linear(embedding_units, len(self.vocab) + 1)  # +1 for PAD

        for n, m in self.named_modules():
            # Don't override the initialization of the backbone
            if n.startswith("feat_extractor."):
                continue
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(
        self,
        x: torch.Tensor,
        target: list[str] | None = None,
        return_model_output: bool = False,
        return_preds: bool = False,
    ) -> dict[str, Any]:
        if target is not None:
            _gt, _seq_len = self.build_target(target)
            gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len)
            gt, seq_len = gt.to(x.device), seq_len.to(x.device)

        if self.training and target is None:
            raise ValueError("Need to provide labels during training")

        features = self.feat_extractor(x)["features"]  # (B, max_len, embed_dim)
        B, N, E = features.size()
        logits = self.head(features).view(B, N, len(self.vocab) + 1)
        decoded_features = _bf16_to_float32(logits)

        out: dict[str, Any] = {}
        if self.exportable:
            out["logits"] = decoded_features
            return out

        if return_model_output:
            out["out_map"] = decoded_features

        if target is None or return_preds:
            # Disable for torch.compile compatibility
            @torch.compiler.disable  # type: ignore[attr-defined]
            def _postprocess(decoded_features: torch.Tensor) -> list[tuple[str, float]]:
                return self.postprocessor(decoded_features)

            # Post-process predictions
            out["preds"] = _postprocess(decoded_features)

        if target is not None:
            out["loss"] = self.compute_loss(decoded_features, gt, seq_len, len(self.vocab))

        return out

    @staticmethod
    def compute_loss(
        model_output: torch.Tensor,
        gt: torch.Tensor,
        seq_len: torch.Tensor,
        blank_idx: int = 0,
    ) -> torch.Tensor:
        """Compute CTC loss for the model.
        Args:
            model_output: predicted logits of the model
            gt: ground truth tensor
            seq_len: sequence lengths of the ground truth
            blank_idx: index of the blank label

        Returns:
            The loss of the model on the batch
        """
        batch_len = model_output.shape[0]
        input_length = model_output.shape[1] * torch.ones(size=(batch_len,), dtype=torch.int32)
        # N x T x C -> T x N x C
        logits = model_output.permute(1, 0, 2)
        probs = F.log_softmax(logits, dim=-1)
        ctc_loss = F.ctc_loss(
            probs,
            gt,
            input_length,
            seq_len,
            blank_idx,
            zero_infinity=True,
        )
        return ctc_loss


def _viptr(
    arch: str,
    pretrained: bool,
    backbone_fn: Callable[[bool], nn.Module],
    layer: str,
    pretrained_backbone: bool = True,
    ignore_keys: list[str] | None = None,
    **kwargs: Any,
) -> VIPTR:
    pretrained_backbone = pretrained_backbone and not pretrained

    # Patch the config
    _cfg = deepcopy(default_cfgs[arch])
    _cfg["vocab"] = kwargs.get("vocab", _cfg["vocab"])
    _cfg["input_shape"] = kwargs.get("input_shape", _cfg["input_shape"])

    # Feature extractor
    feat_extractor = IntermediateLayerGetter(
        backbone_fn(pretrained_backbone, input_shape=_cfg["input_shape"]),  # type: ignore[call-arg]
        {layer: "features"},
    )
    kwargs["vocab"] = _cfg["vocab"]
    kwargs["input_shape"] = _cfg["input_shape"]

    model = VIPTR(feat_extractor, cfg=_cfg, **kwargs)
    # Load pretrained parameters
    if pretrained:
        # The number of classes is not the same as the number of classes in the pretrained model =>
        # remove the last layer weights
        _ignore_keys = ignore_keys if _cfg["vocab"] != default_cfgs[arch]["vocab"] else None
        load_pretrained_params(model, default_cfgs[arch]["url"], ignore_keys=_ignore_keys)

    return model
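# Illustrative sketch (not part of the module): a toy walk-through of the greedy
# CTC decoding performed by `VIPTRPostProcessor.ctc_best_path` above - take the
# argmax path, collapse consecutive repeats, then drop blanks. The vocabulary and
# logits below are made up; as in this model, blank is the last class index
# (len(vocab)).
#
# >>> import torch
# >>> from itertools import groupby
# >>> vocab = "ab"
# >>> blank = len(vocab)  # 2
# >>> # (N=1, T=5, C=3) logits whose best path is [a, a, blank, b, b]
# >>> logits = torch.tensor(
# ...     [[[5.0, 0.0, 0.0], [5.0, 0.0, 0.0], [0.0, 0.0, 5.0], [0.0, 5.0, 0.0], [0.0, 5.0, 0.0]]]
# ... )
# >>> path = torch.argmax(logits, dim=-1)[0].tolist()  # [0, 0, 2, 1, 1]
# >>> "".join(vocab[k] for k, _ in groupby(path) if k != blank)
# 'ab'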
def viptr_base(pretrained: bool = False, **kwargs: Any) -> VIPTR:
    """VIPTR-Base as described in `"A Vision Permutable Extractor for Fast and Efficient Scene Text
    Recognition" <https://arxiv.org/abs/2401.10110>`_.

    >>> import torch
    >>> from doctr.models import viptr_base
    >>> model = viptr_base(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 32, 128))
    >>> out = model(input_tensor)

    Args:
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the VIPTR architecture

    Returns:
        VIPTR: a VIPTR model instance
    """
    return _viptr(
        "viptr_base",
        pretrained,
        vip_base,
        "5",
        ignore_keys=["head.weight", "head.bias"],
        **kwargs,
    )
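# Illustrative sketch (not part of the module): training-style use of the forward
# pass defined above. Passing `target` adds out["loss"] (CTC loss), and
# `return_preds=True` additionally runs the post-processor. The sample words are
# arbitrary strings over the model vocabulary.
#
# >>> import torch
# >>> from doctr.models import viptr_base
# >>> model = viptr_base(pretrained=False).train()
# >>> x = torch.rand((2, 3, 32, 128))
# >>> out = model(x, target=["hello", "world"], return_preds=True)
# >>> loss, preds = out["loss"], out["preds"]  # scalar CTC loss, [(word, confidence), ...]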
def viptr_tiny(pretrained: bool = False, **kwargs: Any) -> VIPTR:
    """VIPTR-Tiny as described in `"A Vision Permutable Extractor for Fast and Efficient Scene Text
    Recognition" <https://arxiv.org/abs/2401.10110>`_.

    >>> import torch
    >>> from doctr.models import viptr_tiny
    >>> model = viptr_tiny(pretrained=False)
    >>> input_tensor = torch.rand((1, 3, 32, 128))
    >>> out = model(input_tensor)

    Args:
        pretrained (bool): If True, returns a model pre-trained on our text recognition dataset
        **kwargs: keyword arguments of the VIPTR architecture

    Returns:
        VIPTR: a VIPTR model instance
    """
    return _viptr(
        "viptr_tiny",
        pretrained,
        vip_tiny,
        "5",
        ignore_keys=["head.weight", "head.bias"],
        **kwargs,
    )
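# Illustrative sketch (not part of the module): with `exportable=True`, the
# forward pass returns only the raw logits (convenient for ONNX export); they can
# then be decoded separately with the model's post-processor.
#
# >>> import torch
# >>> from doctr.models import viptr_tiny
# >>> model = viptr_tiny(pretrained=False, exportable=True).eval()
# >>> logits = model(torch.rand((1, 3, 32, 128)))["logits"]  # (N, seq_len, len(vocab) + 1)
# >>> model.postprocessor(logits)  # [(word, confidence), ...]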