
Cross-Validation

[uv run] darts cross-validate-smp ...

Fold strategies

While cross-validating, the data can further be split into a training and a validation set. The size of the validation set is controlled by the integer total_folds: higher values result in smaller validation sets and therefore more fold combinations. To reduce the number of folds actually run, the n_folds parameter can be provided; the remaining folds are then skipped. The "folding" is based on scikit-learn and currently supports the following folding methods, which can be specified via the fold_method parameter: "kfold", "shuffle", "stratified", "region" and "region-stratified".
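
As a rough illustration of how total_folds controls the validation size, here is a minimal sketch using plain scikit-learn (not the actual DARTS implementation; the seed and the early break that emulates n_folds are purely illustrative):

```py
from sklearn.model_selection import KFold

samples = list(range(100))  # stand-in for the training samples
total_folds = 5  # each validation set covers 1/5 = 20% of the data
n_folds = 2  # only run the first two folds, skip the rest

kf = KFold(n_splits=total_folds, shuffle=True, random_state=42)
for i, (train_idx, val_idx) in enumerate(kf.split(samples)):
    if i >= n_folds:
        break
    print(f"fold {i}: {len(train_idx)} train / {len(val_idx)} val samples")
```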

Even in normal training, a single KFold split is used to split between training and validation. This can be disabled by setting fold_method to None. In that case the validation set is identical to the training set, meaning longer validation times and metrics that are always calculated on seen data. This is useful e.g. for the final training of a model before deployment.

Using DartsDataModule

The data splitting is implemented by the darts_segmentation.training.data.DartsDataModule and can therefore be used in other settings as well.

darts_segmentation.training.data.DartsDataModule

DartsDataModule(
    data_dir: pathlib.Path,
    batch_size: int,
    data_split_method: typing.Literal[
        "random", "region", "sample"
    ]
    | None = None,
    data_split_by: list[str | float] | None = None,
    fold_method: typing.Literal[
        "kfold",
        "shuffle",
        "stratified",
        "region",
        "region-stratified",
    ]
    | None = "kfold",
    total_folds: int = 5,
    fold: int = 0,
    subsample: int | None = None,
    bands: darts_segmentation.utils.Bands
    | list[str]
    | None = None,
    augment: list[
        darts_segmentation.training.augmentations.Augmentation
    ]
    | None = None,
    num_workers: int = 0,
    in_memory: bool = False,
)

Bases: lightning.LightningDataModule

Initialize the data module.

Supports splitting the data into train and test set while also defining cv-folds. Folding only applies to the non-test set and splits it into a train and a validation set.

Example
  1. Normal train-validate. (Can also be used for testing on the complete dataset)

    dm = DartsDataModule(data_dir, batch_size)
    

  2. Specifying a random test split (20% of the data will be used for testing)

    dm = DartsDataModule(data_dir, batch_size, data_split_method="random")
    

  3. Specific fold for cross-validation (on the complete dataset, because data_split_method is None). This will take the third of a total of 7 folds to determine the validation set.

    dm = DartsDataModule(data_dir, batch_size, fold_method="region-stratified", fold=2, total_folds=7)
    

In general this should be used in combination with a cross-validation loop.

for fold in range(total_folds):
    dm = DartsDataModule(
        data_dir,
        batch_size,
        fold_method="region-stratified",
        fold=fold,
        total_folds=total_folds)
    ...

  4. Don't split anything -> only train
    dm = DartsDataModule(data_dir, batch_size, fold_method=None)
    

Parameters:

  • data_dir (pathlib.Path) –

    The path to the data to be used for training. Expects a directory containing:

    1. a zarr group called "data.zarr" containing a "x" and "y" array
    2. a geoparquet file called "metadata.parquet" containing the metadata for the data. This metadata should contain at least the following columns:
        - "sample_id": The id of the sample
        - "region": The region the sample belongs to
        - "empty": Whether the image is empty

       The index should refer to the index of the sample in the zarr data.

    This directory should be created by a preprocessing script. A short sketch for inspecting such a directory follows after this parameter list.

  • batch_size (int) –

    Batch size for training and validation.

  • data_split_method (typing.Literal['random', 'region', 'sample'] | None, default: None ) –

    The method to use for splitting the data into a train and a test set. "random" will split the data randomly; the seed is always 42 and the test size can be specified by providing a list with a single float between 0 and 1 to data_split_by. This will be the fraction of the data used for testing, e.g. [0.2] will use 20% of the data for testing. "region" will split the data by one or multiple regions, which can be specified by providing a str or list of str to data_split_by. "sample" will split the data by sample ids, which can also be specified similar to "region". If None, no split is done and the complete dataset is used for both training and testing. The train split will further be split in the cross-validation process. Defaults to None.

  • data_split_by (list[str | float] | None, default: None ) –

    Select by which regions/samples to split or the size of test set. Defaults to None.

  • fold_method (typing.Literal['kfold', 'shuffle', 'stratified', 'region', 'region-stratified'] | None, default: 'kfold' ) –

    Method for cross-validation split. Defaults to "kfold".

  • total_folds (int, default: 5 ) –

    Total number of folds in cross-validation. Defaults to 5.

  • fold (int, default: 0 ) –

    Index of the current fold. Defaults to 0.

  • subsample (int | None, default: None ) –

    If set, will subsample the dataset to this number of samples. This is useful for debugging and testing. Defaults to None.

  • bands (darts_segmentation.utils.Bands | list[str] | None, default: None ) –

    List of bands to use. Expects the data_dir to contain a config.toml with a "darts.bands" key, which is used to map the requested band names to their indices in the stored data. Defaults to None.

  • augment (list[darts_segmentation.training.augmentations.Augmentation] | None, default: None ) –

    List of augmentations to apply to the training data. Not applied for validation or testing. Defaults to None.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. See torch.utils.data.DataLoader. Defaults to 0.

  • in_memory (bool, default: False ) –

    Whether to load the data into memory. Defaults to False.
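
The expected layout of data_dir can be inspected with a short sketch like the following (the path is a placeholder; the zarr and geopandas calls mirror the checks in the source code shown below):

```py
from pathlib import Path

import geopandas as gpd
import zarr
from zarr.storage import LocalStore

data_dir = Path("path/to/preprocessed/data")  # placeholder

# the zarr group holds the image stack "x" and the labels "y"
zroot = zarr.group(store=LocalStore(data_dir / "data.zarr"))
print(zroot["x"].shape, zroot["y"].shape)

# the geoparquet metadata is indexed like the zarr arrays
metadata = gpd.read_parquet(data_dir / "metadata.parquet")
print(metadata[["sample_id", "region", "empty"]].head())
```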

Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __init__(
    self,
    data_dir: Path,
    batch_size: int,
    # data_split is for the test split
    data_split_method: Literal["random", "region", "sample"] | None = None,
    data_split_by: list[str | float] | None = None,
    # fold is for cross-validation split (train/val)
    fold_method: Literal["kfold", "shuffle", "stratified", "region", "region-stratified"] | None = "kfold",
    total_folds: int = 5,
    fold: int = 0,
    subsample: int | None = None,
    bands: Bands | list[str] | None = None,
    augment: list[Augmentation] | None = None,  # Not used for val or test
    num_workers: int = 0,
    in_memory: bool = False,
):
    """Initialize the data module.

    Supports splitting the data into train and test set while also defining cv-folds.
    Folding only applies to the non-test set and splits this into a train and validation set.

    Example:
        1. Normal train-validate. (Can also be used for testing on the complete dataset)
        ```py
        dm = DartsDataModule(data_dir, batch_size)
        ```

        2. Specifying a random test split (20% of the data will be used for testing)
        ```py
        dm = DartsDataModule(data_dir, batch_size, data_split_method="random")
        ```

        3. Specific fold for cross-validation (on the complete dataset, because data_split_method is None).
        This will take the third of a total of 7 folds to determine the validation set.
        ```py
        dm = DartsDataModule(data_dir, batch_size, fold_method="region-stratified", fold=2, total_folds=7)
        ```

        In general this should be used in combination with a cross-validation loop.
        ```py
        for fold in range(total_folds):
            dm = DartsDataModule(
                data_dir,
                batch_size,
                fold_method="region-stratified",
                fold=fold,
                total_folds=total_folds)
            ...
        ```

        4. Don't split anything -> only train
        ```py
        dm = DartsDataModule(data_dir, batch_size, fold_method=None)
        ```

    Args:
        data_dir (Path): The path to the data to be used for training.
            Expects a directory containing:
            1. a zarr group called "data.zarr" containing a "x" and "y" array
            2. a geoparquet file called "metadata.parquet" containing the metadata for the data.
                This metadata should contain at least the following columns:
                - "sample_id": The id of the sample
                - "region": The region the sample belongs to
                - "empty": Whether the image is empty
                The index should refer to the index of the sample in the zarr data.
            This directory should be created by a preprocessing script.
        batch_size (int): Batch size for training and validation.
        data_split_method (Literal["random", "region", "sample"] | None, optional):
            The method to use for splitting the data into a train and a test set.
            "random" will split the data randomly, the seed is always 42 and the test size can be specified
            by providing a list with a single float between 0 and 1 to data_split_by.
            This will be the fraction of the data to be used for testing.
            E.g. [0.2] will use 20% of the data for testing.
            "region" will split the data by one or multiple regions,
            which can be specified by providing a str or list of str to data_split_by.
            "sample" will split the data by sample ids, which can also be specified similar to "region".
            If None, no split is done and the complete dataset is used for both training and testing.
            The train split will further be split in the cross validation process.
            Defaults to None.
        data_split_by (list[str | float] | None, optional): Select by which regions/samples to split or
            the size of test set. Defaults to None.
        fold_method (Literal["kfold", "shuffle", "stratified", "region", "region-stratified"] | None, optional):
            Method for cross-validation split. Defaults to "kfold".
        total_folds (int, optional): Total number of folds in cross-validation. Defaults to 5.
        fold (int, optional): Index of the current fold. Defaults to 0.
        subsample (int | None, optional): If set, will subsample the dataset to this number of samples.
            This is useful for debugging and testing. Defaults to None.
        bands (Bands | list[str] | None, optional): List of bands to use.
            Expects the data_dir to contain a config.toml with a "darts.bands" key,
            with which the indices of the bands will be mapped to.
            Defaults to None.
        augment (list[Augmentation] | None, optional): List of augmentations to apply to the training data.
            Not applied for validation or testing. Defaults to None.
        num_workers (int, optional): Number of workers for data loading. See torch.utils.data.DataLoader.
            Defaults to 0.
        in_memory (bool, optional): Whether to load the data into memory. Defaults to False.

    """
    super().__init__()
    self.save_hyperparameters(ignore=["num_workers", "in_memory"])
    self.data_dir = data_dir
    self.batch_size = batch_size

    self.fold = fold
    self.data_split_method = data_split_method
    self.data_split_by = data_split_by
    self.fold_method = fold_method
    self.total_folds = total_folds

    self.subsample = subsample
    self.augment = augment
    self.num_workers = num_workers
    self.in_memory = in_memory

    data_dir = Path(data_dir)

    metadata_file = data_dir / "metadata.parquet"
    assert metadata_file.exists(), f"Metadata file {metadata_file} not found!"

    config_file = data_dir / "config.toml"
    assert config_file.exists(), f"Config file {config_file} not found!"
    data_bands = toml.load(config_file)["darts"]["bands"]
    bands = bands.names if isinstance(bands, Bands) else bands
    self.bands = [data_bands.index(b) for b in bands] if bands else None

    zdir = data_dir / "data.zarr"
    assert zdir.exists(), f"Data directory {zdir} not found!"
    zroot = zarr.group(store=LocalStore(data_dir / "data.zarr"))
    self.nsamples = zroot["x"].shape[0]
    logger.debug(f"Data directory {zdir} found with {self.nsamples} samples.")

Scoring strategies

To turn the information (metrics) gathered during a single cross-validation into a useful score, we need to aggregate the metrics somehow. If we are only interested in a single metric, this is easy: we simply compute the mean. This metric can be specified via the scoring_metric parameter of the cross-validation. It is also possible to use multiple metrics by specifying a list of metrics in the scoring_metric parameter. This, however, makes things a little more complicated.

Multi-metric scoring is implemented as combine-then-reduce, meaning that the metrics of each fold are first combined using the specified strategy, and the per-fold results are then reduced via the mean. The combining strategy can be specified by the multi_score_strategy parameter. As of now, four strategies are implemented: "arithmetic", "geometric", "harmonic" and "min".
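
The following sketch illustrates the combine-then-reduce idea with made-up metric values (three folds, two metrics); it is not the actual darts implementation, only a minimal illustration of the four strategies:

```py
import numpy as np

# hypothetical per-fold metric values
fold_metrics = [
    {"val/JaccardIndex": 0.71, "val/Recall": 0.80},
    {"val/JaccardIndex": 0.68, "val/Recall": 0.84},
    {"val/JaccardIndex": 0.74, "val/Recall": 0.78},
]

def combine(values: list[float], strategy: str) -> float:
    """Combine the metrics of a single fold into one value."""
    v = np.asarray(values)
    if strategy == "arithmetic":
        return float(v.mean())
    if strategy == "geometric":
        return float(v.prod() ** (1 / v.size))
    if strategy == "harmonic":
        return float(v.size / (1 / v).sum())
    if strategy == "min":
        return float(v.min())
    raise ValueError(f"Unknown strategy: {strategy}")

# combine per fold, then reduce over folds via the mean
per_fold = [combine(list(m.values()), "harmonic") for m in fold_metrics]
score = float(np.mean(per_fold))
print(score)
```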

The following visualization should help to understand how the different strategies behave. Note that the loss is interpreted as "lower is better" and also has a broader range of possible values, exceeding 1. For the multi-metric scoring with IoU and Loss, the arithmetic and geometric strategies are very unstable. The scores for very low loss values were so high that they needed to be clipped to the range [0, 1] for the visualization to show the behaviour of these strategies. However, the geometric mean in particular shows a smoother curve than the harmonic mean for the multi-metric scoring with IoU and Recall. This shows that the strategy should be chosen carefully and with respect to the metrics used.

Figure (IoU & Loss): Scoring strategies for JaccardIndex and Loss.
Figure (IoU & Recall): Scoring strategies for JaccardIndex and Recall.
Code to reproduce the visualization

If you are unsure which strategy to use, you can use this code snippet to make a visualization based on your metrics:

import matplotlib.pyplot as plt
import numpy as np
import xarray as xr

a = np.arange(0, 1, 0.01)
a = xr.DataArray(a, dims=["a"], coords={"a": a})
# 1 / ... indicates "lower is better" - replace it if needed
b = np.arange(0, 2, 0.01)
b = 1 / xr.DataArray(b, dims=["b"], coords={"b": b})

def viz_strategies(a, b):
    harmonic = 2 / (1 / a + 1 / b)
    geometric = np.sqrt(a * b)
    arithmetic = (a + b) / 2
    minimum = np.minimum(a, b)

    harmonic = harmonic.rename("harmonic mean")
    geometric = geometric.rename("geometric mean")
    arithmetic = arithmetic.rename("arithmetic mean")
    minimum = minimum.rename("minimum")

    fig, axs = plt.subplots(1, 4, figsize=(25, 5))
    axs = axs.flatten()
    harmonic.plot(ax=axs[0])
    axs[0].set_title("Harmonic")
    geometric.plot(ax=axs[1], vmax=min(geometric.max(), 1))
    axs[1].set_title("Geometric")
    arithmetic.plot(ax=axs[2], vmax=min(arithmetic.max(), 1))
    axs[2].set_title("Arithmetic")
    minimum.plot(ax=axs[3])
    axs[3].set_title("Minimum")
    return fig

viz_strategies(a, b).show()

Each metric can be suffixed with either ":higher" or ":lower" to indicate its direction. This allows multiple metrics to be combined correctly, by computing 1/metric before the calculation if a metric is marked ":lower". If no direction is provided, ":higher" is assumed. The direction has no real effect on the single-metric score, since only the mean is calculated there.
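
A small sketch of how such a suffix could be handled before combining (parse_metric is a hypothetical helper for illustration, not the actual darts API):

```py
def parse_metric(spec: str) -> tuple[str, bool]:
    """Split 'name:direction' into the metric name and a higher-is-better flag."""
    name, _, direction = spec.partition(":")
    return name, (direction or "higher") == "higher"

metrics = {"val/JaccardIndex": 0.7, "val/loss": 0.25}  # made-up values
for spec in ["val/JaccardIndex:higher", "val/loss:lower"]:
    name, higher_is_better = parse_metric(spec)
    value = metrics[name]
    # ":lower" metrics are inverted so that "higher is better" holds for all of them
    combined_input = value if higher_is_better else 1 / value
    print(name, combined_input)
```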

Available metrics

The following metrics are visible to the scoring function:

  • 'train/time'
  • 'train/device/batches_per_second'
  • 'train/device/samples_per_second'
  • 'train/device/flops_per_second'
  • 'train/device/mfu'
  • 'train/loss'
  • 'train/Accuracy'
  • 'train/CohenKappa'
  • 'train/F1Score'
  • 'train/HammingDistance'
  • 'train/JaccardIndex'
  • 'train/Precision'
  • 'train/Recall'
  • 'train/Specificity'
  • 'val/loss'
  • 'val/Accuracy'
  • 'val/CohenKappa'
  • 'val/F1Score'
  • 'val/HammingDistance'
  • 'val/JaccardIndex'
  • 'val/Precision'
  • 'val/Recall'
  • 'val/Specificity'
  • 'val/AUROC'
  • 'val/AveragePrecision'

These are derived from trainer.logged_metrics.

Random-state

All random state of the tuning and the cross-validation is seeded to 42. The random state of the training can be specified through a parameter. The cross-validation not only validates along different folds but also across different random seeds. Thus, a single cross-validation with 5 folds and 3 seeds will execute 15 runs.
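
A tiny sketch of the resulting run grid (the seed values and variable names here are illustrative, not the actual API):

```py
from itertools import product

total_folds = 5
seeds = [42, 43, 44]  # illustrative seed values

# every (fold, seed) combination is run once -> 5 folds x 3 seeds = 15 runs
runs = list(product(range(total_folds), seeds))
print(len(runs))  # 15
```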