darts_segmentation.training.BinarySegmentationMetrics
Bases: lightning.pytorch.callbacks.Callback
Callback for training, validation, and test metrics and visualizations.
Initialize the BinarySegmentationMetrics callback.
Parameters:
- input_combination (list[str]) – List of input names to combine for the visualization.
- val_set (str, default: 'val') – Name of the validation set. Only used for naming the validation metrics. Defaults to "val".
- test_set (str, default: 'test') – Name of the test set. Only used for naming the test metrics. Defaults to "test".
- plot_every_n_val_epochs (int, default: 5) – Plot validation samples every n validation epochs. Defaults to 5.
- is_crossval (bool, default: False) – Whether the training is done with cross-validation. If True, scalar metrics are logged under "val" instead of under {val_set}. The logging behavior of the samples is not affected. Defaults to False.
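The effect of is_crossval on metric naming can be illustrated with a small sketch. The helper name `scalar_metric_name` and the `"{prefix}/{metric}"` key layout are assumptions made for illustration; they are not taken from the callback's source.

```python
def scalar_metric_name(metric: str, val_set: str = "val", is_crossval: bool = False) -> str:
    """Hypothetical helper: build the logging key for a scalar validation metric."""
    # With cross-validation, scalar metrics are logged under the fixed prefix "val";
    # otherwise the configured val_set name is used.
    prefix = "val" if is_crossval else val_set
    # Assumption: keys are laid out as "<prefix>/<metric>".
    return f"{prefix}/{metric}"
```

With a fixed prefix, every fold of a cross-validation run reports its scalars under the same key, which presumably makes folds directly comparable.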
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
input_combination
instance-attribute
input_combination = darts_segmentation.training.callbacks.BinarySegmentationMetrics(input_combination)
is_crossval
instance-attribute
is_crossval = darts_segmentation.training.callbacks.BinarySegmentationMetrics(is_crossval)
plot_every_n_val_epochs
instance-attribute
plot_every_n_val_epochs = darts_segmentation.training.callbacks.BinarySegmentationMetrics(plot_every_n_val_epochs)
test_instance_cmx
instance-attribute
test_instance_cmx: darts_segmentation.metrics.BinaryInstanceConfusionMatrix
test_instance_prc
instance-attribute
test_instance_prc: darts_segmentation.metrics.BinaryInstancePrecisionRecallCurve
test_set
instance-attribute
test_set = darts_segmentation.training.callbacks.BinarySegmentationMetrics(test_set)
val_set
instance-attribute
val_set = darts_segmentation.training.callbacks.BinarySegmentationMetrics(val_set)
is_val_plot_epoch
Check if the current epoch is an epoch where validation samples should be plotted.
Parameters:
- current_epoch (int) – The current epoch.
- check_val_every_n_epoch (int | None) – The number of epochs to check for plotting. If None, no plotting is done.
Returns:
- bool – True if the current epoch is a plot epoch, False otherwise.
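One way such a check could work is sketched below as a standalone function. The arithmetic is an assumption, not the library's confirmed logic: it supposes that epochs are zero-based, that validation runs every check_val_every_n_epoch epochs, and that samples are plotted on every plot_every_n_val_epochs-th validation run.

```python
from typing import Optional


def is_val_plot_epoch(
    current_epoch: int,
    check_val_every_n_epoch: Optional[int],
    plot_every_n_val_epochs: int = 5,
) -> bool:
    """Hypothetical standalone version of the plot-epoch check."""
    # No validation interval configured: never plot.
    if check_val_every_n_epoch is None:
        return False
    # Assumption: plotting happens on every n-th validation run, so the
    # interval in training epochs is the product of both settings.
    plot_interval = plot_every_n_val_epochs * check_val_every_n_epoch
    # Epochs are assumed to be zero-based, hence the +1.
    return (current_epoch + 1) % plot_interval == 0
```

Under these assumptions, validating every 2 epochs and plotting every 5 validation runs yields a plot at the end of every 10th training epoch.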
on_test_batch_end
on_test_batch_end(
trainer: lightning.Trainer,
pl_module: lightning.LightningModule,
outputs,
batch,
batch_idx,
dataloader_idx=0,
)
on_test_epoch_end
on_train_batch_end
on_train_batch_end(
trainer: lightning.Trainer,
pl_module: lightning.LightningModule,
outputs,
batch,
batch_idx,
)
on_train_epoch_end
on_validation_batch_end
on_validation_batch_end(
trainer: lightning.Trainer,
pl_module: lightning.LightningModule,
outputs,
batch,
batch_idx,
dataloader_idx=0,
)
on_validation_epoch_end
setup
setup(
trainer: lightning.Trainer,
pl_module: lightning.LightningModule,
stage: darts_segmentation.training.callbacks.Stage,
)
Sets up the callback.
Creates metrics required for the specific stage:
- For the "fit" stage, creates training and validation metrics and visualizations.
- For the "validate" stage, only creates validation metrics and visualizations.
- For the "test" stage, only creates test metrics and visualizations.
- For the "predict" stage, no metrics or visualizations are created.
Always maps the trainer and pl_module to the callback.
Training and validation metrics are "simple" metrics from torchmetrics. The validation visualizations are more complex metrics from torchmetrics. The test metrics and visualizations are the same as the validation ones, and also include custom "Instance" metrics.
Parameters:
- trainer (lightning.Trainer) – The lightning trainer.
- pl_module (lightning.LightningModule) – The lightning module.
- stage (typing.Literal['fit', 'validate', 'test', 'predict']) – The current stage. One of: "fit", "validate", "test", "predict".
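The per-stage behavior of setup can be summarized as a lookup from stage to the metric groups that get created. The sketch below is only an illustration of that mapping; the function name `metric_groups_for_stage` and the group labels "train", "val", and "test" are hypothetical and do not appear in the callback.

```python
from typing import Literal

Stage = Literal["fit", "validate", "test", "predict"]


def metric_groups_for_stage(stage: Stage) -> set[str]:
    """Hypothetical helper: which metric groups are created for a given stage."""
    groups: dict[str, set[str]] = {
        "fit": {"train", "val"},  # training and validation metrics
        "validate": {"val"},      # validation metrics only
        "test": {"test"},         # test metrics, including "Instance" metrics
        "predict": set(),         # nothing is created
    }
    return groups[stage]
```

Keeping the mapping in one table makes it easy to see that only the "fit" stage creates more than one group.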
teardown
teardown(
trainer: lightning.Trainer,
pl_module: lightning.LightningModule,
stage: darts_segmentation.training.callbacks.Stage,
)