Bases: torch.utils.data.Dataset
Source code in darts-segmentation/src/darts_segmentation/training/data.py
| def __init__(self, data_dir: Path | str, augment: bool, indices: list[int] | None = None):
if isinstance(data_dir, str):
data_dir = Path(data_dir)
x_files = sorted((data_dir / "x").glob("*.pt"))
y_files = sorted((data_dir / "y").glob("*.pt"))
assert len(x_files) == len(y_files), f"Dataset corrupted! Got {len(x_files)=} and {len(y_files)=}!"
if indices is not None:
x_files = [x_files[i] for i in indices]
y_files = [y_files[i] for i in indices]
self.x = []
self.y = []
for xfile, yfile in zip(x_files, y_files):
assert xfile.stem == yfile.stem, (
f"Dataset corrupted! Files must have the same name, but got {xfile=} {yfile=}!"
)
x = torch.load(xfile).numpy()
y = torch.load(yfile).int().numpy()
self.x.append(x)
self.y.append(y)
self.transform = (
A.Compose(
[
A.HorizontalFlip(),
A.VerticalFlip(),
A.RandomRotate90(),
# A.Blur(),
A.RandomBrightnessContrast(),
A.MultiplicativeNoise(per_channel=True, elementwise=True),
# ToTensorV2(),
]
)
if augment
else None
)
|
transform = (
    albumentations.Compose(
        [
            albumentations.HorizontalFlip(),
            albumentations.VerticalFlip(),
            albumentations.RandomRotate90(),
            albumentations.RandomBrightnessContrast(),
            albumentations.MultiplicativeNoise(
                per_channel=True, elementwise=True
            ),
        ]
    )
    if augment
    else None
)
__getitem__
Source code in darts-segmentation/src/darts_segmentation/training/data.py
| def __getitem__(self, idx):
x = self.x[idx]
y = self.y[idx]
# Apply augmentations
if self.transform is not None:
augmented = self.transform(image=x.transpose(1, 2, 0), mask=y)
x = augmented["image"].transpose(2, 0, 1)
y = augmented["mask"]
return x, y
|
__len__
Source code in darts-segmentation/src/darts_segmentation/training/data.py
| def __len__(self):
return len(self.x)
|