
darts.legacy_training

Legacy training module for DARTS.

Functions:

convert_lightning_checkpoint

convert_lightning_checkpoint(
    *,
    lightning_checkpoint: pathlib.Path,
    out_directory: pathlib.Path,
    checkpoint_name: str,
    framework: str = "smp",
)

Convert a lightning checkpoint to our own format.

The final checkpoint will contain the model configuration and the state dict. It will be saved to:

    out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"
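
As a minimal sketch of the naming scheme, assuming formatted_date is the current date formatted with strftime("%Y-%m-%d") (as in the source code below), the output path can be rebuilt like this. The helper name converted_checkpoint_path is hypothetical and not part of darts:

```python
from datetime import date
from pathlib import Path


def converted_checkpoint_path(out_directory: Path, checkpoint_name: str, day: date) -> Path:
    """Hypothetical helper: rebuild the output path used by convert_lightning_checkpoint."""
    formatted_date = day.strftime("%Y-%m-%d")  # same format string as the source code
    return out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"


print(converted_checkpoint_path(Path("models"), "my-model", date(2024, 5, 1)))
```

Because only the date, not the time, is part of the file name, converting twice with the same checkpoint_name on the same day overwrites the earlier file.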

Parameters:

  • lightning_checkpoint (pathlib.Path) –

    Path to the lightning checkpoint.

  • out_directory (pathlib.Path) –

    Output directory for the converted checkpoint.

  • checkpoint_name (str) –

    A unique name of the new checkpoint.

  • framework (str, default: 'smp' ) –

    The framework used for the model. Defaults to "smp".
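
The conversion itself only reshapes dictionaries. The following sketch mirrors the steps of the source code below on plain Python dicts instead of a checkpoint loaded with torch.load: it drops model.encoder_weights from the config, adds time, name and model_framework, and removes the "model." prefix that every weight name carries in the Lightning checkpoint. The function name convert_checkpoint_dict is hypothetical; the real function delegates prefix removal to torch.nn.modules.utils.consume_prefix_in_state_dict_if_present and writes the result with torch.save.

```python
from copy import deepcopy


def convert_checkpoint_dict(lckpt: dict, checkpoint_name: str, framework: str, formatted_date: str) -> dict:
    """Hypothetical pure-Python mirror of convert_lightning_checkpoint's dict handling."""
    # Work on a copy; the real function mutates the loaded checkpoint in place.
    config = deepcopy(lckpt["hyper_parameters"]["config"])
    # The real function uses `del`, which raises KeyError if the key is missing.
    config["model"].pop("encoder_weights", None)
    config["time"] = formatted_date
    config["name"] = checkpoint_name
    config["model_framework"] = framework

    # Strip the "model." prefix from every key that has it; other keys stay unchanged.
    prefix = "model."
    statedict = {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in lckpt["state_dict"].items()
    }
    return {"config": config, "statedict": statedict}


lightning_ckpt = {
    "hyper_parameters": {"config": {"model": {"arch": "Unet", "encoder_weights": "imagenet"}}},
    "state_dict": {"model.encoder.weight": 1.0, "model.decoder.bias": 0.0},
}
own = convert_checkpoint_dict(lightning_ckpt, "demo", "smp", "2024-05-01")
print(own["config"])
print(own["statedict"])
```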

Source code in darts/src/darts/legacy_training/util.py
def convert_lightning_checkpoint(
    *,
    lightning_checkpoint: Path,
    out_directory: Path,
    checkpoint_name: str,
    framework: str = "smp",
):
    """Convert a lightning checkpoint to our own format.

    The final checkpoint will contain the model configuration and the state dict.
    It will be saved to:

    ```python
        out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"
    ```

    Args:
        lightning_checkpoint (Path): Path to the lightning checkpoint.
        out_directory (Path): Output directory for the converted checkpoint.
        checkpoint_name (str): A unique name of the new checkpoint.
        framework (str, optional): The framework used for the model. Defaults to "smp".

    """
    import torch

    logger.debug(f"Loading checkpoint from {lightning_checkpoint.resolve()}")
    lckpt = torch.load(lightning_checkpoint, weights_only=False, map_location=torch.device("cpu"))

    now = datetime.now()
    formatted_date = now.strftime("%Y-%m-%d")
    config = lckpt["hyper_parameters"]["config"]
    del config["model"]["encoder_weights"]
    config["time"] = formatted_date
    config["name"] = checkpoint_name
    config["model_framework"] = framework

    statedict = lckpt["state_dict"]
    # Statedict has model. prefix before every weight. We need to remove them. This is an in-place function
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(statedict, "model.")

    own_ckpt = {
        "config": config,
        "statedict": lckpt["state_dict"],
    }

    out_directory.mkdir(exist_ok=True, parents=True)

    out_checkpoint = out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"

    torch.save(own_ckpt, out_checkpoint)

    logger.info(f"Saved converted checkpoint to {out_checkpoint.resolve()}")

optuna_sweep_smp

optuna_sweep_smp(
    *,
    train_data_dir: pathlib.Path,
    sweep_config: pathlib.Path,
    n_trials: int = 10,
    sweep_db: str | None = None,
    sweep_id: str | None = None,
    n_folds: int = 5,
    n_randoms: int = 3,
    artifact_dir: pathlib.Path = pathlib.Path(
        "lightning_logs"
    ),
    max_epochs: int = 100,
    log_every_n_steps: int = 10,
    check_val_every_n_epoch: int = 3,
    plot_every_n_val_epochs: int = 5,
    num_workers: int = 0,
    device: int | str | None = None,
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
    model_arch: str = "Unet",
    model_encoder: str = "dpn107",
    augment: bool = True,
    learning_rate: float = 0.001,
    gamma: float = 0.9,
    focal_loss_alpha: float | None = None,
    focal_loss_gamma: float = 2.0,
    batch_size: int = 8,
)

Create an optuna sweep and run it on the specified cuda device, or continue an existing sweep.

If sweep_id already exists in sweep_db, the sweep will be continued. Otherwise, a new sweep will be created.

If device is specified, run an agent on this device. If None, the study is only created (or loaded) and no trials are run.

You can specify how often logs are written and validation is performed:

- log_every_n_steps specifies how often training logs are written. This does not affect validation.
- check_val_every_n_epoch specifies how often validation is performed. This also affects early stopping.
- plot_every_n_val_epochs specifies how often validation samples are plotted. Since plotting is costly, you can reduce its frequency. Like early stopping, it counts validation runs rather than epochs. In epochs, this corresponds to check_val_every_n_epoch * plot_every_n_val_epochs.
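
To make the interplay of these settings concrete, the sketch below lists the epochs at which validation and plotting would happen. It assumes 1-based epoch numbering and that plotting happens on every plot_every_n_val_epochs-th validation run; the exact phase used by the training callbacks may differ. The helper names are hypothetical.

```python
def validation_epochs(max_epochs: int, check_val_every_n_epoch: int) -> list[int]:
    """Epochs (1-based) after which validation runs, assuming a fixed interval."""
    return list(range(check_val_every_n_epoch, max_epochs + 1, check_val_every_n_epoch))


def plot_epochs(max_epochs: int, check_val_every_n_epoch: int, plot_every_n_val_epochs: int) -> list[int]:
    """Epochs at which validation samples are plotted: every n-th validation run."""
    val_epochs = validation_epochs(max_epochs, check_val_every_n_epoch)
    return val_epochs[plot_every_n_val_epochs - 1 :: plot_every_n_val_epochs]


# Defaults of optuna_sweep_smp: max_epochs=100, check_val_every_n_epoch=3, plot_every_n_val_epochs=5
print(len(validation_epochs(100, 3)))  # 33 validation runs
print(plot_epochs(100, 3, 5))  # one plot every 3 * 5 = 15 epochs
```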

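This will use cross-validation: for each trial, one model is trained per fold and random seed, and the trial's objective values are the mean val/JaccardIndex and the mean val/Recall across all these runs. Both objectives are maximized.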
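
Each trial therefore trains n_folds times the number of seeds models. The sketch below mirrors the seed selection and score aggregation of the source code below; train_one is a hypothetical stand-in for train_smp that returns the final validation metrics of one run as plain floats.

```python
import random
from collections import defaultdict
from statistics import mean


def make_seeds(n_randoms: int) -> list[int]:
    """Seed list as built in optuna_sweep_smp: fixed seeds first, then reproducible extras."""
    rng = random.Random(42)  # fixed generator, so extra seeds are identical across runs
    seeds = [42, 21, 69]
    if n_randoms > 3:
        seeds += rng.sample(range(9999), n_randoms - 3)
    elif n_randoms < 3:
        seeds = seeds[:n_randoms]
    return seeds


def crossval_objective(n_folds: int, n_randoms: int, train_one) -> tuple[float, float]:
    """Train one run per (seed, fold) and average the validation metrics."""
    scores = defaultdict(list)
    for seed in make_seeds(n_randoms):
        for fold in range(n_folds):
            for metric, value in train_one(fold, seed).items():
                scores[metric].append(value)
    return mean(scores["val/JaccardIndex"]), mean(scores["val/Recall"])


def fake_train(fold: int, seed: int) -> dict[str, float]:
    # Hypothetical stand-in for train_smp returning fixed dummy metrics.
    return {"val/JaccardIndex": 0.5, "val/Recall": 0.75}


print(make_seeds(3))
print(crossval_objective(n_folds=5, n_randoms=3, train_one=fake_train))
```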

Example

In one terminal, start a sweep:

    $ rye run darts sweep-smp --config-file /path/to/sweep-config.toml
    ...  # Many logs
    Created sweep with ID 123456789
    ... # More logs from spawned agent

In another terminal, start a second agent:

    $ rye run darts sweep-smp --sweep-id 123456789
    ...

Parameters:

  • train_data_dir (pathlib.Path) –

    Path to the training data directory.

  • sweep_config (pathlib.Path) –

    Path to the sweep yaml configuration file. Must contain a valid wandb sweep configuration. Hyperparameters must contain the following fields: model_arch, model_encoder, augment, gamma, batch_size. Please read https://docs.wandb.ai/guides/sweeps/sweep-config-keys for more information.

  • n_trials (int, default: 10 ) –

    Number of runs to execute. Defaults to 10.

  • sweep_db (str | None, default: None ) –

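    Optuna storage URL (e.g. "sqlite:///sweep.db"). If None, Optuna keeps the study in memory, so it is not persisted and cannot be continued by other agents. Defaults to None.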

  • sweep_id (str | None, default: None ) –

    The ID of the sweep. If None, a new sweep will be created. Defaults to None.

  • n_folds (int, default: 5 ) –

    Number of folds in cross-validation. Max 5. Defaults to 5.

  • n_randoms (int, default: 3 ) –

    Number of repetitions with different random seeds. The first 3 seeds are always 42, 21 and 69, for better comparability with the rest of this pipeline. The remaining seeds are generated beforehand by a pseudo-random generator with a fixed seed, hence always the same. Defaults to 3.

  • artifact_dir (pathlib.Path, default: pathlib.Path('lightning_logs') ) –

    Path to the training output directory. Checkpoints and metrics are written to a subdirectory named after the sweep ID. Defaults to Path("lightning_logs").

  • max_epochs (int, default: 100 ) –

    Maximum number of epochs to train. Defaults to 100.

  • log_every_n_steps (int, default: 10 ) –

    Log every n steps. Defaults to 10.

  • check_val_every_n_epoch (int, default: 3 ) –

    Check validation every n epochs. Defaults to 3.

  • plot_every_n_val_epochs (int, default: 5 ) –

    Plot validation samples every n validation runs, i.e. every check_val_every_n_epoch * plot_every_n_val_epochs epochs. Defaults to 5.

  • num_workers (int, default: 0 ) –

    Number of Dataloader workers. Defaults to 0.

  • device (int | str | None, default: None ) –

    The device to run the model on. Defaults to None.

  • wandb_entity (str | None, default: None ) –

    Weights and Biases Entity. Defaults to None.

  • wandb_project (str | None, default: None ) –

    Weights and Biases Project. Defaults to None.

  • model_arch (str, default: 'Unet' ) –

    Model architecture to use. Defaults to "Unet".

  • model_encoder (str, default: 'dpn107' ) –

    Encoder to use. Defaults to "dpn107".

  • augment (bool, default: True ) –

    Whether to apply augmentations or not. Defaults to True.

  • learning_rate (float, default: 0.001 ) –

    Learning Rate. Defaults to 1e-3.

  • gamma (float, default: 0.9 ) –

    Multiplicative factor of learning rate decay. Defaults to 0.9.

  • focal_loss_alpha (float | None, default: None ) –

    Weight factor to balance positive and negative samples. Alpha must be in the range [0, 1]; higher values give more weight to the positive class. None will not weight samples. Defaults to None.

  • focal_loss_gamma (float, default: 2.0 ) –

    Focal loss power factor. Defaults to 2.0.

  • batch_size (int, default: 8 ) –

    Batch Size. Defaults to 8.
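
A minimal sketch of a sweep configuration, written as the Python dict that yaml.safe_load would return for the corresponding YAML file. It follows the wandb sweep-config format (method, plus parameters given as values or as distribution/min/max); which of these forms suggest_optuna_params_from_wandb_config actually supports is an assumption here, as are the illustrative architecture and encoder names. The missing_hparams check is hypothetical.

```python
REQUIRED_HPARAMS = ("model_arch", "model_encoder", "augment", "gamma", "batch_size")

# Illustrative sweep configuration (what yaml.safe_load would return for the YAML file).
sweep_configuration = {
    "method": "random",
    "parameters": {
        "model_arch": {"values": ["Unet", "UnetPlusPlus"]},
        "model_encoder": {"values": ["dpn107", "resnet50"]},
        "augment": {"values": [True, False]},
        "gamma": {"distribution": "uniform", "min": 0.8, "max": 0.99},
        "batch_size": {"values": [4, 8, 16]},
        # Optional: hyperparameters absent from the sweep fall back to the keyword arguments.
        "learning_rate": {"values": [1e-4, 3e-4, 1e-3]},
    },
}


def missing_hparams(config: dict) -> list[str]:
    """Hypothetical check: return required hyperparameters absent from the sweep config."""
    parameters = config.get("parameters", {})
    return [name for name in REQUIRED_HPARAMS if name not in parameters]


print(missing_hparams(sweep_configuration))
```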

Source code in darts/src/darts/legacy_training/sweep.py
def optuna_sweep_smp(
    *,
    # Data and sweep config
    train_data_dir: Path,
    sweep_config: Path,
    n_trials: int = 10,
    sweep_db: str | None = None,
    sweep_id: str | None = None,
    n_folds: int = 5,
    n_randoms: int = 3,
    artifact_dir: Path = Path("lightning_logs"),
    # Epoch and Logging config
    max_epochs: int = 100,
    log_every_n_steps: int = 10,
    check_val_every_n_epoch: int = 3,
    plot_every_n_val_epochs: int = 5,
    # Device and Manager config
    num_workers: int = 0,
    device: int | str | None = None,
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
    # Hyperparameters (default values if not provided by sweep-config)
    model_arch: str = "Unet",
    model_encoder: str = "dpn107",
    augment: bool = True,
    learning_rate: float = 1e-3,
    gamma: float = 0.9,
    focal_loss_alpha: float | None = None,
    focal_loss_gamma: float = 2.0,
    batch_size: int = 8,
):
    """Create an optuna sweep and run it on the specified cuda device, or continue an existing sweep.

    If `sweep_id` already exists in `sweep_db`, the sweep will be continued. Otherwise, a new sweep will be created.

    If `device` is specified, run an agent on this device. If None, the study is only created (or loaded) and no trials are run.

    You can specify how often logs are written and validation is performed:
        - `log_every_n_steps` specifies how often training logs are written. This does not affect validation.
        - `check_val_every_n_epoch` specifies how often validation is performed.
            This also affects early stopping.
        - `plot_every_n_val_epochs` specifies how often validation samples are plotted.
            Since plotting is costly, you can reduce the frequency. Like early stopping, it counts validation runs.
            In epochs, this would be `check_val_every_n_epoch * plot_every_n_val_epochs`.

    This will use cross-validation.

    Example:
        In one terminal, start a sweep:
        ```sh
            $ rye run darts sweep-smp --config-file /path/to/sweep-config.toml
            ...  # Many logs
            Created sweep with ID 123456789
            ... # More logs from spawned agent
        ```

        In another terminal, start a second agent:
        ```sh
            $ rye run darts sweep-smp --sweep-id 123456789
            ...
        ```

    Args:
        train_data_dir (Path): Path to the training data directory.
        sweep_config (Path): Path to the sweep yaml configuration file. Must contain a valid wandb sweep configuration.
            Hyperparameters must contain the following fields: `model_arch`, `model_encoder`, `augment`, `gamma`,
            `batch_size`.
            Please read https://docs.wandb.ai/guides/sweeps/sweep-config-keys for more information.
        n_trials (int, optional): Number of runs to execute. Defaults to 10.
        sweep_db (str | None, optional): Optuna storage URL (e.g. "sqlite:///sweep.db").
            If None, the study is kept in memory and cannot be continued by other agents. Defaults to None.
        sweep_id (str | None, optional): The ID of the sweep. If None, a new sweep will be created. Defaults to None.
        n_folds (int, optional): Number of folds in cross-validation. Max 5. Defaults to 5.
        n_randoms (int, optional): Number of repetitions with different random-seeds.
            The first 3 are always 42, 21 and 69 for better comparability with the rest of this pipeline.
            The rest are pseudo-randomly generated beforehand from a fixed seed, hence always the same.
            Defaults to 3.
        artifact_dir (Path, optional): Path to the training output directory.
            Will contain checkpoints and metrics. Defaults to Path("lightning_logs").
        max_epochs (int, optional): Maximum number of epochs to train. Defaults to 100.
        log_every_n_steps (int, optional): Log every n steps. Defaults to 10.
        check_val_every_n_epoch (int, optional): Check validation every n epochs. Defaults to 3.
        plot_every_n_val_epochs (int, optional): Plot validation samples every n validation runs. Defaults to 5.
        num_workers (int, optional): Number of Dataloader workers. Defaults to 0.
        device (int | str | None, optional): The device to run the model on. Defaults to None.
        wandb_entity (str | None, optional): Weights and Biases Entity. Defaults to None.
        wandb_project (str | None, optional): Weights and Biases Project. Defaults to None.
        model_arch (str, optional): Model architecture to use. Defaults to "Unet".
        model_encoder (str, optional): Encoder to use. Defaults to "dpn107".
        augment (bool, optional): Whether to apply augmentations or not. Defaults to True.
        learning_rate (float, optional): Learning Rate. Defaults to 1e-3.
        gamma (float, optional): Multiplicative factor of learning rate decay. Defaults to 0.9.
        focal_loss_alpha (float | None, optional): Weight factor to balance positive and negative samples.
            Alpha must be in the range [0, 1]; higher values give more weight to the positive class.
            None will not weight samples. Defaults to None.
        focal_loss_gamma (float, optional): Focal loss power factor. Defaults to 2.0.
        batch_size (int, optional): Batch Size. Defaults to 8.

    """
    import optuna
    from names_generator import generate_name

    from darts.legacy_training.util import suggest_optuna_params_from_wandb_config

    with sweep_config.open("r") as f:
        sweep_configuration = yaml.safe_load(f)

    logger.debug(f"Loaded sweep configuration from {sweep_config.resolve()}:\n{sweep_configuration}")

    # Create a new study-id if none is given
    if sweep_id is None:
        sweep_id = f"sweep-{generate_name('hyphen')}"
        logger.info(f"Generated new sweep ID: {sweep_id}")
        logger.info("To start sweep agents, use the following command:")
        logger.info(f"$ rye run darts optuna-sweep-smp --sweep-id {sweep_id}")

    artifact_dir = artifact_dir / sweep_id
    artifact_dir.mkdir(parents=True, exist_ok=True)

    def objective(trial):
        hparams = suggest_optuna_params_from_wandb_config(trial, sweep_configuration)
        logger.info(f"Running trial with parameters: {hparams}")

        # Give the trial a more readable name
        trial_name = f"{generate_name(style='hyphen')}-{trial.number}"

        # We set the default weights to None, to be able to use different architectures
        model_encoder_weights = None
        # We set early stopping to None, because wandb will handle the early stopping
        early_stopping_patience = None

        # Overwrite the default values with the suggested ones, if they are present
        learning_rate_trial = hparams.get("learning_rate", learning_rate)
        gamma_trial = hparams.get("gamma", gamma)
        focal_loss_alpha_trial = hparams.get("focal_loss_alpha", focal_loss_alpha)
        focal_loss_gamma_trial = hparams.get("focal_loss_gamma", focal_loss_gamma)
        batch_size_trial = hparams.get("batch_size", batch_size)
        model_arch_trial = hparams.get("model_arch", model_arch)
        model_encoder_trial = hparams.get("model_encoder", model_encoder)
        augment_trial = hparams.get("augment", augment)

        crossval_scores = defaultdict(list)

        folds = list(range(n_folds))
        rng = random.Random(42)
        seeds = [42, 21, 69]
        if n_randoms > 3:
            seeds += rng.sample(range(9999), n_randoms - 3)
        elif n_randoms < 3:
            seeds = seeds[:n_randoms]

        for random_seed in seeds:
            for fold in folds:
                logger.info(f"Running cross-validation fold {fold}")
                _gather_and_reset_wandb_env()
                trainer = train_smp(
                    # Data config
                    train_data_dir=train_data_dir,
                    artifact_dir=artifact_dir,
                    fold=fold,
                    random_seed=random_seed,
                    # Hyperparameters
                    model_arch=model_arch_trial,
                    model_encoder=model_encoder_trial,
                    model_encoder_weights=model_encoder_weights,
                    augment=augment_trial,
                    learning_rate=learning_rate_trial,
                    gamma=gamma_trial,
                    focal_loss_alpha=focal_loss_alpha_trial,
                    focal_loss_gamma=focal_loss_gamma_trial,
                    batch_size=batch_size_trial,
                    # Epoch and Logging config
                    early_stopping_patience=early_stopping_patience,
                    max_epochs=max_epochs,
                    log_every_n_steps=log_every_n_steps,
                    check_val_every_n_epoch=check_val_every_n_epoch,
                    plot_every_n_val_epochs=plot_every_n_val_epochs,
                    # Device and Manager config
                    num_workers=num_workers,
                    device=device,
                    wandb_entity=wandb_entity,
                    wandb_project=wandb_project,
                    wandb_group=sweep_id,
                    trial_name=trial_name,
                    run_name=f"{trial_name}-f{fold}r{random_seed}",
                )
                for metric, value in trainer.callback_metrics.items():
                    crossval_scores[metric].append(value.item())

        logger.debug(f"Cross-validation scores: {crossval_scores}")
        crossval_jaccard = mean(crossval_scores["val/JaccardIndex"])
        crossval_recall = mean(crossval_scores["val/Recall"])

        return crossval_jaccard, crossval_recall

    study = optuna.create_study(
        storage=sweep_db,
        study_name=sweep_id,
        directions=["maximize", "maximize"],
        load_if_exists=True,
    )

    if device is None:
        logger.info("No device specified, closing script...")
        return

    logger.info("Starting optimizing")
    study.optimize(objective, n_trials=n_trials)

preprocess_planet_train_data

preprocess_planet_train_data(
    *,
    bands: list[str],
    data_dir: pathlib.Path,
    labels_dir: pathlib.Path,
    train_data_dir: pathlib.Path,
    arcticdem_dir: pathlib.Path,
    tcvis_dir: pathlib.Path,
    admin_dir: pathlib.Path,
    preprocess_cache: pathlib.Path | None = None,
    device: typing.Literal["cuda", "cpu", "auto"]
    | int
    | None = None,
    dask_worker: int = min(
        16, multiprocessing.cpu_count() - 1
    ),
    ee_project: str | None = None,
    ee_use_highvolume: bool = True,
    tpi_outer_radius: int = 100,
    tpi_inner_radius: int = 0,
    patch_size: int = 1024,
    overlap: int = 16,
    exclude_nopositive: bool = False,
    exclude_nan: bool = True,
    mask_erosion_size: int = 10,
    test_val_split: float = 0.05,
    test_regions: list[str] | None = None,
)

Preprocess Planet data for training.

The data is split into a cross-validation, a validation-test and a test set:

- `cross-val` is meant to be used for train and validation
- `val-test` (5% by default) random leave-out for testing the random distribution shift of the data
- `test` leave-out region for testing the spatial distribution shift of the data

Each split is stored as a zarr group containing an x and a y array. The x array contains the input data with the shape (n_patches, n_bands, patch_size, patch_size). The y array contains the labels with the shape (n_patches, patch_size, patch_size). Both arrays are chunked along the n_patches dimension with one patch per chunk. Because each sample / patch is stored in a separate chunk, and therefore in a separate file, random access to the data is very fast.
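
The one-patch-per-chunk layout can be illustrated with NumPy (a stand-in for illustration; the real arrays are zarr arrays). The arrays start with zero patches, as in the source code below, and grow along the first axis; with a chunk shape of (1, n_bands, patch_size, patch_size), the number of chunks, and thus files, equals the number of patches. The tiny sizes and the n_chunks helper are hypothetical.

```python
import math

import numpy as np

n_bands, patch_size = 4, 8  # tiny sizes for illustration; the default patch_size is 1024

# Start with zero patches, like the zarr arrays created during preprocessing.
x = np.empty((0, n_bands, patch_size, patch_size), dtype="float32")
y = np.empty((0, patch_size, patch_size), dtype="uint8")

# Append a batch of 3 patches along the n_patches axis.
x_new = np.zeros((3, n_bands, patch_size, patch_size), dtype="float32")
y_new = np.zeros((3, patch_size, patch_size), dtype="uint8")
x = np.concatenate([x, x_new], axis=0)
y = np.concatenate([y, y_new], axis=0)


def n_chunks(shape: tuple[int, ...], chunks: tuple[int, ...]) -> int:
    """Number of chunks a regular chunk grid needs to cover an array of this shape."""
    return math.prod(math.ceil(s / c) for s, c in zip(shape, chunks))


print(x.shape, y.shape)
print(n_chunks(x.shape, (1, n_bands, patch_size, patch_size)))  # one chunk per patch
```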

The test and validation splits are controlled through the parameters test_val_split and test_regions. test_regions accepts a list of admin 1 or admin 2 region names, based on the region shapefile maintained by https://github.com/wmgeolab/geoBoundaries. Scenes intersecting these regions are removed from the dataset and put into the test split. The test_val_split parameter controls the fraction of the data that is further split off into the val-test set.
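
A hypothetical sketch of how scenes could be assigned to the three splits, assuming that scenes intersecting a test region go to test and each remaining scene goes to val-test with probability test_val_split. The actual assignment is done by split_dataset_paths, whose exact strategy (e.g. a random draw per scene versus a fixed fraction) may differ; the scene IDs are placeholders.

```python
import random


def assign_splits(scene_ids, test_scene_ids, test_val_split, seed=42):
    """Hypothetical split assignment: test regions first, then a random val-test leave-out."""
    rng = random.Random(seed)  # fixed seed so the assignment is reproducible
    splits = {}
    for scene_id in scene_ids:
        if scene_id in test_scene_ids:
            splits[scene_id] = "test"
        elif rng.random() < test_val_split:
            splits[scene_id] = "val-test"
        else:
            splits[scene_id] = "cross-val"
    return splits


scenes = [f"scene-{i}" for i in range(100)]
splits = assign_splits(scenes, test_scene_ids={"scene-0", "scene-1"}, test_val_split=0.05)
print(sum(mode == "val-test" for mode in splits.values()))
```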

Through exclude_nopositive and exclude_nan, patches without positive labels or with NaN input values can be excluded from the final data.
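
The effect of both flags on a single patch can be sketched as follows. This is an assumption about the filtering logic, not the code of create_training_patches: a patch counts as positive if any label pixel is nonzero, and as invalid if any input value is NaN. The function name keep_patch is hypothetical.

```python
import numpy as np


def keep_patch(x: np.ndarray, y: np.ndarray, exclude_nopositive: bool, exclude_nan: bool) -> bool:
    """Hypothetical filter mirroring the documented meaning of the two flags."""
    if exclude_nan and np.isnan(x).any():
        return False  # input data contains NaN values
    if exclude_nopositive and not (y > 0).any():
        return False  # label patch contains no positive pixel
    return True


x = np.zeros((4, 8, 8), dtype="float32")
y = np.zeros((8, 8), dtype="uint8")
print(keep_patch(x, y, exclude_nopositive=False, exclude_nan=True))  # default flags: patch is kept
```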

Further, a config.toml file containing the configuration used for the preprocessing is saved in the train_data_dir. Additionally, a labels.geojson file is saved in the train_data_dir. It contains the joined label geometries used to create the binarized label masks, along with the split each geometry belongs to in the mode column.

The final directory structure of train_data_dir will look like this:

    train_data_dir/
    ├── config.toml
    ├── cross-val.zarr/
    ├── test.zarr/
    ├── val-test.zarr/
    └── labels.geojson

Parameters:

  • bands (list[str]) –

    The bands to be used for training. Must be present in the preprocessing.

  • data_dir (pathlib.Path) –

    The directory containing the Planet scenes and orthotiles.

  • labels_dir (pathlib.Path) –

    The directory containing the labels.

  • train_data_dir (pathlib.Path) –

    The "output" directory where the tensors are written to.

  • arcticdem_dir (pathlib.Path) –

    The directory containing the ArcticDEM data (the datacube and the extent files). Will be created and downloaded if it does not exist.

  • tcvis_dir (pathlib.Path) –

    The directory containing the TCVis data.

  • admin_dir (pathlib.Path) –

    The directory containing the admin files.

  • preprocess_cache (pathlib.Path | None, default: None ) –

    The directory to store the preprocessed data. Defaults to None.

  • device (typing.Literal['cuda', 'cpu', 'auto'] | int | None, default: None ) –

    The device to run the model on. If "cuda" take the first device (0), if int take the specified device. If "auto" try to automatically select a free GPU (<50% memory usage). Defaults to "cuda" if available, else "cpu".

  • dask_worker (int, default: min(16, multiprocessing.cpu_count() - 1) ) –

    The number of Dask workers to use. Defaults to min(16, mp.cpu_count() - 1).

  • ee_project (str | None, default: None ) –

    The Earth Engine project ID or number to use. May be omitted if project is defined within persistent API credentials obtained via earthengine authenticate.

  • ee_use_highvolume (bool, default: True ) –

    Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).

  • tpi_outer_radius (int, default: 100 ) –

    The outer radius of the annulus kernel for the TPI calculation in m. Defaults to 100 m.

  • tpi_inner_radius (int, default: 0 ) –

    The inner radius of the annulus kernel for the TPI calculation in m. Defaults to 0.

  • patch_size (int, default: 1024 ) –

    The size of the training patches in pixels. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The overlap between neighbouring training patches. Defaults to 16.

  • exclude_nopositive (bool, default: False ) –

    Whether to exclude patches where the labels do not contain positives. Defaults to False.

  • exclude_nan (bool, default: True ) –

    Whether to exclude patches where the input data has NaN values. Defaults to True.

  • mask_erosion_size (int, default: 10 ) –

    The size of the disk to use for mask erosion and the edge-cropping. Defaults to 10.

  • test_val_split (float, default: 0.05 ) –

    The fraction of the data split off into the val-test set. Defaults to 0.05.

  • test_regions (list[str] | None, default: None ) –

    The regions to use for the test set. Defaults to None.
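
When no cached file is available, the ArcticDEM data is loaded at 2 m resolution with a buffer derived from tpi_outer_radius (see the source code below). The buffer, in pixels, is the outer radius converted to pixels and multiplied by sqrt(2), presumably to account for the diagonal of the TPI kernel. A worked version of that arithmetic, with a hypothetical helper name:

```python
from math import ceil, sqrt


def arcticdem_buffer(tpi_outer_radius: int, resolution: int = 2) -> int:
    """Buffer in pixels, computed as in preprocess_planet_train_data."""
    return ceil(tpi_outer_radius / resolution * sqrt(2))


print(arcticdem_buffer(100))  # default: ceil(50 * 1.414...) = 71 pixels, about 142 m
```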

Source code in darts/src/darts/legacy_training/preprocess/planet.py
def preprocess_planet_train_data(
    *,
    bands: list[str],
    data_dir: Path,
    labels_dir: Path,
    train_data_dir: Path,
    arcticdem_dir: Path,
    tcvis_dir: Path,
    admin_dir: Path,
    preprocess_cache: Path | None = None,
    device: Literal["cuda", "cpu", "auto"] | int | None = None,
    dask_worker: int = min(16, mp.cpu_count() - 1),
    ee_project: str | None = None,
    ee_use_highvolume: bool = True,
    tpi_outer_radius: int = 100,
    tpi_inner_radius: int = 0,
    patch_size: int = 1024,
    overlap: int = 16,
    exclude_nopositive: bool = False,
    exclude_nan: bool = True,
    mask_erosion_size: int = 10,
    test_val_split: float = 0.05,
    test_regions: list[str] | None = None,
):
    """Preprocess Planet data for training.

    The data is split into a cross-validation, a validation-test and a test set:

        - `cross-val` is meant to be used for train and validation
        - `val-test` (5% by default) random leave-out for testing the random distribution shift of the data
        - `test` leave-out region for testing the spatial distribution shift of the data

    Each split is stored as a zarr group, containing an x and a y array.
    The x array contains the input data with the shape (n_patches, n_bands, patch_size, patch_size).
    The y array contains the labels with the shape (n_patches, patch_size, patch_size).
    Both arrays are chunked along the n_patches dimension, with one patch per chunk.
    This results in very fast random access to the data, because each sample / patch is stored in a separate chunk and
    therefore in a separate file.

    The test and validation splits are controlled through the parameters `test_val_split` and `test_regions`.
    `test_regions` accepts a list of admin 1 or admin 2 region names, based on the region shapefile maintained by
    https://github.com/wmgeolab/geoBoundaries. Scenes intersecting these regions are removed from the dataset and
    put into the test split.
    The `test_val_split` parameter controls the fraction of the data that is further split off into the val-test set.

    Through `exclude_nopositive` and `exclude_nan`, the respective patches can be excluded from the final data.

    Further, a `config.toml` file containing the configuration used for the preprocessing is saved in the
    `train_data_dir`.
    Additionally, a `labels.geojson` file is saved in the `train_data_dir` containing the joined label geometries used
    for the creation of the binarized label masks, including the split of each geometry in the `mode` column.

    The final directory structure of `train_data_dir` will look like this:

    ```sh
    train_data_dir/
    ├── config.toml
    ├── cross-val.zarr/
    ├── test.zarr/
    ├── val-test.zarr/
    └── labels.geojson
    ```

    Args:
        bands (list[str]): The bands to be used for training. Must be present in the preprocessing.
        data_dir (Path): The directory containing the Planet scenes and orthotiles.
        labels_dir (Path): The directory containing the labels.
        train_data_dir (Path): The "output" directory where the tensors are written to.
        arcticdem_dir (Path): The directory containing the ArcticDEM data (the datacube and the extent files).
            Will be created and downloaded if it does not exist.
        tcvis_dir (Path): The directory containing the TCVis data.
        admin_dir (Path): The directory containing the admin files.
        preprocess_cache (Path | None, optional): The directory to store the preprocessed data. Defaults to None.
        device (Literal["cuda", "cpu", "auto"] | int | None, optional): The device to run the model on.
            If "cuda" take the first device (0), if int take the specified device.
            If "auto" try to automatically select a free GPU (<50% memory usage).
            Defaults to "cuda" if available, else "cpu".
        dask_worker (int, optional): The number of Dask workers to use. Defaults to min(16, mp.cpu_count() - 1).
        ee_project (str | None, optional): The Earth Engine project ID or number to use. May be omitted if
            project is defined within persistent API credentials obtained via `earthengine authenticate`.
        ee_use_highvolume (bool, optional): Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).
        tpi_outer_radius (int, optional): The outer radius of the annulus kernel for the TPI calculation
            in m. Defaults to 100 m.
        tpi_inner_radius (int, optional): The inner radius of the annulus kernel for the TPI calculation
            in m. Defaults to 0.
        patch_size (int, optional): The size of the training patches in pixels. Defaults to 1024.
        overlap (int, optional): The overlap between neighbouring training patches. Defaults to 16.
        exclude_nopositive (bool, optional): Whether to exclude patches where the labels do not contain positives.
            Defaults to False.
        exclude_nan (bool, optional): Whether to exclude patches where the input data has nan values.
            Defaults to True.
        mask_erosion_size (int, optional): The size of the disk to use for mask erosion and the edge-cropping.
            Defaults to 10.
        test_val_split (float, optional): The fraction of the data split off into the val-test set. Defaults to 0.05.
        test_regions (list[str] | None, optional): The regions to use for the test set. Defaults to None.

    """
    # Import here to avoid long loading times when running other commands
    import geopandas as gpd
    import pandas as pd
    import toml
    import xarray as xr
    import zarr
    from darts_acquisition import load_arcticdem, load_planet_masks, load_planet_scene, load_tcvis
    from darts_preprocessing import preprocess_legacy_fast
    from darts_segmentation.training.prepare_training import create_training_patches
    from dask.distributed import Client, LocalCluster
    from lovely_tensors import monkey_patch
    from odc.stac import configure_rio
    from rich.progress import track
    from zarr.codecs import BloscCodec
    from zarr.storage import LocalStore

    from darts.utils.cuda import debug_info, decide_device
    from darts.utils.earthengine import init_ee
    from darts.utils.logging import console

    monkey_patch()
    debug_info()
    device = decide_device(device)
    init_ee(ee_project, ee_use_highvolume)

    with LocalCluster(n_workers=dask_worker) as cluster, Client(cluster) as client:
        logger.info(f"Using Dask client: {client} on cluster {cluster}")
        logger.info(f"Dashboard available at: {client.dashboard_link}")
        configure_rio(cloud_defaults=True, aws={"aws_unsigned": True}, client=client)
        logger.info("Configured Rasterio with Dask")

        labels = (gpd.read_file(labels_file) for labels_file in labels_dir.glob("*/TrainingLabel*.gpkg"))
        labels = gpd.GeoDataFrame(pd.concat(labels, ignore_index=True))

        footprints = (gpd.read_file(footprints_file) for footprints_file in labels_dir.glob("*/ImageFootprints*.gpkg"))
        footprints = gpd.GeoDataFrame(pd.concat(footprints, ignore_index=True))

        # We hardcode these because they depend on the preprocessing used
        norm_factors = {
            "red": 1 / 3000,
            "green": 1 / 3000,
            "blue": 1 / 3000,
            "nir": 1 / 3000,
            "ndvi": 1 / 20000,
            "relative_elevation": 1 / 30000,
            "slope": 1 / 90,
            "tc_brightness": 1 / 255,
            "tc_greenness": 1 / 255,
            "tc_wetness": 1 / 255,
        }
        # Filter out bands that are not in the specified bands
        norm_factors = {k: v for k, v in norm_factors.items() if k in bands}

        train_data_dir.mkdir(exist_ok=True, parents=True)

        zgroups = {
            "cross-val": zarr.group(store=LocalStore(train_data_dir / "cross-val.zarr"), overwrite=True),
            "val-test": zarr.group(store=LocalStore(train_data_dir / "val-test.zarr"), overwrite=True),
            "test": zarr.group(store=LocalStore(train_data_dir / "test.zarr"), overwrite=True),
        }
        # We need to declare the number of patches as 0, because we can't know the final number of patches
        for root in zgroups.values():
            root.create(
                name="x",
                shape=(0, len(bands), patch_size, patch_size),
                # shards=(100, len(bands), patch_size, patch_size),
                chunks=(1, len(bands), patch_size, patch_size),
                dtype="float32",
                compressor=BloscCodec(cname="lz4", clevel=9),
            )
            root.create(
                name="y",
                shape=(0, patch_size, patch_size),
                # shards=(100, patch_size, patch_size),
                chunks=(1, patch_size, patch_size),
                dtype="uint8",
                compressors=BloscCodec(cname="lz4", clevel=9),
            )

        # Find all Sentinel 2 scenes and split into train+val (cross-val), val-test (variance) and test (region)
        n_patches = 0
        n_patches_by_mode = {"cross-val": 0, "val-test": 0, "test": 0}
        joint_lables = []
        planet_paths = sorted(_legacy_path_gen(data_dir))
        logger.info(f"Found {len(planet_paths)} PLANET scenes and orthotiles in {data_dir}")
        path_gen = split_dataset_paths(
            planet_paths, footprints, train_data_dir, test_val_split, test_regions, admin_dir
        )

        for i, (fpath, mode) in track(
            enumerate(path_gen), description="Processing samples", total=len(planet_paths), console=console
        ):
            try:
                planet_id = fpath.stem
                logger.debug(
                    f"Processing sample {i + 1} of {len(planet_paths)}"
                    f" '{fpath.resolve()}' ({planet_id=}) to split '{mode}'"
                )

                # Check for a cached preprocessed file
                if preprocess_cache and (preprocess_cache / f"{planet_id}.nc").exists():
                    cache_file = preprocess_cache / f"{planet_id}.nc"
                    logger.info(f"Loading preprocessed data from {cache_file.resolve()}")
                    tile = xr.open_dataset(preprocess_cache / f"{planet_id}.nc", engine="h5netcdf").set_coords(
                        "spatial_ref"
                    )
                else:
                    optical = load_planet_scene(fpath)
                    logger.info(f"Found optical tile with size {optical.sizes}")
                    arcticdem_res = 2
                    arcticdem_buffer = ceil(tpi_outer_radius / arcticdem_res * sqrt(2))
                    arcticdem = load_arcticdem(
                        optical.odc.geobox, arcticdem_dir, resolution=arcticdem_res, buffer=arcticdem_buffer
                    )
                    tcvis = load_tcvis(optical.odc.geobox, tcvis_dir)
                    data_masks = load_planet_masks(fpath)

                    tile: xr.Dataset = preprocess_legacy_fast(
                        optical,
                        arcticdem,
                        tcvis,
                        data_masks,
                        tpi_outer_radius,
                        tpi_inner_radius,
                        device,
                    )
                    # Only cache if we have a cache directory
                    if preprocess_cache:
                        preprocess_cache.mkdir(exist_ok=True, parents=True)
                        cache_file = preprocess_cache / f"{planet_id}.nc"
                        logger.info(f"Caching preprocessed data to {cache_file.resolve()}")
                        tile.to_netcdf(cache_file, engine="h5netcdf")

                # Only use the labels that belong to this scene
                scene_labels = labels[labels.image_id == planet_id]

                # Save the patches
                gen = create_training_patches(
                    tile=tile,
                    labels=scene_labels,
                    bands=bands,
                    norm_factors=norm_factors,
                    patch_size=patch_size,
                    overlap=overlap,
                    exclude_nopositive=exclude_nopositive,
                    exclude_nan=exclude_nan,
                    device=device,
                    mask_erosion_size=mask_erosion_size,
                )

                zx = zgroups[mode]["x"]
                zy = zgroups[mode]["y"]
                n_scene_patches = 0
                for x, y in gen:
                    zx.append(x.unsqueeze(0).numpy().astype("float32"))
                    zy.append(y.unsqueeze(0).numpy().astype("uint8"))
                    n_scene_patches += 1
                    n_patches += 1
                    n_patches_by_mode[mode] += 1
                # Only keep the labels of scenes that contributed patches
                if n_scene_patches > 0 and len(scene_labels) > 0:
                    joint_lables.append(scene_labels.assign(mode=mode).to_crs("EPSG:3413"))

                logger.info(
                    f"Processed sample {i + 1} of {len(planet_paths)} '{fpath.resolve()}' "
                    f"({planet_id=}) with {n_scene_patches} patches."
                )

            except KeyboardInterrupt:
                logger.info("Interrupted by user.")
                break

            except Exception as e:
                logger.warning(f"Could not process folder sample {i} '{fpath.resolve()}'.\nSkipping...")
                logger.exception(e)

    # Save the used labels
    joint_lables = pd.concat(joint_lables)
    joint_lables.to_file(train_data_dir / "labels.geojson", driver="GeoJSON")

    # Save a config file as toml
    config = {
        "darts": {
            "data_dir": data_dir,
            "labels_dir": labels_dir,
            "train_data_dir": train_data_dir,
            "arcticdem_dir": arcticdem_dir,
            "tcvis_dir": tcvis_dir,
            "bands": bands,
            "norm_factors": norm_factors,
            "device": device,
            "ee_project": ee_project,
            "ee_use_highvolume": ee_use_highvolume,
            "tpi_outer_radius": tpi_outer_radius,
            "tpi_inner_radius": tpi_inner_radius,
            "patch_size": patch_size,
            "overlap": overlap,
            "exclude_nopositive": exclude_nopositive,
            "exclude_nan": exclude_nan,
            "n_patches": n_patches,
        }
    }
    with open(train_data_dir / "config.toml", "w") as f:
        toml.dump(config, f)

    logger.info(f"Saved {n_patches} ({n_patches_by_mode}) patches to {train_data_dir}")
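Both preprocessing functions hardcode per-band normalization factors and keep only those for the requested bands. The following minimal sketch reproduces that filtering. The scaling step is an assumption (the listing only passes `norm_factors` on to `create_training_patches`), and `select_norm_factors` and `normalize_pixel` are hypothetical helpers, not part of darts.

```python
# Per-band factors, copied from the hardcoded table in the listing above.
NORM_FACTORS = {
    "red": 1 / 3000,
    "green": 1 / 3000,
    "blue": 1 / 3000,
    "nir": 1 / 3000,
    "ndvi": 1 / 20000,
    "relative_elevation": 1 / 30000,
    "slope": 1 / 90,
    "tc_brightness": 1 / 255,
    "tc_greenness": 1 / 255,
    "tc_wetness": 1 / 255,
}


def select_norm_factors(bands: list[str]) -> dict[str, float]:
    # Same filter as the listing: requested bands missing from the table get no factor,
    # and the result keeps the order of NORM_FACTORS, not the order of `bands`.
    return {k: v for k, v in NORM_FACTORS.items() if k in bands}


def normalize_pixel(pixel: dict[str, float], bands: list[str]) -> dict[str, float]:
    # ASSUMPTION: normalization multiplies each raw band value by its factor.
    factors = select_norm_factors(bands)
    return {band: pixel[band] * factor for band, factor in factors.items()}


print(select_norm_factors(["slope", "red", "unknown"]))
print(normalize_pixel({"red": 1500.0, "slope": 45.0}, ["red", "slope"]))
```

A requested band that is not in the table (such as "unknown" above) is silently dropped rather than raising an error.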

preprocess_s2_train_data

preprocess_s2_train_data(
    *,
    bands: list[str],
    sentinel2_dir: pathlib.Path,
    train_data_dir: pathlib.Path,
    arcticdem_dir: pathlib.Path,
    tcvis_dir: pathlib.Path,
    admin_dir: pathlib.Path,
    preprocess_cache: pathlib.Path | None = None,
    device: typing.Literal["cuda", "cpu", "auto"]
    | int
    | None = None,
    dask_worker: int = min(
        16, multiprocessing.cpu_count() - 1
    ),
    ee_project: str | None = None,
    ee_use_highvolume: bool = True,
    tpi_outer_radius: int = 100,
    tpi_inner_radius: int = 0,
    patch_size: int = 1024,
    overlap: int = 16,
    exclude_nopositive: bool = False,
    exclude_nan: bool = True,
    mask_erosion_size: int = 10,
    test_val_split: float = 0.05,
    test_regions: list[str] | None = None,
)

Preprocess Sentinel 2 data for training.

The data is split into a cross-validation, a validation-test and a test set:

- `cross-val` is meant to be used for training and validation
- `val-test` is a random leave-out (5% by default) for testing the distribution shift caused by a random split of the data
- `test` is a leave-out region for testing the spatial distribution shift of the data

Each split is stored as a zarr group containing an x and a y array. The x array contains the input data with the shape (n_patches, n_bands, patch_size, patch_size). The y array contains the labels with the shape (n_patches, patch_size, patch_size). Both arrays are chunked along the n_patches dimension with a chunk size of 1. This results in fast random access to the data, because each sample / patch is stored in a separate chunk and therefore in a separate file.

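To make the chunking concrete, this hypothetical helper computes how many chunks each array has and how large one uncompressed chunk is, using the dtypes from the source (float32 for x, uint8 for y) and the default patch_size of 1024. The band count of 10 is an arbitrary example value, and files on disk are smaller because of the Blosc/LZ4 compression.

```python
def chunk_summary(n_patches: int, n_bands: int, patch_size: int) -> dict[str, int]:
    # x chunks: shape (1, n_bands, patch_size, patch_size), float32 = 4 bytes per value.
    x_chunk_bytes = n_bands * patch_size * patch_size * 4
    # y chunks: shape (1, patch_size, patch_size), uint8 = 1 byte per value.
    y_chunk_bytes = patch_size * patch_size * 1
    return {
        # A chunk size of 1 along the patch axis means one chunk per patch.
        "n_chunks_per_array": n_patches,
        "x_chunk_bytes": x_chunk_bytes,
        "y_chunk_bytes": y_chunk_bytes,
    }


summary = chunk_summary(n_patches=100, n_bands=10, patch_size=1024)
print(summary)
```

With 10 bands, one x chunk holds 41,943,040 bytes (40 MiB) and one y chunk 1,048,576 bytes (1 MiB) before compression.
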
The test and validation splits are controlled through the parameters test_val_split and test_regions. test_regions accepts a list of admin 1 or admin 2 region names, based on the region shapefiles maintained by https://github.com/wmgeolab/geoBoundaries; scenes intersecting these regions are removed from the dataset and put into the test split. The test_val_split parameter controls the fraction of the remaining scenes that is randomly held out for the val-test split.

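The splitting rules can be sketched as follows. This is not `split_dataset_paths` from the source (its implementation is not shown here): `assign_splits` is a hypothetical helper, the geometric intersection with geoBoundaries regions is replaced by a precomputed set of scene IDs, and drawing an exact, rounded fraction of the remaining scenes is an assumption.

```python
import random
from collections import Counter


def assign_splits(
    scene_ids: list[str],
    test_region_scene_ids: set[str],
    test_val_split: float = 0.05,
    seed: int = 42,
) -> dict[str, str]:
    splits: dict[str, str] = {}
    remaining: list[str] = []
    for scene_id in scene_ids:
        # Scenes intersecting a test region always go to the test split.
        if scene_id in test_region_scene_ids:
            splits[scene_id] = "test"
        else:
            remaining.append(scene_id)
    # ASSUMPTION: a fixed fraction of the remaining scenes is drawn at random for val-test.
    rng = random.Random(seed)
    n_val_test = round(len(remaining) * test_val_split)
    val_test = set(rng.sample(remaining, n_val_test))
    for scene_id in remaining:
        splits[scene_id] = "val-test" if scene_id in val_test else "cross-val"
    return splits


scenes = [f"scene_{i:03d}" for i in range(110)]
in_test_region = set(scenes[100:])  # stand-in for the geometric intersection test
counts = Counter(assign_splits(scenes, in_test_region).values())
print(counts)
```

Using a seeded `random.Random` keeps the val-test selection reproducible across runs.
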
Through exclude_nopositive and exclude_nan, patches whose labels contain no positives or whose input data contains NaN values can be excluded from the final data.

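A minimal sketch of the two exclusion rules, applied to a single patch. `keep_patch` is hypothetical and only mirrors the parameter semantics described above; how `create_training_patches` actually checks patches is not shown in this section.

```python
import numpy as np


def keep_patch(
    x: np.ndarray,
    y: np.ndarray,
    exclude_nopositive: bool = False,
    exclude_nan: bool = True,
) -> bool:
    # x: input patch of shape (n_bands, patch_size, patch_size)
    # y: label mask of shape (patch_size, patch_size)
    if exclude_nan and np.isnan(x).any():
        return False
    # ASSUMPTION: positive pixels are encoded as 1 in the label mask.
    if exclude_nopositive and not (y == 1).any():
        return False
    return True


x = np.zeros((3, 4, 4), dtype=np.float32)
y = np.zeros((4, 4), dtype=np.uint8)
print(keep_patch(x, y))  # True: no NaNs, and empty labels are allowed by default
print(keep_patch(x, y, exclude_nopositive=True))  # False: no positive pixel
x[0, 0, 0] = np.nan
print(keep_patch(x, y))  # False: NaN in the input
```
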
Further, a config.toml file containing the configuration used for the preprocessing is saved in the train_data_dir. Additionally, a labels.geojson file is saved in the train_data_dir. It contains the joined label geometries used to create the binarized label masks, including the split of each geometry in the mode column.

The final directory structure of train_data_dir will look like this:

train_data_dir/
├── config.toml
├── cross-val.zarr/
├── test.zarr/
├── val-test.zarr/
└── labels.geojson
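Each .zarr group holds an x and a y array that start with zero patches and grow as patches are appended, since the final count is not known up front. The NumPy sketch below only mimics the shapes of that pattern in memory (the patch values are random stand-ins); the source appends to zarr arrays on disk, which avoids holding all patches in memory.

```python
import numpy as np

n_bands, patch_size = 3, 8
rng = np.random.default_rng(0)

# Start with zero patches, like shape=(0, len(bands), patch_size, patch_size) in the source.
x_store = np.empty((0, n_bands, patch_size, patch_size), dtype="float32")
y_store = np.empty((0, patch_size, patch_size), dtype="uint8")

for _ in range(5):  # stand-in for iterating over the patch generator
    x = rng.random((n_bands, patch_size, patch_size))  # fake input patch
    y = rng.random((patch_size, patch_size)) > 0.5  # fake binary label mask
    # x[np.newaxis] adds the leading patch axis, like unsqueeze(0) on a torch tensor.
    x_store = np.concatenate([x_store, x[np.newaxis].astype("float32")], axis=0)
    y_store = np.concatenate([y_store, y[np.newaxis].astype("uint8")], axis=0)

print(x_store.shape, y_store.shape)  # (5, 3, 8, 8) (5, 8, 8)
```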

Parameters:

  • bands (list[str]) –

    The bands to be used for training. Must be present in the preprocessing.

  • sentinel2_dir (pathlib.Path) –

    The directory containing the Sentinel 2 scenes.

  • train_data_dir (pathlib.Path) –

    The "output" directory where the tensors are written to.

  • arcticdem_dir (pathlib.Path) –

    The directory containing the ArcticDEM data (the datacube and the extent files). Will be created and downloaded if it does not exist.

  • tcvis_dir (pathlib.Path) –

    The directory containing the TCVis data.

  • admin_dir (pathlib.Path) –

    The directory containing the admin files.

  • preprocess_cache (pathlib.Path, default: None ) –

    The directory in which preprocessed tiles are cached as NetCDF files. If None, no caching is done. Defaults to None.

  • device (typing.Literal['cuda', 'cpu', 'auto'] | int | None, default: None ) –

    The device to run the computations on. If "cuda", the first device (0) is used; if an int, the device with that index is used. If "auto", a free GPU (<50% memory usage) is selected automatically. Defaults to "cuda" if available, else "cpu".

  • dask_worker (int, default: min(16, multiprocessing.cpu_count() - 1) ) –

    The number of Dask workers to use. Defaults to min(16, mp.cpu_count() - 1).

  • ee_project (str, default: None ) –

    The Earth Engine project ID or number to use. May be omitted if project is defined within persistent API credentials obtained via earthengine authenticate.

  • ee_use_highvolume (bool, default: True ) –

    Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).

  • tpi_outer_radius (int, default: 100 ) –

    The outer radius of the annulus kernel for the TPI (topographic position index) calculation, in m. Defaults to 100 m.

  • tpi_inner_radius (int, default: 0 ) –

    The inner radius of the annulus kernel for the TPI calculation, in m. Defaults to 0.

  • patch_size (int, default: 1024 ) –

    The size (height and width, in pixels) of the square training patches. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The overlap between neighboring patches, in pixels. Defaults to 16.

  • exclude_nopositive (bool, default: False ) –

    Whether to exclude patches where the labels do not contain positives. Defaults to False.

  • exclude_nan (bool, default: True ) –

    Whether to exclude patches where the input data has NaN values. Defaults to True.

  • mask_erosion_size (int, default: 10 ) –

    The size of the disk to use for mask erosion and the edge-cropping. Defaults to 10.

  • test_val_split (float, default: 0.05 ) –

    The fraction of scenes randomly held out for the val-test split. Defaults to 0.05.

  • test_regions (list[str] | None, default: None ) –

    The admin 1 or admin 2 region names whose intersecting scenes form the test set. Defaults to None.

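For orientation, the topographic position index (TPI) compares each elevation value with the mean elevation inside an annulus around it; with tpi_inner_radius = 0 the annulus becomes a full disk. The sketch below builds such a kernel from radii in metres and evaluates TPI at one pixel. It illustrates the general technique under the stated assumptions and is not the darts implementation (which is not shown here); `annulus_kernel` and `tpi_at` are hypothetical helpers.

```python
import numpy as np


def annulus_kernel(outer_radius_m: float, inner_radius_m: float, resolution_m: float) -> np.ndarray:
    # Convert the radii from metres to pixels.
    outer = outer_radius_m / resolution_m
    inner = inner_radius_m / resolution_m
    r = int(np.ceil(outer))
    yy, xx = np.mgrid[-r : r + 1, -r : r + 1]
    dist = np.hypot(xx, yy)
    # ASSUMPTION: pixels with inner <= distance <= outer belong to the annulus.
    mask = (dist >= inner) & (dist <= outer)
    return mask / mask.sum()  # weights sum to 1, so the kernel averages the annulus


def tpi_at(dem: np.ndarray, row: int, col: int, kernel: np.ndarray) -> float:
    # TPI: elevation of the pixel minus the mean elevation of its annulus neighbourhood.
    r = kernel.shape[0] // 2
    window = dem[row - r : row + r + 1, col - r : col + r + 1]
    return float(dem[row, col] - (window * kernel).sum())


kernel = annulus_kernel(outer_radius_m=100, inner_radius_m=0, resolution_m=10)
dem = np.zeros((41, 41))
dem[20, 20] = 10.0  # a single raised pixel in the middle of a flat area
print(kernel.shape)  # (21, 21)
print(tpi_at(dem, 20, 20, kernel) > 0)  # True: the pixel lies above its surroundings
```
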
Source code in darts/src/darts/legacy_training/preprocess/s2.py
def preprocess_s2_train_data(
    *,
    bands: list[str],
    sentinel2_dir: Path,
    train_data_dir: Path,
    arcticdem_dir: Path,
    tcvis_dir: Path,
    admin_dir: Path,
    preprocess_cache: Path | None = None,
    device: Literal["cuda", "cpu", "auto"] | int | None = None,
    dask_worker: int = min(16, mp.cpu_count() - 1),
    ee_project: str | None = None,
    ee_use_highvolume: bool = True,
    tpi_outer_radius: int = 100,
    tpi_inner_radius: int = 0,
    patch_size: int = 1024,
    overlap: int = 16,
    exclude_nopositive: bool = False,
    exclude_nan: bool = True,
    mask_erosion_size: int = 10,
    test_val_split: float = 0.05,
    test_regions: list[str] | None = None,
):
    """Preprocess Sentinel 2 data for training.

    The data is split into a cross-validation, a validation-test and a test set:

        - `cross-val` is meant to be used for training and validation
        - `val-test` is a random leave-out (5% by default) for testing the distribution shift caused by a random
          split of the data
        - `test` is a leave-out region for testing the spatial distribution shift of the data

    Each split is stored as a zarr group containing an x and a y array.
    The x array contains the input data with the shape (n_patches, n_bands, patch_size, patch_size).
    The y array contains the labels with the shape (n_patches, patch_size, patch_size).
    Both arrays are chunked along the n_patches dimension with a chunk size of 1.
    This results in fast random access to the data, because each sample / patch is stored in a separate chunk and
    therefore in a separate file.

    The test and validation splits are controlled through the parameters `test_val_split` and `test_regions`.
    `test_regions` accepts a list of admin 1 or admin 2 region names, based on the region shapefiles maintained by
    https://github.com/wmgeolab/geoBoundaries; scenes intersecting these regions are removed from the dataset and
    put into the test split.
    The `test_val_split` parameter controls the fraction of the remaining scenes that is randomly held out for the
    val-test split.

    Through `exclude_nopositive` and `exclude_nan`, patches whose labels contain no positives or whose input data
    contains NaN values can be excluded from the final data.

    Further, a `config.toml` file containing the configuration used for the preprocessing is saved in the
    `train_data_dir`.
    Additionally, a `labels.geojson` file is saved in the `train_data_dir`. It contains the joined label geometries
    used to create the binarized label masks, including the split of each geometry in the `mode` column.

    The final directory structure of `train_data_dir` will look like this:

    ```sh
    train_data_dir/
    ├── config.toml
    ├── cross-val.zarr/
    ├── test.zarr/
    ├── val-test.zarr/
    └── labels.geojson
    ```

    Args:
        bands (list[str]): The bands to be used for training. Must be present in the preprocessing.
        sentinel2_dir (Path): The directory containing the Sentinel 2 scenes.
        train_data_dir (Path): The "output" directory where the tensors are written to.
        arcticdem_dir (Path): The directory containing the ArcticDEM data (the datacube and the extent files).
            Will be created and downloaded if it does not exist.
        tcvis_dir (Path): The directory containing the TCVis data.
        admin_dir (Path): The directory containing the admin files.
        preprocess_cache (Path, optional): The directory in which preprocessed tiles are cached as NetCDF files.
            If None, no caching is done. Defaults to None.
        device (Literal["cuda", "cpu", "auto"] | int, optional): The device to run the computations on.
            If "cuda", the first device (0) is used; if an int, the device with that index is used.
            If "auto", a free GPU (<50% memory usage) is selected automatically.
            Defaults to "cuda" if available, else "cpu".
        dask_worker (int, optional): The number of Dask workers to use. Defaults to min(16, mp.cpu_count() - 1).
        ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if
            project is defined within persistent API credentials obtained via `earthengine authenticate`.
        ee_use_highvolume (bool, optional): Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).
        tpi_outer_radius (int, optional): The outer radius of the annulus kernel for the TPI calculation
            in m. Defaults to 100 m.
        tpi_inner_radius (int, optional): The inner radius of the annulus kernel for the TPI calculation
            in m. Defaults to 0.
        patch_size (int, optional): The size (height and width, in pixels) of the square training patches.
            Defaults to 1024.
        overlap (int, optional): The overlap between neighboring patches, in pixels. Defaults to 16.
        exclude_nopositive (bool, optional): Whether to exclude patches where the labels do not contain positives.
            Defaults to False.
        exclude_nan (bool, optional): Whether to exclude patches where the input data has NaN values.
            Defaults to True.
        mask_erosion_size (int, optional): The size of the disk to use for mask erosion and the edge-cropping.
            Defaults to 10.
        test_val_split (float, optional): The fraction of scenes randomly held out for the val-test split.
            Defaults to 0.05.
        test_regions (list[str] | None, optional): The admin 1 or admin 2 region names whose intersecting scenes
            form the test set. Defaults to None.

    """
    # Import here to avoid long loading times when running other commands
    import geopandas as gpd
    import pandas as pd
    import toml
    import xarray as xr
    import zarr
    from darts_acquisition import load_arcticdem, load_s2_masks, load_s2_scene, load_tcvis
    from darts_acquisition.s2 import parse_s2_tile_id
    from darts_preprocessing import preprocess_legacy_fast
    from darts_segmentation.training.prepare_training import create_training_patches
    from dask.distributed import Client, LocalCluster
    from lovely_tensors import monkey_patch
    from odc.stac import configure_rio
    from rich.progress import track
    from zarr.codecs import BloscCodec
    from zarr.storage import LocalStore

    from darts.utils.cuda import debug_info, decide_device
    from darts.utils.earthengine import init_ee
    from darts.utils.logging import console

    monkey_patch()
    debug_info()
    device = decide_device(device)
    init_ee(ee_project, ee_use_highvolume)

    with LocalCluster(n_workers=dask_worker) as cluster, Client(cluster) as client:
        logger.info(f"Using Dask client: {client} on cluster {cluster}")
        logger.info(f"Dashboard available at: {client.dashboard_link}")
        configure_rio(cloud_defaults=True, aws={"aws_unsigned": True}, client=client)
        logger.info("Configured Rasterio with Dask")

        # We hardcode these because they depend on the preprocessing used
        norm_factors = {
            "red": 1 / 3000,
            "green": 1 / 3000,
            "blue": 1 / 3000,
            "nir": 1 / 3000,
            "ndvi": 1 / 20000,
            "relative_elevation": 1 / 30000,
            "slope": 1 / 90,
            "tc_brightness": 1 / 255,
            "tc_greenness": 1 / 255,
            "tc_wetness": 1 / 255,
        }
        # Filter out bands that are not in the specified bands
        norm_factors = {k: v for k, v in norm_factors.items() if k in bands}

        train_data_dir.mkdir(exist_ok=True, parents=True)

        zgroups = {
            "cross-val": zarr.group(store=LocalStore(train_data_dir / "cross-val.zarr"), overwrite=True),
            "val-test": zarr.group(store=LocalStore(train_data_dir / "val-test.zarr"), overwrite=True),
            "test": zarr.group(store=LocalStore(train_data_dir / "test.zarr"), overwrite=True),
        }
        # The arrays start with 0 patches, because the final number of patches is not known in advance
        for root in zgroups.values():
            root.create(
                name="x",
                shape=(0, len(bands), patch_size, patch_size),
                # shards=(100, len(bands), patch_size, patch_size),
                chunks=(1, len(bands), patch_size, patch_size),
                dtype="float32",
                compressors=BloscCodec(cname="lz4", clevel=9),
            )
            root.create(
                name="y",
                shape=(0, patch_size, patch_size),
                # shards=(100, patch_size, patch_size),
                chunks=(1, patch_size, patch_size),
                dtype="uint8",
                compressors=BloscCodec(cname="lz4", clevel=9),
            )

        # Find all Sentinel 2 scenes and split into train+val (cross-val), val-test (variance) and test (region)
        n_patches = 0
        n_patches_by_mode = {"cross-val": 0, "val-test": 0, "test": 0}
        joint_lables = []
        s2_paths = sorted(sentinel2_dir.glob("*/"))
        logger.info(f"Found {len(s2_paths)} Sentinel 2 scenes in {sentinel2_dir}")
        path_gen = split_dataset_paths(s2_paths, train_data_dir, test_val_split, test_regions, admin_dir)
        for i, (fpath, mode) in track(
            enumerate(path_gen), description="Processing samples", total=len(s2_paths), console=console
        ):
            try:
                _, s2_tile_id, tile_id = parse_s2_tile_id(fpath)

                logger.debug(
                    f"Processing sample {i + 1} of {len(s2_paths)} '{fpath.resolve()}' ({tile_id=}) to split '{mode}'"
                )

                # Check for a cached preprocessed file
                if preprocess_cache and (preprocess_cache / f"{tile_id}.nc").exists():
                    cache_file = preprocess_cache / f"{tile_id}.nc"
                    logger.info(f"Loading preprocessed data from {cache_file.resolve()}")
                    tile = xr.open_dataset(preprocess_cache / f"{tile_id}.nc", engine="h5netcdf").set_coords(
                        "spatial_ref"
                    )
                else:
                    optical = load_s2_scene(fpath)
                    logger.info(f"Found optical tile with size {optical.sizes}")
                    arcticdem_res = 10
                    arcticdem_buffer = ceil(tpi_outer_radius / arcticdem_res * sqrt(2))
                    arcticdem = load_arcticdem(
                        optical.odc.geobox, arcticdem_dir, resolution=arcticdem_res, buffer=arcticdem_buffer
                    )
                    tcvis = load_tcvis(optical.odc.geobox, tcvis_dir)
                    data_masks = load_s2_masks(fpath, optical.odc.geobox)

                    tile: xr.Dataset = preprocess_legacy_fast(
                        optical,
                        arcticdem,
                        tcvis,
                        data_masks,
                        tpi_outer_radius,
                        tpi_inner_radius,
                        device,
                    )
                    # Only cache if we have a cache directory
                    if preprocess_cache:
                        preprocess_cache.mkdir(exist_ok=True, parents=True)
                        cache_file = preprocess_cache / f"{tile_id}.nc"
                        logger.info(f"Caching preprocessed data to {cache_file.resolve()}")
                        tile.to_netcdf(cache_file, engine="h5netcdf")

                labels = gpd.read_file(fpath / f"{s2_tile_id}.shp")

                # Save the patches
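                gen = create_training_patches(
                    tile=tile,
                    labels=labels,
                    bands=bands,
                    norm_factors=norm_factors,
                    patch_size=patch_size,
                    overlap=overlap,
                    exclude_nopositive=exclude_nopositive,
                    exclude_nan=exclude_nan,
                    device=device,
                    mask_erosion_size=mask_erosion_size,
                )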

                zx = zgroups[mode]["x"]
                zy = zgroups[mode]["y"]
                n_scene_patches = 0
                for x, y in gen:
                    zx.append(x.unsqueeze(0).numpy().astype("float32"))
                    zy.append(y.unsqueeze(0).numpy().astype("uint8"))
                    n_scene_patches += 1
                    n_patches += 1
                    n_patches_by_mode[mode] += 1
                # Only keep the labels of scenes that contributed patches
                if n_scene_patches > 0 and len(labels) > 0:
                    labels["mode"] = mode
                    joint_lables.append(labels.to_crs("EPSG:3413"))

                logger.info(
                    f"Processed sample {i + 1} of {len(s2_paths)} '{fpath.resolve()}' "
                    f"({tile_id=}) with {n_scene_patches} patches."
                )
            except KeyboardInterrupt:
                logger.info("Interrupted by user.")
                break

            except Exception as e:
                logger.warning(f"Could not process folder sample {i} '{fpath.resolve()}'.\nSkipping...")
                logger.exception(e)

    # Save the used labels
    joint_lables = pd.concat(joint_lables)
    joint_lables.to_file(train_data_dir / "labels.geojson", driver="GeoJSON")

    # Save a config file as toml
    config = {
        "darts": {
            "sentinel2_dir": sentinel2_dir,
            "train_data_dir": train_data_dir,
            "arcticdem_dir": arcticdem_dir,
            "tcvis_dir": tcvis_dir,
            "bands": bands,
            "norm_factors": norm_factors,
            "device": device,
            "ee_project": ee_project,
            "ee_use_highvolume": ee_use_highvolume,
            "tpi_outer_radius": tpi_outer_radius,
            "tpi_inner_radius": tpi_inner_radius,
            "patch_size": patch_size,
            "overlap": overlap,
            "exclude_nopositive": exclude_nopositive,
            "exclude_nan": exclude_nan,
            "n_patches": n_patches,
        }
    }
    with open(train_data_dir / "config.toml", "w") as f:
        toml.dump(config, f)

    logger.info(f"Saved {n_patches} ({n_patches_by_mode}) patches to {train_data_dir}")
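Both listings load ArcticDEM with a buffer derived from tpi_outer_radius and the ArcticDEM resolution (10 m for Sentinel 2, 2 m for PLANET). The arithmetic for the default radius of 100 m is sketched below; `arcticdem_buffer` is a hypothetical wrapper around the expression in the source, and reading the result as a pixel count is an assumption.

```python
from math import ceil, sqrt


def arcticdem_buffer(tpi_outer_radius: int, resolution: int) -> int:
    # Formula from both listings: the outer TPI radius converted to ArcticDEM pixels,
    # scaled by sqrt(2) (presumably to cover the kernel's reach along the diagonal).
    return ceil(tpi_outer_radius / resolution * sqrt(2))


print(arcticdem_buffer(100, 10))  # 15, Sentinel 2 listing (10 m ArcticDEM)
print(arcticdem_buffer(100, 2))  # 71, PLANET listing (2 m ArcticDEM)
```

The finer 2 m resolution needs a buffer almost five times wider in pixels for the same radius in metres.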

test_smp

test_smp(
    *,
    train_data_dir: pathlib.Path,
    run_id: str,
    run_name: str,
    model_ckp: pathlib.Path | None = None,
    batch_size: int = 8,
    artifact_dir: pathlib.Path = pathlib.Path(
        "lightning_logs"
    ),
    num_workers: int = 0,
    device: int | str = "auto",
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
) -> pytorch_lightning.Trainer

Run the testing of the SMP model.

This function expects the "preprocessing" step to have been run beforehand, which results in the following data structure:

preprocessed-data/ # the top-level directory
├── config.toml
├── cross-val.zarr/ # this zarr group contains the dataarrays x and y for the training and validation
├── test.zarr/ # this zarr group contains the dataarrays x and y for the left-out-region test set
├── val-test.zarr/ # this zarr group contains the dataarrays x and y for the randomly selected validation set
└── labels.geojson

Parameters:

  • train_data_dir (pathlib.Path) –

    Path to the training data directory (top-level).

  • run_id (str) –

    ID of the run.

  • run_name (str) –

    Name of the run.

  • model_ckp (pathlib.Path | None, default: None ) –

    Path to the model checkpoint. If None, try to find the latest checkpoint in artifact_dir / run_name / run_id / checkpoints. Defaults to None.

  • batch_size (int, default: 8 ) –

    Batch size. Defaults to 8.

  • artifact_dir (pathlib.Path, default: pathlib.Path('lightning_logs') ) –

    Directory to save artifacts. Defaults to Path("lightning_logs").

  • num_workers (int, default: 0 ) –

    Number of workers for the DataLoader. Defaults to 0.

  • device (int | str, default: 'auto' ) –

    Device to use. Defaults to "auto".

  • wandb_entity (str | None, default: None ) –

    WandB entity. Defaults to None.

  • wandb_project (str | None, default: None ) –

    WandB project. Defaults to None.

Returns:

  • Trainer ( pytorch_lightning.Trainer ) –

    The trainer object used for testing.

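When model_ckp is None, the checkpoint is inferred from the artifact directory. A minimal sketch of that lookup follows (`find_latest_checkpoint` is a hypothetical name); unlike the source, it raises an explicit FileNotFoundError when no checkpoint exists.

```python
import os
import tempfile
from pathlib import Path


def find_latest_checkpoint(artifact_dir: Path, run_name: str, run_id: str) -> Path:
    checkpoint_dir = artifact_dir / run_name / run_id / "checkpoints"
    candidates = list(checkpoint_dir.glob("*.ckpt"))
    if not candidates:
        # max() on an empty sequence raises a bare ValueError; fail with a clearer message.
        raise FileNotFoundError(f"No *.ckpt files found in {checkpoint_dir}")
    # Pick the most recently modified checkpoint, as the source does.
    return max(candidates, key=lambda p: p.stat().st_mtime)


with tempfile.TemporaryDirectory() as tmp:
    ckpt_dir = Path(tmp) / "my-run" / "abc123" / "checkpoints"
    ckpt_dir.mkdir(parents=True)
    older = ckpt_dir / "epoch=1.ckpt"
    newer = ckpt_dir / "epoch=2.ckpt"
    older.touch()
    newer.touch()
    os.utime(older, (1_000, 1_000))  # set access and modification times explicitly
    os.utime(newer, (2_000, 2_000))
    latest = find_latest_checkpoint(Path(tmp), "my-run", "abc123")
    print(latest.name)  # epoch=2.ckpt
```
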
Source code in darts/src/darts/legacy_training/test.py
def test_smp(
    *,
    train_data_dir: Path,
    run_id: str,
    run_name: str,
    model_ckp: Path | None = None,
    batch_size: int = 8,
    artifact_dir: Path = Path("lightning_logs"),
    num_workers: int = 0,
    device: int | str = "auto",
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
) -> "pl.Trainer":
    """Run the testing of the SMP model.

    This function expects the "preprocessing" step to have been run beforehand,
    which results in the following data structure:

    ```sh
    preprocessed-data/ # the top-level directory
    ├── config.toml
    ├── cross-val.zarr/ # this zarr group contains the dataarrays x and y for the training and validation
    ├── test.zarr/ # this zarr group contains the dataarrays x and y for the left-out-region test set
    ├── val-test.zarr/ # this zarr group contains the dataarrays x and y for the randomly selected validation set
    └── labels.geojson
    ```

    Args:
        train_data_dir (Path): Path to the training data directory (top-level).
        run_id (str): ID of the run.
        run_name (str): Name of the run.
        model_ckp (Path | None): Path to the model checkpoint.
            If None, try to find the latest checkpoint in `artifact_dir / run_name / run_id / checkpoints`.
            Defaults to None.
        batch_size (int, optional): Batch size. Defaults to 8.
        artifact_dir (Path, optional): Directory to save artifacts. Defaults to Path("lightning_logs").
        num_workers (int, optional): Number of workers for the DataLoader. Defaults to 0.
        device (int | str, optional): Device to use. Defaults to "auto".
        wandb_entity (str | None, optional): WandB entity. Defaults to None.
        wandb_project (str | None, optional): WandB project. Defaults to None.

    Returns:
        Trainer: The trainer object used for testing.

    """
    import lightning as L  # noqa: N812
    import lovely_tensors
    import torch
    from darts_segmentation.training.callbacks import BinarySegmentationMetrics
    from darts_segmentation.training.data import DartsDataModule
    from darts_segmentation.training.module import SMPSegmenter
    from lightning.pytorch import seed_everything
    from lightning.pytorch.callbacks import RichProgressBar
    from lightning.pytorch.loggers import CSVLogger, WandbLogger

    from darts.utils.logging import LoggingManager

    LoggingManager.apply_logging_handlers("lightning.pytorch")

    tick_fstart = time.perf_counter()
    logger.info(f"Starting testing '{run_name}' ('{run_id}') with data from {train_data_dir.resolve()}.")
    logger.debug(f"Using config:\n\t{batch_size=}\n\t{device=}")

    lovely_tensors.monkey_patch()

    torch.set_float32_matmul_precision("medium")
    seed_everything(42, workers=True)

    preprocess_config = toml.load(train_data_dir / "config.toml")["darts"]

    # Data and model
    datamodule_val_test = DartsDataModule(
        data_dir=train_data_dir / "val-test.zarr",
        batch_size=batch_size,
        num_workers=num_workers,
    )
    datamodule_test = DartsDataModule(
        data_dir=train_data_dir / "test.zarr",
        batch_size=batch_size,
        num_workers=num_workers,
    )
    # Try to infer model checkpoint if not given
    if model_ckp is None:
        checkpoint_dir = artifact_dir / run_name / run_id / "checkpoints"
        logger.debug(f"No checkpoint provided. Looking for model checkpoint in {checkpoint_dir.resolve()}")
        model_ckp = max(checkpoint_dir.glob("*.ckpt"), key=lambda x: x.stat().st_mtime)
    model = SMPSegmenter.load_from_checkpoint(model_ckp)

    # Loggers
    trainer_loggers = [
        CSVLogger(save_dir=artifact_dir, name=run_name, version=run_id),
    ]
    logger.debug(f"Logging CSV to {Path(trainer_loggers[0].log_dir).resolve()}")
    if wandb_entity and wandb_project:
        wandb_logger = WandbLogger(
            save_dir=artifact_dir,
            name=run_name,
            id=run_id,
            project=wandb_project,
            entity=wandb_entity,
        )
        trainer_loggers.append(wandb_logger)
        logger.debug(
            f"Logging to WandB with entity '{wandb_entity}' and project '{wandb_project}'. "
            f"Artifacts are logged to {(Path(wandb_logger.save_dir) / 'wandb').resolve()}"
        )

    # Callbacks
    metrics_cb = BinarySegmentationMetrics(
        input_combination=preprocess_config["bands"],
    )
    callbacks = [
        RichProgressBar(),
        metrics_cb,
    ]

    # Test
    trainer = L.Trainer(
        callbacks=callbacks,
        logger=trainer_loggers,
        accelerator="gpu" if isinstance(device, int) else device,
        devices=[device] if isinstance(device, int) else device,
        deterministic=True,
    )
    # Overwrite the names of the test sets to test against two separate sets
    metrics_cb.test_set = "val-test"
    model.test_set = "val-test"
    trainer.test(model, datamodule_val_test, ckpt_path=model_ckp)
    metrics_cb.test_set = "test"
    model.test_set = "test"
    trainer.test(model, datamodule_test)

    tick_fend = time.perf_counter()
    logger.info(f"Finished testing '{run_name}' in {tick_fend - tick_fstart:.2f}s.")

    if wandb_entity and wandb_project:
        wandb_logger.finalize("success")
        wandb_logger.experiment.finish(exit_code=0)
        logger.debug(f"Finalized WandB logging for '{run_name}'")

    return trainer
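The `accelerator` and `devices` expressions passed to `L.Trainer` in the source above (the same expressions appear in `train_smp`) map the `device` argument as sketched below. `resolve_device` is a hypothetical standalone helper written for illustration; it is not part of `darts.legacy_training`.

```python
from __future__ import annotations


def resolve_device(device: int | str) -> tuple[str, list[int] | str]:
    """Map `device` to Lightning's (accelerator, devices) pair.

    Hypothetical helper mirroring the two conditional expressions passed to L.Trainer.
    """
    if isinstance(device, int):
        # An integer is a GPU index: use the GPU accelerator with exactly that device
        return "gpu", [device]
    # Anything else (e.g. "auto") is forwarded unchanged as accelerator and devices
    return device, device


print(resolve_device(0))  # ('gpu', [0])
print(resolve_device("auto"))  # ('auto', 'auto')
```

Because a string is forwarded to both settings unchanged, whether a value other than "auto" is accepted is left to Lightning's own argument validation.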

train_smp

train_smp(
    *,
    train_data_dir: pathlib.Path,
    artifact_dir: pathlib.Path = pathlib.Path(
        "lightning_logs"
    ),
    fold: int = 0,
    continue_from_checkpoint: pathlib.Path | None = None,
    model_arch: str = "Unet",
    model_encoder: str = "dpn107",
    model_encoder_weights: str | None = None,
    augment: bool = True,
    learning_rate: float = 0.001,
    gamma: float = 0.9,
    focal_loss_alpha: float | None = None,
    focal_loss_gamma: float = 2.0,
    batch_size: int = 8,
    max_epochs: int = 100,
    log_every_n_steps: int = 10,
    check_val_every_n_epoch: int = 3,
    early_stopping_patience: int = 5,
    plot_every_n_val_epochs: int = 5,
    random_seed: int = 42,
    num_workers: int = 0,
    device: int | str = "auto",
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
    wandb_group: str | None = None,
    run_name: str | None = None,
    run_id: str | None = None,
    trial_name: str | None = None,
) -> pytorch_lightning.Trainer

Run the training of the SMP model.

Please see https://smp.readthedocs.io/en/latest/index.html for model configurations.

Each training run is assigned a unique name and id pair and, optionally, a trial name. The name, which the user can provide, is meant as a grouping mechanism for runs with the same hyperparameters and code. Hence, different versions of the same name should only differ in their random state or in run settings such as logging. Each version is assigned a unique id. Artifacts (metrics & checkpoints) are stored under {artifact_dir}/{run_name}/{run_id} in runs without cross-validation. If trial_name is specified, the artifacts are stored under {artifact_dir}/{trial_name}/{run_name}-{run_id}. Wandb logs are always stored under {wandb_entity}/{wandb_project}/{run_name}, regardless of trial_name. However, they are further grouped by the trial_name (via job_type), if specified. Both run_name and run_id are also stored in the hparams of each checkpoint.
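The directory layout described above follows from the `name` and `version` arguments passed to `CSVLogger` in the source code below, which the logger joins as `save_dir/name/version`. The helper `run_log_dir` is hypothetical and only reproduces that path logic; the run name, id, and trial name in the example are made up.

```python
from __future__ import annotations

from pathlib import Path


def run_log_dir(artifact_dir: Path, run_name: str, run_id: str, trial_name: str | None = None) -> Path:
    """Return the directory in which a run's metrics and checkpoints are stored.

    Hypothetical helper reproducing the CSVLogger arguments used in train_smp.
    """
    if trial_name:
        # Cross-validation trial: all runs of the trial share one folder, name and id are joined
        return artifact_dir / trial_name / f"{run_name}-{run_id}"
    # Plain run: one folder per run name, one subfolder per run id (the "version")
    return artifact_dir / run_name / run_id


print(run_log_dir(Path("lightning_logs"), "my-run", "abc123"))
print(run_log_dir(Path("lightning_logs"), "my-run", "abc123", trial_name="cv-1"))
```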

You can specify how often logs are written and how often validation is performed:

  • log_every_n_steps specifies how often train-logs are written. This does not affect validation.
  • check_val_every_n_epoch specifies how often validation is performed. This also affects early stopping.
  • early_stopping_patience specifies how many validation checks without improvement to wait before stopping. In epochs, this is check_val_every_n_epoch * early_stopping_patience.
  • plot_every_n_val_epochs specifies how often validation samples are plotted. Since plotting is quite costly, you can reduce the frequency. Like early stopping, it counts validation checks, so in epochs this is check_val_every_n_epoch * plot_every_n_val_epochs.
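A small sketch of the arithmetic above, assuming (as described) that both the early-stopping patience and the plotting interval are counted in validation checks. `validation_schedule` is a hypothetical helper; the branch that returns None for a falsy patience mirrors the `if early_stopping_patience:` check in the source code.

```python
from __future__ import annotations


def validation_schedule(
    check_val_every_n_epoch: int = 3,
    early_stopping_patience: int | None = 5,
    plot_every_n_val_epochs: int = 5,
) -> dict[str, int | None]:
    """Express the validation-based settings in training epochs (hypothetical helper)."""
    return {
        # Validation runs once every `check_val_every_n_epoch` training epochs
        "val_interval_epochs": check_val_every_n_epoch,
        # Patience counts validation checks; a falsy value (0 or None) disables early stopping
        "patience_epochs": (
            check_val_every_n_epoch * early_stopping_patience if early_stopping_patience else None
        ),
        # Plotting is also scheduled relative to validation checks
        "plot_interval_epochs": check_val_every_n_epoch * plot_every_n_val_epochs,
    }


print(validation_schedule())
# {'val_interval_epochs': 3, 'patience_epochs': 15, 'plot_interval_epochs': 15}
```

With the defaults, training therefore stops at the earliest after 15 epochs without an improvement of the monitored metric, and validation samples are plotted every 15 epochs.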

The training data is expected to be produced beforehand by the "preprocessing" step, which results in the following data structure:

preprocessed-data/ # the top-level directory
├── config.toml
├── cross-val.zarr/ # this zarr group contains the dataarrays x and y for the training and validation
├── test.zarr/ # this zarr group contains the dataarrays x and y for the left-out-region test set
├── val-test.zarr/ # this zarr group contains the dataarrays x and y for the randomly selected validation set
└── labels.geojson
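Before launching a long run it can help to check that the preprocessing output is complete. The following is a hypothetical pre-flight check based only on the tree above: it tests for the presence of the entries, not for their contents (for example, the `darts` table with `bands` and `norm_factors` that `train_smp` reads from `config.toml`).

```python
from __future__ import annotations

import tempfile
from pathlib import Path

# Top-level entries of the preprocessed data directory, taken from the tree above
EXPECTED_ENTRIES = ("config.toml", "cross-val.zarr", "test.zarr", "val-test.zarr", "labels.geojson")


def missing_entries(train_data_dir: Path) -> list[str]:
    """Return the expected entries that do not exist in `train_data_dir` (hypothetical check)."""
    return [name for name in EXPECTED_ENTRIES if not (train_data_dir / name).exists()]


# Demo on a temporary directory that only contains two of the five entries
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "config.toml").touch()
    (root / "cross-val.zarr").mkdir()
    print(missing_entries(root))  # ['test.zarr', 'val-test.zarr', 'labels.geojson']
```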

Parameters:

  • train_data_dir (pathlib.Path) –

    Path to the training data directory (top-level).

  • artifact_dir (pathlib.Path, default: pathlib.Path('lightning_logs') ) –

    Path to the training output directory. Will contain checkpoints and metrics. Defaults to Path("lightning_logs").

  • fold (int, default: 0 ) –

    The current fold to train on. Must be in [0, 4]. Defaults to 0.

  • continue_from_checkpoint (pathlib.Path | None, default: None ) –

    Path to a checkpoint to continue training from. Defaults to None.

  • model_arch (str, default: 'Unet' ) –

    Model architecture to use. Defaults to "Unet".

  • model_encoder (str, default: 'dpn107' ) –

    Encoder to use. Defaults to "dpn107".

  • model_encoder_weights (str | None, default: None ) –

    Path to the encoder weights. Defaults to None.

  • augment (bool, default: True ) –

    Whether to apply augmentations or not. Defaults to True.

  • learning_rate (float, default: 0.001 ) –

    Learning Rate. Defaults to 1e-3.

  • gamma (float, default: 0.9 ) –

    Multiplicative factor of learning rate decay. Defaults to 0.9.

  • focal_loss_alpha (float | None, default: None ) –

    Weight factor to balance positive and negative samples. Alpha must be in the [0, 1] range; high values give more weight to the positive class. None will not weight samples. Defaults to None.

  • focal_loss_gamma (float, default: 2.0 ) –

    Focal loss power factor. Defaults to 2.0.

  • batch_size (int, default: 8 ) –

    Batch Size. Defaults to 8.

  • max_epochs (int, default: 100 ) –

    Maximum number of epochs to train. Defaults to 100.

  • log_every_n_steps (int, default: 10 ) –

    Log every n steps. Defaults to 10.

  • check_val_every_n_epoch (int, default: 3 ) –

    Check validation every n epochs. Defaults to 3.

  • early_stopping_patience (int, default: 5 ) –

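    Number of validation checks without improvement to wait before stopping (in epochs: check_val_every_n_epoch * early_stopping_patience). A falsy value (0 or None) disables early stopping. Defaults to 5.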

  • plot_every_n_val_epochs (int, default: 5 ) –

    Plot validation samples every n validation checks. Defaults to 5.

  • random_seed (int, default: 42 ) –

    Random seed, passed to Lightning's seed_everything (with workers=True). Defaults to 42.

  • num_workers (int, default: 0 ) –

    Number of Dataloader workers. Defaults to 0.

  • device (int | str, default: 'auto' ) –

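    The device to run the model on. An integer selects the index of the GPU to use; any other value is passed to the Lightning Trainer as both accelerator and devices. Defaults to "auto".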

  • wandb_entity (str | None, default: None ) –

    Weights and Biases Entity. Defaults to None.

  • wandb_project (str | None, default: None ) –

    Weights and Biases Project. Defaults to None.

  • wandb_group (str | None, default: None ) –

    Wandb group. Useful for CV-Sweeps. Defaults to None.

  • run_name (str | None, default: None ) –

    Name of this run, as a further grouping method for logs etc. If None, will generate a random one. Defaults to None.

  • run_id (str | None, default: None ) –

    ID of the run. If None, will generate a random one. Defaults to None.

  • trial_name (str | None, default: None ) –

    Name of the cross-validation run / trial. This affects the primary logging and artifact storage. If None, the run is not treated as part of a cross-validation trial. Defaults to None.

Returns:

  • Trainer ( pytorch_lightning.Trainer ) –

    The trainer object used for training.
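focal_loss_alpha and focal_loss_gamma parametrize a focal loss. The sketch below evaluates the standard binary focal loss for a single pixel, assuming SMPSegmenter uses this common formulation with the alpha convention described above (alpha weights the positive class, 1 - alpha the negative class). It is an illustration, not the DARTS implementation.

```python
from __future__ import annotations

import math


def binary_focal_loss(logit: float, target: int, alpha: float | None = None, gamma: float = 2.0) -> float:
    """Standard binary focal loss for one pixel (illustrative sketch)."""
    p = 1.0 / (1.0 + math.exp(-logit))  # predicted probability of the positive class
    p_t = p if target == 1 else 1.0 - p  # probability assigned to the true class
    # (1 - p_t) ** gamma shrinks the loss of pixels that are already classified confidently
    loss = -((1.0 - p_t) ** gamma) * math.log(p_t)
    if alpha is not None:
        # alpha weights positive pixels, (1 - alpha) weights negative pixels
        loss *= alpha if target == 1 else 1.0 - alpha
    return loss


print(round(binary_focal_loss(0.0, 1, gamma=0.0), 4))  # 0.6931
print(round(binary_focal_loss(0.0, 1), 4))  # 0.1733
print(round(binary_focal_loss(0.0, 1, alpha=0.25), 4))  # 0.0433
```

With gamma = 0 and no alpha, the loss reduces to plain binary cross-entropy; larger gamma values focus training on the pixels that are still misclassified.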

Source code in darts/src/darts/legacy_training/train.py
def train_smp(
    *,
    # Data config
    train_data_dir: Path,
    artifact_dir: Path = Path("lightning_logs"),
    fold: int = 0,
    continue_from_checkpoint: Path | None = None,
    # Hyperparameters
    model_arch: str = "Unet",
    model_encoder: str = "dpn107",
    model_encoder_weights: str | None = None,
    augment: bool = True,
    learning_rate: float = 1e-3,
    gamma: float = 0.9,
    focal_loss_alpha: float | None = None,
    focal_loss_gamma: float = 2.0,
    batch_size: int = 8,
    # Epoch and Logging config
    max_epochs: int = 100,
    log_every_n_steps: int = 10,
    check_val_every_n_epoch: int = 3,
    early_stopping_patience: int = 5,
    plot_every_n_val_epochs: int = 5,
    # Device and Manager config
    random_seed: int = 42,
    num_workers: int = 0,
    device: int | str = "auto",
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
    wandb_group: str | None = None,
    run_name: str | None = None,
    run_id: str | None = None,
    trial_name: str | None = None,
) -> "pl.Trainer":
    """Run the training of the SMP model.

    Please see https://smp.readthedocs.io/en/latest/index.html for model configurations.

    Each training run is assigned a unique **name** and **id** pair and optionally a trial name.
    The name, which the user _can_ provide, should be used as a grouping mechanism for runs with the same hyperparameters
    and code. Hence, different versions of the same name should only differ in their random state or in run settings
    such as logging.
    Each version is assigned a unique id.
    Artifacts (metrics & checkpoints) are then stored under `{artifact_dir}/{run_name}/{run_id}` in no-crossval runs.
    If `trial_name` is specified, the artifacts are stored under `{artifact_dir}/{trial_name}/{run_name}-{run_id}`.
    Wandb logs are always stored under `{wandb_entity}/{wandb_project}/{run_name}`, regardless of `trial_name`.
    However, they are further grouped by the `trial_name` (via job_type), if specified.
    Both `run_name` and `run_id` are also stored in the hparams of each checkpoint.

    You can specify how often logs are written and how often validation is performed:
        - `log_every_n_steps` specifies how often train-logs are written. This does not affect validation.
        - `check_val_every_n_epoch` specifies how often validation is performed.
            This also affects early stopping.
        - `early_stopping_patience` specifies how many validation checks without improvement to wait before stopping.
            In epochs, this is `check_val_every_n_epoch * early_stopping_patience`.
        - `plot_every_n_val_epochs` specifies how often validation samples are plotted.
            Since plotting is quite costly, you can reduce the frequency. Like early stopping, it counts validation checks.
            In epochs, this is `check_val_every_n_epoch * plot_every_n_val_epochs`.

    The training data is expected to be produced beforehand by the "preprocessing" step,
    which results in the following data structure:

    ```sh
    preprocessed-data/ # the top-level directory
    ├── config.toml
    ├── cross-val.zarr/ # this zarr group contains the dataarrays x and y for the training and validation
    ├── test.zarr/ # this zarr group contains the dataarrays x and y for the left-out-region test set
    ├── val-test.zarr/ # this zarr group contains the dataarrays x and y for the randomly selected validation set
    └── labels.geojson
    ```

    Args:
        train_data_dir (Path): Path to the training data directory (top-level).
        artifact_dir (Path, optional): Path to the training output directory.
            Will contain checkpoints and metrics. Defaults to Path("lightning_logs").
        fold (int, optional): The current fold to train on. Must be in [0, 4]. Defaults to 0.
        continue_from_checkpoint (Path | None, optional): Path to a checkpoint to continue training from.
            Defaults to None.
        model_arch (str, optional): Model architecture to use. Defaults to "Unet".
        model_encoder (str, optional): Encoder to use. Defaults to "dpn107".
        model_encoder_weights (str | None, optional): Path to the encoder weights. Defaults to None.
        augment (bool, optional): Whether to apply augmentations or not. Defaults to True.
        learning_rate (float, optional): Learning Rate. Defaults to 1e-3.
        gamma (float, optional): Multiplicative factor of learning rate decay. Defaults to 0.9.
        focal_loss_alpha (float | None, optional): Weight factor to balance positive and negative samples.
            Alpha must be in the [0, 1] range; high values give more weight to the positive class.
            None will not weight samples. Defaults to None.
        focal_loss_gamma (float, optional): Focal loss power factor. Defaults to 2.0.
        batch_size (int, optional): Batch Size. Defaults to 8.
        max_epochs (int, optional): Maximum number of epochs to train. Defaults to 100.
        log_every_n_steps (int, optional): Log every n steps. Defaults to 10.
        check_val_every_n_epoch (int, optional): Check validation every n epochs. Defaults to 3.
        early_stopping_patience (int, optional): Number of validation checks without improvement to wait before
            stopping. A falsy value (0 or None) disables early stopping. Defaults to 5.
        plot_every_n_val_epochs (int, optional): Plot validation samples every n validation checks. Defaults to 5.
        random_seed (int, optional): Random seed, passed to `seed_everything`. Defaults to 42.
        num_workers (int, optional): Number of Dataloader workers. Defaults to 0.
        device (int | str, optional): The device to run the model on. An integer selects the index of the GPU
            to use; any other value is passed to the Lightning Trainer as both `accelerator` and `devices`.
            Defaults to "auto".
        wandb_entity (str | None, optional): Weights and Biases Entity. Defaults to None.
        wandb_project (str | None, optional): Weights and Biases Project. Defaults to None.
        wandb_group (str | None, optional): Wandb group. Useful for CV-Sweeps. Defaults to None.
        run_name (str | None, optional): Name of this run, as a further grouping method for logs etc.
            If None, will generate a random one. Defaults to None.
        run_id (str | None, optional): ID of the run. If None, will generate a random one. Defaults to None.
        trial_name (str | None, optional): Name of the cross-validation run / trial.
            This affects the primary logging and artifact storage.
            If None, the run is not treated as part of a cross-validation trial. Defaults to None.

    Returns:
        Trainer: The trainer object used for training.

    """
    import lightning as L  # noqa: N812
    import lovely_tensors
    import torch
    from darts_segmentation.segment import SMPSegmenterConfig
    from darts_segmentation.training.callbacks import BinarySegmentationMetrics
    from darts_segmentation.training.data import DartsDataModule
    from darts_segmentation.training.module import SMPSegmenter
    from lightning.pytorch import seed_everything
    from lightning.pytorch.callbacks import EarlyStopping, RichProgressBar
    from lightning.pytorch.loggers import CSVLogger, WandbLogger

    from darts.legacy_training.util import generate_id, get_generated_name
    from darts.utils.logging import LoggingManager

    LoggingManager.apply_logging_handlers("lightning.pytorch")

    tick_fstart = time.perf_counter()

    # Create unique run identification (name can be specified by user, id can be interpreted as a 'version')
    run_name = run_name or get_generated_name(artifact_dir)
    run_id = run_id or generate_id()

    logger.info(f"Starting training '{run_name}' ('{run_id}') with data from {train_data_dir.resolve()}.")
    logger.debug(
        f"Using config:\n\t{model_arch=}\n\t{model_encoder=}\n\t{model_encoder_weights=}\n\t{augment=}\n\t"
        f"{learning_rate=}\n\t{gamma=}\n\t{batch_size=}\n\t{max_epochs=}\n\t{log_every_n_steps=}\n\t"
        f"{check_val_every_n_epoch=}\n\t{early_stopping_patience=}\n\t{plot_every_n_val_epochs=}\n\t{num_workers=}"
        f"\n\t{device=}\n\t{random_seed=}"
    )

    lovely_tensors.monkey_patch()

    torch.set_float32_matmul_precision("medium")
    seed_everything(random_seed, workers=True)

    preprocess_config = toml.load(train_data_dir / "config.toml")["darts"]

    config = SMPSegmenterConfig(
        input_combination=preprocess_config["bands"],
        model={
            "arch": model_arch,
            "encoder_name": model_encoder,
            "encoder_weights": model_encoder_weights,
            "in_channels": len(preprocess_config["bands"]),
            "classes": 1,
        },
        norm_factors=preprocess_config["norm_factors"],
    )

    # Data and model
    datamodule = DartsDataModule(
        data_dir=train_data_dir / "cross-val.zarr",
        batch_size=batch_size,
        fold=fold,
        augment=augment,
        num_workers=num_workers,
    )
    model = SMPSegmenter(
        config=config,
        learning_rate=learning_rate,
        gamma=gamma,
        focal_loss_alpha=focal_loss_alpha,
        focal_loss_gamma=focal_loss_gamma,
        # These are only stored in the hparams and are not used
        run_id=run_id,
        run_name=run_name,
        trial_name=trial_name,
        random_seed=random_seed,
    )

    # Loggers
    is_crossval = bool(trial_name)
    trainer_loggers = [
        CSVLogger(
            save_dir=artifact_dir,
            name=run_name if not is_crossval else trial_name,
            version=run_id if not is_crossval else f"{run_name}-{run_id}",
        ),
    ]
    logger.debug(f"Logging CSV to {Path(trainer_loggers[0].log_dir).resolve()}")
    if wandb_entity and wandb_project:
        wandb_logger = WandbLogger(
            save_dir=artifact_dir,
            name=run_name,
            version=run_id,
            project=wandb_project,
            entity=wandb_entity,
            resume="allow",
            group=wandb_group,
            job_type=trial_name,
        )
        trainer_loggers.append(wandb_logger)
        logger.debug(
            f"Logging to WandB with entity '{wandb_entity}' and project '{wandb_project}'. "
            f"Artifacts are logged to {(Path(wandb_logger.save_dir) / 'wandb').resolve()}"
        )

    # Callbacks
    callbacks = [
        RichProgressBar(),
        BinarySegmentationMetrics(
            input_combination=config["input_combination"],
            val_set=f"val{fold}",
            plot_every_n_val_epochs=plot_every_n_val_epochs,
            is_crossval=is_crossval,
        ),
    ]
    if early_stopping_patience:
        logger.debug(f"Using EarlyStopping with patience {early_stopping_patience}")
        early_stopping = EarlyStopping(monitor="val/JaccardIndex", mode="max", patience=early_stopping_patience)
        callbacks.append(early_stopping)

    # Train
    trainer = L.Trainer(
        max_epochs=max_epochs,
        callbacks=callbacks,
        log_every_n_steps=log_every_n_steps,
        logger=trainer_loggers,
        check_val_every_n_epoch=check_val_every_n_epoch,
        accelerator="gpu" if isinstance(device, int) else device,
        devices=[device] if isinstance(device, int) else device,
        deterministic=False,
    )
    trainer.fit(model, datamodule, ckpt_path=continue_from_checkpoint)

    tick_fend = time.perf_counter()
    logger.info(f"Finished training '{run_name}' in {tick_fend - tick_fstart:.2f}s.")

    if wandb_entity and wandb_project:
        wandb_logger.finalize("success")
        wandb_logger.experiment.finish(exit_code=0)
        logger.debug(f"Finalized WandB logging for '{run_name}'")

    return trainer
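The gamma parameter is described as a "multiplicative factor of learning rate decay", the same wording PyTorch uses for torch.optim.lr_scheduler.ExponentialLR. Assuming the scheduler is an exponential decay stepped once per epoch (the scheduler setup is not visible in the source shown here), the learning rate evolves as in this sketch; `lr_at_epoch` is a hypothetical helper.

```python
def lr_at_epoch(learning_rate: float, gamma: float, epoch: int) -> float:
    """Learning rate after `epoch` exponential decay steps (assumed schedule)."""
    # Each step multiplies the previous learning rate by gamma
    return learning_rate * gamma**epoch


for epoch in (0, 1, 10):
    print(epoch, f"{lr_at_epoch(1e-3, 0.9, epoch):.2e}")
# 0 1.00e-03
# 1 9.00e-04
# 10 3.49e-04
```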

wandb_sweep_smp

wandb_sweep_smp(
    *,
    train_data_dir: pathlib.Path,
    sweep_config: pathlib.Path,
    n_trials: int = 10,
    sweep_id: str | None = None,
    artifact_dir: pathlib.Path = pathlib.Path(
        "lightning_logs"
    ),
    max_epochs: int = 100,
    log_every_n_steps: int = 10,
    check_val_every_n_epoch: int = 3,
    plot_every_n_val_epochs: int = 5,
    num_workers: int = 0,
    device: int | str | None = None,
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
)

Create a sweep with wandb and run it on the specified cuda device, or continue an existing sweep.

If sweep_id is None, a new sweep will be created. Otherwise, the sweep with the given ID will be continued. All artifacts are gathered under a nested directory based on the sweep id: {artifact_dir}/sweep-{sweep_id}. Since each sweep configuration (currently) has its own name and id, a single run can be found under: {artifact_dir}/sweep-{sweep_id}/{run_name}/{run_id}. Read the training docs for more info.
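The nested sweep layout combines the sweep-{sweep_id} directory created by wandb_sweep_smp with the {run_name}/{run_id} layout of train_smp, which is called without a trial_name. A hypothetical helper with made-up example values:

```python
from pathlib import Path


def sweep_run_dir(artifact_dir: Path, sweep_id: str, run_name: str, run_id: str) -> Path:
    """Where the artifacts of a single sweep run end up (hypothetical helper)."""
    # wandb_sweep_smp nests all runs of a sweep under "sweep-<sweep_id>" ...
    sweep_dir = artifact_dir / f"sweep-{sweep_id}"
    # ... and train_smp (no trial_name) appends <run_name>/<run_id>
    return sweep_dir / run_name / run_id


print(sweep_run_dir(Path("lightning_logs"), "123456789", "example-run-1", "x1y2z3"))
```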

If a device is specified, run an agent on this device. If None, no agent is started and no trials are run.

You can specify how often logs are written and how often validation is performed:

  • log_every_n_steps specifies how often train-logs are written. This does not affect validation.
  • check_val_every_n_epoch specifies how often validation is performed. This also affects early stopping.
  • plot_every_n_val_epochs specifies how often validation samples are plotted. Since plotting is quite costly, you can reduce the frequency. Like early stopping, it counts validation checks, so in epochs this is check_val_every_n_epoch * plot_every_n_val_epochs.

This will NOT use cross-validation. For cross-validation, use optuna_sweep_smp.

Example

In one terminal, start a sweep:

    $ rye run darts wandb-sweep-smp --config-file /path/to/sweep-config.toml
    ...  # Many logs
    Created sweep with ID 123456789
    ... # More logs from spawned agent

In another terminal, start a second agent:

    $ rye run darts wandb-sweep-smp --sweep-id 123456789
    ...

Parameters:

  • train_data_dir (pathlib.Path) –

    Path to the training data directory.

  • sweep_config (pathlib.Path) –

    Path to the sweep YAML configuration file. Must contain a valid wandb sweep configuration. The hyperparameters must contain the following fields: model_arch, model_encoder, augment, learning_rate, gamma, batch_size, focal_loss_alpha, focal_loss_gamma. The fields fold and random_seed are optional and default to 0 and 42. Please read https://docs.wandb.ai/guides/sweeps/sweep-config-keys for more information.

  • n_trials (int, default: 10 ) –

    Number of runs to execute. Defaults to 10.

  • sweep_id (str | None, default: None ) –

    The ID of the sweep. If None, a new sweep will be created. Defaults to None.

  • artifact_dir (pathlib.Path, default: pathlib.Path('lightning_logs') ) –

    Path to the training output directory. Will contain checkpoints and metrics. Defaults to Path("lightning_logs").

  • max_epochs (int, default: 100 ) –

    Maximum number of epochs to train. Defaults to 100.

  • log_every_n_steps (int, default: 10 ) –

    Log every n steps. Defaults to 10.

  • check_val_every_n_epoch (int, default: 3 ) –

    Check validation every n epochs. Defaults to 3.

  • plot_every_n_val_epochs (int, default: 5 ) –

    Plot validation samples every n epochs. Defaults to 5.

  • num_workers (int, default: 0 ) –

    Number of Dataloader workers. Defaults to 0.

  • device (int | str | None, default: None ) –

    The device to run the sweep agent on. If None, no agent is started. Defaults to None.

  • wandb_entity (str | None, default: None ) –

    Weights and Biases Entity. Defaults to None.

  • wandb_project (str | None, default: None ) –

    Weights and Biases Project. Defaults to None.
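The run() function in the source code reads learning_rate, gamma, batch_size, model_arch, model_encoder, augment, focal_loss_alpha, and focal_loss_gamma with wandb.config[...], and fold and random_seed with .get(...). The sketch below shows a sweep configuration as the dict that yaml.safe_load would return for a matching YAML file, plus a hypothetical check for missing keys. The parameter values and the random search method are illustrative choices; the metric name is taken from the EarlyStopping monitor in train_smp, assuming it is also a sensible sweep objective.

```python
from __future__ import annotations

# Keys read with wandb.config[...] in run() (required) and with .get() (optional)
REQUIRED_KEYS = {
    "learning_rate", "gamma", "batch_size", "model_arch", "model_encoder",
    "augment", "focal_loss_alpha", "focal_loss_gamma",
}
OPTIONAL_KEYS = {"fold", "random_seed"}

# Illustrative sweep configuration; all values are example choices, not recommendations
sweep_configuration = {
    "method": "random",
    "metric": {"name": "val/JaccardIndex", "goal": "maximize"},
    "parameters": {
        "learning_rate": {"min": 1e-4, "max": 1e-2},
        "gamma": {"values": [0.9, 0.95]},
        "batch_size": {"values": [8, 16]},
        "model_arch": {"value": "Unet"},
        "model_encoder": {"values": ["dpn107", "resnet34"]},
        "augment": {"value": True},
        "focal_loss_alpha": {"values": [0.25, 0.5]},
        "focal_loss_gamma": {"value": 2.0},
    },
}


def missing_sweep_keys(config: dict) -> set[str]:
    """Return the required hyperparameters that the sweep config does not define."""
    return REQUIRED_KEYS - set(config.get("parameters", {}))


print(missing_sweep_keys(sweep_configuration))  # set()
```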

Source code in darts/src/darts/legacy_training/train.py
def wandb_sweep_smp(
    *,
    # Data and sweep config
    train_data_dir: Path,
    sweep_config: Path,
    n_trials: int = 10,
    sweep_id: str | None = None,
    artifact_dir: Path = Path("lightning_logs"),
    # Epoch and Logging config
    max_epochs: int = 100,
    log_every_n_steps: int = 10,
    check_val_every_n_epoch: int = 3,
    plot_every_n_val_epochs: int = 5,
    # Device and Manager config
    num_workers: int = 0,
    device: int | str | None = None,
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
):
    """Create a sweep with wandb and run it on the specified cuda device, or continue an existing sweep.

    If `sweep_id` is None, a new sweep will be created. Otherwise, the sweep with the given ID will be continued.
    All artifacts are gathered under a nested directory based on the sweep id: {artifact_dir}/sweep-{sweep_id}.
    Since each sweep configuration (currently) has its own name and id, a single run can be found under:
    {artifact_dir}/sweep-{sweep_id}/{run_name}/{run_id}. Read the training docs for more info.

    If a `device` is specified, run an agent on this device. If None, no agent is started and no trials are run.

    You can specify how often logs are written and how often validation is performed:
        - `log_every_n_steps` specifies how often train-logs are written. This does not affect validation.
        - `check_val_every_n_epoch` specifies how often validation is performed.
            This also affects early stopping.
        - `plot_every_n_val_epochs` specifies how often validation samples are plotted.
            Since plotting is quite costly, you can reduce the frequency. Like early stopping, it counts validation checks.
            In epochs, this is `check_val_every_n_epoch * plot_every_n_val_epochs`.

    This will NOT use cross-validation. For cross-validation, use `optuna_sweep_smp`.

    Example:
        In one terminal, start a sweep:
        ```sh
            $ rye run darts wandb-sweep-smp --config-file /path/to/sweep-config.toml
            ...  # Many logs
            Created sweep with ID 123456789
            ... # More logs from spawned agent
        ```

        In another terminal, start a second agent:
        ```sh
            $ rye run darts wandb-sweep-smp --sweep-id 123456789
            ...
        ```

    Args:
        train_data_dir (Path): Path to the training data directory.
        sweep_config (Path): Path to the sweep yaml configuration file. Must contain a valid wandb sweep configuration.
            Hyperparameters must contain the following fields: `model_arch`, `model_encoder`, `augment`,
            `learning_rate`, `gamma`, `batch_size`, `focal_loss_alpha`, `focal_loss_gamma`.
            The fields `fold` and `random_seed` are optional and default to 0 and 42.
            Please read https://docs.wandb.ai/guides/sweeps/sweep-config-keys for more information.
        n_trials (int, optional): Number of runs to execute. Defaults to 10.
        sweep_id (str | None, optional): The ID of the sweep. If None, a new sweep will be created. Defaults to None.
        artifact_dir (Path, optional): Path to the training output directory.
            Will contain checkpoints and metrics. Defaults to Path("lightning_logs").
        max_epochs (int, optional): Maximum number of epochs to train. Defaults to 100.
        log_every_n_steps (int, optional): Log every n steps. Defaults to 10.
        check_val_every_n_epoch (int, optional): Check validation every n epochs. Defaults to 3.
        plot_every_n_val_epochs (int, optional): Plot validation samples every n epochs. Defaults to 5.
        num_workers (int, optional): Number of Dataloader workers. Defaults to 0.
        device (int | str | None, optional): The device to run the sweep agent on. If None, no agent is started.
            Defaults to None.
        wandb_entity (str | None, optional): Weights and Biases Entity. Defaults to None.
        wandb_project (str | None, optional): Weights and Biases Project. Defaults to None.

    """
    import wandb

    # Wandb has a stupid way of logging (they log per default with click.echo to stdout)
    # We need to silence this and redirect all possible logs to our logger
    # wl = wandb.setup({"silent": True})
    # wandb.termsetup(wl.settings, logging.getLogger("wandb"))
    # LoggingManager.apply_logging_handlers("wandb")

    if sweep_id is not None and device is None:
        logger.warning("Continuing a sweep without specifying a device will not do anything.")

    with sweep_config.open("r") as f:
        sweep_configuration = yaml.safe_load(f)

    logger.debug(f"Loaded sweep configuration from {sweep_config.resolve()}:\n{sweep_configuration}")

    if sweep_id is None:
        sweep_id = wandb.sweep(sweep=sweep_configuration, project=wandb_project, entity=wandb_entity)
        logger.info(f"Created sweep with ID {sweep_id}")
        logger.info("To start additional sweep agents, use the following command:")
        logger.info(f"$ rye run darts wandb-sweep-smp --sweep-id {sweep_id}")

    artifact_dir = artifact_dir / f"sweep-{sweep_id}"
    artifact_dir.mkdir(parents=True, exist_ok=True)

    def run():
        run = wandb.init(config=sweep_configuration)
        # We need to manually log the run data since the wandb logger only logs to its own logs and click
        logger.info(f"Starting sweep run '{run.settings.run_name}'")
        logger.debug(f"Run data is saved locally in {Path(run.settings.sync_dir).resolve()}")
        logger.debug(f"View project at {run.settings.project_url}")
        logger.debug(f"View sweep at {run.settings.sweep_url}")
        logger.debug(f"View run at {run.settings.run_url}")

        # We set the default weights to None, to be able to use different architectures
        model_encoder_weights = None
        # We set early stopping to None, because wandb will handle the early stopping
        early_stopping_patience = None
        learning_rate = wandb.config["learning_rate"]
        gamma = wandb.config["gamma"]
        batch_size = wandb.config["batch_size"]
        model_arch = wandb.config["model_arch"]
        model_encoder = wandb.config["model_encoder"]
        augment = wandb.config["augment"]
        focal_loss_alpha = wandb.config["focal_loss_alpha"]
        focal_loss_gamma = wandb.config["focal_loss_gamma"]
        fold = wandb.config.get("fold", 0)
        random_seed = wandb.config.get("random_seed", 42)

        train_smp(
            # Data config
            train_data_dir=train_data_dir,
            artifact_dir=artifact_dir,
            fold=fold,
            # Hyperparameters
            model_arch=model_arch,
            model_encoder=model_encoder,
            model_encoder_weights=model_encoder_weights,
            augment=augment,
            learning_rate=learning_rate,
            gamma=gamma,
            focal_loss_alpha=focal_loss_alpha,
            focal_loss_gamma=focal_loss_gamma,
            batch_size=batch_size,
            # Epoch and Logging config
            early_stopping_patience=early_stopping_patience,
            max_epochs=max_epochs,
            log_every_n_steps=log_every_n_steps,
            check_val_every_n_epoch=check_val_every_n_epoch,
            plot_every_n_val_epochs=plot_every_n_val_epochs,
            # Device and Manager config
            random_seed=random_seed,
            num_workers=num_workers,
            device=device,
            wandb_entity=wandb_entity,
            wandb_project=wandb_project,
            run_name=wandb.run.name,
            run_id=wandb.run.id,
        )

    if device is None:
        logger.info("No device specified, closing script...")
        return

    logger.info("Starting a default sweep agent")
    wandb.agent(sweep_id, function=run, count=n_trials, project=wandb_project, entity=wandb_entity)