# Copyright (C) 2021-2025, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
import os
from typing import Any
import h5py
import numpy as np
from tqdm import tqdm
from .datasets import VisionDataset
from .utils import convert_target_to_relative, crop_bboxes_from_image
__all__ = ["SVHN"]
[docs]
class SVHN(VisionDataset):
"""SVHN dataset from `"The Street View House Numbers (SVHN) Dataset"
<http://ufldl.stanford.edu/housenumbers/>`_.
.. image:: https://doctr-static.mindee.com/models?id=v0.5.0/svhn-grid.png&src=0
:align: center
>>> from doctr.datasets import SVHN
>>> train_set = SVHN(train=True, download=True)
>>> img, target = train_set[0]
Args:
train: whether the subset should be the training one
use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones)
recognition_task: whether the dataset should be used for recognition task
detection_task: whether the dataset should be used for detection task
**kwargs: keyword arguments from `VisionDataset`.
"""
TRAIN = (
"http://ufldl.stanford.edu/housenumbers/train.tar.gz",
"4b17bb33b6cd8f963493168f80143da956f28ec406cc12f8e5745a9f91a51898",
"svhn_train.tar",
)
TEST = (
"http://ufldl.stanford.edu/housenumbers/test.tar.gz",
"57ac9ceb530e4aa85b55d991be8fc49c695b3d71c6f6a88afea86549efde7fb5",
"svhn_test.tar",
)
def __init__(
self,
train: bool = True,
use_polygons: bool = False,
recognition_task: bool = False,
detection_task: bool = False,
**kwargs: Any,
) -> None:
url, sha256, name = self.TRAIN if train else self.TEST
super().__init__(
url,
file_name=name,
file_hash=sha256,
extract_archive=True,
pre_transforms=convert_target_to_relative if not recognition_task else None,
**kwargs,
)
if recognition_task and detection_task:
raise ValueError(
"`recognition_task` and `detection_task` cannot be set to True simultaneously. "
+ "To get the whole dataset with boxes and labels leave both parameters to False."
)
self.train = train
self.data: list[tuple[str | np.ndarray, str | dict[str, Any] | np.ndarray]] = []
np_dtype = np.float32
tmp_root = os.path.join(self.root, "train" if train else "test")
# Load mat data (matlab v7.3 - can not be loaded with scipy)
with h5py.File(os.path.join(tmp_root, "digitStruct.mat"), "r") as f:
img_refs = f["digitStruct/name"]
box_refs = f["digitStruct/bbox"]
for img_ref, box_ref in tqdm(
iterable=zip(img_refs, box_refs), desc="Preparing and Loading SVHN", total=len(img_refs)
):
# convert ascii matrix to string
img_name = "".join(map(chr, f[img_ref[0]][()].flatten()))
# File existence check
if not os.path.exists(os.path.join(tmp_root, img_name)):
raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}")
# Unpack the information
box = f[box_ref[0]]
if box["left"].shape[0] == 1:
box_dict = {k: [int(vals[0][0])] for k, vals in box.items()}
else:
box_dict = {k: [int(f[v[0]][()].item()) for v in vals] for k, vals in box.items()}
# Convert it to the right format
coords: np.ndarray = np.array(
[box_dict["left"], box_dict["top"], box_dict["width"], box_dict["height"]], dtype=np_dtype
).transpose()
label_targets = list(map(str, box_dict["label"]))
if use_polygons:
# (x, y) coordinates of top left, top right, bottom right, bottom left corners
box_targets: np.ndarray = np.stack(
[
np.stack([coords[:, 0], coords[:, 1]], axis=-1),
np.stack([coords[:, 0] + coords[:, 2], coords[:, 1]], axis=-1),
np.stack([coords[:, 0] + coords[:, 2], coords[:, 1] + coords[:, 3]], axis=-1),
np.stack([coords[:, 0], coords[:, 1] + coords[:, 3]], axis=-1),
],
axis=1,
)
else:
# x, y, width, height -> xmin, ymin, xmax, ymax
box_targets = np.stack(
[
coords[:, 0],
coords[:, 1],
coords[:, 0] + coords[:, 2],
coords[:, 1] + coords[:, 3],
],
axis=-1,
)
if recognition_task:
crops = crop_bboxes_from_image(img_path=os.path.join(tmp_root, img_name), geoms=box_targets)
for crop, label in zip(crops, label_targets):
if crop.shape[0] > 0 and crop.shape[1] > 0 and len(label) > 0 and " " not in label:
self.data.append((crop, label))
elif detection_task:
self.data.append((img_name, box_targets))
else:
self.data.append((img_name, dict(boxes=box_targets, labels=label_targets)))
self.root = tmp_root
def extra_repr(self) -> str:
return f"train={self.train}"