# Source code for doctr.datasets.detection

# Copyright (C) 2021-2023, Mindee.

# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

import json
import os
from typing import Any, Dict, List, Tuple, Type, Union

import numpy as np

from doctr.file_utils import CLASS_NAME

from .datasets import AbstractDataset
from .utils import pre_transform_multiclass

# Public API of this module: only the dataset class is exported by `from ... import *`.
__all__ = ["DetectionDataset"]

class DetectionDataset(AbstractDataset):
    """Implements a text detection dataset

    >>> from doctr.datasets import DetectionDataset
    >>> train_set = DetectionDataset(img_folder="/path/to/images",
    >>>                              label_path="/path/to/labels.json")
    >>> img, target = train_set[0]

    The label file is a JSON object mapping each image file name (relative to
    ``img_folder``) to a dict holding a ``"polygons"`` entry. That entry is either
    a list of polygons (all assigned the default class name) or a dict mapping a
    class name to its list of polygons.

    Args:
        img_folder: folder with all the images of the dataset
        label_path: path to the annotations of each image
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        **kwargs: keyword arguments from `AbstractDataset`.

    Raises:
        FileNotFoundError: if ``label_path`` or any image referenced by it does not exist.
        TypeError: if an image's ``"polygons"`` entry is neither a list nor a dict.
    """

    def __init__(
        self,
        img_folder: str,
        label_path: str,
        use_polygons: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__(
            img_folder,
            pre_transforms=pre_transform_multiclass,
            **kwargs,
        )

        # Every class name encountered while parsing the labels (duplicates included);
        # `class_names` exposes the deduplicated, sorted view.
        self._class_names: List = []

        # File existence check
        if not os.path.exists(label_path):
            raise FileNotFoundError(f"unable to locate {label_path}")
        with open(label_path, "rb") as f:
            labels = json.load(f)

        # One (image file name, (geometries, per-geometry class names)) entry per image.
        # NOTE(review): presumably consumed by AbstractDataset when loading samples — confirm.
        self.data: List[Tuple[str, Tuple[np.ndarray, List[str]]]] = []
        np_dtype = np.float32
        for img_name, label in labels.items():
            # File existence check (`self.root` is presumably derived from `img_folder` by the base class)
            img_path = os.path.join(self.root, img_name)
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"unable to locate {img_path}")

            geoms, polygons_classes = self.format_polygons(label["polygons"], use_polygons, np_dtype)
            self.data.append((img_name, (np.asarray(geoms, dtype=np_dtype), polygons_classes)))

    def format_polygons(
        self, polygons: Union[List, Dict], use_polygons: bool, np_dtype: Type
    ) -> Tuple[np.ndarray, List[str]]:
        """format polygons into an array

        Also records the class names found in ``polygons`` into ``self._class_names``.

        Args:
            polygons: the bounding boxes, either a list of polygons (default class)
                or a dict mapping a class name to its list of polygons
            use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
            np_dtype: dtype of array

        Returns:
            geoms: bounding boxes as np array. With ``use_polygons`` these are the
                polygons themselves. Otherwise each polygon is reduced to its
                per-coordinate min followed by its per-coordinate max.
                An image without annotations yields an empty array with a
                consistent trailing shape instead of raising.
            polygons_classes: list of classes for each bounding box

        Raises:
            TypeError: if ``polygons`` is neither a list nor a dict.
        """
        if isinstance(polygons, list):
            self._class_names += [CLASS_NAME]
            polygons_classes = [CLASS_NAME for _ in polygons]
            _polygons: np.ndarray = np.asarray(polygons, dtype=np_dtype)
        elif isinstance(polygons, dict):
            self._class_names += list(polygons.keys())
            # Empty class lists contribute neither classes nor geometries, so both stay aligned
            polygons_classes = [k for k, v in polygons.items() for _ in v]
            non_empty = [np.asarray(poly, dtype=np_dtype) for poly in polygons.values() if poly]
            # np.concatenate raises on an empty sequence: fall back to an empty array
            _polygons = np.concatenate(non_empty, axis=0) if non_empty else np.asarray([], dtype=np_dtype)
        else:
            raise TypeError(f"polygons should be a dictionary or list, it was {type(polygons)}")

        if _polygons.size == 0 and _polygons.ndim < 3:
            # No annotation for this image: `np.asarray([])` is 1-D, so `.min(axis=1)` below
            # would raise. Give it an explicit (0, points, coords) shape instead.
            # NOTE(review): assumes 4-point, 2-coordinate polygons, matching the rotated-box convention.
            _polygons = _polygons.reshape((0, 4, 2))

        geoms = _polygons if use_polygons else np.concatenate((_polygons.min(axis=1), _polygons.max(axis=1)), axis=1)
        return geoms, polygons_classes

    @property
    def class_names(self) -> List[str]:
        """Sorted list of the unique class names found in the annotations."""
        return sorted(set(self._class_names))