# Copyright (C) 2021-2024, Mindee.# This program is licensed under the Apache License 2.0.# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.importjsonimportosfrompathlibimportPathfromtypingimportAny,Dict,List,Tupleimportnumpyasnpfrom.datasetsimportAbstractDataset__all__=["OCRDataset"]
class OCRDataset(AbstractDataset):
    """Implements an OCR dataset

    >>> from doctr.datasets import OCRDataset
    >>> train_set = OCRDataset(img_folder="/path/to/images",
    ...                        label_file="/path/to/labels.json")
    >>> img, target = train_set[0]

    The label file is a JSON object mapping each image name (relative to ``img_folder``) to an
    annotation dict with a ``"typed_words"`` list. Each word entry provides ``"geometry"``, whose
    first four values are read as ``(xmin, ymin, xmax, ymax)``, and ``"value"``, its text.

    Each sample target is a dict with:

    - ``boxes``: float32 array of shape ``(N, 4)``, or ``(N, 4, 2)`` when ``use_polygons`` is set
      (top left, top right, bottom right and bottom left corners). ``N`` may be 0.
    - ``labels``: list of the ``N`` word strings.

    Args:
    ----
        img_folder: local path to image folder (all jpg at the root)
        label_file: local path to the label file
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        **kwargs: keyword arguments from `AbstractDataset`.

    Raises:
    ------
        FileNotFoundError: if an image listed in the label file does not exist under the image folder.
    """

    def __init__(
        self,
        img_folder: str,
        label_file: str,
        use_polygons: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(img_folder, **kwargs)

        # List images: each sample is (image path relative to self.root, target dict)
        self.data: List[Tuple[Path, Dict[str, Any]]] = []
        np_dtype = np.float32
        # Trailing shape of a single box. It is used for images without annotations so that their
        # empty `boxes` array has the same trailing dimensions as non-empty ones.
        box_shape: Tuple[int, ...] = (4, 2) if use_polygons else (4,)

        with open(label_file, "rb") as f:
            labels = json.load(f)

        for raw_name, annotations in labels.items():
            # Get image path (kept as a Path object, as stored in self.data)
            img_name = Path(raw_name)
            img_path = os.path.join(self.root, img_name)
            # File existence check: fail at construction instead of at sample loading time
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"unable to locate {img_path}")

            words = annotations["typed_words"]
            # handle empty images
            if not words:
                self.data.append((img_name, dict(boxes=np.zeros((0, *box_shape), dtype=np_dtype), labels=[])))
                continue

            # Unpack the straight boxes (xmin, ymin, xmax, ymax)
            geoms = [list(map(float, obj["geometry"][:4])) for obj in words]
            if use_polygons:
                # (x, y) coordinates of top left, top right, bottom right, bottom left corners
                geoms = [
                    [geom[:2], [geom[2], geom[1]], geom[2:], [geom[0], geom[3]]]  # type: ignore[list-item]
                    for geom in geoms
                ]

            text_targets = [obj["value"] for obj in words]
            self.data.append((img_name, dict(boxes=np.asarray(geoms, dtype=np_dtype), labels=text_targets)))