darts_segmentation.training¶
Training related functions and classes for Image Segmentation.
Classes:
- BinarySegmentationMetrics – Callback for validation metrics and visualizations.
- DartsDataModule
- DartsDataset
- DartsDatasetInMemory
- DartsDatasetZarr
- SMPSegmenter – Lightning module for training a segmentation model using the segmentation_models_pytorch library.

Functions:
- create_training_patches – Create training patches from a tile and labels.
BinarySegmentationMetrics ¶
BinarySegmentationMetrics(
    *,
    input_combination: list[str],
    val_set: str = "val",
    test_set: str = "test",
    plot_every_n_val_epochs: int = 5,
    is_crossval: bool = False,
)
Bases: lightning.pytorch.callbacks.Callback
Callback for validation metrics and visualizations.
Initialize the ValidationCallback.
Parameters:
- input_combination (list[str]) – List of input names to combine for the visualization.
- val_set (str, default: "val") – Name of the validation set. Only used for naming the validation metrics. Defaults to "val".
- test_set (str, default: "test") – Name of the test set. Only used for naming the test metrics. Defaults to "test".
- plot_every_n_val_epochs (int, default: 5) – Plot validation samples every n epochs. Defaults to 5.
- is_crossval (bool, default: False) – Whether the training is done with cross-validation. This changes the logging behavior of scalar metrics from logging to {val_set} to just "val". The logging behavior of the samples is not affected. Defaults to False.
Methods:
- is_val_plot_epoch – Check if the current epoch is an epoch where validation samples should be plotted.
- on_test_batch_end
- on_test_epoch_end
- on_train_batch_end
- on_train_epoch_end
- on_validation_batch_end
- on_validation_epoch_end
- setup – Sets up the callback.
- teardown
Attributes:
- input_combination
- is_crossval
- pl_module (lightning.LightningModule)
- plot_every_n_val_epochs
- stage (darts_segmentation.training.callbacks.Stage)
- test_cmx (torchmetrics.ConfusionMatrix)
- test_instance_cmx (darts_segmentation.metrics.BinaryInstanceConfusionMatrix)
- test_instance_prc (darts_segmentation.metrics.BinaryInstancePrecisionRecallCurve)
- test_metrics (torchmetrics.MetricCollection)
- test_prc (torchmetrics.PrecisionRecallCurve)
- test_roc (torchmetrics.ROC)
- test_set
- train_metrics (torchmetrics.MetricCollection)
- trainer (lightning.Trainer)
- val_cmx (torchmetrics.ConfusionMatrix)
- val_metrics (torchmetrics.MetricCollection)
- val_prc (torchmetrics.PrecisionRecallCurve)
- val_roc (torchmetrics.ROC)
- val_set
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
input_combination instance-attribute ¶
input_combination = input_combination

is_crossval instance-attribute ¶
is_crossval = is_crossval

plot_every_n_val_epochs instance-attribute ¶
plot_every_n_val_epochs = plot_every_n_val_epochs

test_instance_cmx instance-attribute ¶
test_instance_cmx: darts_segmentation.metrics.BinaryInstanceConfusionMatrix

test_instance_prc instance-attribute ¶
test_instance_prc: darts_segmentation.metrics.BinaryInstancePrecisionRecallCurve

test_set instance-attribute ¶
test_set = test_set

val_set instance-attribute ¶
val_set = val_set
is_val_plot_epoch ¶
Check if the current epoch is an epoch where validation samples should be plotted.
Parameters:
- current_epoch (int) – The current epoch.
- check_val_every_n_epoch (int | None) – The number of epochs to check for plotting. If None, no plotting is done.
Returns:
- bool – True if the current epoch is a plot epoch, False otherwise.
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
on_test_batch_end ¶
on_test_batch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    outputs,
    batch,
    batch_idx,
    dataloader_idx=0,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
on_test_epoch_end ¶
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
on_train_batch_end ¶
on_train_batch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    outputs,
    batch,
    batch_idx,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
on_train_epoch_end ¶
on_validation_batch_end ¶
on_validation_batch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    outputs,
    batch,
    batch_idx,
    dataloader_idx=0,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
on_validation_epoch_end ¶
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
setup ¶
setup(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    stage: darts_segmentation.training.callbacks.Stage,
)
Sets up the callback.
Creates the metrics required for the specific stage:
- For the "fit" stage, creates training and validation metrics and visualizations.
- For the "validate" stage, only creates validation metrics and visualizations.
- For the "test" stage, only creates test metrics and visualizations.
- For the "predict" stage, no metrics or visualizations are created.
Always maps the trainer and pl_module to the callback.
Training and validation metrics are "simple" metrics from torchmetrics. The validation visualizations are more complex metrics from torchmetrics. The test metrics and visualizations are the same as the validation ones, and also include custom "Instance" metrics.
Parameters:
- trainer (lightning.Trainer) – The lightning trainer.
- pl_module (lightning.LightningModule) – The lightning module.
- stage (typing.Literal['fit', 'validate', 'test', 'predict']) – The current stage. One of: "fit", "validate", "test", "predict".
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
teardown ¶
teardown(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    stage: darts_segmentation.training.callbacks.Stage,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
DartsDataModule ¶
DartsDataModule(
    data_dir: pathlib.Path,
    batch_size: int,
    fold: int = 0,
    augment: bool = True,
    num_workers: int = 0,
    in_memory: bool = False,
)
Bases: lightning.LightningDataModule
Methods:
- setup
- test_dataloader
- train_dataloader
Attributes:
- augment
- batch_size
- data_dir
- fold
- in_memory
- nsamples
- num_workers
Source code in darts-segmentation/src/darts_segmentation/training/data.py
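Since the DartsDataModule takes a fold parameter, a cross-validation split has to be derived from it. The snippet below is a minimal, standalone sketch of one plausible k-fold partition (assumed 5 folds, round-robin assignment); the actual split logic in data.py may differ:

```python
def kfold_split(n_samples: int, fold: int, n_folds: int = 5) -> tuple[list[int], list[int]]:
    """Round-robin k-fold: every n_folds-th sample (offset by fold) goes to validation."""
    val_idx = [i for i in range(n_samples) if i % n_folds == fold]
    train_idx = [i for i in range(n_samples) if i % n_folds != fold]
    return train_idx, val_idx
```

The two index lists could then back a training and a validation dataset respectively, e.g. via the indices argument of DartsDatasetInMemory.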
batch_size instance-attribute ¶
batch_size = batch_size

data_dir instance-attribute ¶
data_dir = data_dir

in_memory instance-attribute ¶
in_memory = in_memory

num_workers instance-attribute ¶
num_workers = num_workers
setup ¶
Source code in darts-segmentation/src/darts_segmentation/training/data.py
test_dataloader ¶
train_dataloader ¶
DartsDataset ¶
Bases: torch.utils.data.Dataset
Methods:
- __getitem__
- __len__
Attributes:
Source code in darts-segmentation/src/darts_segmentation/training/data.py
transform instance-attribute ¶
transform = (
    albumentations.Compose(
        [
            albumentations.HorizontalFlip(),
            albumentations.VerticalFlip(),
            albumentations.RandomRotate90(),
            albumentations.RandomBrightnessContrast(),
            albumentations.MultiplicativeNoise(per_channel=True, elementwise=True),
        ]
    )
    if augment
    else None
)

x_files instance-attribute ¶
x_files = sorted((data_dir / "x").glob("*.pt"))

y_files instance-attribute ¶
y_files = sorted((data_dir / "y").glob("*.pt"))
__getitem__ ¶
Source code in darts-segmentation/src/darts_segmentation/training/data.py
DartsDatasetInMemory ¶
DartsDatasetInMemory(
    data_dir: pathlib.Path | str,
    augment: bool,
    indices: list[int] | None = None,
)
Bases: torch.utils.data.Dataset
Methods:
- __getitem__
- __len__
Attributes:
Source code in darts-segmentation/src/darts_segmentation/training/data.py
transform instance-attribute ¶
transform = (
    albumentations.Compose(
        [
            albumentations.HorizontalFlip(),
            albumentations.VerticalFlip(),
            albumentations.RandomRotate90(),
            albumentations.RandomBrightnessContrast(),
            albumentations.MultiplicativeNoise(per_channel=True, elementwise=True),
        ]
    )
    if augment
    else None
)
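The geometric part of the albumentations pipeline above (flips and 90° rotations) can be mimicked with plain numpy, which makes it easy to see that image and mask must receive the same random transform. This is an illustrative stand-in, not the library code:

```python
import numpy as np


def flip_rotate(image: np.ndarray, mask: np.ndarray, rng: np.random.Generator):
    """Apply the same random flip/rotation to image and mask (label alignment matters)."""
    if rng.random() < 0.5:  # horizontal flip
        image, mask = image[:, ::-1], mask[:, ::-1]
    if rng.random() < 0.5:  # vertical flip
        image, mask = image[::-1, :], mask[::-1, :]
    k = int(rng.integers(0, 4))  # 0-3 quarter turns
    return np.rot90(image, k), np.rot90(mask, k)
```

Photometric augmentations (RandomBrightnessContrast, MultiplicativeNoise) would touch only the image, never the mask.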
__getitem__ ¶
Source code in darts-segmentation/src/darts_segmentation/training/data.py
DartsDatasetZarr ¶
Bases: torch.utils.data.Dataset
Methods:
- __getitem__
- __len__
Attributes:
Source code in darts-segmentation/src/darts_segmentation/training/data.py
indices instance-attribute ¶
indices = (
    indices
    if indices is not None
    else list(range(self.zroot["x"].shape[0]))
)
transform instance-attribute ¶
transform = (
    albumentations.Compose(
        [
            albumentations.HorizontalFlip(),
            albumentations.VerticalFlip(),
            albumentations.RandomRotate90(),
            albumentations.RandomBrightnessContrast(),
            albumentations.MultiplicativeNoise(per_channel=True, elementwise=True),
        ]
    )
    if augment
    else None
)
__getitem__ ¶
Source code in darts-segmentation/src/darts_segmentation/training/data.py
SMPSegmenter ¶
SMPSegmenter(
    config: darts_segmentation.segment.SMPSegmenterConfig,
    learning_rate: float = 1e-05,
    gamma: float = 0.9,
    focal_loss_alpha: float | None = None,
    focal_loss_gamma: float = 2.0,
    **kwargs: dict[str, typing.Any],
)
Bases: lightning.LightningModule
Lightning module for training a segmentation model using the segmentation_models_pytorch library.
Initialize the SMPSegmenter.
Parameters:
- config (darts_segmentation.segment.SMPSegmenterConfig) – Configuration for the segmentation model.
- learning_rate (float, default: 1e-05) – Initial learning rate. Defaults to 1e-5.
- gamma (float, default: 0.9) – Multiplicative factor of learning rate decay. Defaults to 0.9.
- focal_loss_alpha (float, default: None) – Weight factor to balance positive and negative samples. Alpha must be in the [0...1] range; higher values give more weight to the positive class. None will not weight samples. Defaults to None.
- focal_loss_gamma (float, default: 2.0) – Focal loss power factor. Defaults to 2.0.
- kwargs (dict[str, typing.Any], default: {}) – Additional keyword arguments which should be saved to the hyperparameter file.
Methods:
- __repr__
- configure_optimizers
- on_train_epoch_end
- test_step
- training_step
- validation_step
Attributes:
- loss_fn
- model
Source code in darts-segmentation/src/darts_segmentation/training/module.py
loss_fn instance-attribute ¶
loss_fn = segmentation_models_pytorch.losses.FocalLoss(
    mode="binary",
    alpha=focal_loss_alpha,
    gamma=focal_loss_gamma,
    ignore_index=2,
)

model instance-attribute ¶
model = segmentation_models_pytorch.create_model(
    **config["model"],
    activation="sigmoid",
)
__repr__ ¶
configure_optimizers ¶
Source code in darts-segmentation/src/darts_segmentation/training/module.py
on_train_epoch_end ¶
test_step ¶
training_step ¶
validation_step ¶
create_training_patches ¶
create_training_patches(
    tile: xarray.Dataset,
    labels: geopandas.GeoDataFrame,
    bands: list[str],
    norm_factors: dict[str, float],
    patch_size: int,
    overlap: int,
    exclude_nopositive: bool,
    exclude_nan: bool,
    device: typing.Literal["cuda", "cpu"] | int,
    mask_erosion_size: int,
) -> collections.abc.Generator[tuple[torch.tensor, torch.tensor]]
Create training patches from a tile and labels.
Parameters:
- tile (xarray.Dataset) – The input tile, containing preprocessed, harmonized data.
- labels (geopandas.GeoDataFrame) – The labels to be used for training.
- bands (list[str]) – The bands to be used for training. Must be present in the tile.
- norm_factors (dict[str, float]) – The normalization factors for the bands.
- patch_size (int) – The size of the patches.
- overlap (int) – The size of the overlap.
- exclude_nopositive (bool) – Whether to exclude patches where the labels do not contain positives.
- exclude_nan (bool) – Whether to exclude patches where the input data has nan values.
- device (typing.Literal['cuda', 'cpu'] | int) – The device to use for the erosion.
- mask_erosion_size (int) – The size of the disk to use for erosion.
Yields:
- Generator[tuple[torch.tensor, torch.tensor]] – A tuple containing the input and the labels as pytorch tensors. The input has the format (C, H, W), the labels (H, W).
Raises:
- ValueError – If a band is not found in the preprocessed data.
Source code in darts-segmentation/src/darts_segmentation/training/prepare_training.py
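The patch_size/overlap pair determines a sliding-window grid over the tile. The helper below sketches one plausible way to enumerate patch origins; the actual implementation in prepare_training.py may handle tile edges differently:

```python
def patch_origins(height: int, width: int, patch_size: int, overlap: int) -> list[tuple[int, int]]:
    """Top-left (y, x) corners of overlapping patch_size x patch_size windows."""
    step = patch_size - overlap
    ys = list(range(0, max(height - patch_size, 0) + 1, step))
    xs = list(range(0, max(width - patch_size, 0) + 1, step))
    # Cover the bottom/right edges if the stride did not land exactly.
    if ys[-1] + patch_size < height:
        ys.append(height - patch_size)
    if xs[-1] + patch_size < width:
        xs.append(width - patch_size)
    return [(y, x) for y in ys for x in xs]
```

Each origin would then yield one (C, H, W) input patch and its (H, W) label patch, subject to the exclude_nopositive and exclude_nan filters.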