Bases: lightning.LightningDataModule
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __init__(
    self,
    data_dir: Path,
    batch_size: int,
    fold: int = 0,  # Not used for test
    augment: bool = True,  # Not used for test
    num_workers: int = 0,
    in_memory: bool = False,
):
    """Set up the data module for a zarr-backed DARTS training dataset.

    Args:
        data_dir: Directory of the zarr store. Its root group must contain an
            array named ``"x"``; the length of that array is taken as the
            number of samples.
        batch_size: Batch size passed to every dataloader.
        fold: Which of the 5 KFold splits ``setup`` uses for the
            ``"fit"``/``"validate"`` stages. Not used for test.
        augment: Forwarded to the training dataset only (validation and test
            datasets always receive ``False``). Not used for test.
        num_workers: Number of DataLoader worker processes.
        in_memory: If True, ``setup`` builds ``DartsDatasetInMemory``
            datasets, otherwise ``DartsDatasetZarr`` datasets.

    """
    super().__init__()
    # Record the constructor arguments in ``self.hparams`` (Lightning).
    self.save_hyperparameters()
    # Normalise once so every consumer (e.g. the datasets built in ``setup``)
    # receives a ``Path``, matching the parameter annotation.
    self.data_dir = Path(data_dir)
    self.batch_size = batch_size
    self.fold = fold
    self.augment = augment
    self.num_workers = num_workers
    self.in_memory = in_memory

    # Open the existing group read-only. ``zarr.group`` opens in append mode and
    # initialises a new, empty group when none exists. That would write into
    # the data directory and then fail on the missing "x" array with a
    # confusing KeyError instead of a clear "group not found" error.
    store = zarr.storage.DirectoryStore(self.data_dir)
    zroot = zarr.open_group(store=store, mode="r")
    self.nsamples = len(zroot["x"])
|
augment
instance-attribute
batch_size
instance-attribute
data_dir
instance-attribute
in_memory
instance-attribute
nsamples
instance-attribute
nsamples = len(zroot['x'])
num_workers
instance-attribute
setup
setup(
stage: typing.Literal[
"fit", "validate", "test", "predict"
]
| None = None,
)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def setup(self, stage: Literal["fit", "validate", "test", "predict"] | None = None):
    """Create the datasets needed for ``stage``.

    * ``"fit"`` / ``"validate"``: split the sample indices with ``KFold``
      (5 splits, no shuffling) and build ``self.train`` from the training part
      of split ``self.fold`` and ``self.val`` from its validation part.
      Only the training dataset gets ``self.augment``.
    * ``"test"``: build ``self.test`` over all samples, without augmentation.
    * ``None``: set up all of the above (Lightning's "set up everything"
      convention).
    * ``"predict"``: nothing is created.

    Raises:
        ValueError: If ``self.fold`` does not select one of the splits.

    """
    # Number of cross-validation folds; ``self.fold`` selects one of them.
    n_splits = 5
    dsclass = DartsDatasetInMemory if self.in_memory else DartsDatasetZarr

    if stage in ("fit", "validate", None):
        # NOTE(review): assumes sklearn's KFold. Its default is shuffle=False,
        # so each validation fold is a contiguous block of sample indices.
        kf = KFold(n_splits=n_splits)
        splits = list(kf.split(range(self.nsamples)))
        try:
            # Plain indexing keeps the original support for negative folds.
            train_idx, val_idx = splits[self.fold]
        except IndexError:
            raise ValueError(
                f"fold must select one of the {n_splits} cross-validation splits, got fold={self.fold}"
            ) from None
        self.train = dsclass(self.data_dir, self.augment, train_idx)
        self.val = dsclass(self.data_dir, False, val_idx)

    if stage in ("test", None):
        self.test = dsclass(self.data_dir, False)
|
test_dataloader
Source code in darts-segmentation/src/darts_segmentation/training/data.py
| def test_dataloader(self):
return DataLoader(self.test, batch_size=self.batch_size, num_workers=self.num_workers)
|
train_dataloader
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def train_dataloader(self):
    """Return a shuffling DataLoader over the training dataset from ``setup``.

    The dataset order is reshuffled by the loader on each pass.
    """
    loader_options = {
        "batch_size": self.batch_size,
        "num_workers": self.num_workers,
        "shuffle": True,
    }
    return DataLoader(self.train, **loader_options)
|
val_dataloader
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def val_dataloader(self):
    """Return a DataLoader over the validation dataset built by ``setup``.

    Batches are produced in dataset order (no ``shuffle`` is requested).
    """
    workers = self.num_workers
    size = self.batch_size
    return DataLoader(self.val, batch_size=size, num_workers=workers)
|