darts_segmentation.training.data
Training script for DARTS segmentation.
Augmentation
module-attribute
Augmentation = typing.Literal[
"HorizontalFlip",
"VerticalFlip",
"RandomRotate90",
"D4",
"Blur",
"RandomBrightnessContrast",
"MultiplicativeNoise",
"Posterize",
]
DartsDataModule
DartsDataModule(
data_dir: pathlib.Path,
batch_size: int,
data_split_method: typing.Literal[
"random", "region", "sample"
]
| None = None,
data_split_by: list[str | float] | None = None,
fold_method: typing.Literal[
"kfold",
"shuffle",
"stratified",
"region",
"region-stratified",
]
| None = "kfold",
total_folds: int = 5,
fold: int = 0,
subsample: int | None = None,
bands: list[str] | None = None,
augment: list[
darts_segmentation.training.augmentations.Augmentation
]
| None = None,
num_workers: int = 0,
in_memory: bool = False,
)
Bases: lightning.LightningDataModule
Initialize the data module.
Supports splitting the data into a train and a test set while also defining cross-validation folds. Folding applies only to the non-test set, which it splits into a train and a validation set.
Example
- Normal train-validate. (Can also be used for testing on the complete dataset)
- Specifying a test split by random (20% of the data will be used for testing)
- Specific fold for cross-validation (on the complete dataset, because data_split_method is None). This takes the third of a total of 7 folds to determine the validation set.
  In general this should be used in combination with a cross-validation loop.
      for fold in range(total_folds):
          dm = DartsDataModule(
              data_dir,
              batch_size,
              fold_method="region-stratified",
              fold=fold,
              total_folds=total_folds,
          )
          ...
- Don't split anything -> only train
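The two-stage partition described above (test set first, then folds on the remainder) can be sketched with the standard library alone. This is an illustration under stated assumptions, not the module's implementation: `split_then_fold` is a hypothetical helper, the strided fold assignment is a simple stand-in for the documented "kfold" method, and the shuffling will not reproduce the library's actual sample assignment.

```python
import random


def split_then_fold(n_samples, test_fraction, total_folds, fold, seed=42):
    """Hypothetical helper: hold out a test set, then k-fold the remainder."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)

    # Stage 1: the test set is taken first and never takes part in folding.
    n_test = round(n_samples * test_fraction)
    test, rest = indices[:n_test], indices[n_test:]

    # Stage 2: strided k-fold over the non-test samples only.
    # Fold number `fold` becomes the validation set, the other folds the train set.
    val = rest[fold::total_folds]
    val_set = set(val)
    train = [i for i in rest if i not in val_set]
    return train, val, test


train, val, test = split_then_fold(
    n_samples=100, test_fraction=0.2, total_folds=5, fold=2
)
print(len(train), len(val), len(test))
```

Looping `fold` over `range(total_folds)` visits every non-test sample exactly once as validation data, while the test set stays identical across folds.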
Parameters:
-
data_dir(pathlib.Path) –The path to the data to be used for training. Expects a directory containing:
  1. a zarr group called "data.zarr" containing an "x" and a "y" array
  2. a geoparquet file called "metadata.parquet" containing the metadata for the data. This metadata should contain at least the following columns:
     - "sample_id": The id of the sample
     - "region": The region the sample belongs to
     - "empty": Whether the image is empty
  The index should refer to the index of the sample in the zarr data. This directory should be created by a preprocessing script.
-
batch_size(int) –Batch size for training and validation.
-
data_split_method(typing.Literal['random', 'region', 'sample'] | None, default:None) –The method to use for splitting the data into a train and a test set. Defaults to None.
  - "random" splits the data randomly. The seed is always 42, and the test size is specified by passing a list with a single float between 0 and 1 to data_split_by. This is the fraction of the data used for testing, e.g. [0.2] uses 20% of the data for testing.
  - "region" splits the data by one or multiple regions, specified by passing a str or list of str to data_split_by.
  - "sample" splits the data by sample ids, specified in the same way as for "region".
  - If None, no split is done and the complete dataset is used for both training and testing.
  The train split is further split in the cross-validation process.
-
data_split_by(list[str | float] | None, default:None) –Select by which regions/samples to split or the size of test set. Defaults to None.
-
fold_method(typing.Literal['kfold', 'shuffle', 'stratified', 'region', 'region-stratified'] | None, default:'kfold') –Method for cross-validation split. Defaults to "kfold".
-
total_folds(int, default:5) –Total number of folds in cross-validation. Defaults to 5.
-
fold(int, default:0) –Index of the current fold. Defaults to 0.
-
subsample(int | None, default:None) –If set, will subsample the dataset to this number of samples. This is useful for debugging and testing. Defaults to None.
-
bands(Bands | list[str] | None, default:None) –List of bands to use. Expects the data_dir to contain a config.toml with a "darts.bands" key, which is used to map the band names to their indices. Defaults to None.
-
augment(list[darts_segmentation.training.augmentations.Augmentation] | None, default:None) –List of augmentations to apply to the data. Does nothing for testing. Defaults to None.
-
num_workers(int, default:0) –Number of workers for data loading. See torch.utils.data.DataLoader. Defaults to 0.
-
in_memory(bool, default:False) –Whether to load the data into memory. Defaults to False.
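As a rough illustration of how data_split_method and data_split_by interact, the sketch below mirrors the documented semantics on a plain list of metadata rows. It is a hedged sketch, not the library's code: `split_metadata_sketch` and the example rows are hypothetical, and the random branch uses Python's `random` module, so it will not select the same samples as the real implementation even though both use seed 42.

```python
import random


def split_metadata_sketch(rows, data_split_method=None, data_split_by=None):
    """Hypothetical sketch: return (train_rows, test_rows) for a list of dicts."""
    if data_split_method is None:
        # No split: the complete dataset serves for both training and testing.
        return list(rows), list(rows)
    if data_split_method == "random":
        test_fraction = data_split_by[0]  # e.g. [0.2] -> 20% of the data for testing
        shuffled = list(rows)
        random.Random(42).shuffle(shuffled)  # fixed seed, as documented
        n_test = round(len(shuffled) * test_fraction)
        return shuffled[n_test:], shuffled[:n_test]
    if data_split_method in ("region", "sample"):
        # "region" matches on the region column, "sample" on the sample id column.
        column = "region" if data_split_method == "region" else "sample_id"
        held_out = set(data_split_by)
        test = [r for r in rows if r[column] in held_out]
        train = [r for r in rows if r[column] not in held_out]
        return train, test
    raise ValueError(f"Unknown data_split_method: {data_split_method!r}")


# Made-up metadata: six samples in region "A", four in region "B".
rows = [{"sample_id": f"s{i}", "region": "A" if i < 6 else "B"} for i in range(10)]

region_train, region_test = split_metadata_sketch(rows, "region", ["B"])
random_train, random_test = split_metadata_sketch(rows, "random", [0.2])
print(len(region_train), len(region_test))
print(len(random_train), len(random_test))
```

Splitting by region or sample keeps whole groups out of training, which a random split does not guarantee.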
Source code in darts-segmentation/src/darts_segmentation/training/data.py
bands
instance-attribute
bands = [data_bands.index(b) for b in bands] if bands else None
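The bands attribute above translates band names into positions in the stored data. The sketch below shows that lookup in isolation. The band names are made up, `band_indices` is a hypothetical helper, and `data_bands` is assumed to have been read from the "darts.bands" key of config.toml beforehand.

```python
# Assumed content of the "darts.bands" key in config.toml (made-up band names).
data_bands = ["red", "green", "blue", "nir", "ndvi"]


def band_indices(data_bands, bands):
    """Hypothetical helper: map selected band names to their positions."""
    if not bands:
        # None or an empty list means no band selection is stored.
        return None
    # list.index raises ValueError if a requested band is not available.
    return [data_bands.index(b) for b in bands]


selected = band_indices(data_bands, ["nir", "red"])
print(selected)
```

The order of the requested names is preserved, so the resulting indices can be used directly to select and reorder channels.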
batch_size
instance-attribute
batch_size = batch_size
data_dir
instance-attribute
data_dir = data_dir
data_split_by
instance-attribute
data_split_by = data_split_by
data_split_method
instance-attribute
data_split_method = data_split_method
fold_method
instance-attribute
fold_method = fold_method
in_memory
instance-attribute
in_memory = in_memory
num_workers
instance-attribute
num_workers = num_workers
subsample
instance-attribute
subsample = subsample
total_folds
instance-attribute
total_folds = total_folds
plot
Source code in darts-segmentation/src/darts_segmentation/training/data.py
setup
Source code in darts-segmentation/src/darts_segmentation/training/data.py
test_dataloader
train_dataloader
DartsDatasetInMemory
DartsDatasetInMemory(
data_dir: pathlib.Path | str,
augment: list[
darts_segmentation.training.augmentations.Augmentation
]
| None = None,
indices: list[int] | None = None,
bands: list[int] | None = None,
)
Bases: torch.utils.data.Dataset
Source code in darts-segmentation/src/darts_segmentation/training/data.py
transform
instance-attribute
transform = darts_segmentation.training.augmentations.get_augmentation(augment)
__getitem__
Source code in darts-segmentation/src/darts_segmentation/training/data.py
DartsDatasetZarr
DartsDatasetZarr(
data_dir: pathlib.Path | str,
augment: list[
darts_segmentation.training.augmentations.Augmentation
]
| None = None,
indices: list[int] | None = None,
bands: list[int] | None = None,
)
Bases: torch.utils.data.Dataset
Source code in darts-segmentation/src/darts_segmentation/training/data.py
indices
instance-attribute
indices = (
    indices
    if indices is not None
    else list(range(self.zroot["x"].shape[0]))
)
transform
instance-attribute
transform = darts_segmentation.training.augmentations.get_augmentation(augment)
__getitem__
Source code in darts-segmentation/src/darts_segmentation/training/data.py
_get_fold
_get_fold(
metadata: geopandas.GeoDataFrame,
fold_method: typing.Literal[
"kfold",
"shuffle",
"stratified",
"region",
"region-stratified",
"none",
]
| None,
n_folds: int,
fold: int,
) -> tuple[list[int], list[int]]
Source code in darts-segmentation/src/darts_segmentation/training/data.py
_log_stats
Source code in darts-segmentation/src/darts_segmentation/training/data.py
_split_metadata
_split_metadata(
metadata: geopandas.GeoDataFrame,
data_split_method: typing.Literal[
"random", "region", "sample", "none"
]
| None,
data_split_by: list[str | float] | None,
)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
get_augmentation
get_augmentation(
augment: list[
darts_segmentation.training.augmentations.Augmentation
]
| None,
always_apply: bool = False,
) -> albumentations.Compose | None
Get augmentations for segmentation tasks.
Parameters:
-
augment(list[darts_segmentation.training.augmentations.Augmentation] | None) –List of augmentations to apply. If None or empty, no augmentations are applied. If not empty, augmentations are applied in the order they are listed. Available augmentations:
  - D4 (Combination of HorizontalFlip, VerticalFlip, and RandomRotate90)
  - Blur
  - RandomBrightnessContrast
  - MultiplicativeNoise
  - Posterize (quantization to reduce number of bits per channel)
-
always_apply(bool, default:False) –If True, augmentations are always applied. This is useful for visualization/testing augmentations. Default is False.
Raises:
-
ValueError–If an unknown augmentation is provided.
Returns:
-
albumentations.Compose | None–A Compose object containing the augmentations. If no augmentations are provided, returns None.
Source code in darts-segmentation/src/darts_segmentation/training/augmentations.py
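The parameter description calls D4 a combination of HorizontalFlip, VerticalFlip, and RandomRotate90. The sketch below illustrates that idea on a tiny nested-list "image": a horizontal flip together with 90 degree rotations produces the eight symmetries of a square, and the vertical flip is one of them. It is a pure-Python illustration with hypothetical helper names, not how albumentations implements the transform.

```python
def hflip(img):
    """Mirror each row (left-right flip)."""
    return [list(row[::-1]) for row in img]


def vflip(img):
    """Reverse the row order (up-down flip)."""
    return [list(row) for row in img[::-1]]


def rot90(img):
    """Rotate by 90 degrees counter-clockwise: transpose, then reverse the rows."""
    return [list(col) for col in zip(*img)][::-1]


def d4_orbit(img):
    """Return the 8 variants reachable through flips and 90 degree rotations."""
    variants = []
    for base in (img, hflip(img)):
        current = base
        for _ in range(4):
            variants.append(current)
            current = rot90(current)
    return variants


img = [[1, 2], [3, 4]]
orbit = d4_orbit(img)
distinct = {tuple(map(tuple, v)) for v in orbit}
print(len(distinct))
```

Because the vertical flip already appears among the generated variants, one flip plus the rotations is enough to cover everything the three separate augmentations can produce.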
plot_training_data_distribution
plot_training_data_distribution(
train_metadata: geopandas.GeoDataFrame,
val_metadata: geopandas.GeoDataFrame | None,
test_metadata: geopandas.GeoDataFrame | None,
name: str,
) -> tuple[
matplotlib.pyplot.Figure,
dict[str, matplotlib.pyplot.Axes],
]
Plot the distribution of training data by region on a polar projection.
Parameters:
-
train_metadata(geopandas.GeoDataFrame) –GeoDataFrame containing training metadata.
-
val_metadata(geopandas.GeoDataFrame | None) –GeoDataFrame containing validation metadata.
-
test_metadata(geopandas.GeoDataFrame | None) –GeoDataFrame containing test metadata.
-
name(str) –Name of the dataset or experiment for the plot title.
Returns:
-
tuple[matplotlib.pyplot.Figure, dict[str, matplotlib.pyplot.Axes]]–The figure and the axes of the plot, keyed by name.
Source code in darts-segmentation/src/darts_segmentation/training/viz.py