Skip to content

darts_segmentation.training.DartsDatasetZarr

Bases: torch.utils.data.Dataset

Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __init__(self, data_dir: Path | str, augment: bool, indices: list[int] | None = None):
    """Open a zarr-backed segmentation dataset.

    Args:
        data_dir: Directory of a local zarr store whose root group must contain
            an ``x`` array and a ``y`` array.
        augment: If True, build an augmentation pipeline (random horizontal/vertical
            flips, 90-degree rotations, brightness/contrast jitter and multiplicative
            noise) that ``__getitem__`` applies to each sample; otherwise no
            augmentation is applied.
        indices: Positions along the first axis of the ``x`` array that this dataset
            exposes. Defaults to every position in ``x``.

    Raises:
        AssertionError: If the store lacks the ``x`` array or the ``y`` array
            (only while assertions are enabled, i.e. not under ``python -O``).
    """
    if isinstance(data_dir, str):
        data_dir = Path(data_dir)

    store = zarr.storage.LocalStore(data_dir)
    self.zroot = zarr.group(store=store)

    # Both arrays are required, so the message says "and", matching the condition.
    assert "x" in self.zroot and "y" in self.zroot, (
        f"Dataset corrupted! {self.zroot.info=} must contain 'x' and 'y' arrays!"
    )

    # Default to exposing every sample stored along the first axis of ``x``.
    self.indices = indices if indices is not None else list(range(self.zroot["x"].shape[0]))

    self.transform = (
        A.Compose(
            [
                A.HorizontalFlip(),
                A.VerticalFlip(),
                A.RandomRotate90(),
                # A.Blur(),
                A.RandomBrightnessContrast(),
                A.MultiplicativeNoise(per_channel=True, elementwise=True),
                # ToTensorV2(),
            ]
        )
        if augment
        else None
    )

indices (instance attribute)

# Sample positions exposed by the dataset: the caller-supplied ``indices``,
# or every position along the first axis of the ``x`` array when omitted.
indices = (
    indices
    if indices is not None
    else list(range(self.zroot["x"].shape[0]))
)

transform (instance attribute)

# Augmentation pipeline applied to image and mask together; built only when
# ``augment`` is true, otherwise no transform is applied.
transform = (
    albumentations.Compose(
        [
            albumentations.HorizontalFlip(),
            albumentations.VerticalFlip(),
            albumentations.RandomRotate90(),
            albumentations.RandomBrightnessContrast(),
            albumentations.MultiplicativeNoise(
                per_channel=True, elementwise=True
            ),
        ]
    )
    if augment
    else None
)

zroot (instance attribute)

# Root zarr group opened over the LocalStore built from ``data_dir`` in ``__init__``.
zroot = zarr.group(store=store)

__getitem__

__getitem__(idx)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __getitem__(self, idx):
    """Return the ``(x, y)`` sample at position ``idx`` of this dataset.

    ``idx`` is first mapped through ``self.indices`` to a position in the zarr
    ``x`` and ``y`` arrays. When ``self.transform`` is set, image and mask are
    passed to it together in a single call and the augmented pair is returned.
    """
    sample_pos = self.indices[idx]
    image = self.zroot["x"][sample_pos]
    mask = self.zroot["y"][sample_pos]

    if self.transform is None:
        return image, mask

    # The transform receives the image with axis 0 moved last (presumably
    # (C, H, W) -> (H, W, C), the layout albumentations expects -- confirm
    # against the writer of the store); the result is moved back afterwards.
    augmented = self.transform(image=image.transpose(1, 2, 0), mask=mask)
    return augmented["image"].transpose(2, 0, 1), augmented["mask"]

__len__

__len__()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __len__(self):
    """Return the number of samples exposed, i.e. the length of ``self.indices``."""
    n_samples = len(self.indices)
    return n_samples