# Copyright (C) 2021-2025, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import os
from typing import Any
import defusedxml.ElementTree as ET
import numpy as np
from tqdm import tqdm
from .datasets import VisionDataset
from .utils import convert_target_to_relative, crop_bboxes_from_image
__all__ = ["SVT"]
[docs]
class SVT(VisionDataset):
    """SVT dataset from `"The Street View Text Dataset - UCSD Computer Vision"
    <http://vision.ucsd.edu/~kai/svt/>`_.

    .. image:: https://doctr-static.mindee.com/models?id=v0.5.0/svt-grid.png&src=0
        :align: center

    >>> from doctr.datasets import SVT
    >>> train_set = SVT(train=True, download=True)
    >>> img, target = train_set[0]

    Args:
        train: whether the subset should be the training one
        use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
        recognition_task: whether the dataset should be used for recognition task
        detection_task: whether the dataset should be used for detection task
        **kwargs: keyword arguments from `VisionDataset`.

    Raises:
        ValueError: if both `recognition_task` and `detection_task` are True
        FileNotFoundError: if an image referenced by the XML annotation file is missing on disk
    """

    URL = "http://www.iapr-tc11.org/dataset/SVT/svt.zip"
    SHA256 = "63b3d55e6b6d1e036e2a844a20c034fe3af3c32e4d914d6e0c4a3cd43df3bebf"

    def __init__(
        self,
        train: bool = True,
        use_polygons: bool = False,
        recognition_task: bool = False,
        detection_task: bool = False,
        **kwargs: Any,
    ) -> None:
        # Reject incompatible flags first, so an invalid call fails before the
        # parent constructor (which receives the archive URL) does any work.
        if recognition_task and detection_task:
            raise ValueError(
                "`recognition_task` and `detection_task` cannot be set to True simultaneously. "
                + "To get the whole dataset with boxes and labels leave both parameters to False."
            )

        # Positional arguments follow `VisionDataset`'s signature (not visible in this file).
        # Targets go through `convert_target_to_relative` unless recognition crops are produced.
        super().__init__(
            self.URL,
            None,
            self.SHA256,
            True,
            pre_transforms=convert_target_to_relative if not recognition_task else None,
            **kwargs,
        )

        self.train = train
        self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
        np_dtype = np.float32

        # Load xml data
        tmp_root = os.path.join(self.root, "svt1") if self.SHA256 else self.root
        xml_name = "train.xml" if self.train else "test.xml"
        xml_root = ET.parse(os.path.join(tmp_root, xml_name)).getroot()

        for image in tqdm(iterable=xml_root, desc="Preparing and Loading SVT", total=len(xml_root)):
            # Each image element is unpacked into exactly five children; only the
            # first (file name) and last (annotated rectangles) are used.
            name, _, _, _resolution, rectangles = image

            # File existence check
            img_path = os.path.join(tmp_root, name.text)
            if not os.path.exists(img_path):
                raise FileNotFoundError(f"unable to locate {img_path}")

            _boxes = [self._rect_to_geometry(rect.attrib, use_polygons) for rect in rectangles]
            boxes: np.ndarray = np.asarray(_boxes, dtype=np_dtype)

            # Get the labels: the text of every child element of every rectangle
            labels = [lab.text for rect in rectangles for lab in rect]

            if recognition_task:
                crops = crop_bboxes_from_image(img_path=img_path, geoms=boxes)
                for crop, label in zip(crops, labels):
                    # Keep only non-empty crops with a non-empty, space-free label.
                    # The truthiness test also skips `None` (text of an empty XML
                    # element), which `len()` would reject with a TypeError.
                    if crop.shape[0] > 0 and crop.shape[1] > 0 and label and " " not in label:
                        self.data.append((crop, label))
            elif detection_task:
                self.data.append((name.text, boxes))
            else:
                self.data.append((name.text, {"boxes": boxes, "labels": labels}))

        # From here on, paths stored in `self.data` are relative to the extracted folder.
        self.root = tmp_root

    @staticmethod
    def _rect_to_geometry(attrib: Any, use_polygons: bool) -> list:
        """Convert the x/y/width/height attributes of one annotated rectangle to a box geometry.

        Args:
            attrib: mapping holding the "x", "y", "width" and "height" values of the rectangle
            use_polygons: whether to return the four corners instead of the two extreme points

        Returns:
            the four (x, y) corners (top left, top right, bottom right, bottom left) if
            `use_polygons` is True, otherwise [x_min, y_min, x_max, y_max]
        """
        x_min, y_min = float(attrib["x"]), float(attrib["y"])
        x_max = x_min + float(attrib["width"])
        y_max = y_min + float(attrib["height"])
        if use_polygons:
            return [[x_min, y_min], [x_max, y_min], [x_max, y_max], [x_min, y_max]]
        return [x_min, y_min, x_max, y_max]

    def extra_repr(self) -> str:
        """Return the dataset-specific part of the representation (the selected subset)."""
        return f"train={self.train}"