# Copyright (C) 2021-2025, Mindee.# This program is licensed under the Apache License 2.0.# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.importjsonimportosfromtypingimportAnyimportnumpyasnpfromdoctr.file_utilsimportCLASS_NAMEfrom.datasetsimportAbstractDatasetfrom.utilsimportpre_transform_multiclass__all__=["DetectionDataset"]
class DetectionDataset(AbstractDataset):
    """Implements a text detection dataset

    >>> from doctr.datasets import DetectionDataset
    >>> train_set = DetectionDataset(img_folder="/path/to/images",
    >>>                              label_path="/path/to/labels.json")
    >>> img, target = train_set[0]

    The label file is a JSON object mapping each image name to an entry holding a
    ``"polygons"`` field. That field is either a list of polygons (all assigned the default
    class ``CLASS_NAME``) or a dictionary mapping a class name to its list of polygons.

    Args:
        img_folder: folder with all the images of the dataset
        label_path: path to the annotations of each image
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        **kwargs: keyword arguments from `AbstractDataset`.

    Raises:
        FileNotFoundError: if the label file, or any image it references, cannot be found
        TypeError: if a ``"polygons"`` entry is neither a list nor a dictionary
    """

    def __init__(
        self,
        img_folder: str,
        label_path: str,
        use_polygons: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            img_folder,
            pre_transforms=pre_transform_multiclass,
            **kwargs,
        )

        # Every class name met while parsing the labels; duplicates are kept here and
        # removed by the `class_names` property.
        self._class_names: list = []

        # File existence check
        if not os.path.exists(label_path):
            raise FileNotFoundError(f"unable to locate {label_path}")
        with open(label_path, "rb") as f:
            labels = json.load(f)

        # One (image name, (geometries, class of each geometry)) entry per annotated image
        self.data: list[tuple[str, tuple[np.ndarray, list[str]]]] = []
        np_dtype = np.float32
        for img_name, label in labels.items():
            # File existence check
            # NOTE(review): `self.root` is not assigned in this class; presumably it is set
            # by `AbstractDataset.__init__` from `img_folder` -- confirm.
            img_path = os.path.join(self.root, img_name)
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"unable to locate {img_path}")

            geoms, polygons_classes = self.format_polygons(label["polygons"], use_polygons, np_dtype)
            self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))

    def format_polygons(
        self,
        polygons: list | dict,
        use_polygons: bool,
        np_dtype: type,
    ) -> tuple[np.ndarray, list[str]]:
        """Format polygons into an array

        As a side effect, the class names found in ``polygons`` are appended to
        ``self._class_names``.

        Args:
            polygons: the bounding boxes, either a list of polygons or a dictionary mapping
                a class name to its list of polygons
            use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
            np_dtype: dtype of array

        Returns:
            geoms: bounding boxes as np array. The polygons themselves when ``use_polygons``
                is True, otherwise the per-polygon minimum and maximum coordinates
                concatenated along axis 1. Empty (zero rows) when there is no polygon.
            polygons_classes: list of classes for each bounding box

        Raises:
            TypeError: if ``polygons`` is neither a list nor a dictionary
        """
        if isinstance(polygons, list):
            self._class_names += [CLASS_NAME]
            polygons_classes = [CLASS_NAME for _ in polygons]
            _polygons: np.ndarray = np.asarray(polygons, dtype=np_dtype)
        elif isinstance(polygons, dict):
            self._class_names += list(polygons.keys())
            polygons_classes = [k for k, v in polygons.items() for _ in v]
            # Classes without any polygon are skipped: they contribute no row and no label
            per_class = [np.asarray(poly, dtype=np_dtype) for poly in polygons.values() if poly]
            # `np.concatenate` rejects an empty sequence, so only call it when there is data
            _polygons = np.concatenate(per_class, axis=0) if per_class else np.zeros((0, 4, 2), dtype=np_dtype)
        else:
            raise TypeError(f"polygons should be a dictionary or list, it was {type(polygons)}")

        if _polygons.shape[0] == 0:
            # An image without annotations: `np.asarray([])` is 1-D, which would make the
            # `axis=1` reductions below fail. Use an explicitly shaped empty array instead.
            # NOTE(review): assumes polygons are 4 points of (x, y) -- confirm against the
            # label format; only the zero-length first axis matters for an empty target.
            _polygons = np.zeros((0, 4, 2), dtype=np_dtype)

        if use_polygons:
            geoms = _polygons
        else:
            # Straight box = per-polygon minimum and maximum over its points
            geoms = np.concatenate((_polygons.min(axis=1), _polygons.max(axis=1)), axis=1)
        return geoms, polygons_classes

    @property
    def class_names(self) -> list[str]:
        """Sorted list of the distinct class names encountered in the labels."""
        return sorted(set(self._class_names))