Cross-Validation¶
Fold strategies¶
While cross-validating, the data can further be split into a training and validation set.
One can control the size of the validation set by providing an integer to the total_folds parameter: each validation set holds roughly 1/total_folds of the data.
Higher values therefore result in smaller validation sets and more possible fold combinations.
To reduce the number of folds actually run, one can provide the n_folds parameter; the remaining folds will be skipped.
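For intuition, here is a minimal sketch of that relationship using plain scikit-learn (this is not the actual darts code; total_folds and n_folds are simply mirrored as local variables):

from sklearn.model_selection import KFold

total_folds = 5  # each validation set holds ~1/5 = 20% of the training data
n_folds = 3      # only the first 3 of the 5 possible folds are actually run

splitter = KFold(n_splits=total_folds, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(splitter.split(range(100))):
    if fold >= n_folds:
        break  # the remaining folds are skipped
    print(f"fold {fold}: {len(train_idx)} train / {len(val_idx)} validation samples")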
The "folding" is based on scikit-learn
and currently supports the following folding methods, which can be specified by the fold_method
parameter:
"kfold"
: Split the data intototal_folds
folds, where each fold can be used as a validation set. Uses sklearn.model_selection.KFold."stratified"
: Will use the"empty"
column of the metadata to createtotal_folds
shuffled folds where each fold contains the same amount of empty and non-empty samples. Uses sklearn.model_selection.StratifiedKFold."shuffle"
: Similar to"stratified"
, but the order of the data is shuffled before splitting. Uses sklearn.model_selection.StratifiedShuffleSplit."region"
: Will use the"region"
column of the metadata to createtotal_folds
folds where each fold splits the data by one or multiple regions. Uses sklearn.model_selection.GroupShuffleSplit."region-stratified"
: Merge of the"region"
and"stratified"
methods. Uses sklearn.model_selection.StratifiedGroupKFold.
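To illustrate, the following is a rough sketch of what the "region-stratified" method does under the hood, here with a toy metadata table rather than the actual darts implementation (the "empty" column drives stratification, the "region" column drives grouping):

import pandas as pd
from sklearn.model_selection import StratifiedGroupKFold

# Toy metadata with the columns described for DartsDataModule below.
metadata = pd.DataFrame({
    "empty": [True, False, True, False, True, False, True, False],
    "region": ["A", "A", "B", "B", "C", "C", "D", "D"],
})

total_folds, fold = 4, 0
splitter = StratifiedGroupKFold(n_splits=total_folds, shuffle=True, random_state=42)
splits = list(splitter.split(metadata.index, y=metadata["empty"], groups=metadata["region"]))
train_idx, val_idx = splits[fold]  # the fold selected via the fold parameter
print(train_idx, val_idx)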
Even in normal training, a single KFold split is used to separate training and validation data.
This can be disabled by setting fold_method to None.
In that case, the validation set becomes equal to the training set, which means longer validation times and metrics that are always calculated on already-seen data.
This is useful e.g. for the final training of a model before deployment.
Using DartsDataModule
The data splitting is implemented by the darts_segmentation.training.data.DartsDataModule and can therefore be used in other settings as well.
darts_segmentation.training.data.DartsDataModule
¶
DartsDataModule(
    data_dir: pathlib.Path,
    batch_size: int,
    data_split_method: typing.Literal["random", "region", "sample"] | None = None,
    data_split_by: list[str | float] | None = None,
    fold_method: typing.Literal["kfold", "shuffle", "stratified", "region", "region-stratified"] | None = "kfold",
    total_folds: int = 5,
    fold: int = 0,
    subsample: int | None = None,
    bands: darts_segmentation.utils.Bands | list[str] | None = None,
    augment: list[darts_segmentation.training.augmentations.Augmentation] | None = None,
    num_workers: int = 0,
    in_memory: bool = False,
)
Bases: lightning.LightningDataModule
Initialize the data module.
Supports splitting the data into a train and a test set while also defining cv-folds. Folding only applies to the non-test set and splits it into a train and a validation set.
Example
- Normal train-validate. (Can also be used for testing on the complete dataset.)
- Specifying a random test split (20% of the data will be used for testing).
- A specific fold for cross-validation (on the complete dataset, because data_split_method is None). This will take the third of a total of 7 folds to determine the validation set. In general this should be used in combination with a cross-validation loop:
for fold in range(total_folds):
    dm = DartsDataModule(
        data_dir,
        batch_size,
        fold_method="region-stratified",
        fold=fold,
        total_folds=total_folds,
    )
    ...
- Don't split anything -> only train. (Rough usage sketches for these cases follow below.)
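The examples above roughly correspond to constructor calls like the following (a hedged sketch assuming data_dir and batch_size are already defined; the argument names are taken from the parameter list below):

# 1. Normal train-validate on the complete dataset (no test split):
dm = DartsDataModule(data_dir, batch_size)

# 2. Random test split, 20% of the data held out for testing:
dm = DartsDataModule(data_dir, batch_size, data_split_method="random", data_split_by=[0.2])

# 3. Third fold (fold index 2) of 7 total folds, on the complete dataset:
dm = DartsDataModule(data_dir, batch_size, fold=2, total_folds=7)

# 4. Don't split anything -> only train:
dm = DartsDataModule(data_dir, batch_size, fold_method=None)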
Parameters:
- data_dir (pathlib.Path) – The path to the data to be used for training. Expects a directory containing: 1. a zarr group called "data.zarr" containing an "x" and a "y" array, 2. a geoparquet file called "metadata.parquet" containing the metadata for the data. This metadata should contain at least the following columns: "sample_id" (the id of the sample), "region" (the region the sample belongs to) and "empty" (whether the image is empty). The index should refer to the index of the sample in the zarr data. This directory should be created by a preprocessing script.
- batch_size (int) – Batch size for training and validation.
- data_split_method (typing.Literal['random', 'region', 'sample'] | None, default: None) – The method to use for splitting the data into a train and a test set. "random" splits the data randomly; the seed is always 42 and the test size can be specified by providing a list with a single float between 0 and 1 to data_split_by, which is the fraction of the data used for testing, e.g. [0.2] uses 20% of the data for testing. "region" splits the data by one or multiple regions, which can be specified by providing a str or list of str to data_split_by. "sample" splits the data by sample ids, which can also be specified similarly to "region". If None, no split is done and the complete dataset is used for both training and testing. The train split will further be split in the cross-validation process. Defaults to None.
- data_split_by (list[str | float] | None, default: None) – Select by which regions/samples to split, or the size of the test set. Defaults to None.
- fold_method (typing.Literal['kfold', 'shuffle', 'stratified', 'region', 'region-stratified'] | None, default: 'kfold') – Method for cross-validation split. Defaults to "kfold".
- total_folds (int, default: 5) – Total number of folds in cross-validation. Defaults to 5.
- fold (int, default: 0) – Index of the current fold. Defaults to 0.
- subsample (int | None, default: None) – If set, will subsample the dataset to this number of samples. This is useful for debugging and testing. Defaults to None.
- bands (darts_segmentation.utils.Bands | list[str] | None, default: None) – List of bands to use. Expects the data_dir to contain a config.toml with a "darts.bands" key, to which the indices of the bands will be mapped. Defaults to None.
- augment (list[darts_segmentation.training.augmentations.Augmentation] | None, default: None) – Augmentations to apply to the training data. Does nothing for testing. Defaults to None.
- num_workers (int, default: 0) – Number of workers for data loading. See torch.utils.data.DataLoader. Defaults to 0.
- in_memory (bool, default: False) – Whether to load the data into memory. Defaults to False.
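Putting the data_dir and bands descriptions together, the expected layout of the training data directory looks roughly like this (assumed from the parameter descriptions above; details may differ depending on the preprocessing script):

data_dir/
├── data.zarr/          (zarr group with "x" and "y" arrays)
├── metadata.parquet    (geoparquet with at least "sample_id", "region" and "empty" columns)
└── config.toml         (contains the "darts.bands" key used to map band names to indices)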
Source code in darts-segmentation/src/darts_segmentation/training/data.py
Scoring strategies¶
To turn the metrics gathered during a single cross-validation into a useful score, we need to aggregate them somehow.
If we are only interested in a single metric, this is easy: we simply compute its mean across folds.
This metric can be specified by the scoring_metric parameter of the cross-validation.
It is also possible to use multiple metrics by passing a list of metrics to the scoring_metric parameter.
This, however, makes the aggregation a little more complicated.
Multi-metric scoring is implemented as combine-then-reduce, meaning that first for each fold the metrics are combined using the specified strategy, and then the results are reduced via mean.
The combining strategy can be specified by the multi_score_strategy parameter.
As of now, there are four strategies implemented: "arithmetic", "geometric", "harmonic" and "min".
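As a minimal numeric sketch of combine-then-reduce (illustrative values and a plain-Python re-implementation, not the actual darts scoring code):

import numpy as np

# Per-fold metrics, already oriented so that higher is better.
fold_metrics = [
    {"val/JaccardIndex": 0.62, "val/Recall": 0.71},
    {"val/JaccardIndex": 0.58, "val/Recall": 0.69},
    {"val/JaccardIndex": 0.65, "val/Recall": 0.74},
]

def combine(values, strategy):
    values = np.asarray(values, dtype=float)
    if strategy == "arithmetic":
        return values.mean()
    if strategy == "geometric":
        return values.prod() ** (1 / len(values))
    if strategy == "harmonic":
        return len(values) / (1 / values).sum()
    return values.min()  # "min"

# 1. combine the metrics within each fold, 2. reduce across folds via the mean
score = np.mean([combine(list(m.values()), "harmonic") for m in fold_metrics])
print(score)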
The following visualization should help illustrate how the different strategies behave. Note that the loss is interpreted as "lower is better" and also has a broader range of possible values, exceeding 1. For the multi-metric scoring with IoU and Loss, the arithmetic and geometric strategies are very unstable: the scores for very low loss values were so high that they had to be clipped to the range [0, 1] for the visualization to show the behaviour of these strategies at all. However, the geometric mean shows a smoother curve than the harmonic mean for the multi-metric scoring with IoU and Recall. In short, the strategy should be chosen carefully and with respect to the metrics used.
IoU & Loss | ![]() |
IoU & Recall | ![]() |
Code to reproduce the visualization
If you are unsure which strategy to use, you can use this code snippet to make a visualization based on your metrics:
import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

# Metric a: "higher is better", ranging from 0 to 1
a = np.arange(0, 1, 0.01)
a = xr.DataArray(a, dims=["a"], coords={"a": a})
# Metric b: 1 / ... indicates "lower is better" - replace it if needed
b = np.arange(0, 2, 0.01)
b = 1 / xr.DataArray(b, dims=["b"], coords={"b": b})


def viz_strategies(a, b):
    # Combine the two metrics with each of the four strategies
    harmonic = 2 / (1 / a + 1 / b)
    geometric = np.sqrt(a * b)
    arithmetic = (a + b) / 2
    minimum = np.minimum(a, b)

    harmonic = harmonic.rename("harmonic mean")
    geometric = geometric.rename("geometric mean")
    arithmetic = arithmetic.rename("arithmetic mean")
    minimum = minimum.rename("minimum")

    fig, axs = plt.subplots(1, 4, figsize=(25, 5))
    axs = axs.flatten()
    harmonic.plot(ax=axs[0])
    axs[0].set_title("Harmonic")
    # Clip the color range to [0, 1] so the unstable regions don't dominate
    geometric.plot(ax=axs[1], vmax=min(geometric.max(), 1))
    axs[1].set_title("Geometric")
    arithmetic.plot(ax=axs[2], vmax=min(arithmetic.max(), 1))
    axs[2].set_title("Arithmetic")
    minimum.plot(ax=axs[3])
    axs[3].set_title("Minimum")
    return fig


viz_strategies(a, b).show()
Each metric can be suffixed with either ":higher" or ":lower" to indicate its direction. This allows multiple metrics to be combined correctly, since 1/metric is used in the calculation whenever a metric is marked ":lower". If no direction is provided, ":higher" is assumed. The suffix has no real effect on single-metric scoring, since only the mean is calculated there.
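For example, a multi-metric score combining IoU with the validation loss could be specified along these lines (a hypothetical configuration fragment; the exact entry point and surrounding arguments depend on how you launch the cross-validation):

scoring_metric = ["val/JaccardIndex", "val/loss:lower"]  # loss is marked as "lower is better"
multi_score_strategy = "harmonic"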
Available metrics
The following metrics are visible to the scoring function:
'train/time'
'train/device/batches_per_second'
'train/device/samples_per_second'
'train/device/flops_per_second'
'train/device/mfu'
'train/loss'
'train/Accuracy'
'train/CohenKappa'
'train/F1Score'
'train/HammingDistance'
'train/JaccardIndex'
'train/Precision'
'train/Recall'
'train/Specificity'
'val/loss'
'val/Accuracy'
'val/CohenKappa'
'val/F1Score'
'val/HammingDistance'
'val/JaccardIndex'
'val/Precision'
'val/Recall'
'val/Specificity'
'val/AUROC'
'val/AveragePrecision'
These are derived from trainer.logged_metrics.
Random-state¶
All random state of the tuning and the cross-validation is seeded to 42. The random state of the training can be specified through a parameter. The cross-validation not only validates across different folds but also across different random seeds. Thus, a single cross-validation with 5 folds and 3 seeds will execute 15 runs.
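As a small sketch of how folds and seeds multiply into runs (a hypothetical loop, not the actual darts code; the seed values here are placeholders):

total_folds = 5
seeds = [42, 43, 44]  # example seeds; the actual seed handling is configured elsewhere

runs = [(fold, seed) for fold in range(total_folds) for seed in seeds]
print(len(runs))  # 5 folds x 3 seeds = 15 runs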