tune

darts_segmentation.training.tune

More advanced hyper-parameter tuning.

available_devices module-attribute

available_devices = multiprocessing.Queue()

logger module-attribute

logger = logging.getLogger(
    __name__.replace("darts_", "darts.")
)

CrossValidationConfig dataclass

CrossValidationConfig(
    n_folds: int | None = None,
    n_randoms: int = 3,
    scoring_metric: list[str] = ["val/JaccardIndex", "val/Recall"],
    multi_score_strategy: typing.Literal[
        "harmonic", "arithmetic", "geometric", "min"
    ] = "harmonic",
)

Configuration for cross-validation.

This is used to configure the cross-validation process. It is used by the cross_validation_smp function.

Attributes:

  • n_folds (int | None) –

    Number of folds to perform in cross-validation. If None, all folds (total_folds) will be used. Defaults to None.

  • n_randoms (int) –

    Number of random seeds to perform in cross-validation. The first three seeds are always 42, 21, 69; further seeds are deterministically generated. Defaults to 3.

  • scoring_metric (list[str]) –

    Metric(s) to use for scoring. Defaults to ["val/JaccardIndex", "val/Recall"].

  • multi_score_strategy (typing.Literal['harmonic', 'arithmetic', 'geometric', 'min']) –

    Strategy for combining multiple metrics. Defaults to "harmonic".
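
For illustration, a minimal construction sketch (assuming CrossValidationConfig can be imported from darts_segmentation.training.tune, as documented on this page):

from darts_segmentation.training.tune import CrossValidationConfig

# Use 5 folds, the default three seeds (42, 21, 69), and score each fold on
# IoU and Recall combined via a harmonic mean.
cv_config = CrossValidationConfig(
    n_folds=5,
    n_randoms=3,
    scoring_metric=["val/JaccardIndex", "val/Recall"],
    multi_score_strategy="harmonic",
)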

multi_score_strategy class-attribute instance-attribute

multi_score_strategy: typing.Literal[
    "harmonic", "arithmetic", "geometric", "min"
] = "harmonic"

n_folds class-attribute instance-attribute

n_folds: int | None = None

n_randoms class-attribute instance-attribute

n_randoms: int = 3

rng_seeds property

rng_seeds: list[int]

Generate a list of seeds for cross-validation.

Returns:

  • list[int] –

    A list of seeds for cross-validation. The first three seeds are always 42, 21, 69; further seeds are deterministically generated.

scoring_metric class-attribute instance-attribute

scoring_metric: list[str] = dataclasses.field(
    default_factory=lambda: [
        "val/JaccardIndex",
        "val/Recall",
    ]
)

DataConfig dataclass

DataConfig(
    train_data_dir: pathlib.Path = pathlib.Path("train"),
    data_split_method: typing.Literal[
        "random", "region", "sample"
    ]
    | None = None,
    data_split_by: list[str | float] | None = None,
    fold_method: typing.Literal[
        "kfold",
        "shuffle",
        "stratified",
        "region",
        "region-stratified",
    ] = "kfold",
    total_folds: int = 5,
    subsample: int | None = None,
)

Data related parameters for training.

Defines the script inputs for the training script and can be propagated by the cross-validation and tuning scripts.

Attributes:

  • train_data_dir (pathlib.Path) –

    The path (top-level) to the data to be used for training. Expects a directory containing:

      1. a zarr group called "data.zarr" containing a "x" and "y" array
      2. a geoparquet file called "metadata.parquet" containing the metadata for the data. This metadata should contain at least the following columns:
         - "sample_id": The id of the sample
         - "region": The region the sample belongs to
         - "empty": Whether the image is empty

    The index should refer to the index of the sample in the zarr data. This directory should be created by a preprocessing script. Defaults to "train".

  • data_split_method (typing.Literal['random', 'region', 'sample'] | None) –

    The method to use for splitting the data into a train and a test set. "random" will split the data randomly; the seed is always 42, and the test size can be specified by providing a list with a single float between 0 and 1 to data_split_by, which is the fraction of the data to be used for testing. E.g. [0.2] will use 20% of the data for testing. "region" will split the data by one or multiple regions, which can be specified by providing a str or list of str to data_split_by. "sample" will split the data by sample ids, which can also be specified similarly to "region". If None, no split is done and the complete dataset is used for both training and testing. The train split will be split further in the cross-validation process. Defaults to None.

  • data_split_by (list[str | float] | None) –

    Select by which regions/samples to split or the size of test set. Defaults to None.

  • fold_method (typing.Literal['kfold', 'shuffle', 'stratified', 'region', 'region-stratified']) –

    Method for cross-validation split. Defaults to "kfold".

  • total_folds (int) –

    Total number of folds in cross-validation. Defaults to 5.

  • subsample (int | None) –

    If set, will subsample the dataset to this number of samples. This is useful for debugging and testing. Defaults to None.
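
A sketch of a DataConfig for the directory layout described above (the path is a placeholder; the import assumes the dataclass is exposed by darts_segmentation.training.tune):

from pathlib import Path

from darts_segmentation.training.tune import DataConfig

# Hold out 20% of the samples as a test split (the seed is fixed to 42) and use
# a region-stratified 5-fold cross-validation on the remaining training data.
data_config = DataConfig(
    train_data_dir=Path("preprocessed-data"),  # placeholder path
    data_split_method="random",
    data_split_by=[0.2],
    fold_method="region-stratified",
    total_folds=5,
)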

data_split_by class-attribute instance-attribute

data_split_by: list[str | float] | None = None

data_split_method class-attribute instance-attribute

data_split_method: (
    typing.Literal["random", "region", "sample"] | None
) = None

fold_method class-attribute instance-attribute

fold_method: typing.Literal[
    "kfold",
    "shuffle",
    "stratified",
    "region",
    "region-stratified",
] = "kfold"

subsample class-attribute instance-attribute

subsample: int | None = None

total_folds class-attribute instance-attribute

total_folds: int = 5

train_data_dir class-attribute instance-attribute

train_data_dir: pathlib.Path = pathlib.Path('train')

DeviceConfig dataclass

DeviceConfig(
    accelerator: typing.Literal[
        "auto", "cpu", "gpu", "mps", "tpu"
    ] = "auto",
    strategy: typing.Literal[
        "auto",
        "ddp",
        "ddp_fork",
        "ddp_notebook",
        "fsdp",
        "cv-parallel",
        "tune-parallel",
    ] = "auto",
    devices: list[int | str] = ["auto"],
    num_nodes: int = 1,
)

Device and Distributed Strategy related parameters.

Attributes:

  • accelerator (typing.Literal['auto', 'cpu', 'gpu', 'mps', 'tpu']) –

    Accelerator to use. Defaults to "auto".

  • strategy (typing.Literal['auto', 'ddp', 'ddp_fork', 'ddp_notebook', 'fsdp', 'cv-parallel', 'tune-parallel']) –

    Distributed strategy to use. Defaults to "auto".

  • devices (list[int | str]) –

    List of devices to use. Defaults to ["auto"].

  • num_nodes (int) –

    Number of nodes to use for distributed training. Defaults to 1.
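
A sketch of a configuration for parallel tuning (the device indices are placeholders for whatever GPUs are available):

from darts_segmentation.training.tune import DeviceConfig

# Run one cross-validation per GPU in parallel; the tuning script hands out
# the listed devices to the worker processes.
device_config = DeviceConfig(
    accelerator="gpu",
    strategy="tune-parallel",
    devices=[0, 1, 2, 3],
)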

accelerator class-attribute instance-attribute

accelerator: typing.Literal[
    "auto", "cpu", "gpu", "mps", "tpu"
] = "auto"

devices class-attribute instance-attribute

devices: list[int | str] = dataclasses.field(
    default_factory=lambda: ["auto"]
)

lightning_strategy property

lightning_strategy: str

Get the Lightning strategy for the current configuration.

Returns:

  • str ( str ) –

    The Lightning strategy to use.

num_nodes class-attribute instance-attribute

num_nodes: int = 1

strategy class-attribute instance-attribute

strategy: typing.Literal[
    "auto",
    "ddp",
    "ddp_fork",
    "ddp_notebook",
    "fsdp",
    "cv-parallel",
    "tune-parallel",
] = "auto"

in_parallel

in_parallel(
    device: int | str | None = None,
) -> darts_segmentation.training.train.DeviceConfig

Turn the current configuration into a suitable configuration for parallel training.

Parameters:

  • device (int | str | None, default: None ) –

    The device to use for parallel training. If None, assumes non-multiprocessing parallel training and propagates all devices. Defaults to None.

Returns:

  • DeviceConfig ( darts_segmentation.training.train.DeviceConfig ) –

    A new DeviceConfig instance that is suitable for parallel training.

Source code in darts-segmentation/src/darts_segmentation/training/train.py
def in_parallel(self, device: int | str | None = None) -> "DeviceConfig":
    """Turn the current configuration into a suitable configuration for parallel training.

    Args:
        device (int | str | None, optional): The device to use for parallel training.
            If None, assumes non-multiprocessing parallel training and propagate all devices.
            Defaults to None.

    Returns:
        DeviceConfig: A new DeviceConfig instance that is suitable for parallel training.

    """
    # In case of parallel training via multiprocessing, only few strategies are allowed.
    if self.strategy in ["ddp", "ddp_fork", "ddp_notebook", "fsdp"]:
        logger.warning("Using 'ddp_fork' instead of 'ddp' for multiprocessing.")
        return DeviceConfig(
            accelerator=self.accelerator,
            strategy="ddp_fork",  # Fork is the only supported strategy for multiprocessing
            devices=self.devices,
            num_nodes=self.num_nodes,
        )
    elif device is not None:
        return DeviceConfig(
            accelerator=self.accelerator,
            strategy=self.strategy,
            # If a device is specified, we assume that we want to run on a single device
            devices=[device],
            num_nodes=1,
        )
    else:
        return self
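
For illustration, two hypothetical calls and the configurations they return according to the source above (the import path is an assumption):

from darts_segmentation.training.tune import DeviceConfig

# "ddp" is not fork-safe for multiprocessing, so it is downgraded to "ddp_fork";
# devices and num_nodes are kept as they are.
DeviceConfig(accelerator="gpu", strategy="ddp", devices=[0, 1]).in_parallel()

# With an explicit device, the configuration is narrowed to that single device
# (devices=[1], num_nodes=1) while the strategy is kept.
DeviceConfig(accelerator="gpu", strategy="tune-parallel", devices=[0, 1]).in_parallel(device=1)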

Hyperparameters dataclass

Hyperparameters(
    model_arch: str = "Unet",
    model_encoder: str = "dpn107",
    model_encoder_weights: str | None = None,
    augment: list[
        darts_segmentation.training.augmentations.Augmentation
    ]
    | None = None,
    learning_rate: float = 0.001,
    gamma: float = 0.9,
    focal_loss_alpha: float | None = None,
    focal_loss_gamma: float = 2.0,
    batch_size: int = 8,
    bands: list[str] | None = None,
)

Hyperparameters for Cyclopts CLI.

Attributes:

augment class-attribute instance-attribute

augment: (
    list[
        darts_segmentation.training.augmentations.Augmentation
    ]
    | None
) = None

bands class-attribute instance-attribute

bands: list[str] | None = None

batch_size class-attribute instance-attribute

batch_size: int = 8

focal_loss_alpha class-attribute instance-attribute

focal_loss_alpha: float | None = None

focal_loss_gamma class-attribute instance-attribute

focal_loss_gamma: float = 2.0

gamma class-attribute instance-attribute

gamma: float = 0.9

learning_rate class-attribute instance-attribute

learning_rate: float = 0.001

model_arch class-attribute instance-attribute

model_arch: str = 'Unet'

model_encoder class-attribute instance-attribute

model_encoder: str = 'dpn107'

model_encoder_weights class-attribute instance-attribute

model_encoder_weights: str | None = None
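
A minimal sketch of a single hyperparameter configuration (the encoder and band names are placeholders, not recommendations; see the training guide for the meaning of the individual parameters):

from darts_segmentation.training.tune import Hyperparameters

hparams = Hyperparameters(
    model_arch="Unet",
    model_encoder="resnet34",  # any encoder supported by segmentation_models_pytorch
    learning_rate=1e-3,
    gamma=0.9,
    focal_loss_gamma=2.0,
    batch_size=8,
    bands=["red", "green", "blue", "nir"],  # placeholder band names
)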

LoggingConfig dataclass

LoggingConfig(
    artifact_dir: pathlib.Path = pathlib.Path("artifacts"),
    log_every_n_steps: int = 10,
    check_val_every_n_epoch: int = 3,
    plot_every_n_val_epochs: int = 5,
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
)

Logging related parameters for training.

Defines the script inputs for the training script and can be propagated by the cross-validation and tuning scripts.

Attributes:

artifact_dir class-attribute instance-attribute

artifact_dir: pathlib.Path = pathlib.Path('artifacts')

check_val_every_n_epoch class-attribute instance-attribute

check_val_every_n_epoch: int = 3

log_every_n_steps class-attribute instance-attribute

log_every_n_steps: int = 10

plot_every_n_val_epochs class-attribute instance-attribute

plot_every_n_val_epochs: int = 5

wandb_entity class-attribute instance-attribute

wandb_entity: str | None = None

wandb_project class-attribute instance-attribute

wandb_project: str | None = None
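
A small construction sketch (the wandb entity and project are placeholders):

from pathlib import Path

from darts_segmentation.training.tune import LoggingConfig

logging_config = LoggingConfig(
    artifact_dir=Path("artifacts"),
    log_every_n_steps=10,        # train-log frequency in steps
    check_val_every_n_epoch=3,   # validation (and early-stopping check) frequency in epochs
    plot_every_n_val_epochs=5,   # plot every 5th validation epoch, i.e. every 15 epochs here
    wandb_entity="my-entity",    # placeholder
    wandb_project="darts-tune",  # placeholder
)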

artifact_dir_at_cv

artifact_dir_at_cv(tune_name: str | None) -> pathlib.Path

Nest the artifact directory for cross-validation runs.

Similar to parse_artifact_dir_for_run, but meant to be used by the cross-validation script.

Also creates the directory if it does not exist.

Parameters:

  • tune_name (str | None) –

    Name of the tuning, if applicable.

Returns:

  • Path ( pathlib.Path ) –

    The nested artifact directory path for cross-validation runs.

Source code in darts-segmentation/src/darts_segmentation/training/train.py
def artifact_dir_at_cv(self, tune_name: str | None) -> Path:
    """Nest the artifact directory for cross-validation runs.

    Similar to `parse_artifact_dir_for_run`, but meant to be used by the cross-validation script.

    Also creates the directory if it does not exist.

    Args:
        tune_name (str | None): Name of the tuning, if applicable.

    Returns:
        Path: The nested artifact directory path for cross-validation runs.

    """
    artifact_dir = self.artifact_dir / tune_name if tune_name else self.artifact_dir / "_cross_validations"
    artifact_dir.mkdir(parents=True, exist_ok=True)
    return artifact_dir

artifact_dir_at_run

artifact_dir_at_run(
    cv_name: str | None, tune_name: str | None
) -> pathlib.Path

Nest the artifact directory to avoid cluttering the root directory.

For cv it is expected that the cv function already nests the artifact directory. This means that for cv the artifact_dir of this function should be either {artifact_dir}/_cross_validations/{cv_name} or {artifact_dir}/{tune_name}/{cv_name}.

Also creates the directory if it does not exist.

Parameters:

  • cv_name (str | None) –

    Name of the cross-validation.

  • tune_name (str | None) –

    Name of the tuning.

Raises:

  • ValueError

    If tune_name is specified, but cv_name is not, which is invalid.

Returns:

  • Path ( pathlib.Path ) –

    The nested artifact directory path.

Source code in darts-segmentation/src/darts_segmentation/training/train.py
def artifact_dir_at_run(self, cv_name: str | None, tune_name: str | None) -> Path:
    """Nest the artifact directory to avoid cluttering the root directory.

    For cv it is expected that the cv function already nests the artifact directory
    Meaning for cv the artifact_dir of this function should be either
    {artifact_dir}/_cross_validations/{cv_name} or {artifact_dir}/{tune_name}/{cv_name}

    Also creates the directory if it does not exist.

    Args:
        cv_name (str | None): Name of the cross-validation.
        tune_name (str | None): Name of the tuning.

    Raises:
        ValueError: If tune_name is specified, but cv_name is not, which is invalid.

    Returns:
        Path: The nested artifact directory path.

    """
    # Run only
    if cv_name is None and tune_name is None:
        artifact_dir = self.artifact_dir / "_runs"
    # Cross-validation only
    elif cv_name is not None and tune_name is None:
        artifact_dir = self.artifact_dir / "_cross_validations" / cv_name
    # Cross-validation and tuning
    elif cv_name is not None and tune_name is not None:
        artifact_dir = self.artifact_dir / tune_name / cv_name
    # Tuning only (invalid)
    else:
        raise ValueError(
            "Cannot parse artifact directory for cross-validation and tuning. "
            "Please specify either cv_name or tune_name, but not both."
        )
    artifact_dir.mkdir(parents=True, exist_ok=True)
    return artifact_dir
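
Following the nesting logic in the source above, a sketch of how the paths resolve (the names are placeholders; note that each call also creates the directory):

from pathlib import Path

from darts_segmentation.training.tune import LoggingConfig

cfg = LoggingConfig(artifact_dir=Path("artifacts"))

cfg.artifact_dir_at_run(cv_name=None, tune_name=None)        # artifacts/_runs
cfg.artifact_dir_at_run(cv_name="cv-0", tune_name=None)      # artifacts/_cross_validations/cv-0
cfg.artifact_dir_at_run(cv_name="cv-0", tune_name="tune-1")  # artifacts/tune-1/cv-0
cfg.artifact_dir_at_run(cv_name=None, tune_name="tune-1")    # raises ValueError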

TrainRunConfig dataclass

TrainRunConfig(
    name: str | None = None,
    cv_name: str | None = None,
    tune_name: str | None = None,
    fold: int = 0,
    random_seed: int = 42,
)

Run related parameters for training.

Defines the script inputs for the training script. Must be built by the cross-validation and tuning scripts.

Attributes:

  • name (str | None) –

    Name of the run. If None, it is generated automatically. Defaults to None.

  • cv_name (str | None) –

    Name of the cross-validation. Should only be specified by a cross-validation script. Defaults to None.

  • tune_name (str | None) –

    Name of the tuning. Should only be specified by a tuning script. Defaults to None.

  • fold (int) –

    Index of the current fold. Defaults to 0.

  • random_seed (int) –

    Random seed for deterministic training. Defaults to 42.

cv_name class-attribute instance-attribute

cv_name: str | None = None

fold class-attribute instance-attribute

fold: int = 0

name class-attribute instance-attribute

name: str | None = None

random_seed class-attribute instance-attribute

random_seed: int = 42

tune_name class-attribute instance-attribute

tune_name: str | None = None

TrainingConfig dataclass

TrainingConfig(
    continue_from_checkpoint: pathlib.Path | None = None,
    max_epochs: int = 100,
    early_stopping_patience: int = 5,
    num_workers: int = 0,
)

Training related parameters for training.

Defines the script inputs for the training script and can be propagated by the cross-validation and tuning scripts.

Attributes:

continue_from_checkpoint class-attribute instance-attribute

continue_from_checkpoint: pathlib.Path | None = None

early_stopping_patience class-attribute instance-attribute

early_stopping_patience: int = 5

max_epochs class-attribute instance-attribute

max_epochs: int = 100

num_workers class-attribute instance-attribute

num_workers: int = 0

_ProcessInputs dataclass

current instance-attribute

current: int

data_config instance-attribute

device_config instance-attribute

hparams instance-attribute

logging_config instance-attribute

total instance-attribute

total: int

training_config instance-attribute

tune_name instance-attribute

tune_name: str

_ProcessOutputs dataclass

_ProcessOutputs(
    run_infos: pandas.DataFrame,
    score: float,
    is_unstable: bool,
)

is_unstable instance-attribute

is_unstable: bool

run_infos instance-attribute

run_infos: pandas.DataFrame

score instance-attribute

score: float

_run_cv

Source code in darts-segmentation/src/darts_segmentation/training/tune.py
def _run_cv(inp: _ProcessInputs):
    # Wrapper function for handling parallel multiprocessing training runs.
    import pandas as pd

    from darts_segmentation.training.cv import cross_validation_smp

    cv_name = f"{inp.tune_name}-cv{inp.current}"

    # Setup device configuration: If strategy is "tune-parallel" expect a mp scenario:
    # Wait for a device to become available.
    # Otherwise, expect a serial scenario, where the devices and strategy are set by the user.
    is_parallel = inp.device_config.strategy == "tune-parallel"
    if is_parallel:
        device = available_devices.get()
        device_config = inp.device_config.in_parallel(device)
        logger.info(f"Starting cv '{cv_name}' ({inp.current + 1}/{inp.total}) on device {device}.")
    else:
        device = None
        device_config = inp.device_config.in_parallel()
        logger.info(f"Starting cv '{cv_name}' ({inp.current + 1}/{inp.total}).")

    try:
        score, is_unstable, cv_run_infos = cross_validation_smp(
            name=cv_name,
            tune_name=inp.tune_name,
            cv=inp.cv,
            training_config=inp.training_config,
            data_config=inp.data_config,
            logging_config=inp.logging_config,
            hparams=inp.hparams,
            device_config=device_config,
        )

        for key, value in asdict(inp.hparams).items():
            cv_run_infos[key] = value if not isinstance(value, list) else pd.Series([value] * len(cv_run_infos))

        cv_run_infos["cv_name"] = cv_name
        output = _ProcessOutputs(
            run_infos=cv_run_infos,
            score=score,
            is_unstable=is_unstable,
        )
    finally:
        # If we are in parallel mode, we need to return the device to the queue.
        if is_parallel:
            logger.debug(f"Free device {device} for cv {cv_name}")
            available_devices.put(device)
    return output

check_score_is_unstable

check_score_is_unstable(
    run_info: dict, scoring_metric: list[str] | str
) -> bool

Check the stability of the scoring metric.

If any metric value is not finite or equal to zero, the scoring metric is considered unstable.

Parameters:

  • run_info (dict) –

    The run information.

  • scoring_metric (list[str] | str) –

    The scoring metric.

Returns:

  • bool ( bool ) –

    True if the scoring metric is unstable, False otherwise.

Raises:

  • ValueError

    If an unknown scoring metric type is provided.

Source code in darts-segmentation/src/darts_segmentation/training/scoring.py
def check_score_is_unstable(run_info: dict, scoring_metric: list[str] | str) -> bool:
    """Check the stability of the scoring metric.

    If any metric value is not finite or equal to zero, the scoring metric is considered unstable.

    Args:
        run_info (dict): The run information.
        scoring_metric (list[str] | str): The scoring metric.

    Returns:
        bool: True if the scoring metric is unstable, False otherwise.

    Raises:
        ValueError: If an unknown scoring metric type is provided.

    """
    # Single score in list
    if isinstance(scoring_metric, list) and len(scoring_metric) == 1:
        scoring_metric = scoring_metric[0]

    if isinstance(scoring_metric, str):
        metric_value = run_info[scoring_metric]
        is_unstable = not isfinite(metric_value) or metric_value == 0
        return is_unstable
    elif isinstance(scoring_metric, list):
        metric_values = [run_info[metric] for metric in scoring_metric]
        is_unstable = any(not isfinite(val) or val == 0 for val in metric_values)
        return is_unstable
    else:
        raise ValueError("Invalid scoring metric type")
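
A usage sketch with a hypothetical run_info dict:

from darts_segmentation.training.scoring import check_score_is_unstable

run_info = {"val/JaccardIndex": 0.71, "val/Recall": 0.0}

# Recall is exactly zero, so the multi-metric score is flagged as unstable.
check_score_is_unstable(run_info, ["val/JaccardIndex", "val/Recall"])  # True

# A single finite, non-zero metric is considered stable.
check_score_is_unstable(run_info, "val/JaccardIndex")  # False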

tune_smp

tune_smp(
    *,
    name: str | None = None,
    n_trials: int | typing.Literal["grid"] = 100,
    retrain_and_test: bool = False,
    cv_config: CrossValidationConfig = CrossValidationConfig(),
    training_config: TrainingConfig = TrainingConfig(),
    data_config: DataConfig = DataConfig(),
    device_config: DeviceConfig = DeviceConfig(),
    logging_config: LoggingConfig = LoggingConfig(),
    hpconfig: pathlib.Path | None = None,
    config_file: pathlib.Path | None = None,
)

Tune the hyper-parameters of the model using cross-validation and random states.

Please see https://smp.readthedocs.io/en/latest/index.html for model configurations of architecture and encoder.

Please also consider reading our training guide (docs/guides/training.md).

This tuning script is designed to sweep over hyperparameters with a cross-validation used to evaluate each hyperparameter configuration. Optionally, by setting retrain_and_test to True, the best hyperparameters are then selected based on the cross-validation scores and a new model is trained on the entire train-split and tested on the test-split.

Hyperparameters can be configured using an hpconfig file (YAML or TOML). Please consult the training guide or the documentation of darts_segmentation.training.hparams.parse_hyperparameters to learn how such a file should be structured. By default, a random search is performed, where the number of samples can be specified by n_trials. If n_trials is set to "grid", a grid search is performed instead. However, this expects every hyperparameter to be configured as either a constant value or a choice / list.

To specify on which metric(s) the cv score is calculated, the scoring_metric parameter can be used. Each metric can be annotated with either ":higher" or ":lower" to indicate its direction. This allows multiple metrics to be combined correctly by taking 1/metric before the calculation if a metric is ":lower". If no direction is provided, ":higher" is assumed. The direction has no real effect on the single-score calculation, since only the mean is calculated there.

In a multi-score setting, the score is calculated with a combine-then-reduce approach: first, the metrics of each fold are combined using the specified strategy, then the results are reduced via the mean. Please refer to the documentation to understand the different multi-score strategies.
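
A minimal sketch of the combine-then-reduce arithmetic for the "harmonic" strategy (plain Python, not the library implementation; a ":lower" metric would be inverted via 1/metric before combining):

from statistics import harmonic_mean, mean

# One dict of validation metrics per fold.
fold_metrics = [
    {"val/JaccardIndex": 0.70, "val/Recall": 0.80},
    {"val/JaccardIndex": 0.65, "val/Recall": 0.75},
]

# Combine the metrics within each fold first, then reduce across folds via mean.
per_fold = [harmonic_mean(list(m.values())) for m in fold_metrics]
score = mean(per_fold)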

If one of the metrics of any of the runs is NaN, Inf, -Inf or 0, the score is reported as "unstable". In such cases, the configuration is not considered for further evaluation.

Artifacts are stored under {artifact_dir}/{tune_name}.

You can specify how often logs will be written and how often validation will be performed:

  • log_every_n_steps specifies how often train-logs will be written. This does not affect validation.

  • check_val_every_n_epoch specifies how often validation will be performed. This will also affect early stopping.

  • early_stopping_patience specifies how many epochs to wait for improvement before stopping. In epochs, this would be check_val_every_n_epoch * early_stopping_patience.

  • plot_every_n_val_epochs specifies how often validation samples will be plotted. Since plotting is quite costly, you can reduce the frequency. Works similarly to early stopping. In epochs, this would be check_val_every_n_epoch * plot_every_n_val_epochs.

Example: There are 400 training samples and the batch size is 2, resulting in 200 training steps per epoch. If log_every_n_steps is set to 50, then the training logs and metrics will be logged 4 times per epoch. If check_val_every_n_epoch is set to 5, then validation will be performed every 5 epochs. If plot_every_n_val_epochs is set to 2, then validation samples will be plotted every 10 epochs. If early_stopping_patience is set to 3, then early stopping will be performed after 15 epochs without improvement.

The data structure of the training data expects the "preprocessing" step to be done beforehand, which results in the following data structure:

preprocessed-data/ # the top-level directory
├── config.toml
├── data.zarr/ # this zarr group contains the dataarrays x and y
├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
└── labels.geojson

Parameters:

  • name (str | None, default: None ) –

    Name of the tuning run. Will be generated based on the number of existing directories in the artifact directory if None. Defaults to None.

  • n_trials (int | typing.Literal['grid'], default: 100 ) –

    Number of trials to perform in hyperparameter tuning. If "grid", span a grid search over all configured hyperparameters. In a grid search, only constant or choice hyperparameters are allowed. Defaults to 100.

  • retrain_and_test (bool, default: False ) –

    Whether to retrain the model with the best hyperparameters and test it. Defaults to False.

  • cv_config (CrossValidationConfig, default: CrossValidationConfig() ) –

    Configuration for cross-validation. Defaults to CrossValidationConfig().

  • training_config (TrainingConfig, default: TrainingConfig() ) –

    Configuration for training. Defaults to TrainingConfig().

  • data_config (DataConfig, default: DataConfig() ) –

    Configuration for data. Defaults to DataConfig().

  • device_config (DeviceConfig, default: DeviceConfig() ) –

    Configuration for device. Defaults to DeviceConfig().

  • logging_config (LoggingConfig, default: LoggingConfig() ) –

    Configuration for logging. Defaults to LoggingConfig().

  • hpconfig (pathlib.Path | None, default: None ) –

    Path to the hyperparameter configuration file. Please see the documentation of hyperparameters for more information. Defaults to None.

  • config_file (pathlib.Path | None, default: None ) –

    Path to the configuration file. If provided, it will be used instead of hpconfig if hpconfig is None. Defaults to None.

Returns:

  • tuple[float, pandas.DataFrame] –

    The best score (if retrained and tested) and the run infos of all runs.

Raises:

  • ValueError

    If no hyperparameter configuration file is provided.
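
A minimal invocation sketch based on the signature above; the tuning name, paths, and device indices are placeholders:

from pathlib import Path

from darts_segmentation.training.tune import (
    CrossValidationConfig,
    DataConfig,
    DeviceConfig,
    tune_smp,
)

score, run_infos = tune_smp(
    name="my-tune",                        # placeholder
    n_trials=50,
    retrain_and_test=True,
    cv_config=CrossValidationConfig(n_folds=3),
    data_config=DataConfig(
        train_data_dir=Path("preprocessed-data"),  # placeholder
        data_split_method="random",
        data_split_by=[0.2],
    ),
    device_config=DeviceConfig(accelerator="gpu", strategy="tune-parallel", devices=[0, 1]),
    hpconfig=Path("hpconfig.yaml"),        # placeholder path to the hyperparameter config
)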

Source code in darts-segmentation/src/darts_segmentation/training/tune.py
def tune_smp(
    *,
    name: str | None = None,
    n_trials: int | Literal["grid"] = 100,
    retrain_and_test: bool = False,
    cv_config: CrossValidationConfig = CrossValidationConfig(),
    training_config: TrainingConfig = TrainingConfig(),
    data_config: DataConfig = DataConfig(),
    device_config: DeviceConfig = DeviceConfig(),
    logging_config: LoggingConfig = LoggingConfig(),
    hpconfig: Path | None = None,
    config_file: Annotated[Path | None, cyclopts.Parameter(parse=False)] = None,
):
    """Tune the hyper-parameters of the model using cross-validation and random states.

    Please see https://smp.readthedocs.io/en/latest/index.html for model configurations of architecture and encoder.

    Please also consider reading our training guide (docs/guides/training.md).

    This tuning script is designed to sweep over hyperparameters with a cross-validation
    used to evaluate each hyperparameter configuration.
    Optionally, by setting `retrain_and_test` to True, the best hyperparameters are then selected based on the
    cross-validation scores and a new model is trained on the entire train-split and tested on the test-split.

    Hyperparameters can be configured using a `hpconfig` file (YAML or Toml).
    Please consult the training guide or the documentation of
    `darts_segmentation.training.hparams.parse_hyperparameters` to learn how such a file should be structured.
    Per default, a random search is performed, where the number of samples can be specified by `n_trials`.
    If `n_trials` is set to "grid", a grid search is performed instead.
    However, this expects to be every hyperparameter to be configured as either constant value or a choice / list.

    To specify on which metric(s) the cv score is calculated, the `scoring_metric` parameter can be specified.
    Each score can be provided by either ":higher" or ":lower" to indicate the direction of the metrics.
    This allows to correctly combine multiple metrics by doing 1/metric before calculation if a metric is ":lower".
    If no direction is provided, it is assumed to be ":higher".
    Has no real effect on the single score calculation, since only the mean is calculated there.

    In a multi-score setting, the score is calculated by combine-then-reduce the metrics.
    Meaning that first for each fold the metrics are combined using the specified strategy,
    and then the results are reduced via mean.
    Please refer to the documentation to understand the different multi-score strategies.

    If one of the metrics of any of the runs contains NaN, Inf, -Inf or is 0 the score is reported to be "unstable".
    In such cases, the configuration is not considered for further evaluation.

    Artifacts are stored under `{artifact_dir}/{tune_name}`.

    You can specify the frequency on how often logs will be written and validation will be performed.
        - `log_every_n_steps` specifies how often train-logs will be written. This does not affect validation.
        - `check_val_every_n_epoch` specifies how often validation will be performed.
            This will also affect early stopping.
        - `early_stopping_patience` specifies how many epochs to wait for improvement before stopping.
            In epochs, this would be `check_val_every_n_epoch * early_stopping_patience`.
        - `plot_every_n_val_epochs` specifies how often validation samples will be plotted.
            Since plotting is quite costly, you can reduce the frequency. Works similar like early stopping.
            In epochs, this would be `check_val_every_n_epoch * plot_every_n_val_epochs`.
    Example: There are 400 training samples and the batch size is 2, resulting in 200 training steps per epoch.
    If `log_every_n_steps` is set to 50 then the training logs and metrics will be logged 4 times per epoch.
    If `check_val_every_n_epoch` is set to 5 then validation will be performed every 5 epochs.
    If `plot_every_n_val_epochs` is set to 2 then validation samples will be plotted every 10 epochs.
    If `early_stopping_patience` is set to 3 then early stopping will be performed after 15 epochs without improvement.

    The data structure of the training data expects the "preprocessing" step to be done beforehand,
    which results in the following data structure:

    ```sh
    preprocessed-data/ # the top-level directory
    ├── config.toml
    ├── data.zarr/ # this zarr group contains the dataarrays x and y
    ├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
    └── labels.geojson
    ```

    Args:
        name (str | None, optional): Name of the tuning run.
            Will be generated based on the number of existing directories in the artifact directory if None.
            Defaults to None.
        n_trials (int | Literal["grid"], optional): Number of trials to perform in hyperparameter tuning.
            If "grid", span a grid search over all configured hyperparameters.
            In a grid search, only constant or choice hyperparameters are allowed.
            Defaults to 100.
        retrain_and_test (bool, optional): Whether to retrain the model with the best hyperparameters and test it.
            Defaults to False.
        cv_config (CrossValidationConfig, optional): Configuration for cross-validation.
            Defaults to CrossValidationConfig().
        training_config (TrainingConfig, optional): Configuration for training.
            Defaults to TrainingConfig().
        data_config (DataConfig, optional): Configuration for data.
            Defaults to DataConfig().
        device_config (DeviceConfig, optional): Configuration for device.
            Defaults to DeviceConfig().
        logging_config (LoggingConfig, optional): Configuration for logging.
            Defaults to LoggingConfig().
        hpconfig (Path | None, optional): Path to the hyperparameter configuration file.
            Please see the documentation of `hyperparameters` for more information.
            Defaults to None.
        config_file (Path | None, optional): Path to the configuration file. If provided,
            it will be used instead of `hpconfig` if `hpconfig` is None. Defaults to None.

    Returns:
        tuple[float, pd.DataFrame]: The best score (if retrained and tested) and the run infos of all runs.

    Raises:
        ValueError: If no hyperparameter configuration file is provided.

    """
    import pandas as pd
    from darts_utils.namegen import generate_counted_name

    from darts_segmentation.training.adp import _adp
    from darts_segmentation.training.hparams import parse_hyperparameters, sample_hyperparameters
    from darts_segmentation.training.scoring import score_from_single_run
    from darts_segmentation.training.train import test_smp, train_smp

    tick_fstart = time.perf_counter()

    tune_name = name or generate_counted_name(logging_config.artifact_dir)
    artifact_dir = logging_config.artifact_dir / tune_name
    run_infos_file = artifact_dir / f"{tune_name}.parquet"

    # Check if the artifact directory is empty
    assert not artifact_dir.exists(), f"{artifact_dir} already exists."
    artifact_dir.mkdir(parents=True, exist_ok=True)

    hpconfig = hpconfig or config_file
    if hpconfig is None:
        raise ValueError(
            "No hyperparameter configuration file provided. Please provide a valid file via the `--hpconfig` flag."
        )
    param_grid = parse_hyperparameters(hpconfig)
    logger.debug(f"Parsed hyperparameter grid: {param_grid}")
    param_list = sample_hyperparameters(param_grid, n_trials)

    logger.info(
        f"Starting tune '{tune_name}' with data from {data_config.train_data_dir.resolve()}."
        f" Artifacts will be saved to {artifact_dir.resolve()}."
        f" Will run n_trials*n_randoms*n_folds ="
        f" {len(param_list)}*{cv_config.n_randoms}*{cv_config.n_folds} ="
        f" {len(param_list) * cv_config.n_randoms * cv_config.n_folds} experiments."
    )

    # Plan which runs to perform. These are later consumed based on the parallelization strategy.
    process_inputs = [
        _ProcessInputs(
            current=i,
            total=len(param_list),
            tune_name=tune_name,
            cv=cv_config,
            training_config=training_config,
            logging_config=logging_config,
            data_config=data_config,
            device_config=device_config,
            hparams=hparams,
        )
        for i, hparams in enumerate(param_list)
    ]

    run_infos: list[pd.DataFrame] = []
    best_score = 0
    best_hp = None

    # This function abstracts away common logic for running multiprocessing
    for inp, output in _adp(
        process_inputs=process_inputs,
        is_parallel=device_config.strategy == "tune-parallel",
        devices=device_config.devices,
        available_devices=available_devices,
        _run=_run_cv,
    ):
        run_infos.append(output.run_infos)
        if not output.is_unstable and output.score > best_score:
            best_score = output.score
            best_hp = inp.hparams

        # Save already here to prevent data loss if something goes wrong
        pd.concat(run_infos).reset_index(drop=True).to_parquet(run_infos_file)
        logger.debug(f"Saved run infos to {run_infos_file}")

    if len(run_infos) == 0:
        logger.error("No hyperparameters resulted in a valid score. Please check the logs for more information.")
        return 0, run_infos

    run_infos = pd.concat(run_infos).reset_index(drop=True)

    tick_fend = time.perf_counter()

    if best_hp is None:
        logger.warning(
            f"Tuning completed in {tick_fend - tick_fstart:.2f}s."
            " No hyperparameters resulted in a valid score. Please check the logs for more information."
        )
        return 0, run_infos
    logger.info(
        f"Tuning completed in {tick_fend - tick_fstart:.2f}s. The best score was {best_score:.4f} with {best_hp}."
    )

    # =====================
    # === End of tuning ===
    # =====================

    if not retrain_and_test:
        return 0, run_infos

    logger.info("Starting retraining with the best hyperparameters.")

    tick_fstart = time.perf_counter()
    trainer = train_smp(
        run=TrainRunConfig(name=f"{tune_name}-retrain"),
        training_config=training_config,  # TODO: device and strategy
        data_config=DataConfig(
            train_data_dir=data_config.train_data_dir,
            data_split_method=data_config.data_split_method,
            data_split_by=data_config.data_split_by,
            fold_method=None,  # No fold method for retraining
            total_folds=None,  # No folds for retraining
        ),
        logging_config=LoggingConfig(
            artifact_dir=artifact_dir,
            log_every_n_steps=logging_config.log_every_n_steps,
            check_val_every_n_epoch=logging_config.check_val_every_n_epoch,
            plot_every_n_val_epochs=logging_config.plot_every_n_val_epochs,
            wandb_entity=logging_config.wandb_entity,
            wandb_project=logging_config.wandb_project,
        ),
        hparams=best_hp,
    )
    run_id = trainer.lightning_module.hparams["run_id"]
    trainer = test_smp(
        train_data_dir=data_config.train_data_dir,
        run_id=run_id,
        run_name=f"{tune_name}-retrain",
        model_ckp=trainer.checkpoint_callback.best_model_path,
        batch_size=best_hp.batch_size,
        data_split_method=data_config.data_split_method,
        data_split_by=data_config.data_split_by,
        artifact_dir=artifact_dir,
        num_workers=training_config.num_workers,
        device_config=device_config,
        wandb_entity=logging_config.wandb_entity,
        wandb_project=logging_config.wandb_project,
    )

    run_info = {k: v.item() for k, v in trainer.callback_metrics.items()}
    test_scoring_metric = (
        cv_config.scoring_metric.replace("val/", "test/")
        if isinstance(cv_config.scoring_metric, str)
        else [sm.replace("val/", "test/") for sm in cv_config.scoring_metric]
    )
    score = score_from_single_run(run_info, test_scoring_metric, cv_config.multi_score_strategy)
    is_unstable = check_score_is_unstable(run_info, cv_config.scoring_metric)
    tick_fend = time.perf_counter()
    logger.info(
        f"Retraining and testing completed successfully in {tick_fend - tick_fstart:.2f}s"
        f" with {score=:.4f} ({'stable' if not is_unstable else 'unstable'})."
    )

    return score, run_infos