Skip to content

darts_segmentation.training.DartsDataModule

Bases: lightning.LightningDataModule

Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __init__(
    self,
    data_dir: Path,
    batch_size: int,
    fold: int = 0,  # Not used for test
    augment: bool = True,  # Not used for test
    num_workers: int = 0,
    in_memory: bool = False,
):
    """Store the loader configuration and count the samples in the zarr store.

    Args:
        data_dir: Directory of a zarr store whose root group contains an ``x`` array.
        batch_size: Batch size used by every DataLoader this module creates.
        fold: Index of the cross-validation fold used by ``setup`` for "fit"/"validate".
        augment: Whether the training dataset is built with augmentation enabled.
        num_workers: Number of DataLoader worker processes.
        in_memory: Select the in-memory dataset class instead of the zarr-backed one.

    Raises:
        Whatever zarr raises when no group exists at ``data_dir`` (the store is opened
        read-only), or ``KeyError`` if the group has no ``x`` array.
    """
    super().__init__()
    self.save_hyperparameters()
    # Normalize once so the datasets built in ``setup`` always receive a Path,
    # even when the caller passes a plain string.
    self.data_dir = Path(data_dir)
    self.batch_size = batch_size
    self.fold = fold
    self.augment = augment
    self.num_workers = num_workers
    self.in_memory = in_memory

    # Open read-only: ``zarr.group`` would initialise (and write) an empty group
    # at a mistyped path before failing on the missing "x" array.
    store = zarr.storage.DirectoryStore(self.data_dir)
    zroot = zarr.open_group(store=store, mode="r")
    # Number of samples = length of the first axis of the "x" array.
    self.nsamples = len(zroot["x"])

augment instance-attribute

augment = augment  # set from the `augment` constructor argument

batch_size instance-attribute

batch_size = batch_size  # set from the `batch_size` constructor argument

data_dir instance-attribute

data_dir = data_dir  # set from the `data_dir` constructor argument

fold instance-attribute

in_memory instance-attribute

in_memory = in_memory  # set from the `in_memory` constructor argument

nsamples instance-attribute

nsamples = len(zroot['x'])

num_workers instance-attribute

num_workers = num_workers  # set from the `num_workers` constructor argument

setup

setup(
    stage: typing.Literal[
        "fit", "validate", "test", "predict"
    ]
    | None = None,
)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def setup(self, stage: Literal["fit", "validate", "test", "predict"] | None = None):
    """Build the datasets required for ``stage``.

    For "fit" and "validate", the sample indices ``0..nsamples-1`` are split with
    ``KFold(n_splits=5)``. Fold ``self.fold`` provides the training indices
    (augmented if ``self.augment``) and the validation indices (never augmented).
    For "test", the whole store is wrapped without augmentation. Other stages,
    including ``None``, build nothing.

    Raises:
        ValueError: If ``self.fold`` is not a valid fold index for the split.
    """
    dsclass = DartsDatasetInMemory if self.in_memory else DartsDatasetZarr
    if stage in ("fit", "validate"):
        kf = KFold(n_splits=5)
        n_folds = kf.get_n_splits()
        # Negative indices would silently pick another fold, and too-large ones give
        # a cryptic IndexError, so reject them explicitly.
        if not 0 <= self.fold < n_folds:
            raise ValueError(f"fold must be in [0, {n_folds}), got {self.fold}")
        train_idx, val_idx = list(kf.split(range(self.nsamples)))[self.fold]

        self.train = dsclass(self.data_dir, self.augment, train_idx)
        self.val = dsclass(self.data_dir, False, val_idx)
    elif stage == "test":
        self.test = dsclass(self.data_dir, False)

test_dataloader

test_dataloader()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def test_dataloader(self):
    return DataLoader(self.test, batch_size=self.batch_size, num_workers=self.num_workers)

train_dataloader

train_dataloader()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def train_dataloader(self):
    """Return a shuffling DataLoader over the training fold built in ``setup``."""
    dataset = self.train
    loader = DataLoader(
        dataset,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        shuffle=True,
    )
    return loader

val_dataloader

val_dataloader()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def val_dataloader(self):
    """Return an unshuffled DataLoader over the validation fold built in ``setup``."""
    dataset = self.val
    loader = DataLoader(
        dataset,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
    )
    return loader