Skip to content

darts_segmentation.training.DartsDataset

Bases: torch.utils.data.Dataset

Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __init__(self, data_dir: Path | str, augment: bool, indices: list[int] | None = None):
    if isinstance(data_dir, str):
        data_dir = Path(data_dir)

    self.x_files = sorted((data_dir / "x").glob("*.pt"))
    self.y_files = sorted((data_dir / "y").glob("*.pt"))
    assert len(self.x_files) == len(self.y_files), (
        f"Dataset corrupted! Got {len(self.x_files)=} and {len(self.y_files)=}!"
    )
    if indices is not None:
        self.x_files = [self.x_files[i] for i in indices]
        self.y_files = [self.y_files[i] for i in indices]

    self.transform = (
        A.Compose(
            [
                A.HorizontalFlip(),
                A.VerticalFlip(),
                A.RandomRotate90(),
                # A.Blur(),
                A.RandomBrightnessContrast(),
                A.MultiplicativeNoise(per_channel=True, elementwise=True),
                # ToTensorV2(),
            ]
        )
        if augment
        else None
    )

transform instance-attribute

transform = (
    albumentations.Compose(
        [
            albumentations.HorizontalFlip(),
            albumentations.VerticalFlip(),
            albumentations.RandomRotate90(),
            albumentations.RandomBrightnessContrast(),
            albumentations.MultiplicativeNoise(
                per_channel=True, elementwise=True
            ),
        ]
    )
    if augment
    else None
)

x_files instance-attribute

x_files = sorted((data_dir / "x").glob("*.pt"))

y_files instance-attribute

y_files = sorted((data_dir / "y").glob("*.pt"))

__getitem__

__getitem__(idx)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __getitem__(self, idx):
    """Load the ``idx``-th sample pair and optionally augment it.

    Both files are read with ``torch.load`` and converted to NumPy arrays;
    the ``y`` tensor is cast to ``int`` first.

    Args:
        idx: Position in ``self.x_files`` / ``self.y_files``.

    Returns:
        A tuple ``(x, y)`` of NumPy arrays.

    Raises:
        AssertionError: If the two files at ``idx`` do not share a stem.

    """
    xfile, yfile = self.x_files[idx], self.y_files[idx]
    assert xfile.stem == yfile.stem, f"Dataset corrupted! Files must have the same name, but got {xfile=} {yfile=}!"

    image = torch.load(xfile).numpy()
    mask = torch.load(yfile).int().numpy()

    # No augmentation configured: hand back the raw arrays.
    if self.transform is None:
        return image, mask

    # The transform receives the image with axis 0 moved to the end and the
    # result is moved back afterwards (presumably channel-first storage vs.
    # channel-last augmentation input -- confirm against the data format).
    augmented = self.transform(image=image.transpose(1, 2, 0), mask=mask)
    return augmented["image"].transpose(2, 0, 1), augmented["mask"]

__len__

__len__()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __len__(self):
    """Return the number of samples, i.e. the number of entries in ``x_files``."""
    n_samples = len(self.x_files)
    return n_samples