darts_segmentation.training

Training related functions and classes for Image Segmentation.

convert_lightning_checkpoint

convert_lightning_checkpoint(
    *,
    lightning_checkpoint: pathlib.Path,
    out_directory: pathlib.Path,
    checkpoint_name: str,
    framework: str = "smp",
)

Convert a lightning checkpoint to our own format.

The final checkpoint will contain the model configuration and the state dict. It will be saved to:

    out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"

Parameters:

  • lightning_checkpoint (pathlib.Path) –

    Path to the lightning checkpoint.

  • out_directory (pathlib.Path) –

    Output directory for the converted checkpoint.

  • checkpoint_name (str) –

    A unique name of the new checkpoint.

  • framework (str, default: 'smp' ) –

    The framework used for the model. Defaults to "smp".
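
For example, converting a finished run's checkpoint could look like this. This is a minimal sketch: the paths are placeholders, and it assumes the function is importable from darts_segmentation.training as documented on this page.

```python
from pathlib import Path

from darts_segmentation.training import convert_lightning_checkpoint

# Convert the lightning checkpoint into the package's own format.
convert_lightning_checkpoint(
    lightning_checkpoint=Path("artifacts/_runs/my-run-abc123/checkpoints/last.ckpt"),
    out_directory=Path("checkpoints"),
    checkpoint_name="my-model",
    framework="smp",
)
# Result: checkpoints/my-model_<YYYY-MM-DD>.ckpt containing "config" and "statedict" keys.
```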

Source code in darts-segmentation/src/darts_segmentation/training/train.py
def convert_lightning_checkpoint(
    *,
    lightning_checkpoint: Path,
    out_directory: Path,
    checkpoint_name: str,
    framework: str = "smp",
):
    """Convert a lightning checkpoint to our own format.

    The final checkpoint will contain the model configuration and the state dict.
    It will be saved to:

    ```python
        out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"
    ```

    Args:
        lightning_checkpoint (Path): Path to the lightning checkpoint.
        out_directory (Path): Output directory for the converted checkpoint.
        checkpoint_name (str): A unique name of the new checkpoint.
        framework (str, optional): The framework used for the model. Defaults to "smp".

    """
    import torch

    logger.debug(f"Loading checkpoint from {lightning_checkpoint.resolve()}")
    lckpt = torch.load(lightning_checkpoint, weights_only=False, map_location=torch.device("cpu"))

    now = datetime.now()
    formatted_date = now.strftime("%Y-%m-%d")
    config = lckpt["hyper_parameters"]["config"]
    del config["model"]["encoder_weights"]
    config["time"] = formatted_date
    config["name"] = checkpoint_name
    config["model_framework"] = framework

    statedict = lckpt["state_dict"]
    # The state dict has a "model." prefix before every weight, which we need to remove.
    # consume_prefix_in_state_dict_if_present modifies the state dict in place.
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(statedict, "model.")

    own_ckpt = {
        "config": config,
        "statedict": statedict,
    }

    out_directory.mkdir(exist_ok=True, parents=True)

    out_checkpoint = out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"

    torch.save(own_ckpt, out_checkpoint)

    logger.info(f"Saved converted checkpoint to {out_checkpoint.resolve()}")

cross_validation_smp

Perform cross-validation for a model with given hyperparameters.

Please see https://smp.readthedocs.io/en/latest/index.html for model configurations of architecture and encoder.

Please also consider reading our training guide (docs/guides/training.md).

This cross-validation function is designed to evaluate the performance of a single model configuration. It can be used by a tuning script to tune hyperparameters. It calls the training function, so most functionality is the same as in the training function. In general, it performs the following:

for seed in seeds:
    for fold in folds:
        train_model(seed=seed, fold=fold, ...)

and calculates a score from the results.

The metric(s) used to calculate the score can be specified with the scoring_metric parameter. Each metric can be suffixed with ":higher" or ":lower" to indicate its direction; if no direction is provided, ":higher" is assumed. This allows multiple metrics to be combined correctly, since 1/metric is taken before calculation for each ":lower" metric. The direction has no real effect on single-metric scores, since only the mean is calculated there.

In a multi-score setting, the score is calculated by combining, then reducing the metrics: first, the metrics of each fold are combined using the specified strategy, and then the results are reduced via the mean. Please refer to the documentation to understand the different multi-score strategies.

If any metric of any run is NaN, Inf, -Inf or 0, the score is reported as "unstable".
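
The following standalone sketch illustrates the direction handling and the combine-then-reduce logic. It is an illustration only, not the actual implementation in darts_segmentation.training.scoring, and it assumes a plain mean as the combine strategy:

```python
import statistics


def _oriented(value: float, metric: str) -> float:
    # ":lower" metrics are inverted so that higher is always better.
    return 1 / value if metric.endswith(":lower") else value


def sketch_score(fold_metrics: list[dict[str, float]], scoring_metric: list[str]) -> float:
    # Combine the metrics of each fold first (here simply via the mean, standing in for
    # the configured multi-score strategy), then reduce over the folds via the mean.
    combined = [
        statistics.mean(_oriented(fold[m.split(":")[0]], m) for m in scoring_metric)
        for fold in fold_metrics
    ]
    return statistics.mean(combined)


# Two folds, scored on IoU (higher is better) and loss (lower is better):
folds = [
    {"val/JaccardIndex": 0.70, "val/loss": 0.40},
    {"val/JaccardIndex": 0.80, "val/loss": 0.30},
]
print(sketch_score(folds, ["val/JaccardIndex:higher", "val/loss:lower"]))
```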

Artifacts are stored under {artifact_dir}/{tune_name} for tunes (i.e. if tune_name is not None), otherwise under {artifact_dir}/_cross_validation.

You can specify how often logs will be written and validation will be performed:

  • log_every_n_steps specifies how often train-logs will be written. This does not affect validation.
  • check_val_every_n_epoch specifies how often validation will be performed. This also affects early stopping.
  • early_stopping_patience specifies how many validation epochs to wait for improvement before stopping. In epochs, this is check_val_every_n_epoch * early_stopping_patience.
  • plot_every_n_val_epochs specifies how often validation samples will be plotted. Since plotting is quite costly, you can reduce the frequency; this works similarly to early stopping. In epochs, this is check_val_every_n_epoch * plot_every_n_val_epochs.

Example: There are 400 training samples and the batch size is 2, resulting in 200 training steps per epoch. If log_every_n_steps is set to 50, the training logs and metrics will be logged 4 times per epoch. If check_val_every_n_epoch is set to 5, validation will be performed every 5 epochs. If plot_every_n_val_epochs is set to 2, validation samples will be plotted every 10 epochs. If early_stopping_patience is set to 3, early stopping will trigger after 15 epochs without improvement.
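
As a standalone sanity check of the arithmetic in the example above (not package code):

```python
n_samples, batch_size = 400, 2
steps_per_epoch = n_samples // batch_size        # 200
logs_per_epoch = steps_per_epoch // 50           # log_every_n_steps=50 -> 4 logs per epoch
check_val_every_n_epoch = 5                      # validate every 5 epochs
plot_interval = check_val_every_n_epoch * 2      # plot_every_n_val_epochs=2 -> every 10 epochs
early_stop_after = check_val_every_n_epoch * 3   # early_stopping_patience=3 -> after 15 epochs
print(steps_per_epoch, logs_per_epoch, plot_interval, early_stop_after)
```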

The training data is expected to have gone through the "preprocessing" step beforehand, which results in the following directory structure:

preprocessed-data/ # the top-level directory
├── config.toml
├── data.zarr/ # this zarr group contains the dataarrays x and y
├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
└── labels.geojson
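
Such a directory can be inspected directly, for example as in the sketch below. It assumes the zarr group was written with xarray and that xarray, zarr and pandas are installed; it is not part of the package API.

```python
from pathlib import Path

import pandas as pd
import xarray as xr

data_dir = Path("preprocessed-data")

# "data.zarr" holds the "x" (inputs) and "y" (labels) dataarrays.
ds = xr.open_zarr(data_dir / "data.zarr")
print(ds["x"].shape, ds["y"].shape)

# "metadata.parquet" indexes the samples and drives the train/val/test split.
metadata = pd.read_parquet(data_dir / "metadata.parquet")
print(metadata[["sample_id", "region", "empty"]].head())
```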

Parameters:

  • name (str | None, default: None ) –

    Name of the cross-validation. If None, a name is generated automatically. Defaults to None.

  • tune_name (str | None, default: None ) –

    Name of the tuning. Should only be specified by a tuning script. Defaults to None.

  • cv (CrossValidationConfig, default: CrossValidationConfig() ) –

    Configuration for cross-validation.

  • training_config (TrainingConfig, default: TrainingConfig() ) –

    Configuration for the training.

  • data_config (DataConfig, default: DataConfig() ) –

    Configuration for the data.

  • device_config (DeviceConfig, default: DeviceConfig() ) –

    Configuration for the devices to use.

  • hparams (Hyperparameters, default: Hyperparameters() ) –

    Hyperparameters for the training.

  • logging_config (LoggingConfig, default: LoggingConfig() ) –

    Logging configuration.

Returns:

  • tuple[float, bool, pd.DataFrame]: A single score, a boolean indicating if the score is unstable, and a DataFrame containing run info (seed, fold, metrics, duration, checkpoint)

Raises:

  • ValueError

    If no runs were performed, meaning the configuration is invalid or no data was found.
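
A minimal usage sketch: it assumes the function is importable from darts_segmentation.training as documented here and that the default configs point to valid preprocessed data.

```python
from darts_segmentation.training import cross_validation_smp

# All config arguments have defaults (see the source below); only the name is overridden here.
score, is_unstable, run_infos = cross_validation_smp(name="my-cv")
print(f"CV score: {score:.4f} (unstable: {is_unstable})")
print(run_infos.head())
```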

Source code in darts-segmentation/src/darts_segmentation/training/cv.py
def cross_validation_smp(
    *,
    name: str | None = None,
    tune_name: str | None = None,
    cv: CrossValidationConfig = CrossValidationConfig(),
    training_config: TrainingConfig = TrainingConfig(),
    data_config: DataConfig = DataConfig(),
    device_config: DeviceConfig = DeviceConfig(),
    hparams: Hyperparameters = Hyperparameters(),
    logging_config: LoggingConfig = LoggingConfig(),
):
    """Perform cross-validation for a model with given hyperparameters.

    Please see https://smp.readthedocs.io/en/latest/index.html for model configurations of architecture and encoder.

    Please also consider reading our training guide (docs/guides/training.md).

    This cross-validation function is designed to evaluate the performance of a single model configuration.
    It can be used by a tuning script to tune hyperparameters.
    It calls the training function, so most functionality is the same as in the training function.
    In general, it performs the following:

    ```py
    for seed in seeds:
        for fold in folds:
            train_model(seed=seed, fold=fold, ...)
    ```

    and calculates a score from the results.

    The metric(s) used to calculate the score can be specified with the `scoring_metric` parameter.
    Each metric can be suffixed with ":higher" or ":lower" to indicate its direction.
    If no direction is provided, ":higher" is assumed.
    This allows multiple metrics to be combined correctly, since 1/metric is taken before
    calculation for each ":lower" metric.
    The direction has no real effect on single-metric scores, since only the mean is calculated there.

    In a multi-score setting, the score is calculated by combining, then reducing the metrics:
    first, the metrics of each fold are combined using the specified strategy,
    and then the results are reduced via the mean.
    Please refer to the documentation to understand the different multi-score strategies.

    If any metric of any run is NaN, Inf, -Inf or 0, the score is reported as "unstable".

    Artifacts are stored under `{artifact_dir}/{tune_name}` for tunes (meaning if `tune_name` is not None)
    else `{artifact_dir}/_cross_validation`.

    You can specify the frequency on how often logs will be written and validation will be performed.
        - `log_every_n_steps` specifies how often train-logs will be written. This does not affect validation.
        - `check_val_every_n_epoch` specifies how often validation will be performed.
            This will also affect early stopping.
        - `early_stopping_patience` specifies how many epochs to wait for improvement before stopping.
            In epochs, this would be `check_val_every_n_epoch * early_stopping_patience`.
        - `plot_every_n_val_epochs` specifies how often validation samples will be plotted.
            Since plotting is quite costly, you can reduce the frequency. This works similarly to early stopping.
            In epochs, this would be `check_val_every_n_epoch * plot_every_n_val_epochs`.
    Example: There are 400 training samples and the batch size is 2, resulting in 200 training steps per epoch.
    If `log_every_n_steps` is set to 50 then the training logs and metrics will be logged 4 times per epoch.
    If `check_val_every_n_epoch` is set to 5 then validation will be performed every 5 epochs.
    If `plot_every_n_val_epochs` is set to 2 then validation samples will be plotted every 10 epochs.
    If `early_stopping_patience` is set to 3 then early stopping will be performed after 15 epochs without improvement.

    The data structure of the training data expects the "preprocessing" step to be done beforehand,
    which results in the following data structure:

    ```sh
    preprocessed-data/ # the top-level directory
    ├── config.toml
    ├── data.zarr/ # this zarr group contains the dataarrays x and y
    ├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
    └── labels.geojson
    ```

    Args:
        name (str | None, optional): Name of the cross-validation. If None, a name is generated automatically.
            Defaults to None.
        tune_name (str | None, optional): Name of the tuning. Should only be specified by a tuning script.
            Defaults to None.
        cv (CrossValidationConfig): Configuration for cross-validation.
        training_config (TrainingConfig): Configuration for the training.
        data_config (DataConfig): Configuration for the data.
        device_config (DeviceConfig): Configuration for the devices to use.
        hparams (Hyperparameters): Hyperparameters for the training.
        logging_config (LoggingConfig): Logging configuration.

    Returns:
        tuple[float, bool, pd.DataFrame]: A single score, a boolean indicating if the score is unstable,
            and a DataFrame containing run info (seed, fold, metrics, duration, checkpoint)

    Raises:
        ValueError: If no runs were performed, meaning the configuration is invalid or no data was found.

    """
    import pandas as pd
    from darts_utils.namegen import generate_counted_name

    from darts_segmentation.training.adp import _adp
    from darts_segmentation.training.scoring import score_from_runs

    tick_fstart = time.perf_counter()

    artifact_dir = logging_config.artifact_dir_at_cv(tune_name)
    cv_name = name or generate_counted_name(artifact_dir)
    artifact_dir = artifact_dir / cv_name
    artifact_dir.mkdir(parents=True, exist_ok=True)

    n_folds = cv.n_folds or data_config.total_folds

    logger.info(
        f"Starting cross-validation '{cv_name}' with data from {data_config.train_data_dir.resolve()}."
        f" Artifacts will be saved to {artifact_dir.resolve()}."
        f" Will run n_randoms*n_folds = {cv.n_randoms}*{n_folds} = {cv.n_randoms * n_folds} experiments."
    )

    seeds = cv.rng_seeds
    logger.debug(f"Using seeds: {seeds}")

    # Plan which runs to perform. These are later consumed based on the parallelization strategy.
    process_inputs: list[_ProcessInputs] = []
    for i, seed in enumerate(seeds):
        for fold in range(n_folds):
            current = i * n_folds + fold  # running index of this run
            total = n_folds * len(seeds)
            run = TrainRunConfig(
                name=f"{cv_name}-run-f{fold}s{seed}",
                cv_name=cv_name,
                tune_name=tune_name,
                fold=fold,
                random_seed=seed,
            )
            process_inputs.append(
                _ProcessInputs(
                    current=current,
                    total=total,
                    seed=seed,
                    fold=fold,
                    cv=cv,
                    run=run,
                    training_config=training_config,
                    logging_config=logging_config,
                    data_config=data_config,
                    device_config=device_config,
                    hparams=hparams,
                )
            )

    run_infos = []
    # This function abstracts away common logic for running multiprocessing
    for inp, output in _adp(
        process_inputs=process_inputs,
        is_parallel=device_config.strategy == "cv-parallel",
        devices=device_config.devices,
        available_devices=available_devices,
        _run=_run_training,
    ):
        run_infos.append(output.run_info)

    if len(run_infos) == 0:
        raise ValueError(
            "No runs were performed. Please check your configuration and data."
            " If you are using a tuning script, make sure to specify the correct parameters."
        )

    logger.debug(f"{run_infos=}")
    score = score_from_runs(run_infos, cv.scoring_metric, cv.multi_score_strategy)

    run_infos = pd.DataFrame(run_infos)
    run_infos["score"] = score
    is_unstable = run_infos["is_unstable"].any()
    run_infos["score_is_unstable"] = is_unstable
    if is_unstable:
        logger.warning("Score is unstable, meaning at least one of the metrics is NaN, Inf, -Inf or 0.")
    run_infos.to_parquet(artifact_dir / "run_infos.parquet")
    logger.debug(f"Saved run infos to {artifact_dir / 'run_infos.parquet'}")

    tick_fend = time.perf_counter()
    logger.info(
        f"Finished cross-validation '{cv_name}' in {tick_fend - tick_fstart:.2f}s"
        f" with {score=:.4f} ({'stable' if not is_unstable else 'unstable'})."
    )

    return score, is_unstable, run_infos

test_smp

test_smp(
    *,
    train_data_dir: pathlib.Path,
    run_id: str,
    run_name: str,
    model_ckp: pathlib.Path | None = None,
    batch_size: int = 8,
    data_split_method: typing.Literal[
        "random", "region", "sample"
    ]
    | None = None,
    data_split_by: list[str] | str | float | None = None,
    bands: list[str] | None = None,
    artifact_dir: pathlib.Path = pathlib.Path("artifacts"),
    num_workers: int = 0,
    device_config: darts_segmentation.training.train.DeviceConfig = darts_segmentation.training.train.DeviceConfig(),
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
) -> pytorch_lightning.Trainer

Run the testing of the SMP model.

The training data is expected to have gone through the "preprocessing" step beforehand, which results in the following directory structure:

preprocessed-data/ # the top-level directory
├── config.toml
├── data.zarr/ # this zarr group contains the dataarrays x and y
├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
└── labels.geojson

Parameters:

  • train_data_dir (pathlib.Path) –

    The path (top-level) to the data to be used for training. Expects a directory containing:

    1. a zarr group called "data.zarr" containing an "x" and a "y" array,
    2. a geoparquet file called "metadata.parquet" containing the metadata for the data. This metadata should contain at least the following columns:
       - "sample_id": The id of the sample
       - "region": The region the sample belongs to
       - "empty": Whether the image is empty
       The index should refer to the index of the sample in the zarr data.

    This directory should be created by a preprocessing script.

  • run_id (str) –

    ID of the run.

  • run_name (str) –

    Name of the run.

  • model_ckp (pathlib.Path | None, default: None ) –

    Path to the model checkpoint. If None, try to find the latest checkpoint in artifact_dir / run_name / run_id / checkpoints. Defaults to None.

  • batch_size (int, default: 8 ) –

    Batch size for training and validation.

  • data_split_method (typing.Literal['random', 'region', 'sample'] | None, default: None ) –

    The method to use for splitting the data into a train and a test set. "random" splits the data randomly; the seed is always 42 and the size of the test set can be specified by providing a float between 0 and 1 to data_split_by. "region" splits the data by one or multiple regions, which can be specified by providing a str or list of str to data_split_by. "sample" splits the data by sample ids, which can be specified in the same way as for "region". If None, no split is done and the complete dataset is used for both training and testing. The train split will be split further during cross-validation. Defaults to None. See the usage sketch below for examples.

  • data_split_by (list[str] | str | float | None, default: None ) –

    Select by which seed/regions/samples to split. Defaults to None.

  • bands (list[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • artifact_dir (pathlib.Path, default: pathlib.Path('artifacts') ) –

    Directory to save artifacts. Defaults to Path("artifacts").

  • num_workers (int, default: 0 ) –

    Number of workers for the DataLoader. Defaults to 0.

  • device_config (darts_segmentation.training.train.DeviceConfig, default: darts_segmentation.training.train.DeviceConfig() ) –

    Device and distributed strategy related parameters.

  • wandb_entity (str | None, default: None ) –

    WandB entity. Defaults to None.

  • wandb_project (str | None, default: None ) –

    WandB project. Defaults to None.

Returns:

  • Trainer ( pytorch_lightning.Trainer ) –

    The trainer object used for testing.
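
A minimal usage sketch of the split options described above. The paths and run identifiers are placeholders, and the import location is an assumption of this sketch.

```python
from pathlib import Path

from darts_segmentation.training import test_smp

# Hold out 20% of the samples as a random test split (the seed is fixed to 42 internally).
trainer = test_smp(
    train_data_dir=Path("preprocessed-data"),
    run_id="abc123",
    run_name="my-run",
    data_split_method="random",
    data_split_by=0.2,
)

# Alternatively, hold out whole regions:
# test_smp(..., data_split_method="region", data_split_by=["region-a", "region-b"])
```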

Source code in darts-segmentation/src/darts_segmentation/training/train.py
def test_smp(
    *,
    train_data_dir: Path,
    run_id: str,
    run_name: str,
    model_ckp: Path | None = None,
    batch_size: int = 8,
    data_split_method: Literal["random", "region", "sample"] | None = None,
    data_split_by: list[str] | str | float | None = None,
    bands: list[str] | None = None,
    artifact_dir: Path = Path("artifacts"),
    num_workers: int = 0,
    device_config: DeviceConfig = DeviceConfig(),
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
) -> "pl.Trainer":
    """Run the testing of the SMP model.

    The data structure of the training data expects the "preprocessing" step to be done beforehand,
    which results in the following data structure:

    ```sh
    preprocessed-data/ # the top-level directory
    ├── config.toml
    ├── data.zarr/ # this zarr group contains the dataarrays x and y
    ├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
    └── labels.geojson
    ```

    Args:
        train_data_dir (Path): The path (top-level) to the data to be used for training.
            Expects a directory containing:
            1. a zarr group called "data.zarr" containing a "x" and "y" array
            2. a geoparquet file called "metadata.parquet" containing the metadata for the data.
                This metadata should contain at least the following columns:
                - "sample_id": The id of the sample
                - "region": The region the sample belongs to
                - "empty": Whether the image is empty
                The index should refer to the index of the sample in the zarr data.
            This directory should be created by a preprocessing script.
        run_id (str): ID of the run.
        run_name (str): Name of the run.
        model_ckp (Path | None): Path to the model checkpoint.
            If None, try to find the latest checkpoint in `artifact_dir / run_name / run_id / checkpoints`.
            Defaults to None.
        batch_size (int): Batch size for training and validation.
        data_split_method (Literal["random", "region", "sample"] | None, optional):
            The method to use for splitting the data into a train and a test set.
            "random" will split the data randomly, the seed is always 42 and the size of the test set can be
            specified by providing a float between 0 and 1 to data_split_by.
            "region" will split the data by one or multiple regions,
            which can be specified by providing a str or list of str to data_split_by.
            "sample" will split the data by sample ids, which can also be specified similar to "region".
            If None, no split is done and the complete dataset is used for both training and testing.
            The train split will further be split in the cross validation process.
            Defaults to None.
        data_split_by (list[str] | str | float | None, optional): Select by which seed/regions/samples to split.
            Defaults to None.
        bands (list[str] | None, optional): List of bands to use. Defaults to None.
        artifact_dir (Path, optional): Directory to save artifacts. Defaults to Path("artifacts").
        num_workers (int, optional): Number of workers for the DataLoader. Defaults to 0.
        device_config (DeviceConfig, optional): Device and distributed strategy related parameters.
        wandb_entity (str | None, optional): WandB entity. Defaults to None.
        wandb_project (str | None, optional): WandB project. Defaults to None.

    Returns:
        Trainer: The trainer object used for testing.

    """
    import lightning as L  # noqa: N812
    import lovely_tensors
    import torch
    from darts.utils.logging import LoggingManager
    from lightning.pytorch import seed_everything
    from lightning.pytorch.callbacks import RichProgressBar, ThroughputMonitor
    from lightning.pytorch.loggers import CSVLogger, WandbLogger

    from darts_segmentation.training.callbacks import BinarySegmentationMetrics
    from darts_segmentation.training.data import DartsDataModule
    from darts_segmentation.training.module import LitSMP
    from darts_segmentation.utils import Bands

    LoggingManager.apply_logging_handlers("lightning.pytorch")

    tick_fstart = time.perf_counter()

    # Further nest the artifact directory to avoid cluttering the root directory
    artifact_dir = artifact_dir / "_runs"

    logger.info(
        f"Starting testing '{run_name}' ('{run_id}') with data from {train_data_dir.resolve()}."
        f" Artifacts will be saved to {(artifact_dir / f'{run_name}-{run_id}').resolve()}."
    )
    logger.debug(f"Using config:\n\t{batch_size=}\n\t{device_config}")

    lovely_tensors.set_config(color=False)
    lovely_tensors.monkey_patch()
    torch.set_float32_matmul_precision("medium")
    seed_everything(42, workers=True)

    data_config = toml.load(train_data_dir / "config.toml")["darts"]

    all_bands = Bands.from_config(data_config)
    bands = all_bands.filter(bands) if bands else all_bands

    # Data and model
    datamodule = DartsDataModule(
        data_dir=train_data_dir,
        batch_size=batch_size,
        data_split_method=data_split_method,
        data_split_by=data_split_by,
        bands=bands,
        num_workers=num_workers,
    )
    # Try to infer model checkpoint if not given
    if model_ckp is None:
        checkpoint_dir = artifact_dir / f"{run_name}-{run_id}" / "checkpoints"
        logger.debug(f"No checkpoint provided. Looking for model checkpoint in {checkpoint_dir.resolve()}")
        model_ckp = max(checkpoint_dir.glob("*.ckpt"), key=lambda x: x.stat().st_mtime)
    logger.debug(f"Using model checkpoint at {model_ckp.resolve()}")
    model = LitSMP.load_from_checkpoint(model_ckp)

    # Loggers
    trainer_loggers = [
        CSVLogger(save_dir=artifact_dir, version=f"{run_name}-{run_id}"),
    ]
    logger.debug(f"Logging CSV to {Path(trainer_loggers[0].log_dir).resolve()}")
    if wandb_entity and wandb_project:
        wandb_logger = WandbLogger(
            save_dir=artifact_dir.parent,
            name=run_name,
            version=run_id,
            project=wandb_project,
            entity=wandb_entity,
            resume="allow",
            # Using the group and job_type is a workaround for wandb's lack of support for manual sweeps
            group="none",
            job_type="none",
        )
        trainer_loggers.append(wandb_logger)
        logger.debug(
            f"Logging to WandB with entity '{wandb_entity}' and project '{wandb_project}'."
            f"Artifacts are logged to {(Path(wandb_logger.save_dir) / 'wandb').resolve()}"
        )

    # Callbacks
    callbacks = [
        RichProgressBar(),
        BinarySegmentationMetrics(
            bands=bands,
            batch_size=batch_size,
            patch_size=data_config["patch_size"],
        ),
        ThroughputMonitor(batch_size_fn=lambda batch: batch[0].size(0)),
    ]

    # Test
    trainer = L.Trainer(
        callbacks=callbacks,
        logger=trainer_loggers,
        accelerator=device_config.accelerator,
        strategy=device_config.lightning_strategy,
        num_nodes=device_config.num_nodes,
        devices=device_config.devices,
        deterministic=True,
    )

    trainer.test(model, datamodule, ckpt_path=model_ckp)

    tick_fend = time.perf_counter()
    logger.info(f"Finished testing '{run_name}' in {tick_fend - tick_fstart:.2f}s.")

    if wandb_entity and wandb_project:
        wandb_logger.finalize("success")
        wandb_logger.experiment.finish(exit_code=0)
        logger.debug(f"Finalized WandB logging for '{run_name}'")

    return trainer

train_smp

Run the training of the SMP model, specifically binary segmentation.

Please see https://smp.readthedocs.io/en/latest/index.html for model configurations of architecture and encoder.

Please also consider reading our training guide (docs/guides/training.md).

This training function is meant for single training runs but is also used for cross-validation and hyperparameter tuning by cv.py and tune.py. This strongly affects where artifacts are stored:

  • Run was created by a tune: {artifact_dir}/{tune_name}/{cv_name}/{run_name}-{run_id}
  • Run was created by a cross-validation: {artifact_dir}/_cross_validations/{cv_name}/{run_name}-{run_id}
  • Single runs: {artifact_dir}/_runs/{run_name}-{run_id}

run_name can be specified by the user; otherwise it is generated automatically. In the case of cross-validation, the run name is generated by the cross-validation itself. run_id is always generated automatically by the training function. Both are saved to the final checkpoint.

You can specify how often logs will be written and validation will be performed:

  • log_every_n_steps specifies how often train-logs will be written. This does not affect validation.
  • check_val_every_n_epoch specifies how often validation will be performed. This also affects early stopping.
  • early_stopping_patience specifies how many validation epochs to wait for improvement before stopping. In epochs, this is check_val_every_n_epoch * early_stopping_patience.
  • plot_every_n_val_epochs specifies how often validation samples will be plotted. Since plotting is quite costly, you can reduce the frequency; this works similarly to early stopping. In epochs, this is check_val_every_n_epoch * plot_every_n_val_epochs.

Example: There are 400 training samples and the batch size is 2, resulting in 200 training steps per epoch. If log_every_n_steps is set to 50, the training logs and metrics will be logged 4 times per epoch. If check_val_every_n_epoch is set to 5, validation will be performed every 5 epochs. If plot_every_n_val_epochs is set to 2, validation samples will be plotted every 10 epochs. If early_stopping_patience is set to 3, early stopping will trigger after 15 epochs without improvement.

The training data is expected to have gone through the "preprocessing" step beforehand, which results in the following directory structure:

preprocessed-data/ # the top-level directory
├── config.toml
├── data.zarr/ # this zarr group contains the dataarrays x and y
├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
└── labels.geojson

Parameters:

  • run (TrainRunConfig, default: TrainRunConfig() ) –

    Run related parameters for training.

  • training_config (TrainingConfig, default: TrainingConfig() ) –

    Training related parameters for training.

  • data_config (DataConfig, default: DataConfig() ) –

    Data related parameters for training.

  • logging_config (LoggingConfig, default: LoggingConfig() ) –

    Logging related parameters for training.

  • device_config (DeviceConfig, default: DeviceConfig() ) –

    Device and distributed strategy related parameters.

  • hparams (Hyperparameters, default: Hyperparameters() ) –

    Hyperparameters for the model.

Returns:

  • pl.Trainer: The trainer object used for training, which also contains the metrics.
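
A minimal usage sketch: it assumes train_smp and TrainRunConfig are importable from darts_segmentation.training as documented here, and leaves all other configuration at its defaults.

```python
from darts_segmentation.training import TrainRunConfig, train_smp

# Train a single run; data, device, logging and hyperparameter configs fall back to their defaults.
trainer = train_smp(run=TrainRunConfig(name="my-run"))
print(trainer.callback_metrics)
```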

Source code in darts-segmentation/src/darts_segmentation/training/train.py
def train_smp(
    *,
    run: TrainRunConfig = TrainRunConfig(),
    training_config: TrainingConfig = TrainingConfig(),
    data_config: DataConfig = DataConfig(),
    logging_config: LoggingConfig = LoggingConfig(),
    device_config: DeviceConfig = DeviceConfig(),
    hparams: Hyperparameters = Hyperparameters(),
):
    """Run the training of the SMP model, specifically binary segmentation.

    Please see https://smp.readthedocs.io/en/latest/index.html for model configurations of architecture and encoder.

    Please also consider reading our training guide (docs/guides/training.md).

    This training function is meant for single training runs but is also used for cross-validation and hyperparameter
    tuning by cv.py and tune.py.
    This strongly affects where artifacts are stored:

    - Run was created by a tune: `{artifact_dir}/{tune_name}/{cv_name}/{run_name}-{run_id}`
    - Run was created by a cross-validation: `{artifact_dir}/_cross_validations/{cv_name}/{run_name}-{run_id}`
    - Single runs: `{artifact_dir}/_runs/{run_name}-{run_id}`

    `run_name` can be specified by the user, else it is generated automatically.
    In case of cross-validation, the run name is generated automatically by the cross-validation.
    `run_id` is generated automatically by the training function.
    Both are saved to the final checkpoint.

    You can specify the frequency on how often logs will be written and validation will be performed.
        - `log_every_n_steps` specifies how often train-logs will be written. This does not affect validation.
        - `check_val_every_n_epoch` specifies how often validation will be performed.
            This will also affect early stopping.
        - `early_stopping_patience` specifies how many epochs to wait for improvement before stopping.
            In epochs, this would be `check_val_every_n_epoch * early_stopping_patience`.
        - `plot_every_n_val_epochs` specifies how often validation samples will be plotted.
            Since plotting is quite costly, you can reduce the frequency. This works similarly to early stopping.
            In epochs, this would be `check_val_every_n_epoch * plot_every_n_val_epochs`.
    Example: There are 400 training samples and the batch size is 2, resulting in 200 training steps per epoch.
    If `log_every_n_steps` is set to 50 then the training logs and metrics will be logged 4 times per epoch.
    If `check_val_every_n_epoch` is set to 5 then validation will be performed every 5 epochs.
    If `plot_every_n_val_epochs` is set to 2 then validation samples will be plotted every 10 epochs.
    If `early_stopping_patience` is set to 3 then early stopping will be performed after 15 epochs without improvement.

    The data structure of the training data expects the "preprocessing" step to be done beforehand,
    which results in the following data structure:

    ```sh
    preprocessed-data/ # the top-level directory
    ├── config.toml
    ├── data.zarr/ # this zarr group contains the dataarrays x and y
    ├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
    └── labels.geojson
    ```

    Args:
        data_config (DataConfig): Data related parameters for training.
        run (TrainRunConfig): Run related parameters for training.
        logging_config (LoggingConfig): Logging related parameters for training.
        device_config (DeviceConfig): Device and distributed strategy related parameters.
        training_config (TrainingConfig): Training related parameters for training.
        hparams (Hyperparameters): Hyperparameters for the model.

    Returns:
        pl.Trainer: The trainer object used for training. Contains also metrics.

    """
    import lightning as L  # noqa: N812
    import lovely_tensors
    import torch
    from darts.utils.logging import LoggingManager
    from darts_utils.namegen import generate_counted_name, generate_id
    from lightning.pytorch import seed_everything
    from lightning.pytorch.callbacks import EarlyStopping, RichProgressBar
    from lightning.pytorch.loggers import CSVLogger, WandbLogger

    from darts_segmentation.segment import SMPSegmenterConfig
    from darts_segmentation.training.callbacks import BinarySegmentationMetrics, BinarySegmentationPreview
    from darts_segmentation.training.data import DartsDataModule
    from darts_segmentation.training.module import LitSMP
    from darts_segmentation.utils import Bands

    LoggingManager.apply_logging_handlers("lightning.pytorch", level=logging.INFO)

    tick_fstart = time.perf_counter()

    # Get the right nesting of the artifact directory
    artifact_dir = logging_config.artifact_dir_at_run(run.cv_name, run.tune_name)

    # Create unique run identification (name can be specified by user, id can be interpreted as a 'version')
    run_name = run.name or generate_counted_name(artifact_dir)
    run_id = generate_id()  # Needed for wandb

    logger.info(
        f"Starting training '{run_name}' ('{run_id}') with data from {data_config.train_data_dir.resolve()}."
        f" Artifacts will be saved to {(artifact_dir / f'{run_name}-{run_id}').resolve()}."
    )
    logger.debug(
        f"Using config:\n\t{run}\n\t{training_config}\n\t{data_config}\n\t{logging_config}\n\t"
        f"{device_config}\n\t{hparams}"
    )
    if training_config.continue_from_checkpoint:
        logger.debug(f"Continuing from checkpoint '{training_config.continue_from_checkpoint.resolve()}'")

    lovely_tensors.monkey_patch()
    lovely_tensors.set_config(color=False)
    torch.set_float32_matmul_precision("medium")
    seed_everything(run.random_seed, workers=True, verbose=False)

    dataset_config = toml.load(data_config.train_data_dir / "config.toml")["darts"]
    all_bands = Bands.from_config(dataset_config)
    bands = all_bands.filter(hparams.bands) if hparams.bands else all_bands
    config = SMPSegmenterConfig(
        bands=bands,
        model={
            "arch": hparams.model_arch,
            "encoder_name": hparams.model_encoder,
            "encoder_weights": hparams.model_encoder_weights,
            "in_channels": len(all_bands) if bands is None else len(bands),
            "classes": 1,
        },
    )

    # Data and model
    datamodule = DartsDataModule(
        data_dir=data_config.train_data_dir,
        batch_size=hparams.batch_size,
        data_split_method=data_config.data_split_method,
        data_split_by=data_config.data_split_by,
        fold_method=data_config.fold_method,
        total_folds=data_config.total_folds,
        fold=run.fold,
        subsample=data_config.subsample,
        bands=hparams.bands,
        augment=hparams.augment,
        num_workers=training_config.num_workers,
    )
    model = LitSMP(
        config=config,
        learning_rate=hparams.learning_rate,
        gamma=hparams.gamma,
        focal_loss_alpha=hparams.focal_loss_alpha,
        focal_loss_gamma=hparams.focal_loss_gamma,
        # These are only stored in the hparams and are not used
        run_id=run_id,
        run_name=run_name,
        cv_name=run.cv_name or "none",
        tune_name=run.tune_name or "none",
        random_seed=run.random_seed,
    )

    # Loggers
    trainer_loggers = [
        CSVLogger(save_dir=artifact_dir, name=None, version=f"{run_name}-{run_id}"),
    ]
    logger.debug(f"Logging CSV to {Path(trainer_loggers[0].log_dir).resolve()}")
    if logging_config.wandb_entity and logging_config.wandb_project:
        tags = [data_config.train_data_dir.stem]
        if run.cv_name:
            tags.append(run.cv_name)
        if run.tune_name:
            tags.append(run.tune_name)
        wandb_logger = WandbLogger(
            save_dir=artifact_dir.parent.parent if run.tune_name or run.cv_name else artifact_dir.parent,
            name=run_name,
            version=run_id,
            project=logging_config.wandb_project,
            entity=logging_config.wandb_entity,
            resume="allow",
            # Using the group and job_type is a workaround for wandb's lack of support for manual sweeps
            group=run.tune_name or "none",
            job_type=run.cv_name or "none",
            # Using tags to quickly identify the run
            tags=tags,
        )
        trainer_loggers.append(wandb_logger)
        logger.debug(
            f"Logging to WandB with entity '{logging_config.wandb_entity}' and project '{logging_config.wandb_project}'"
            f"Artifacts are logged to {(Path(wandb_logger.save_dir) / 'wandb').resolve()}"
        )

    # Callbacks and profiler
    callbacks = [
        RichProgressBar(),
        BinarySegmentationMetrics(
            bands=bands,
            val_set=f"val{run.fold}",
            plot_every_n_val_epochs=logging_config.plot_every_n_val_epochs,
            is_crossval=bool(run.cv_name),
            batch_size=hparams.batch_size,
            patch_size=dataset_config["patch_size"],
        ),
        BinarySegmentationPreview(
            bands=bands,
            val_set=f"val{run.fold}",
            plot_every_n_val_epochs=logging_config.plot_every_n_val_epochs,
        ),
        # Something does not work well here...
        # ThroughputMonitor(batch_size_fn=lambda batch: batch[0].size(0), window_size=log_every_n_steps),
    ]
    if training_config.early_stopping_patience:
        logger.debug(f"Using EarlyStopping with patience {training_config.early_stopping_patience}")
        early_stopping = EarlyStopping(
            monitor="val/JaccardIndex", mode="max", patience=training_config.early_stopping_patience
        )
        callbacks.append(early_stopping)

    # Unsupported: https://github.com/Lightning-AI/pytorch-lightning/issues/19983
    # profiler_dir = artifact_dir / f"{run_name}-{run_id}" / "profiler"
    # profiler_dir.mkdir(parents=True, exist_ok=True)
    # profiler = AdvancedProfiler(dirpath=profiler_dir, filename="perf_logs", dump_stats=True)
    # logger.debug(f"Using profiler with output to {profiler.dirpath.resolve()}")

    logger.debug(
        f"Creating lightning-trainer on {device_config.accelerator} with devices {device_config.devices}"
        f" and strategy '{device_config.lightning_strategy}'"
    )
    # Train
    trainer = L.Trainer(
        max_epochs=training_config.max_epochs,
        callbacks=callbacks,
        log_every_n_steps=logging_config.log_every_n_steps,
        logger=trainer_loggers,
        check_val_every_n_epoch=logging_config.check_val_every_n_epoch,
        accelerator=device_config.accelerator,
        devices=device_config.devices if device_config.devices[0] != "auto" else "auto",
        strategy=device_config.lightning_strategy,
        num_nodes=device_config.num_nodes,
        deterministic=False,  # True does not work for some reason
        # profiler=profiler,
    )
    trainer.fit(model, datamodule, ckpt_path=training_config.continue_from_checkpoint)

    tick_fend = time.perf_counter()
    logger.info(f"Finished training '{run_name}' in {tick_fend - tick_fstart:.2f}s.")

    if logging_config.wandb_entity and logging_config.wandb_project:
        wandb_logger.finalize("success")
        wandb_logger.experiment.finish(exit_code=0)
        logger.debug(f"Finalized WandB logging for '{run_name}'")

    return trainer

tune_smp

Tune the hyper-parameters of the model using cross-validation and random states.

Please see https://smp.readthedocs.io/en/latest/index.html for model configurations of architecture and encoder.

Please also consider reading our training guide (docs/guides/training.md).

This tuning script is designed to sweep over hyperparameters with a cross-validation used to evaluate each hyperparameter configuration. Optionally, by setting retrain_and_test to True, the best hyperparameters are then selected based on the cross-validation scores and a new model is trained on the entire train-split and tested on the test-split.

Hyperparameters can be configured using an hpconfig file (YAML or TOML). Please consult the training guide or the documentation of darts_segmentation.training.hparams.parse_hyperparameters to learn how such a file should be structured. By default, a random search is performed, where the number of samples can be specified by n_trials. If n_trials is set to "grid", a grid search is performed instead; however, this requires every hyperparameter to be configured as either a constant value or a choice / list.

The metric(s) used to calculate the cv score can be specified with the scoring_metric parameter. Each metric can be suffixed with ":higher" or ":lower" to indicate its direction; if no direction is provided, ":higher" is assumed. This allows multiple metrics to be combined correctly, since 1/metric is taken before calculation for each ":lower" metric. The direction has no real effect on single-metric scores, since only the mean is calculated there.

In a multi-score setting, the score is calculated by combining, then reducing the metrics: first, the metrics of each fold are combined using the specified strategy, and then the results are reduced via the mean. Please refer to the documentation to understand the different multi-score strategies.

If any metric of any run is NaN, Inf, -Inf or 0, the score is reported as "unstable". In such cases, the configuration is not considered for further evaluation.

Artifacts are stored under {artifact_dir}/{tune_name}.

You can specify how often logs will be written and validation will be performed:

  • log_every_n_steps specifies how often train-logs will be written. This does not affect validation.
  • check_val_every_n_epoch specifies how often validation will be performed. This also affects early stopping.
  • early_stopping_patience specifies how many validation epochs to wait for improvement before stopping. In epochs, this is check_val_every_n_epoch * early_stopping_patience.
  • plot_every_n_val_epochs specifies how often validation samples will be plotted. Since plotting is quite costly, you can reduce the frequency; this works similarly to early stopping. In epochs, this is check_val_every_n_epoch * plot_every_n_val_epochs.

Example: There are 400 training samples and the batch size is 2, resulting in 200 training steps per epoch. If log_every_n_steps is set to 50, the training logs and metrics will be logged 4 times per epoch. If check_val_every_n_epoch is set to 5, validation will be performed every 5 epochs. If plot_every_n_val_epochs is set to 2, validation samples will be plotted every 10 epochs. If early_stopping_patience is set to 3, early stopping will trigger after 15 epochs without improvement.

The training data is expected to have gone through the "preprocessing" step beforehand, which results in the following directory structure:

preprocessed-data/ # the top-level directory
├── config.toml
├── data.zarr/ # this zarr group contains the dataarrays x and y
├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
└── labels.geojson

Parameters:

  • name (str | None, default: None ) –

    Name of the tuning run. Will be generated based on the number of existing directories in the artifact directory if None. Defaults to None.

  • n_trials (int | typing.Literal['grid'], default: 100 ) –

    Number of trials to perform in hyperparameter tuning. If "grid", span a grid search over all configured hyperparameters. In a grid search, only constant or choice hyperparameters are allowed. Defaults to 100.

  • retrain_and_test (bool, default: False ) –

    Whether to retrain the model with the best hyperparameters and test it. Defaults to False.

  • cv_config (CrossValidationConfig, default: CrossValidationConfig() ) –

    Configuration for cross-validation. Defaults to CrossValidationConfig().

  • training_config (TrainingConfig, default: TrainingConfig() ) –

    Configuration for training. Defaults to TrainingConfig().

  • data_config (DataConfig, default: DataConfig() ) –

    Configuration for data. Defaults to DataConfig().

  • device_config (DeviceConfig, default: DeviceConfig() ) –

    Configuration for device. Defaults to DeviceConfig().

  • logging_config (LoggingConfig, default: LoggingConfig() ) –

    Configuration for logging. Defaults to LoggingConfig().

  • hpconfig (pathlib.Path | None, default: None ) –

    Path to the hyperparameter configuration file. Please see the documentation of hyperparameters for more information. Defaults to None.

  • config_file (pathlib.Path | None, default: None ) –

    Path to the configuration file. If provided, it will be used instead of hpconfig if hpconfig is None. Defaults to None.

Returns:

  • tuple[float, pd.DataFrame]: The best score (if retrained and tested) and the run infos of all runs.

Raises:

  • ValueError

    If no hyperparameter configuration file is provided.
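
A minimal usage sketch: the import location is an assumption of this sketch, and hpconfig.yaml is a placeholder for your own hyperparameter configuration file.

```python
from pathlib import Path

from darts_segmentation.training import tune_smp

# Random search with 100 trials, then retrain and test with the best hyperparameters.
score, run_infos = tune_smp(
    name="my-tune",
    n_trials=100,
    hpconfig=Path("hpconfig.yaml"),
    retrain_and_test=True,
)
print(f"Best test score: {score:.4f}")
print(run_infos.head())
```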

Source code in darts-segmentation/src/darts_segmentation/training/tune.py
def tune_smp(
    *,
    name: str | None = None,
    n_trials: int | Literal["grid"] = 100,
    retrain_and_test: bool = False,
    cv_config: CrossValidationConfig = CrossValidationConfig(),
    training_config: TrainingConfig = TrainingConfig(),
    data_config: DataConfig = DataConfig(),
    device_config: DeviceConfig = DeviceConfig(),
    logging_config: LoggingConfig = LoggingConfig(),
    hpconfig: Path | None = None,
    config_file: Annotated[Path | None, cyclopts.Parameter(parse=False)] = None,
):
    """Tune the hyper-parameters of the model using cross-validation and random states.

    Please see https://smp.readthedocs.io/en/latest/index.html for model configurations of architecture and encoder.

    Please also consider reading our training guide (docs/guides/training.md).

    This tuning script is designed to sweep over hyperparameters with a cross-validation
    used to evaluate each hyperparameter configuration.
    Optionally, by setting `retrain_and_test` to True, the best hyperparameters are then selected based on the
    cross-validation scores and a new model is trained on the entire train-split and tested on the test-split.

    Hyperparameters can be configured using a `hpconfig` file (YAML or Toml).
    Please consult the training guide or the documentation of
    `darts_segmentation.training.hparams.parse_hyperparameters` to learn how such a file should be structured.
    By default, a random search is performed, where the number of samples can be specified by `n_trials`.
    If `n_trials` is set to "grid", a grid search is performed instead.
    However, this requires every hyperparameter to be configured as either a constant value or a choice / list.

    The metric(s) used to calculate the cv score can be specified with the `scoring_metric` parameter.
    Each metric can be suffixed with ":higher" or ":lower" to indicate its direction.
    If no direction is provided, ":higher" is assumed.
    This allows multiple metrics to be combined correctly, since 1/metric is taken before
    calculation for each ":lower" metric.
    The direction has no real effect on single-metric scores, since only the mean is calculated there.

    In a multi-score setting, the score is calculated by combining, then reducing the metrics:
    first, the metrics of each fold are combined using the specified strategy,
    and then the results are reduced via the mean.
    Please refer to the documentation to understand the different multi-score strategies.

    If any metric of any run is NaN, Inf, -Inf or 0, the score is reported as "unstable".
    In such cases, the configuration is not considered for further evaluation.

    Artifacts are stored under `{artifact_dir}/{tune_name}`.

    You can specify the frequency on how often logs will be written and validation will be performed.
        - `log_every_n_steps` specifies how often train-logs will be written. This does not affect validation.
        - `check_val_every_n_epoch` specifies how often validation will be performed.
            This will also affect early stopping.
        - `early_stopping_patience` specifies how many epochs to wait for improvement before stopping.
            In epochs, this would be `check_val_every_n_epoch * early_stopping_patience`.
        - `plot_every_n_val_epochs` specifies how often validation samples will be plotted.
            Since plotting is quite costly, you can reduce the frequency. This works similarly to early stopping.
            In epochs, this would be `check_val_every_n_epoch * plot_every_n_val_epochs`.
    Example: There are 400 training samples and the batch size is 2, resulting in 200 training steps per epoch.
    If `log_every_n_steps` is set to 50 then the training logs and metrics will be logged 4 times per epoch.
    If `check_val_every_n_epoch` is set to 5 then validation will be performed every 5 epochs.
    If `plot_every_n_val_epochs` is set to 2 then validation samples will be plotted every 10 epochs.
    If `early_stopping_patience` is set to 3 then early stopping will be performed after 15 epochs without improvement.

    The data structure of the training data expects the "preprocessing" step to be done beforehand,
    which results in the following data structure:

    ```sh
    preprocessed-data/ # the top-level directory
    ├── config.toml
    ├── data.zarr/ # this zarr group contains the dataarrays x and y
    ├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
    └── labels.geojson
    ```

    Args:
        name (str | None, optional): Name of the tuning run.
            Will be generated based on the number of existing directories in the artifact directory if None.
            Defaults to None.
        n_trials (int | Literal["grid"], optional): Number of trials to perform in hyperparameter tuning.
            If "grid", span a grid search over all configured hyperparameters.
            In a grid search, only constant or choice hyperparameters are allowed.
            Defaults to 100.
        retrain_and_test (bool, optional): Whether to retrain the model with the best hyperparameters and test it.
            Defaults to False.
        cv_config (CrossValidationConfig, optional): Configuration for cross-validation.
            Defaults to CrossValidationConfig().
        training_config (TrainingConfig, optional): Configuration for training.
            Defaults to TrainingConfig().
        data_config (DataConfig, optional): Configuration for data.
            Defaults to DataConfig().
        device_config (DeviceConfig, optional): Configuration for device.
            Defaults to DeviceConfig().
        logging_config (LoggingConfig, optional): Configuration for logging.
            Defaults to LoggingConfig().
        hpconfig (Path | None, optional): Path to the hyperparameter configuration file.
            Please see the documentation of `hyperparameters` for more information.
            Defaults to None.
        config_file (Path | None, optional): Path to the configuration file. If provided,
            it will be used instead of `hpconfig` if `hpconfig` is None. Defaults to None.

    Returns:
        tuple[float, pd.DataFrame]: The best score (if retrained and tested) and the run infos of all runs.

    Raises:
        ValueError: If no hyperparameter configuration file is provided.

    """
    import pandas as pd
    from darts_utils.namegen import generate_counted_name

    from darts_segmentation.training.adp import _adp
    from darts_segmentation.training.hparams import parse_hyperparameters, sample_hyperparameters
    from darts_segmentation.training.scoring import score_from_single_run
    from darts_segmentation.training.train import test_smp, train_smp

    tick_fstart = time.perf_counter()

    tune_name = name or generate_counted_name(logging_config.artifact_dir)
    artifact_dir = logging_config.artifact_dir / tune_name
    run_infos_file = artifact_dir / f"{tune_name}.parquet"

    # Check if the artifact directory is empty
    assert not artifact_dir.exists(), f"{artifact_dir} already exists."
    artifact_dir.mkdir(parents=True, exist_ok=True)

    hpconfig = hpconfig or config_file
    if hpconfig is None:
        raise ValueError(
            "No hyperparameter configuration file provided. Please provide a valid file via the `--hpconfig` flag."
        )
    param_grid = parse_hyperparameters(hpconfig)
    logger.debug(f"Parsed hyperparameter grid: {param_grid}")
    param_list = sample_hyperparameters(param_grid, n_trials)

    logger.info(
        f"Starting tune '{tune_name}' with data from {data_config.train_data_dir.resolve()}."
        f" Artifacts will be saved to {artifact_dir.resolve()}."
        f" Will run n_trials*n_randoms*n_folds ="
        f" {len(param_list)}*{cv_config.n_randoms}*{cv_config.n_folds} ="
        f" {len(param_list) * cv_config.n_randoms * cv_config.n_folds} experiments."
    )

    # Plan which runs to perform. These are later consumed based on the parallelization strategy.
    process_inputs = [
        _ProcessInputs(
            current=i,
            total=len(param_list),
            tune_name=tune_name,
            cv=cv_config,
            training_config=training_config,
            logging_config=logging_config,
            data_config=data_config,
            device_config=device_config,
            hparams=hparams,
        )
        for i, hparams in enumerate(param_list)
    ]

    run_infos: list[pd.DataFrame] = []
    best_score = 0
    best_hp = None

    # This function abstracts away common logic for running multiprocessing
    for inp, output in _adp(
        process_inputs=process_inputs,
        is_parallel=device_config.strategy == "tune-parallel",
        devices=device_config.devices,
        available_devices=available_devices,
        _run=_run_cv,
    ):
        run_infos.append(output.run_infos)
        if not output.is_unstable and output.score > best_score:
            best_score = output.score
            best_hp = inp.hparams

        # Save already here to prevent data loss if something goes wrong
        pd.concat(run_infos).reset_index(drop=True).to_parquet(run_infos_file)
        logger.debug(f"Saved run infos to {run_infos_file}")

    if len(run_infos) == 0:
        logger.error("No hyperparameters resulted in a valid score. Please check the logs for more information.")
        return 0, run_infos

    run_infos = pd.concat(run_infos).reset_index(drop=True)

    tick_fend = time.perf_counter()

    if best_hp is None:
        logger.warning(
            f"Tuning completed in {tick_fend - tick_fstart:.2f}s."
            " No hyperparameters resulted in a valid score. Please check the logs for more information."
        )
        return 0, run_infos
    logger.info(
        f"Tuning completed in {tick_fend - tick_fstart:.2f}s. The best score was {best_score:.4f} with {best_hp}."
    )

    # =====================
    # === End of tuning ===
    # =====================

    if not retrain_and_test:
        return 0, run_infos

    logger.info("Starting retraining with the best hyperparameters.")

    tick_fstart = time.perf_counter()
    trainer = train_smp(
        run=TrainRunConfig(name=f"{tune_name}-retrain"),
        training_config=training_config,  # TODO: device and strategy
        data_config=DataConfig(
            train_data_dir=data_config.train_data_dir,
            data_split_method=data_config.data_split_method,
            data_split_by=data_config.data_split_by,
            fold_method=None,  # No fold method for retraining
            total_folds=None,  # No folds for retraining
        ),
        logging_config=LoggingConfig(
            artifact_dir=artifact_dir,
            log_every_n_steps=logging_config.log_every_n_steps,
            check_val_every_n_epoch=logging_config.check_val_every_n_epoch,
            plot_every_n_val_epochs=logging_config.plot_every_n_val_epochs,
            wandb_entity=logging_config.wandb_entity,
            wandb_project=logging_config.wandb_project,
        ),
        hparams=best_hp,
    )
    run_id = trainer.lightning_module.hparams["run_id"]
    trainer = test_smp(
        train_data_dir=data_config.train_data_dir,
        run_id=run_id,
        run_name=f"{tune_name}-retrain",
        model_ckp=trainer.checkpoint_callback.best_model_path,
        batch_size=best_hp.batch_size,
        data_split_method=data_config.data_split_method,
        data_split_by=data_config.data_split_by,
        artifact_dir=artifact_dir,
        num_workers=training_config.num_workers,
        device_config=device_config,
        wandb_entity=logging_config.wandb_entity,
        wandb_project=logging_config.wandb_project,
    )

    run_info = {k: v.item() for k, v in trainer.callback_metrics.items()}
    test_scoring_metric = (
        cv_config.scoring_metric.replace("val/", "test/")
        if isinstance(cv_config.scoring_metric, str)
        else [sm.replace("val/", "test/") for sm in cv_config.scoring_metric]
    )
    score = score_from_single_run(run_info, test_scoring_metric, cv_config.multi_score_strategy)
    is_unstable = check_score_is_unstable(run_info, cv_config.scoring_metric)
    tick_fend = time.perf_counter()
    logger.info(
        f"Retraining and testing completed successfully in {tick_fend - tick_fstart:.2f}s"
        f" with {score=:.4f} ({'stable' if not is_unstable else 'unstable'})."
    )

    return score, run_infos