# Copyright (C) 2021-2025, Mindee.# This program is licensed under the Apache License 2.0.# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.importcsvimportosfrompathlibimportPathfromtypingimportAnyimportnumpyasnpfromtqdmimporttqdmfrom.datasetsimportVisionDatasetfrom.utilsimportconvert_target_to_relative,crop_bboxes_from_image__all__=["SROIE"]
class SROIE(VisionDataset):
    """SROIE dataset from `"ICDAR2019 Competition on Scanned Receipt OCR and Information Extraction"
    <https://arxiv.org/pdf/2103.10213.pdf>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/sroie-grid.png&src=0
        :align: center

    >>> from doctr.datasets import SROIE
    >>> train_set = SROIE(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        detection_task: whether the dataset should be used for detection task
        **kwargs: keyword arguments from `VisionDataset`.

    Raises:
        ValueError: if `recognition_task` and `detection_task` are both set to True.
    """

    # (url, sha256, archive file name) of each subset
    TRAIN = (
        "https://doctr-static.mindee.com/models?id=v0.1.1/sroie2019_train_task1.zip&src=0",
        "d4fa9e60abb03500d83299c845b9c87fd9c9430d1aeac96b83c5d0bb0ab27f6f",
        "sroie2019_train_task1.zip",
    )

    TEST = (
        "https://doctr-static.mindee.com/models?id=v0.1.1/sroie2019_test.zip&src=0",
        "41b3c746a20226fddc80d86d4b2a903d43b5be4f521dd1bbe759dbf8844745e2",
        "sroie2019_test.zip",
    )

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        detection_task: bool = False,
        **kwargs: Any,
    ) -> None:
        # Validate the flags first: the parent constructor may download and extract the archive,
        # which is wasted work if the arguments are invalid anyway.
        if recognition_task and detection_task:
            raise ValueError(
                "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
                + "To get the whole dataset with boxes and labels leave both parameters to False."
            )

        url, sha256, name = self.TRAIN if train else self.TEST
        super().__init__(
            url,
            name,
            sha256,
            True,
            pre_transforms=convert_target_to_relative if not recognition_task else None,
            **kwargs,
        )
        self.train = train

        tmp_root = os.path.join(self.root, "images")
        self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
        np_dtype = np.float32

        # List the directory once and sort it so sample indices do not depend on filesystem ordering.
        img_names = sorted(os.listdir(tmp_root))
        for img_path in tqdm(iterable=img_names, desc="Preparing and Loading SROIE", total=len(img_names)):
            # File existence check
            if not os.path.exists(os.path.join(tmp_root, img_path)):
                raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}")

            # Each image has a same-stem .txt annotation file in the sibling "annotations" folder
            stem = Path(img_path).stem
            with open(os.path.join(self.root, "annotations", f"{stem}.txt"), encoding="latin") as f:
                # Filter out empty lines
                rows = [row for row in csv.reader(f, delimiter=",") if len(row) > 0]

            coords, labels = self._parse_rows(rows, use_polygons, np_dtype)

            if recognition_task:
                if len(rows) == 0:
                    # Nothing to crop from an image without annotated words
                    continue
                crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_path), geoms=coords)
                for crop, label in zip(crops, labels):
                    # Drop degenerate crops and empty transcriptions
                    if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0:
                        self.data.append((crop, label))
            elif detection_task:
                self.data.append((img_path, coords))
            else:
                self.data.append((img_path, dict(boxes=coords, labels=labels)))

        self.root = tmp_root

    @staticmethod
    def _parse_rows(
        rows: list[list[str]],
        use_polygons: bool,
        np_dtype: Any = np.float32,
    ) -> tuple[np.ndarray, list[str]]:
        """Convert the CSV rows of one SROIE annotation file into coordinates and transcriptions.

        Each row starts with 8 integer coordinates, i.e. the (x, y) coordinates of the top left, top right,
        bottom right and bottom left corners. The remaining fields form the transcription, which may itself
        contain commas and is therefore rebuilt by re-joining them.

        Args:
            rows: non-empty CSV rows of a single annotation file (may be an empty list)
            use_polygons: if True, keep the (N, 4, 2) polygons; otherwise reduce them to (N, 4) straight boxes
                laid out as (xmin, ymin, xmax, ymax)
            np_dtype: dtype of the returned coordinate array

        Returns:
            the coordinate array and the list of labels (one label per row)
        """
        labels = [",".join(row[8:]) for row in rows]
        if rows:
            # reorder coordinates (8 -> (4,2))
            coords = np.stack(
                [np.array(list(map(int, row[:8])), dtype=np_dtype).reshape((4, 2)) for row in rows],
                axis=0,
            )
        else:
            # np.stack raises on an empty list: keep an empty, correctly shaped array instead
            coords = np.zeros((0, 4, 2), dtype=np_dtype)

        if not use_polygons:
            # xmin, ymin, xmax, ymax
            coords = np.concatenate((coords.min(axis=1), coords.max(axis=1)), axis=1)
        return coords, labels

    def extra_repr(self) -> str:
        return f"train={self.train}"