
darts.legacy_training.wandb_sweep_smp

Create a sweep with wandb and run it on the specified cuda device, or continue an existing sweep.

If sweep_id is None, a new sweep will be created. Otherwise, the sweep with the given ID will be continued. All artifacts are gathered under a nested directory based on the sweep ID: {artifact_dir}/sweep-{sweep_id}. Since each sweep configuration currently has its own name and ID, a single run can be found under: {artifact_dir}/sweep-{sweep_id}/{run_name}/{run_id}. Read the training docs for more info.
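
For example, with the default artifact_dir and the sweep ID from the example below, the layout would look roughly like this (run names and IDs are assigned by wandb; the ones shown here are placeholders):

    lightning_logs/
    └── sweep-123456789/
        └── eager-sweep-1/        # run_name (placeholder)
            └── abcd1234/         # run_id (placeholder), contains checkpoints and metrics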

If a device is specified, run an agent on this device. If None, do nothing.

You can specify how often logs are written and validation is performed:

  • log_every_n_steps specifies how often train logs are written. This does not affect validation.

  • check_val_every_n_epoch specifies how often validation is performed. This also affects early stopping.

  • plot_every_n_val_epochs specifies how often validation samples are plotted. Since plotting is quite costly, you can reduce the frequency. The value is counted in validation epochs, similar to early stopping, so the interval in training epochs is check_val_every_n_epoch * plot_every_n_val_epochs. With the defaults (3 and 5), validation samples are plotted every 15 training epochs.

This will NOT use cross-validation. For cross-validation, use optuna_sweep_smp.

Example

In one terminal, start a sweep:

    $ rye run darts wandb-sweep-smp --config-file /path/to/sweep-config.toml
    ...  # Many logs
    Created sweep with ID 123456789
    ... # More logs from spawned agent

In another terminal, start a second agent:

    $ rye run darts wandb-sweep-smp --sweep-id 123456789
    ...

Parameters:

  • train_data_dir (pathlib.Path) –

    Path to the training data directory.

  • sweep_config (pathlib.Path) –

    Path to the sweep YAML configuration file. Must contain a valid wandb sweep configuration. The hyperparameters must contain the following fields: model_arch, model_encoder, augment, learning_rate, gamma, batch_size, focal_loss_alpha, focal_loss_gamma (see the example configuration after this parameter list). Please read https://docs.wandb.ai/guides/sweeps/sweep-config-keys for more information.

  • n_trials (int, default: 10 ) –

    Number of runs to execute. Defaults to 10.

  • sweep_id (str | None, default: None ) –

    The ID of the sweep. If None, a new sweep will be created. Defaults to None.

  • artifact_dir (pathlib.Path, default: pathlib.Path('lightning_logs') ) –

    Path to the training output directory. Will contain checkpoints and metrics. Defaults to Path("lightning_logs").

  • max_epochs (int, default: 100 ) –

    Maximum number of epochs to train. Defaults to 100.

  • log_every_n_steps (int, default: 10 ) –

    Log every n steps. Defaults to 10.

  • check_val_every_n_epoch (int, default: 3 ) –

    Check validation every n epochs. Defaults to 3.

  • plot_every_n_val_epochs (int, default: 5 ) –

    Plot validation samples every n epochs. Defaults to 5.

  • num_workers (int, default: 0 ) –

    Number of Dataloader workers. Defaults to 0.

  • device (int | str | None, default: None ) –

    The device to run the model on. Defaults to None.

  • wandb_entity (str | None, default: None ) –

    Weights and Biases Entity. Defaults to None.

  • wandb_project (str | None, default: None ) –

    Weights and Biases Project. Defaults to None.
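
A minimal sweep configuration covering the required hyperparameter fields could look like the following sketch. The search method, metric name, and all value ranges are illustrative placeholders, not tuned recommendations:

    # sweep-config.yaml (illustrative only)
    method: bayes
    metric:
      name: val/loss        # hypothetical metric name; use one your training actually logs
      goal: minimize
    parameters:
      model_arch:
        values: ["Unet", "UnetPlusPlus"]
      model_encoder:
        values: ["resnet34", "resnet50"]
      augment:
        values: [true, false]
      learning_rate:
        distribution: log_uniform_values
        min: 1.0e-5
        max: 1.0e-2
      gamma:
        min: 0.9
        max: 1.0
      batch_size:
        values: [4, 8, 16]
      focal_loss_alpha:
        min: 0.0
        max: 1.0
      focal_loss_gamma:
        min: 0.0
        max: 5.0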

Source code in darts/src/darts/legacy_training/train.py
def wandb_sweep_smp(
    *,
    # Data and sweep config
    train_data_dir: Path,
    sweep_config: Path,
    n_trials: int = 10,
    sweep_id: str | None = None,
    artifact_dir: Path = Path("lightning_logs"),
    # Epoch and Logging config
    max_epochs: int = 100,
    log_every_n_steps: int = 10,
    check_val_every_n_epoch: int = 3,
    plot_every_n_val_epochs: int = 5,
    # Device and Manager config
    num_workers: int = 0,
    device: int | str | None = None,
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
):
    """Create a sweep with wandb and run it on the specified cuda device, or continue an existing sweep.

    If `sweep_id` is None, a new sweep will be created. Otherwise, the sweep with the given ID will be continued.
    All artifacts are gathered under a nested directory based on the sweep ID: {artifact_dir}/sweep-{sweep_id}.
    Since each sweep configuration currently has its own name and ID, a single run can be found under:
    {artifact_dir}/sweep-{sweep_id}/{run_name}/{run_id}. Read the training docs for more info.

    If a `device` is specified, run an agent on this device. If None, do nothing.

    You can specify how often logs are written and validation is performed:
        - `log_every_n_steps` specifies how often train logs are written. This does not affect validation.
        - `check_val_every_n_epoch` specifies how often validation is performed.
            This also affects early stopping.
        - `plot_every_n_val_epochs` specifies how often validation samples are plotted.
            Since plotting is quite costly, you can reduce the frequency.
            The value is counted in validation epochs, similar to early stopping, so the interval in
            training epochs is `check_val_every_n_epoch * plot_every_n_val_epochs`.

    This will NOT use cross-validation. For cross-validation, use `optuna_sweep_smp`.

    Example:
        In one terminal, start a sweep:
        ```sh
            $ rye run darts wandb-sweep-smp --config-file /path/to/sweep-config.toml
            ...  # Many logs
            Created sweep with ID 123456789
            ... # More logs from spawned agent
        ```

        In another terminal, start a second agent:
        ```sh
            $ rye run darts wandb-sweep-smp --sweep-id 123456789
            ...
        ```

    Args:
        train_data_dir (Path): Path to the training data directory.
        sweep_config (Path): Path to the sweep YAML configuration file. Must contain a valid wandb sweep configuration.
            The hyperparameters must contain the following fields: `model_arch`, `model_encoder`, `augment`,
            `learning_rate`, `gamma`, `batch_size`, `focal_loss_alpha`, `focal_loss_gamma`.
            Please read https://docs.wandb.ai/guides/sweeps/sweep-config-keys for more information.
        n_trials (int, optional): Number of runs to execute. Defaults to 10.
        sweep_id (str | None, optional): The ID of the sweep. If None, a new sweep will be created. Defaults to None.
        artifact_dir (Path, optional): Path to the training output directory.
            Will contain checkpoints and metrics. Defaults to Path("lightning_logs").
        max_epochs (int, optional): Maximum number of epochs to train. Defaults to 100.
        log_every_n_steps (int, optional): Log every n steps. Defaults to 10.
        check_val_every_n_epoch (int, optional): Check validation every n epochs. Defaults to 3.
        plot_every_n_val_epochs (int, optional): Plot validation samples every n epochs. Defaults to 5.
        num_workers (int, optional): Number of Dataloader workers. Defaults to 0.
        device (int | str | None, optional): The device to run the model on. Defaults to None.
        wandb_entity (str | None, optional): Weights and Biases Entity. Defaults to None.
        wandb_project (str | None, optional): Weights and Biases Project. Defaults to None.

    """
    import wandb

    # Wandb has a stupid way of logging (it logs by default with click.echo to stdout)
    # We need to silence this and redirect all possible logs to our logger
    # wl = wandb.setup({"silent": True})
    # wandb.termsetup(wl.settings, logging.getLogger("wandb"))
    # LoggingManager.apply_logging_handlers("wandb")

    if sweep_id is not None and device is None:
        logger.warning("Continuing a sweep without specifying a device will not do anything.")

    with sweep_config.open("r") as f:
        sweep_configuration = yaml.safe_load(f)

    logger.debug(f"Loaded sweep configuration from {sweep_config.resolve()}:\n{sweep_configuration}")

    if sweep_id is None:
        sweep_id = wandb.sweep(sweep=sweep_configuration, project=wandb_project, entity=wandb_entity)
        logger.info(f"Created sweep with ID {sweep_id}")
        logger.info("To start a sweep agents, use the following command:")
        logger.info(f"$ rye run darts sweep_smp --sweep-id {sweep_id}")

    artifact_dir = artifact_dir / f"sweep-{sweep_id}"
    artifact_dir.mkdir(parents=True, exist_ok=True)

    def run():
        run = wandb.init(config=sweep_configuration)
        # We need to manually log the run data since the wandb logger only logs to its own logs and click
        logger.info(f"Starting sweep run '{run.settings.run_name}'")
        logger.debug(f"Run data is saved locally in {Path(run.settings.sync_dir).resolve()}")
        logger.debug(f"View project at {run.settings.project_url}")
        logger.debug(f"View sweep at {run.settings.sweep_url}")
        logger.debug(f"View run at {run.settings.run_url}")

        # We set the default weights to None, to be able to use different architectures
        model_encoder_weights = None
        # We set early stopping to None, because wandb will handle the early stopping
        early_stopping_patience = None
        learning_rate = wandb.config["learning_rate"]
        gamma = wandb.config["gamma"]
        batch_size = wandb.config["batch_size"]
        model_arch = wandb.config["model_arch"]
        model_encoder = wandb.config["model_encoder"]
        augment = wandb.config["augment"]
        focal_loss_alpha = wandb.config["focal_loss_alpha"]
        focal_loss_gamma = wandb.config["focal_loss_gamma"]
        fold = wandb.config.get("fold", 0)
        random_seed = wandb.config.get("random_seed", 42)

        train_smp(
            # Data config
            train_data_dir=train_data_dir,
            artifact_dir=artifact_dir,
            fold=fold,
            # Hyperparameters
            model_arch=model_arch,
            model_encoder=model_encoder,
            model_encoder_weights=model_encoder_weights,
            augment=augment,
            learning_rate=learning_rate,
            gamma=gamma,
            focal_loss_alpha=focal_loss_alpha,
            focal_loss_gamma=focal_loss_gamma,
            batch_size=batch_size,
            # Epoch and Logging config
            early_stopping_patience=early_stopping_patience,
            max_epochs=max_epochs,
            log_every_n_steps=log_every_n_steps,
            check_val_every_n_epoch=check_val_every_n_epoch,
            plot_every_n_val_epochs=plot_every_n_val_epochs,
            # Device and Manager config
            random_seed=random_seed,
            num_workers=num_workers,
            device=device,
            wandb_entity=wandb_entity,
            wandb_project=wandb_project,
            run_name=wandb.run.name,
            run_id=wandb.run.id,
        )

    if device is None:
        logger.info("No device specified, closing script...")
        return

    logger.info("Starting a default sweep agent")
    wandb.agent(sweep_id, function=run, count=n_trials, project=wandb_project, entity=wandb_entity)
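
For completeness, here is a minimal sketch of invoking the same function from Python instead of the CLI, assuming it is importable under the documented path darts.legacy_training. The paths, device, and project name are placeholders:

    from pathlib import Path

    from darts.legacy_training import wandb_sweep_smp

    wandb_sweep_smp(
        train_data_dir=Path("data/train"),       # placeholder path
        sweep_config=Path("sweep-config.yaml"),  # placeholder path
        n_trials=5,
        device=0,               # run an agent on this CUDA device; None would only create the sweep
        wandb_project="darts",  # placeholder project
    )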