Skip to content

darts_segmentation.training.DartsDatasetInMemory

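Bases: torch.utils.data.Dataset

In-memory segmentation dataset: each input tensor in `data_dir/x/*.pt` and its same-named label tensor in `data_dir/y/*.pt` are loaded as numpy arrays when the dataset is constructed. Random flips, 90° rotations, brightness/contrast changes and multiplicative noise can optionally be applied to each sample when it is retrieved.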

Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __init__(self, data_dir: Path | str, augment: bool, indices: list[int] | None = None):
    if isinstance(data_dir, str):
        data_dir = Path(data_dir)

    x_files = sorted((data_dir / "x").glob("*.pt"))
    y_files = sorted((data_dir / "y").glob("*.pt"))
    assert len(x_files) == len(y_files), f"Dataset corrupted! Got {len(x_files)=} and {len(y_files)=}!"
    if indices is not None:
        x_files = [x_files[i] for i in indices]
        y_files = [y_files[i] for i in indices]

    self.x = []
    self.y = []
    for xfile, yfile in zip(x_files, y_files):
        assert xfile.stem == yfile.stem, (
            f"Dataset corrupted! Files must have the same name, but got {xfile=} {yfile=}!"
        )
        x = torch.load(xfile).numpy()
        y = torch.load(yfile).int().numpy()
        self.x.append(x)
        self.y.append(y)

    self.transform = (
        A.Compose(
            [
                A.HorizontalFlip(),
                A.VerticalFlip(),
                A.RandomRotate90(),
                # A.Blur(),
                A.RandomBrightnessContrast(),
                A.MultiplicativeNoise(per_channel=True, elementwise=True),
                # ToTensorV2(),
            ]
        )
        if augment
        else None
    )

transform instance-attribute

# Rendered value of the ``transform`` attribute assigned in ``__init__``. It is either an
# albumentations pipeline (horizontal/vertical flips, random 90° rotations, random
# brightness/contrast, per-channel element-wise multiplicative noise) or ``None``.
# NOTE(review): the condition below is a docs-rendering artifact. In ``__init__`` the
# expression is ``... if augment else None``: it tests the ``augment`` constructor
# argument and does not construct a ``DartsDatasetInMemory``.
transform = (
    albumentations.Compose(
        [
            albumentations.HorizontalFlip(),
            albumentations.VerticalFlip(),
            albumentations.RandomRotate90(),
            albumentations.RandomBrightnessContrast(),
            albumentations.MultiplicativeNoise(
                per_channel=True, elementwise=True
            ),
        ]
    )
    if darts_segmentation.training.data.DartsDatasetInMemory(
        augment
    )
    else None
)

x instance-attribute

# Input samples as numpy arrays, one per selected ``x/*.pt`` file; filled in ``__init__``.
x = []

y instance-attribute

# Label samples as int32 numpy arrays, one per selected ``y/*.pt`` file; filled in ``__init__``.
y = []

__getitem__

__getitem__(idx)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __getitem__(self, idx):
    """Return the ``(x, y)`` numpy pair stored at position ``idx``.

    Without a transform, the stored arrays are returned as-is. Otherwise ``x`` is passed
    to the transform with its first axis moved last, together with ``y`` as the mask, and
    the augmented image's last axis is moved back to the front.
    """
    image, mask = self.x[idx], self.y[idx]
    if self.transform is None:
        return image, mask

    # albumentations takes channel-last images; this assumes x is stored channel-first
    # (C, H, W) — TODO confirm against the data preparation step.
    out = self.transform(image=image.transpose(1, 2, 0), mask=mask)
    return out["image"].transpose(2, 0, 1), out["mask"]

__len__

__len__()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __len__(self):
    """Return how many samples were loaded into memory (one per ``x`` entry)."""
    samples = self.x
    return len(samples)