darts.legacy_training.optuna_sweep_smp

Create an optuna sweep and run it on the specified CUDA device, or continue an existing sweep.

If sweep_id already exists in sweep_db, the sweep will be continued. Otherwise, a new sweep will be created.

If a device is specified, run an agent on this device. If None, only the sweep is created or continued; no agent is run.

You can specify how often logs are written and validation is performed:

  • log_every_n_steps specifies how often train logs are written. This does not affect validation.

  • check_val_every_n_epoch specifies how often validation is performed. This also affects early stopping.

  • plot_every_n_val_epochs specifies how often validation samples are plotted. Since plotting is quite costly, you can reduce the frequency. It works similarly to early stopping: in epochs, plots are produced every check_val_every_n_epoch * plot_every_n_val_epochs epochs (every 15 epochs with the defaults of 3 and 5).

This will use cross-validation: each set of suggested hyperparameters is trained once per fold and per random seed, and the resulting validation scores are averaged.
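
Each optuna trial therefore trains n_folds * n_randoms models. As a minimal sketch (the helper name pick_seeds is illustrative, not part of the darts API), the random seeds are selected like this, mirroring the logic in the source at the bottom of this page:

    import random

    def pick_seeds(n_randoms: int) -> list[int]:
        # The first three seeds are fixed for comparability with the rest of the pipeline.
        seeds = [42, 21, 69]
        # A fixed RNG makes the additional seeds identical on every agent and run.
        rng = random.Random(42)
        if n_randoms > 3:
            seeds += rng.sample(range(9999), n_randoms - 3)
        return seeds[:n_randoms]

    print(pick_seeds(4))  # [42, 21, 69, <one reproducible pseudo-random seed>]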

Example

In one terminal, start a sweep:

    $ rye run darts optuna-sweep-smp --config-file /path/to/sweep-config.toml
    ...  # Many logs
    Generated new sweep ID: sweep-adoring-turing
    ... # More logs from spawned agent

In another terminal, start a second agent:

    $ rye run darts optuna-sweep-smp --sweep-id sweep-adoring-turing
    ...

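For orientation, the following sketch generates a minimal sweep configuration covering the required hyperparameter fields (model_arch, model_encoder, augment, gamma, batch_size). The concrete values and search distributions are placeholders, not recommendations; see the wandb documentation linked below for the full set of config keys.

    import yaml  # PyYAML
    from pathlib import Path

    # Placeholder search space; adjust values and distributions to your experiment.
    sweep_configuration = {
        "method": "bayes",
        "parameters": {
            "model_arch": {"values": ["Unet", "UnetPlusPlus"]},
            "model_encoder": {"values": ["dpn107", "resnet50"]},
            "augment": {"values": [True, False]},
            "gamma": {"min": 0.8, "max": 0.99},
            "batch_size": {"values": [4, 8, 16]},
        },
    }
    Path("sweep-config.yaml").write_text(yaml.safe_dump(sweep_configuration))
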
Parameters:

  • train_data_dir (pathlib.Path) –

    Path to the training data directory.

  • sweep_config (pathlib.Path) –

    Path to the sweep YAML configuration file. Must contain a valid wandb sweep configuration; see the example configuration above. Hyperparameters must contain the following fields: model_arch, model_encoder, augment, gamma, batch_size. Please read https://docs.wandb.ai/guides/sweeps/sweep-config-keys for more information.

  • n_trials (int, default: 10 ) –

    Number of trials to execute. Defaults to 10.

  • sweep_db (str | None, default: None ) –

    Path to the optuna database (an optuna storage URL, e.g. sqlite:///darts-sweeps.db). If None, a new database will be created. See the inspection sketch after this parameter list.

  • sweep_id (str | None, default: None ) –

    The ID of the sweep. If None, a new sweep will be created. Defaults to None.

  • n_folds (int, default: 5 ) –

    Number of folds in cross-validation. Max 5. Defaults to 5.

  • n_randoms (int, default: 3 ) –

    Number of repetitions with different random seeds. The first 3 are always 42, 21 and 69 for better default comparability with the rest of this pipeline. The rest are pseudo-randomly generated beforehand, hence always identical. Defaults to 3.

  • artifact_dir (pathlib.Path, default: pathlib.Path('lightning_logs') ) –

    Path to the training output directory. Will contain checkpoints and metrics. Defaults to Path("lightning_logs").

  • max_epochs (int, default: 100 ) –

    Maximum number of epochs to train. Defaults to 100.

  • log_every_n_steps (int, default: 10 ) –

    Log every n steps. Defaults to 10.

  • check_val_every_n_epoch (int, default: 3 ) –

    Check validation every n epochs. Defaults to 3.

  • plot_every_n_val_epochs (int, default: 5 ) –

    Plot validation samples every n epochs. Defaults to 5.

  • num_workers (int, default: 0 ) –

    Number of Dataloader workers. Defaults to 0.

  • device (int | str | None, default: None ) –

    The device to run the model on. Defaults to None.

  • wandb_entity (str | None, default: None ) –

    Weights and Biases Entity. Defaults to None.

  • wandb_project (str | None, default: None ) –

    Weights and Biases Project. Defaults to None.

  • model_arch (str, default: 'Unet' ) –

    Model architecture to use. Defaults to "Unet".

  • model_encoder (str, default: 'dpn107' ) –

    Encoder to use. Defaults to "dpn107".

  • augment (bool, default: True ) –

    Whether to apply augmentations or not. Defaults to True.

  • learning_rate (float, default: 0.001 ) –

    Learning Rate. Defaults to 1e-3.

  • gamma (float, default: 0.9 ) –

    Multiplicative factor of learning rate decay. Defaults to 0.9.

  • focal_loss_alpha (float | None, default: None ) –

    Weight factor to balance positive and negative samples. Alpha must be in the [0...1] range; high values give more weight to the positive class. None will not weight samples. Defaults to None.

  • focal_loss_gamma (float, default: 2.0 ) –

    Focal loss power factor. Defaults to 2.0.

  • batch_size (int, default: 8 ) –

    Batch Size. Defaults to 8.

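Because results are stored in a regular optuna database, a finished or running sweep can be inspected with plain optuna. A minimal sketch, assuming sweep_db is an SQLite storage URL and using a made-up sweep ID:

    import optuna

    study = optuna.load_study(
        study_name="sweep-adoring-turing",    # the sweep ID logged at startup
        storage="sqlite:///darts-sweeps.db",  # assumed SQLite storage URL
    )
    # The objective is two-dimensional (mean val/JaccardIndex and mean val/Recall,
    # both maximized), so optuna reports a Pareto front rather than a single best trial.
    for trial in study.best_trials:
        print(trial.number, trial.values, trial.params)
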
Source code in darts/src/darts/legacy_training/sweep.py
def optuna_sweep_smp(
    *,
    # Data and sweep config
    train_data_dir: Path,
    sweep_config: Path,
    n_trials: int = 10,
    sweep_db: str | None = None,
    sweep_id: str | None = None,
    n_folds: int = 5,
    n_randoms: int = 3,
    artifact_dir: Path = Path("lightning_logs"),
    # Epoch and Logging config
    max_epochs: int = 100,
    log_every_n_steps: int = 10,
    check_val_every_n_epoch: int = 3,
    plot_every_n_val_epochs: int = 5,
    # Device and Manager config
    num_workers: int = 0,
    device: int | str | None = None,
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
    # Hyperparameters (default values if not provided by sweep-config)
    model_arch: str = "Unet",
    model_encoder: str = "dpn107",
    augment: bool = True,
    learning_rate: float = 1e-3,
    gamma: float = 0.9,
    focal_loss_alpha: float | None = None,
    focal_loss_gamma: float = 2.0,
    batch_size: int = 8,
):
    """Create an optuna sweep and run it on the specified cuda device, or continue an existing sweep.

    If `sweep_id` already exists in `sweep_db`, the sweep will be continued. Otherwise, a new sweep will be created.

    If a `device` is specified, run an agent on this device. If None, only create or continue the sweep, but run no agent.

    You can specify how often logs will be written and validation will be performed.
        - `log_every_n_steps` specifies how often train-logs will be written. This does not affect validation.
        - `check_val_every_n_epoch` specifies how often validation will be performed.
            This will also affect early stopping.
        - `plot_every_n_val_epochs` specifies how often validation samples will be plotted.
            Since plotting is quite costly, you can reduce the frequency. Works similarly to early stopping.
            In epochs, plots are produced every `check_val_every_n_epoch * plot_every_n_val_epochs` epochs.

    This will use cross-validation.

    Example:
        In one terminal, start a sweep:
        ```sh
            $ rye run darts optuna-sweep-smp --config-file /path/to/sweep-config.toml
            ...  # Many logs
            Generated new sweep ID: sweep-adoring-turing
            ... # More logs from spawned agent
        ```

        In another terminal, start a second agent:
        ```sh
            $ rye run darts optuna-sweep-smp --sweep-id sweep-adoring-turing
            ...
        ```

    Args:
        train_data_dir (Path): Path to the training data directory.
        sweep_config (Path): Path to the sweep yaml configuration file. Must contain a valid wandb sweep configuration.
            Hyperparameters must contain the following fields: `model_arch`, `model_encoder`, `augment`, `gamma`,
            `batch_size`.
            Please read https://docs.wandb.ai/guides/sweeps/sweep-config-keys for more information.
        n_trials (int, optional): Number of trials to execute. Defaults to 10.
        sweep_db (str | None, optional): Path to the optuna database. If None, a new database will be created.
        sweep_id (str | None, optional): The ID of the sweep. If None, a new sweep will be created. Defaults to None.
        n_folds (int, optional): Number of folds in cross-validation. Max 5. Defaults to 5.
        n_randoms (int, optional): Number of repetitions with different random seeds.
            The first 3 are always 42, 21 and 69 for better default comparability with the rest of this pipeline.
            The rest are pseudo-randomly generated beforehand, hence always identical.
            Defaults to 3.
        artifact_dir (Path, optional): Path to the training output directory.
            Will contain checkpoints and metrics. Defaults to Path("lightning_logs").
        max_epochs (int, optional): Maximum number of epochs to train. Defaults to 100.
        log_every_n_steps (int, optional): Log every n steps. Defaults to 10.
        check_val_every_n_epoch (int, optional): Check validation every n epochs. Defaults to 3.
        plot_every_n_val_epochs (int, optional): Plot validation samples every n epochs. Defaults to 5.
        num_workers (int, optional): Number of Dataloader workers. Defaults to 0.
        device (int | str | None, optional): The device to run the model on. Defaults to None.
        wandb_entity (str | None, optional): Weights and Biases Entity. Defaults to None.
        wandb_project (str | None, optional): Weights and Biases Project. Defaults to None.
        model_arch (str, optional): Model architecture to use. Defaults to "Unet".
        model_encoder (str, optional): Encoder to use. Defaults to "dpn107".
        augment (bool, optional): Whether to apply augmentations or not. Defaults to True.
        learning_rate (float, optional): Learning Rate. Defaults to 1e-3.
        gamma (float, optional): Multiplicative factor of learning rate decay. Defaults to 0.9.
        focal_loss_alpha (float | None, optional): Weight factor to balance positive and negative samples.
            Alpha must be in the [0...1] range; high values give more weight to the positive class.
            None will not weight samples. Defaults to None.
        focal_loss_gamma (float, optional): Focal loss power factor. Defaults to 2.0.
        batch_size (int, optional): Batch Size. Defaults to 8.

    """
    import optuna
    from names_generator import generate_name

    from darts.legacy_training.util import suggest_optuna_params_from_wandb_config

    with sweep_config.open("r") as f:
        sweep_configuration = yaml.safe_load(f)

    logger.debug(f"Loaded sweep configuration from {sweep_config.resolve()}:\n{sweep_configuration}")

    # Create a new study-id if none is given
    if sweep_id is None:
        sweep_id = f"sweep-{generate_name('hyphen')}"
        logger.info(f"Generated new sweep ID: {sweep_id}")
        logger.info("To start a sweep agents, use the following command:")
        logger.info(f"$ rye run darts optuna-sweep-smp --sweep-id {sweep_id}")

    artifact_dir = artifact_dir / sweep_id
    artifact_dir.mkdir(parents=True, exist_ok=True)

    def objective(trial):
        hparams = suggest_optuna_params_from_wandb_config(trial, sweep_configuration)
        logger.info(f"Running trial with parameters: {hparams}")

        # Get the trial a more readable name
        trial_name = f"{generate_name(style='hyphen')}-{trial.number}"

        # We set the default weights to None, to be able to use different architectures
        model_encoder_weights = None
        # We set early stopping to None, because wandb will handle the early stopping
        early_stopping_patience = None

        # Overwrite the default values with the suggested ones, if they are present
        learning_rate_trial = hparams.get("learning_rate", learning_rate)
        gamma_trial = hparams.get("gamma", gamma)
        focal_loss_alpha_trial = hparams.get("focal_loss_alpha", focal_loss_alpha)
        focal_loss_gamma_trial = hparams.get("focal_loss_gamma", focal_loss_gamma)
        batch_size_trial = hparams.get("batch_size", batch_size)
        model_arch_trial = hparams.get("model_arch", model_arch)
        model_encoder_trial = hparams.get("model_encoder", model_encoder)
        augment_trial = hparams.get("augment", augment)

        crossval_scores = defaultdict(list)

        folds = list(range(n_folds))
        rng = random.Random(42)
        seeds = [42, 21, 69]
        if n_randoms > 3:
            seeds += rng.sample(range(9999), n_randoms - 3)
        elif n_randoms < 3:
            seeds = seeds[:n_randoms]

        for random_seed in seeds:
            for fold in folds:
                logger.info(f"Running cross-validation fold {fold}")
                _gather_and_reset_wandb_env()
                trainer = train_smp(
                    # Data config
                    train_data_dir=train_data_dir,
                    artifact_dir=artifact_dir,
                    fold=fold,
                    random_seed=random_seed,
                    # Hyperparameters
                    model_arch=model_arch_trial,
                    model_encoder=model_encoder_trial,
                    model_encoder_weights=model_encoder_weights,
                    augment=augment_trial,
                    learning_rate=learning_rate_trial,
                    gamma=gamma_trial,
                    focal_loss_alpha=focal_loss_alpha_trial,
                    focal_loss_gamma=focal_loss_gamma_trial,
                    batch_size=batch_size_trial,
                    # Epoch and Logging config
                    early_stopping_patience=early_stopping_patience,
                    max_epochs=max_epochs,
                    log_every_n_steps=log_every_n_steps,
                    check_val_every_n_epoch=check_val_every_n_epoch,
                    plot_every_n_val_epochs=plot_every_n_val_epochs,
                    # Device and Manager config
                    num_workers=num_workers,
                    device=device,
                    wandb_entity=wandb_entity,
                    wandb_project=wandb_project,
                    wandb_group=sweep_id,
                    trial_name=trial_name,
                    run_name=f"{trial_name}-f{fold}r{random_seed}",
                )
                for metric, value in trainer.callback_metrics.items():
                    crossval_scores[metric].append(value.item())

        logger.debug(f"Cross-validation scores: {crossval_scores}")
        crossval_jaccard = mean(crossval_scores["val/JaccardIndex"])
        crossval_recall = mean(crossval_scores["val/Recall"])

        return crossval_jaccard, crossval_recall

    study = optuna.create_study(
        storage=sweep_db,
        study_name=sweep_id,
        directions=["maximize", "maximize"],
        load_if_exists=True,
    )

    if device is None:
        logger.info("No device specified, closing script...")
        return

    logger.info("Starting optimizing")
    study.optimize(objective, n_trials=n_trials)
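
For completeness, a hedged sketch of invoking the function from Python instead of the CLI; all paths and the device string are placeholders, and the import path is assumed from this page's module name:

    from pathlib import Path

    from darts.legacy_training import optuna_sweep_smp  # import path assumed from this page

    optuna_sweep_smp(
        train_data_dir=Path("data/train"),       # placeholder path
        sweep_config=Path("sweep-config.yaml"),  # e.g. the example configuration above
        sweep_db="sqlite:///darts-sweeps.db",    # assumed SQLite storage URL
        n_trials=10,
        device="cuda:0",  # with device=None the sweep is only created, no agent runs
    )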