Training

Preprocessed data

All training runs and sweeps expect the data to be present in preprocessed form. This means that the train_data_dir should look like this:

train_data_dir/
├── config.toml
├── cross-val.zarr/
├── test.zarr/
├── val-test.zarr/
└── labels.geojson

Each zarr group contains an x and a y dataarray.

Ideally, use the preprocessing functions explained below to create this structure.

Preprocess the data

The train, validation, and test flow is best described in the following image: DARTS training process

To split your Sentinel 2 data into the three datasets and preprocess it, you can use the following command:

[uv run] darts preprocess-s2-train-data --your-args-here ... 

PLANET data

If you are using PLANET data, you can use the following command instead:

[uv run] darts preprocess-planet-train-data --your-args-here ...

This will create three data splits:

  • cross-val, used for training and validation
  • val-test, a 5% random leave-out for testing the randomness distribution shift of the data
  • test, a leave-out region for testing the spatial distribution shift of the data

The final training data is saved to disk in the form of zarr arrays with dimensions [n, c, h, w] for the inputs and [n, h, w] for the labels, respectively, with a chunk size of n=1. Hence, every sample is saved in a separate chunk and therefore in a separate file on disk, all managed by zarr.
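
To sanity-check the result, the splits can be opened directly with zarr. The snippet below is a minimal sketch; the path is a placeholder and it assumes the preprocessing has already been run.

```python
import zarr

# Open the cross-val split written by the preprocessing (path is a placeholder).
root = zarr.open_group("train_data_dir/cross-val.zarr", mode="r")
x, y = root["x"], root["y"]

print(x.shape)  # (n_patches, n_bands, patch_size, patch_size)
print(y.shape)  # (n_patches, patch_size, patch_size)

# Because every sample lives in its own chunk, reading a single patch
# only touches a single file on disk.
sample_x = x[0]  # numpy array of shape (n_bands, patch_size, patch_size)
sample_y = y[0]  # numpy array of shape (patch_size, patch_size)
```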

The preprocessing is done with the same components used in the segmentation pipeline. Hence, the same configuration options are available. In addition, this preprocessing splits larger images into smaller patches of a fixed size. Size and overlap can be configured in the configuration file or via the arguments of the CLI.
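
For a rough intuition of how many patches a single scene yields, the sketch below assumes a simple sliding-window tiling with stride patch_size - overlap; the actual patch extraction is done by create_training_patches and may handle borders and nodata differently.

```python
from math import ceil


def estimate_n_patches(height: int, width: int, patch_size: int = 1024, overlap: int = 16) -> int:
    """Rough estimate of the number of patches per scene (assumes a plain sliding window)."""
    stride = patch_size - overlap
    n_y = ceil(max(height - overlap, 1) / stride)
    n_x = ceil(max(width - overlap, 1) / stride)
    return n_y * n_x


# E.g. a 10980 x 10980 px Sentinel 2 tile with the default patch size and overlap:
print(estimate_n_patches(10980, 10980))  # -> 121 (11 x 11 patches)
```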

You can also use the underlying functions directly:

darts.legacy_training.preprocess_s2_train_data

preprocess_s2_train_data(
    *,
    bands: list[str],
    sentinel2_dir: pathlib.Path,
    train_data_dir: pathlib.Path,
    arcticdem_dir: pathlib.Path,
    tcvis_dir: pathlib.Path,
    admin_dir: pathlib.Path,
    preprocess_cache: pathlib.Path | None = None,
    device: typing.Literal["cuda", "cpu", "auto"]
    | int
    | None = None,
    dask_worker: int = min(
        16, multiprocessing.cpu_count() - 1
    ),
    ee_project: str | None = None,
    ee_use_highvolume: bool = True,
    tpi_outer_radius: int = 100,
    tpi_inner_radius: int = 0,
    patch_size: int = 1024,
    overlap: int = 16,
    exclude_nopositive: bool = False,
    exclude_nan: bool = True,
    mask_erosion_size: int = 10,
    test_val_split: float = 0.05,
    test_regions: list[str] | None = None,
)

Preprocess Sentinel 2 data for training.

The data is split into a cross-validation, a validation-test and a test set:

- `cross-val` is meant to be used for train and validation
- `val-test` (5%) random leave-out for testing the randomness distribution shift of the data
- `test` leave-out region for testing the spatial distribution shift of the data

Each split is stored as a zarr group containing an x and a y dataarray. The x dataarray contains the input data with the shape (n_patches, n_bands, patch_size, patch_size). The y dataarray contains the labels with the shape (n_patches, patch_size, patch_size). Both dataarrays are chunked along the n_patches dimension, which results in very fast random access, because each sample / patch is stored in a separate chunk and therefore in a separate file.

The test and validation splits can be controlled through the parameters test_val_split and test_regions. test_regions accepts a list of admin 1 or admin 2 region names, based on the region shapefile maintained by https://github.com/wmgeolab/geoBoundaries; scenes intersecting these regions are removed from the dataset and put into the test split. The test_val_split parameter controls the ratio used to further split off the val-test set.

Through exclude_nopositive and exclude_nan, the respective patches can be excluded from the final data.

Further, a config.toml file is saved in the train_data_dir containing the configuration used for the preprocessing. Additionally, a labels.geojson file is saved in the train_data_dir containing the joined label geometries used to create the binarized label masks, along with information about the split via the mode column.

The final directory structure of train_data_dir will look like this:

train_data_dir/
├── config.toml
├── cross-val.zarr/
├── test.zarr/
├── val-test.zarr/
└── labels.geojson
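
For example, a direct call from Python could look like the following sketch. All paths, the band selection, and the region name are placeholders; the band names shown are among those covered by the hardcoded normalization factors.

```python
from pathlib import Path

from darts.legacy_training import preprocess_s2_train_data

# All paths and the region name are placeholders; adjust them to your setup.
preprocess_s2_train_data(
    bands=["red", "green", "blue", "nir", "ndvi", "relative_elevation", "slope"],
    sentinel2_dir=Path("data/sentinel2"),
    train_data_dir=Path("data/train"),
    arcticdem_dir=Path("data/arcticdem"),
    tcvis_dir=Path("data/tcvis"),
    admin_dir=Path("data/admin"),
    preprocess_cache=Path("data/cache"),
    device="auto",
    patch_size=1024,
    overlap=16,
    exclude_nopositive=False,
    exclude_nan=True,
    test_val_split=0.05,
    test_regions=["Yamalo-Nenets Autonomous Okrug"],  # example admin region name
)
```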

Parameters:

  • bands (list[str]) –

    The bands to be used for training. Must be present in the preprocessing.

  • sentinel2_dir (pathlib.Path) –

    The directory containing the Sentinel 2 scenes.

  • train_data_dir (pathlib.Path) –

    The "output" directory where the tensors are written to.

  • arcticdem_dir (pathlib.Path) –

    The directory containing the ArcticDEM data (the datacube and the extent files). Will be created and downloaded if it does not exist.

  • tcvis_dir (pathlib.Path) –

    The directory containing the TCVis data.

  • admin_dir (pathlib.Path) –

    The directory containing the admin files.

  • preprocess_cache (pathlib.Path, default: None ) –

    The directory to store the preprocessed data. Defaults to None.

  • device (typing.Literal['cuda', 'cpu', 'auto'] | int | None, default: None ) –

    The device to run the model on. If "cuda" take the first device (0), if int take the specified device. If "auto" try to automatically select a free GPU (<50% memory usage). Defaults to "cuda" if available, else "cpu".

  • dask_worker (int, default: min(16, multiprocessing.cpu_count() - 1) ) –

    The number of Dask workers to use. Defaults to min(16, mp.cpu_count() - 1).

  • ee_project (str, default: None ) –

    The Earth Engine project ID or number to use. May be omitted if project is defined within persistent API credentials obtained via earthengine authenticate.

  • ee_use_highvolume (bool, default: True ) –

    Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).

  • tpi_outer_radius (int, default: 100 ) –

    The outer radius of the annulus kernel for the tpi calculation in m. Defaults to 100m.

  • tpi_inner_radius (int, default: 0 ) –

    The inner radius of the annulus kernel for the tpi calculation in m. Defaults to 0.

  • patch_size (int, default: 1024 ) –

    The patch size to use for inference. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The overlap to use for inference. Defaults to 16.

  • exclude_nopositive (bool, default: False ) –

    Whether to exclude patches where the labels do not contain positives. Defaults to False.

  • exclude_nan (bool, default: True ) –

    Whether to exclude patches where the input data has nan values. Defaults to True.

  • mask_erosion_size (int, default: 10 ) –

    The size of the disk to use for mask erosion and the edge-cropping. Defaults to 10.

  • test_val_split (float, default: 0.05 ) –

    The split ratio for the test and validation set. Defaults to 0.05.

  • test_regions (list[str] | str, default: None ) –

    The region to use for the test set. Defaults to None.

Source code in darts/src/darts/legacy_training/preprocess/s2.py
def preprocess_s2_train_data(
    *,
    bands: list[str],
    sentinel2_dir: Path,
    train_data_dir: Path,
    arcticdem_dir: Path,
    tcvis_dir: Path,
    admin_dir: Path,
    preprocess_cache: Path | None = None,
    device: Literal["cuda", "cpu", "auto"] | int | None = None,
    dask_worker: int = min(16, mp.cpu_count() - 1),
    ee_project: str | None = None,
    ee_use_highvolume: bool = True,
    tpi_outer_radius: int = 100,
    tpi_inner_radius: int = 0,
    patch_size: int = 1024,
    overlap: int = 16,
    exclude_nopositive: bool = False,
    exclude_nan: bool = True,
    mask_erosion_size: int = 10,
    test_val_split: float = 0.05,
    test_regions: list[str] | None = None,
):
    """Preprocess Sentinel 2 data for training.

    The data is split into a cross-validation, a validation-test and a test set:

        - `cross-val` is meant to be used for train and validation
        - `val-test` (5%) random leave-out for testing the randomness distribution shift of the data
        - `test` leave-out region for testing the spatial distribution shift of the data

    Each split is stored as a zarr group, containing a x and a y dataarray.
    The x dataarray contains the input data with the shape (n_patches, n_bands, patch_size, patch_size).
    The y dataarray contains the labels with the shape (n_patches, patch_size, patch_size).
    Both dataarrays are chunked along the n_patches dimension.
    This results in super fast random access to the data, because each sample / patch is stored in a separate chunk and
    therefore in a separate file.

    Through the parameters `test_val_split` and `test_regions`, the test and validation split can be controlled.
    To `test_regions` can a list of admin 1 or admin 2 region names, based on the region shapefile maintained by
    https://github.com/wmgeolab/geoBoundaries, be supplied to remove intersecting scenes from the dataset and
    put them in the test-split.
    With the `test_val_split` parameter, the ratio between further splitting of a test-validation set can be controlled.

    Through `exclude_nopositve` and `exclude_nan`, respective patches can be excluded from the final data.

    Further, a `config.toml` file is saved in the `train_data_dir` containing the configuration used for the
    preprocessing.
    Addionally, a `labels.geojson` file is saved in the `train_data_dir` containing the joined labels geometries used
    for the creation of the binarized label-masks, containing also information about the split via the `mode` column.

    The final directory structure of `train_data_dir` will look like this:

    ```sh
    train_data_dir/
    ├── config.toml
    ├── cross-val.zarr/
    ├── test.zarr/
    ├── val-test.zarr/
    └── labels.geojson
    ```

    Args:
        bands (list[str]): The bands to be used for training. Must be present in the preprocessing.
        sentinel2_dir (Path): The directory containing the Sentinel 2 scenes.
        train_data_dir (Path): The "output" directory where the tensors are written to.
        arcticdem_dir (Path): The directory containing the ArcticDEM data (the datacube and the extent files).
            Will be created and downloaded if it does not exist.
        tcvis_dir (Path): The directory containing the TCVis data.
        admin_dir (Path): The directory containing the admin files.
        preprocess_cache (Path, optional): The directory to store the preprocessed data. Defaults to None.
        device (Literal["cuda", "cpu"] | int, optional): The device to run the model on.
            If "cuda" take the first device (0), if int take the specified device.
            If "auto" try to automatically select a free GPU (<50% memory usage).
            Defaults to "cuda" if available, else "cpu".
        dask_worker (int, optional): The number of Dask workers to use. Defaults to min(16, mp.cpu_count() - 1).
        ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if
            project is defined within persistent API credentials obtained via `earthengine authenticate`.
        ee_use_highvolume (bool, optional): Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).
        tpi_outer_radius (int, optional): The outer radius of the annulus kernel for the tpi calculation
            in m. Defaults to 100m.
        tpi_inner_radius (int, optional): The inner radius of the annulus kernel for the tpi calculation
            in m. Defaults to 0.
        patch_size (int, optional): The patch size to use for inference. Defaults to 1024.
        overlap (int, optional): The overlap to use for inference. Defaults to 16.
        exclude_nopositive (bool, optional): Whether to exclude patches where the labels do not contain positives.
            Defaults to False.
        exclude_nan (bool, optional): Whether to exclude patches where the input data has nan values.
            Defaults to True.
        mask_erosion_size (int, optional): The size of the disk to use for mask erosion and the edge-cropping.
            Defaults to 10.
        test_val_split (float, optional): The split ratio for the test and validation set. Defaults to 0.05.
        test_regions (list[str] | str, optional): The region to use for the test set. Defaults to None.

    """
    # Import here to avoid long loading times when running other commands
    import geopandas as gpd
    import pandas as pd
    import toml
    import xarray as xr
    import zarr
    from darts_acquisition import load_arcticdem, load_s2_masks, load_s2_scene, load_tcvis
    from darts_acquisition.s2 import parse_s2_tile_id
    from darts_preprocessing import preprocess_legacy_fast
    from darts_segmentation.training.prepare_training import create_training_patches
    from dask.distributed import Client, LocalCluster
    from lovely_tensors import monkey_patch
    from odc.stac import configure_rio
    from rich.progress import track
    from zarr.codecs import BloscCodec
    from zarr.storage import LocalStore

    from darts.utils.cuda import debug_info, decide_device
    from darts.utils.earthengine import init_ee
    from darts.utils.logging import console

    monkey_patch()
    debug_info()
    device = decide_device(device)
    init_ee(ee_project, ee_use_highvolume)

    with LocalCluster(n_workers=dask_worker) as cluster, Client(cluster) as client:
        logger.info(f"Using Dask client: {client} on cluster {cluster}")
        logger.info(f"Dashboard available at: {client.dashboard_link}")
        configure_rio(cloud_defaults=True, aws={"aws_unsigned": True}, client=client)
        logger.info("Configured Rasterio with Dask")

        # We hardcode these because they depend on the preprocessing used
        norm_factors = {
            "red": 1 / 3000,
            "green": 1 / 3000,
            "blue": 1 / 3000,
            "nir": 1 / 3000,
            "ndvi": 1 / 20000,
            "relative_elevation": 1 / 30000,
            "slope": 1 / 90,
            "tc_brightness": 1 / 255,
            "tc_greenness": 1 / 255,
            "tc_wetness": 1 / 255,
        }
        # Filter out bands that are not in the specified bands
        norm_factors = {k: v for k, v in norm_factors.items() if k in bands}

        train_data_dir.mkdir(exist_ok=True, parents=True)

        zgroups = {
            "cross-val": zarr.group(store=LocalStore(train_data_dir / "cross-val.zarr"), overwrite=True),
            "val-test": zarr.group(store=LocalStore(train_data_dir / "val-test.zarr"), overwrite=True),
            "test": zarr.group(store=LocalStore(train_data_dir / "test.zarr"), overwrite=True),
        }
        # We need to declare the number of patches as 0, because we can't know the final number of patches
        for root in zgroups.values():
            root.create(
                name="x",
                shape=(0, len(bands), patch_size, patch_size),
                # shards=(100, len(bands), patch_size, patch_size),
                chunks=(1, len(bands), patch_size, patch_size),
                dtype="float32",
                compressors=BloscCodec(cname="lz4", clevel=9),
            )
            root.create(
                name="y",
                shape=(0, patch_size, patch_size),
                # shards=(100, patch_size, patch_size),
                chunks=(1, patch_size, patch_size),
                dtype="uint8",
                compressors=BloscCodec(cname="lz4", clevel=9),
            )

        # Find all Sentinel 2 scenes and split into train+val (cross-val), val-test (variance) and test (region)
        n_patches = 0
        n_patches_by_mode = {"cross-val": 0, "val-test": 0, "test": 0}
        joint_lables = []
        s2_paths = sorted(sentinel2_dir.glob("*/"))
        logger.info(f"Found {len(s2_paths)} Sentinel 2 scenes in {sentinel2_dir}")
        path_gen = split_dataset_paths(s2_paths, train_data_dir, test_val_split, test_regions, admin_dir)
        for i, (fpath, mode) in track(
            enumerate(path_gen), description="Processing samples", total=len(s2_paths), console=console
        ):
            try:
                _, s2_tile_id, tile_id = parse_s2_tile_id(fpath)

                logger.debug(
                    f"Processing sample {i + 1} of {len(s2_paths)} '{fpath.resolve()}' ({tile_id=}) to split '{mode}'"
                )

                # Check for a cached preprocessed file
                if preprocess_cache and (preprocess_cache / f"{tile_id}.nc").exists():
                    cache_file = preprocess_cache / f"{tile_id}.nc"
                    logger.info(f"Loading preprocessed data from {cache_file.resolve()}")
                    tile = xr.open_dataset(preprocess_cache / f"{tile_id}.nc", engine="h5netcdf").set_coords(
                        "spatial_ref"
                    )
                else:
                    optical = load_s2_scene(fpath)
                    logger.info(f"Found optical tile with size {optical.sizes}")
                    arctidem_res = 10
                    arcticdem_buffer = ceil(tpi_outer_radius / arctidem_res * sqrt(2))
                    arcticdem = load_arcticdem(
                        optical.odc.geobox, arcticdem_dir, resolution=arctidem_res, buffer=arcticdem_buffer
                    )
                    tcvis = load_tcvis(optical.odc.geobox, tcvis_dir)
                    data_masks = load_s2_masks(fpath, optical.odc.geobox)

                    tile: xr.Dataset = preprocess_legacy_fast(
                        optical,
                        arcticdem,
                        tcvis,
                        data_masks,
                        tpi_outer_radius,
                        tpi_inner_radius,
                        device,
                    )
                    # Only cache if we have a cache directory
                    if preprocess_cache:
                        preprocess_cache.mkdir(exist_ok=True, parents=True)
                        cache_file = preprocess_cache / f"{tile_id}.nc"
                        logger.info(f"Caching preprocessed data to {cache_file.resolve()}")
                        tile.to_netcdf(cache_file, engine="h5netcdf")

                labels = gpd.read_file(fpath / f"{s2_tile_id}.shp")

                # Save the patches
                gen = create_training_patches(
                    tile,
                    labels,
                    bands,
                    norm_factors,
                    patch_size,
                    overlap,
                    exclude_nopositive,
                    exclude_nan,
                    device,
                    mask_erosion_size,
                )

                zx = zgroups[mode]["x"]
                zy = zgroups[mode]["y"]
                patch_id = None
                for patch_id, (x, y) in enumerate(gen):
                    zx.append(x.unsqueeze(0).numpy().astype("float32"))
                    zy.append(y.unsqueeze(0).numpy().astype("uint8"))
                    n_patches += 1
                    n_patches_by_mode[mode] += 1
                if n_patches > 0 and len(labels) > 0:
                    labels["mode"] = mode
                    joint_lables.append(labels.to_crs("EPSG:3413"))

                logger.info(
                    f"Processed sample {i + 1} of {len(s2_paths)} '{fpath.resolve()}'"
                    f"({tile_id=}) with {patch_id} patches."
                )
            except KeyboardInterrupt:
                logger.info("Interrupted by user.")
                break

            except Exception as e:
                logger.warning(f"Could not process folder sample {i} '{fpath.resolve()}'.\nSkipping...")
                logger.exception(e)

    # Save the used labels
    joint_lables = pd.concat(joint_lables)
    joint_lables.to_file(train_data_dir / "labels.geojson", driver="GeoJSON")

    # Save a config file as toml
    config = {
        "darts": {
            "sentinel2_dir": sentinel2_dir,
            "train_data_dir": train_data_dir,
            "arcticdem_dir": arcticdem_dir,
            "tcvis_dir": tcvis_dir,
            "bands": bands,
            "norm_factors": norm_factors,
            "device": device,
            "ee_project": ee_project,
            "ee_use_highvolume": ee_use_highvolume,
            "tpi_outer_radius": tpi_outer_radius,
            "tpi_inner_radius": tpi_inner_radius,
            "patch_size": patch_size,
            "overlap": overlap,
            "exclude_nopositive": exclude_nopositive,
            "exclude_nan": exclude_nan,
            "n_patches": n_patches,
        }
    }
    with open(train_data_dir / "config.toml", "w") as f:
        toml.dump(config, f)

    logger.info(f"Saved {n_patches} ({n_patches_by_mode}) patches to {train_data_dir}")

darts.legacy_training.preprocess_planet_train_data

preprocess_planet_train_data(
    *,
    bands: list[str],
    data_dir: pathlib.Path,
    labels_dir: pathlib.Path,
    train_data_dir: pathlib.Path,
    arcticdem_dir: pathlib.Path,
    tcvis_dir: pathlib.Path,
    admin_dir: pathlib.Path,
    preprocess_cache: pathlib.Path | None = None,
    device: typing.Literal["cuda", "cpu", "auto"]
    | int
    | None = None,
    dask_worker: int = min(
        16, multiprocessing.cpu_count() - 1
    ),
    ee_project: str | None = None,
    ee_use_highvolume: bool = True,
    tpi_outer_radius: int = 100,
    tpi_inner_radius: int = 0,
    patch_size: int = 1024,
    overlap: int = 16,
    exclude_nopositive: bool = False,
    exclude_nan: bool = True,
    mask_erosion_size: int = 10,
    test_val_split: float = 0.05,
    test_regions: list[str] | None = None,
)

Preprocess Planet data for training.

The data is split into a cross-validation, a validation-test and a test set:

- `cross-val` is meant to be used for train and validation
- `val-test` (5%) random leave-out for testing the randomness distribution shift of the data
- `test` leave-out region for testing the spatial distribution shift of the data

Each split is stored as a zarr group containing an x and a y dataarray. The x dataarray contains the input data with the shape (n_patches, n_bands, patch_size, patch_size). The y dataarray contains the labels with the shape (n_patches, patch_size, patch_size). Both dataarrays are chunked along the n_patches dimension, which results in very fast random access, because each sample / patch is stored in a separate chunk and therefore in a separate file.

The test and validation splits can be controlled through the parameters test_val_split and test_regions. test_regions accepts a list of admin 1 or admin 2 region names, based on the region shapefile maintained by https://github.com/wmgeolab/geoBoundaries; scenes intersecting these regions are removed from the dataset and put into the test split. The test_val_split parameter controls the ratio used to further split off the val-test set.

Through exclude_nopositive and exclude_nan, the respective patches can be excluded from the final data.

Further, a config.toml file is saved in the train_data_dir containing the configuration used for the preprocessing. Additionally, a labels.geojson file is saved in the train_data_dir containing the joined label geometries used to create the binarized label masks, along with information about the split via the mode column.

The final directory structure of train_data_dir will look like this:

train_data_dir/
├── config.toml
├── cross-val.zarr/
├── test.zarr/
├── val-test.zarr/
└── labels.geojson
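
Analogously, a direct call for PLANET data could look like this sketch; all paths are placeholders.

```python
from pathlib import Path

from darts.legacy_training import preprocess_planet_train_data

# All paths are placeholders; adjust them to your setup.
preprocess_planet_train_data(
    bands=["red", "green", "blue", "nir", "ndvi", "slope", "tc_brightness"],
    data_dir=Path("data/planet"),
    labels_dir=Path("data/labels"),
    train_data_dir=Path("data/train"),
    arcticdem_dir=Path("data/arcticdem"),
    tcvis_dir=Path("data/tcvis"),
    admin_dir=Path("data/admin"),
    device="auto",
    exclude_nopositive=True,  # drop patches without any positive labels
    test_val_split=0.05,
)
```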

Parameters:

  • bands (list[str]) –

    The bands to be used for training. Must be present in the preprocessing.

  • data_dir (pathlib.Path) –

    The directory containing the Planet scenes and orthotiles.

  • labels_dir (pathlib.Path) –

    The directory containing the labels.

  • train_data_dir (pathlib.Path) –

    The "output" directory where the tensors are written to.

  • arcticdem_dir (pathlib.Path) –

    The directory containing the ArcticDEM data (the datacube and the extent files). Will be created and downloaded if it does not exist.

  • tcvis_dir (pathlib.Path) –

    The directory containing the TCVis data.

  • admin_dir (pathlib.Path) –

    The directory containing the admin files.

  • preprocess_cache (pathlib.Path, default: None ) –

    The directory to store the preprocessed data. Defaults to None.

  • device (typing.Literal['cuda', 'cpu', 'auto'] | int | None, default: None ) –

    The device to run the model on. If "cuda" take the first device (0), if int take the specified device. If "auto" try to automatically select a free GPU (<50% memory usage). Defaults to "cuda" if available, else "cpu".

  • dask_worker (int, default: min(16, multiprocessing.cpu_count() - 1) ) –

    The number of Dask workers to use. Defaults to min(16, mp.cpu_count() - 1).

  • ee_project (str, default: None ) –

    The Earth Engine project ID or number to use. May be omitted if project is defined within persistent API credentials obtained via earthengine authenticate.

  • ee_use_highvolume (bool, default: True ) –

    Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).

  • tpi_outer_radius (int, default: 100 ) –

    The outer radius of the annulus kernel for the tpi calculation in m. Defaults to 100m.

  • tpi_inner_radius (int, default: 0 ) –

    The inner radius of the annulus kernel for the tpi calculation in m. Defaults to 0.

  • patch_size (int, default: 1024 ) –

    The patch size to use for inference. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The overlap to use for inference. Defaults to 16.

  • exclude_nopositive (bool, default: False ) –

    Whether to exclude patches where the labels do not contain positives. Defaults to False.

  • exclude_nan (bool, default: True ) –

    Whether to exclude patches where the input data has nan values. Defaults to True.

  • mask_erosion_size (int, default: 10 ) –

    The size of the disk to use for mask erosion and the edge-cropping. Defaults to 10.

  • test_val_split (float, default: 0.05 ) –

    The split ratio for the test and validation set. Defaults to 0.05.

  • test_regions (list[str] | str, default: None ) –

    The region to use for the test set. Defaults to None.

Source code in darts/src/darts/legacy_training/preprocess/planet.py
def preprocess_planet_train_data(
    *,
    bands: list[str],
    data_dir: Path,
    labels_dir: Path,
    train_data_dir: Path,
    arcticdem_dir: Path,
    tcvis_dir: Path,
    admin_dir: Path,
    preprocess_cache: Path | None = None,
    device: Literal["cuda", "cpu", "auto"] | int | None = None,
    dask_worker: int = min(16, mp.cpu_count() - 1),
    ee_project: str | None = None,
    ee_use_highvolume: bool = True,
    tpi_outer_radius: int = 100,
    tpi_inner_radius: int = 0,
    patch_size: int = 1024,
    overlap: int = 16,
    exclude_nopositive: bool = False,
    exclude_nan: bool = True,
    mask_erosion_size: int = 10,
    test_val_split: float = 0.05,
    test_regions: list[str] | None = None,
):
    """Preprocess Planet data for training.

    The data is split into a cross-validation, a validation-test and a test set:

        - `cross-val` is meant to be used for train and validation
        - `val-test` (5%) random leave-out for testing the randomness distribution shift of the data
        - `test` leave-out region for testing the spatial distribution shift of the data

    Each split is stored as a zarr group, containing a x and a y dataarray.
    The x dataarray contains the input data with the shape (n_patches, n_bands, patch_size, patch_size).
    The y dataarray contains the labels with the shape (n_patches, patch_size, patch_size).
    Both dataarrays are chunked along the n_patches dimension.
    This results in super fast random access to the data, because each sample / patch is stored in a separate chunk and
    therefore in a separate file.

    Through the parameters `test_val_split` and `test_regions`, the test and validation split can be controlled.
    To `test_regions` can a list of admin 1 or admin 2 region names, based on the region shapefile maintained by
    https://github.com/wmgeolab/geoBoundaries, be supplied to remove intersecting scenes from the dataset and
    put them in the test-split.
    With the `test_val_split` parameter, the ratio between further splitting of a test-validation set can be controlled.

    Through `exclude_nopositve` and `exclude_nan`, respective patches can be excluded from the final data.

    Further, a `config.toml` file is saved in the `train_data_dir` containing the configuration used for the
    preprocessing.
    Addionally, a `labels.geojson` file is saved in the `train_data_dir` containing the joined labels geometries used
    for the creation of the binarized label-masks, containing also information about the split via the `mode` column.

    The final directory structure of `train_data_dir` will look like this:

    ```sh
    train_data_dir/
    ├── config.toml
    ├── cross-val.zarr/
    ├── test.zarr/
    ├── val-test.zarr/
    └── labels.geojson
    ```

    Args:
        bands (list[str]): The bands to be used for training. Must be present in the preprocessing.
        data_dir (Path): The directory containing the Planet scenes and orthotiles.
        labels_dir (Path): The directory containing the labels.
        train_data_dir (Path): The "output" directory where the tensors are written to.
        arcticdem_dir (Path): The directory containing the ArcticDEM data (the datacube and the extent files).
            Will be created and downloaded if it does not exist.
        tcvis_dir (Path): The directory containing the TCVis data.
        admin_dir (Path): The directory containing the admin files.
        preprocess_cache (Path, optional): The directory to store the preprocessed data. Defaults to None.
        device (Literal["cuda", "cpu"] | int, optional): The device to run the model on.
            If "cuda" take the first device (0), if int take the specified device.
            If "auto" try to automatically select a free GPU (<50% memory usage).
            Defaults to "cuda" if available, else "cpu".
        dask_worker (int, optional): The number of Dask workers to use. Defaults to min(16, mp.cpu_count() - 1).
        ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if
            project is defined within persistent API credentials obtained via `earthengine authenticate`.
        ee_use_highvolume (bool, optional): Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).
        tpi_outer_radius (int, optional): The outer radius of the annulus kernel for the tpi calculation
            in m. Defaults to 100m.
        tpi_inner_radius (int, optional): The inner radius of the annulus kernel for the tpi calculation
            in m. Defaults to 0.
        patch_size (int, optional): The patch size to use for inference. Defaults to 1024.
        overlap (int, optional): The overlap to use for inference. Defaults to 16.
        exclude_nopositive (bool, optional): Whether to exclude patches where the labels do not contain positives.
            Defaults to False.
        exclude_nan (bool, optional): Whether to exclude patches where the input data has nan values.
            Defaults to True.
        mask_erosion_size (int, optional): The size of the disk to use for mask erosion and the edge-cropping.
            Defaults to 10.
        test_val_split (float, optional): The split ratio for the test and validation set. Defaults to 0.05.
        test_regions (list[str] | str, optional): The region to use for the test set. Defaults to None.

    """
    # Import here to avoid long loading times when running other commands
    import geopandas as gpd
    import pandas as pd
    import toml
    import xarray as xr
    import zarr
    from darts_acquisition import load_arcticdem, load_planet_masks, load_planet_scene, load_tcvis
    from darts_preprocessing import preprocess_legacy_fast
    from darts_segmentation.training.prepare_training import create_training_patches
    from dask.distributed import Client, LocalCluster
    from lovely_tensors import monkey_patch
    from odc.stac import configure_rio
    from rich.progress import track
    from zarr.codecs import BloscCodec
    from zarr.storage import LocalStore

    from darts.utils.cuda import debug_info, decide_device
    from darts.utils.earthengine import init_ee
    from darts.utils.logging import console

    monkey_patch()
    debug_info()
    device = decide_device(device)
    init_ee(ee_project, ee_use_highvolume)

    with LocalCluster(n_workers=dask_worker) as cluster, Client(cluster) as client:
        logger.info(f"Using Dask client: {client} on cluster {cluster}")
        logger.info(f"Dashboard available at: {client.dashboard_link}")
        configure_rio(cloud_defaults=True, aws={"aws_unsigned": True}, client=client)
        logger.info("Configured Rasterio with Dask")

        labels = (gpd.read_file(labels_file) for labels_file in labels_dir.glob("*/TrainingLabel*.gpkg"))
        labels = gpd.GeoDataFrame(pd.concat(labels, ignore_index=True))

        footprints = (gpd.read_file(footprints_file) for footprints_file in labels_dir.glob("*/ImageFootprints*.gpkg"))
        footprints = gpd.GeoDataFrame(pd.concat(footprints, ignore_index=True))

        # We hardcode these because they depend on the preprocessing used
        norm_factors = {
            "red": 1 / 3000,
            "green": 1 / 3000,
            "blue": 1 / 3000,
            "nir": 1 / 3000,
            "ndvi": 1 / 20000,
            "relative_elevation": 1 / 30000,
            "slope": 1 / 90,
            "tc_brightness": 1 / 255,
            "tc_greenness": 1 / 255,
            "tc_wetness": 1 / 255,
        }
        # Filter out bands that are not in the specified bands
        norm_factors = {k: v for k, v in norm_factors.items() if k in bands}

        train_data_dir.mkdir(exist_ok=True, parents=True)

        zgroups = {
            "cross-val": zarr.group(store=LocalStore(train_data_dir / "cross-val.zarr"), overwrite=True),
            "val-test": zarr.group(store=LocalStore(train_data_dir / "val-test.zarr"), overwrite=True),
            "test": zarr.group(store=LocalStore(train_data_dir / "test.zarr"), overwrite=True),
        }
        # We need to declare the number of patches as 0, because we can't know the final number of patches
        for root in zgroups.values():
            root.create(
                name="x",
                shape=(0, len(bands), patch_size, patch_size),
                # shards=(100, len(bands), patch_size, patch_size),
                chunks=(1, len(bands), patch_size, patch_size),
                dtype="float32",
                compressors=BloscCodec(cname="lz4", clevel=9),
            )
            root.create(
                name="y",
                shape=(0, patch_size, patch_size),
                # shards=(100, patch_size, patch_size),
                chunks=(1, patch_size, patch_size),
                dtype="uint8",
                compressors=BloscCodec(cname="lz4", clevel=9),
            )

        # Find all PLANET scenes and orthotiles and split into train+val (cross-val), val-test (variance) and test (region)
        n_patches = 0
        n_patches_by_mode = {"cross-val": 0, "val-test": 0, "test": 0}
        joint_lables = []
        planet_paths = sorted(_legacy_path_gen(data_dir))
        logger.info(f"Found {len(planet_paths)} PLANET scenes and orthotiles in {data_dir}")
        path_gen = split_dataset_paths(
            planet_paths, footprints, train_data_dir, test_val_split, test_regions, admin_dir
        )

        for i, (fpath, mode) in track(
            enumerate(path_gen), description="Processing samples", total=len(planet_paths), console=console
        ):
            try:
                planet_id = fpath.stem
                logger.debug(
                    f"Processing sample {i + 1} of {len(planet_paths)}"
                    f" '{fpath.resolve()}' ({planet_id=}) to split '{mode}'"
                )

                # Check for a cached preprocessed file
                if preprocess_cache and (preprocess_cache / f"{planet_id}.nc").exists():
                    cache_file = preprocess_cache / f"{planet_id}.nc"
                    logger.info(f"Loading preprocessed data from {cache_file.resolve()}")
                    tile = xr.open_dataset(preprocess_cache / f"{planet_id}.nc", engine="h5netcdf").set_coords(
                        "spatial_ref"
                    )
                else:
                    optical = load_planet_scene(fpath)
                    logger.info(f"Found optical tile with size {optical.sizes}")
                    arctidem_res = 2
                    arcticdem_buffer = ceil(tpi_outer_radius / arctidem_res * sqrt(2))
                    arcticdem = load_arcticdem(
                        optical.odc.geobox, arcticdem_dir, resolution=arctidem_res, buffer=arcticdem_buffer
                    )
                    tcvis = load_tcvis(optical.odc.geobox, tcvis_dir)
                    data_masks = load_planet_masks(fpath)

                    tile: xr.Dataset = preprocess_legacy_fast(
                        optical,
                        arcticdem,
                        tcvis,
                        data_masks,
                        tpi_outer_radius,
                        tpi_inner_radius,
                        device,
                    )
                    # Only cache if we have a cache directory
                    if preprocess_cache:
                        preprocess_cache.mkdir(exist_ok=True, parents=True)
                        cache_file = preprocess_cache / f"{planet_id}.nc"
                        logger.info(f"Caching preprocessed data to {cache_file.resolve()}")
                        tile.to_netcdf(cache_file, engine="h5netcdf")

                # Save the patches
                gen = create_training_patches(
                    tile=tile,
                    labels=labels[labels.image_id == planet_id],
                    bands=bands,
                    norm_factors=norm_factors,
                    patch_size=patch_size,
                    overlap=overlap,
                    exclude_nopositive=exclude_nopositive,
                    exclude_nan=exclude_nan,
                    device=device,
                    mask_erosion_size=mask_erosion_size,
                )

                zx = zgroups[mode]["x"]
                zy = zgroups[mode]["y"]
                patch_id = None
                for patch_id, (x, y) in enumerate(gen):
                    zx.append(x.unsqueeze(0).numpy().astype("float32"))
                    zy.append(y.unsqueeze(0).numpy().astype("uint8"))
                    n_patches += 1
                    n_patches_by_mode[mode] += 1
                if n_patches > 0 and len(labels) > 0:
                    labels["mode"] = mode
                    joint_lables.append(labels.to_crs("EPSG:3413"))

                logger.info(
                    f"Processed sample {i + 1} of {len(planet_paths)} '{fpath.resolve()}'"
                    f"({planet_id=}) with {patch_id} patches."
                )

            except KeyboardInterrupt:
                logger.info("Interrupted by user.")
                break

            except Exception as e:
                logger.warning(f"Could not process folder sample {i} '{fpath.resolve()}'.\nSkipping...")
                logger.exception(e)

    # Save the used labels
    joint_lables = pd.concat(joint_lables)
    joint_lables.to_file(train_data_dir / "labels.geojson", driver="GeoJSON")

    # Save a config file as toml
    config = {
        "darts": {
            "data_dir": data_dir,
            "labels_dir": labels_dir,
            "train_data_dir": train_data_dir,
            "arcticdem_dir": arcticdem_dir,
            "tcvis_dir": tcvis_dir,
            "bands": bands,
            "norm_factors": norm_factors,
            "device": device,
            "ee_project": ee_project,
            "ee_use_highvolume": ee_use_highvolume,
            "tpi_outer_radius": tpi_outer_radius,
            "tpi_inner_radius": tpi_inner_radius,
            "patch_size": patch_size,
            "overlap": overlap,
            "exclude_nopositive": exclude_nopositive,
            "exclude_nan": exclude_nan,
            "n_patches": n_patches,
        }
    }
    with open(train_data_dir / "config.toml", "w") as f:
        toml.dump(config, f)

    logger.info(f"Saved {n_patches} ({n_patches_by_mode}) patches to {train_data_dir}")

Simple SMP train and test

To train a simple SMP (Segmentation Models PyTorch) model, you can use the command:

[uv run] darts train-smp --your-args-here ...

Configuration options for the architecture and encoder can be found in the SMP documentation on model configurations.

Change defaults

Even though the CLI defaults are somewhat useful, it is recommended to create a config file and adjust the training behavior there.

This will train a model with the cross-val data and save the model to disk. You don't need to specify the concrete path to the cross-val split: the training script expects --train-data-dir to point to the root directory of the splits, so the same path used in the preprocessing should be specified. The training relies on PyTorch Lightning, a high-level interface for PyTorch. It is recommended to use Weights and Biases (wandb) for logging, because the training script is heavily influenced by how wandb organizes runs.

Each training run is assigned a unique name and id pair, and optionally a trial name. The name, which the user can provide, should be used to group runs with equal hyperparameters and code. Hence, different versions of the same name should only differ by random state or run settings, such as logging. Each version is assigned a unique id. Artifacts (metrics & checkpoints) are stored under {artifact_dir}/{run_name}/{run_id} for non-cross-validation runs. If trial_name is specified, the artifacts are stored under {artifact_dir}/{trial_name}/{run_name}-{run_id} instead. Wandb logs are always stored under {wandb_entity}/{wandb_project}/{run_name}, regardless of trial_name; however, they are further grouped by the trial_name (via job_type) if specified. Both run_name and run_id are also stored in the hparams of each checkpoint.
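
As a quick illustration of where artifacts end up (a sketch only; the names are placeholders, not defaults):

```python
from pathlib import Path


def artifact_path(artifact_dir: Path, run_name: str, run_id: str, trial_name: str | None = None) -> Path:
    """Resolve the artifact location following the scheme described above."""
    if trial_name is None:
        return artifact_dir / run_name / run_id
    return artifact_dir / trial_name / f"{run_name}-{run_id}"


print(artifact_path(Path("lightning_logs"), "unet-baseline", "a1b2c3"))
# lightning_logs/unet-baseline/a1b2c3
print(artifact_path(Path("lightning_logs"), "unet-baseline", "a1b2c3", trial_name="cv-sweep-1"))
# lightning_logs/cv-sweep-1/unet-baseline-a1b2c3
```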

You can now test the model on the other two splits (val-test and test) with the following command:

[uv run] darts test-smp --your-args-here ...

The stored checkpoint is not yet usable in the pipeline, since it is saved in a different format. To use it in the pipeline, you need to convert it first:

[uv run] darts convert-lightning-checkpoint --your-args-here ...

You can also use the underlying functions directly:

darts.legacy_training.train_smp

train_smp(
    *,
    train_data_dir: pathlib.Path,
    artifact_dir: pathlib.Path = pathlib.Path(
        "lightning_logs"
    ),
    fold: int = 0,
    continue_from_checkpoint: pathlib.Path | None = None,
    model_arch: str = "Unet",
    model_encoder: str = "dpn107",
    model_encoder_weights: str | None = None,
    augment: bool = True,
    learning_rate: float = 0.001,
    gamma: float = 0.9,
    focal_loss_alpha: float | None = None,
    focal_loss_gamma: float = 2.0,
    batch_size: int = 8,
    max_epochs: int = 100,
    log_every_n_steps: int = 10,
    check_val_every_n_epoch: int = 3,
    early_stopping_patience: int = 5,
    plot_every_n_val_epochs: int = 5,
    random_seed: int = 42,
    num_workers: int = 0,
    device: int | str = "auto",
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
    wandb_group: str | None = None,
    run_name: str | None = None,
    run_id: str | None = None,
    trial_name: str | None = None,
) -> pytorch_lightning.Trainer

Run the training of the SMP model.

Please see https://smp.readthedocs.io/en/latest/index.html for model configurations.

Each training run is assigned a unique name and id pair, and optionally a trial name. The name, which the user can provide, should be used to group runs with equal hyperparameters and code. Hence, different versions of the same name should only differ by random state or run settings, such as logging. Each version is assigned a unique id. Artifacts (metrics & checkpoints) are stored under {artifact_dir}/{run_name}/{run_id} for non-cross-validation runs. If trial_name is specified, the artifacts are stored under {artifact_dir}/{trial_name}/{run_name}-{run_id} instead. Wandb logs are always stored under {wandb_entity}/{wandb_project}/{run_name}, regardless of trial_name; however, they are further grouped by the trial_name (via job_type) if specified. Both run_name and run_id are also stored in the hparams of each checkpoint.

You can specify how often logs will be written and validation will be performed:

- log_every_n_steps specifies how often train logs will be written. This does not affect validation.
- check_val_every_n_epoch specifies how often validation will be performed. This also affects early stopping.
- early_stopping_patience specifies how many validation rounds to wait for improvement before stopping. In epochs, this amounts to check_val_every_n_epoch * early_stopping_patience.
- plot_every_n_val_epochs specifies how often validation samples will be plotted. Since plotting is quite costly, you can reduce the frequency; it works similarly to early stopping. In epochs, this amounts to check_val_every_n_epoch * plot_every_n_val_epochs.
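
For example, with the defaults check_val_every_n_epoch=3 and early_stopping_patience=5, training stops after 3 * 5 = 15 consecutive epochs without an improvement of the validation metric; likewise, with plot_every_n_val_epochs=5, validation samples are plotted every 3 * 5 = 15 epochs.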

The training data is expected to have been created by the "preprocessing" step described above, which results in the following directory structure:

preprocessed-data/ # the top-level directory
├── config.toml
├── cross-val.zarr/ # this zarr group contains the dataarrays x and y for the training and validation
├── test.zarr/ # this zarr group contains the dataarrays x and y for the left-out-region test set
├── val-test.zarr/ # this zarr group contains the dataarrays x and y for the randomly selected validation set
└── labels.geojson
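
For example, a direct call could look like the following sketch. Paths, wandb settings, and the run name are placeholders; "Unet" and "resnet34" are examples of valid SMP architecture and encoder names.

```python
from pathlib import Path

from darts.legacy_training import train_smp

# Paths, wandb settings, and names below are placeholders; adjust them to your setup.
trainer = train_smp(
    train_data_dir=Path("data/train"),  # the same root directory used in the preprocessing
    artifact_dir=Path("lightning_logs"),
    fold=0,
    model_arch="Unet",
    model_encoder="resnet34",  # any encoder supported by SMP
    learning_rate=1e-3,
    batch_size=8,
    max_epochs=100,
    device="auto",
    wandb_entity="my-entity",
    wandb_project="darts",
    run_name="unet-resnet34-baseline",
)
```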

Parameters:

  • train_data_dir (pathlib.Path) –

    Path to the training data directory (top-level).

  • artifact_dir (pathlib.Path, default: pathlib.Path('lightning_logs') ) –

    Path to the training output directory. Will contain checkpoints and metrics. Defaults to Path("lightning_logs").

  • fold (int, default: 0 ) –

    The current fold to train on. Must be in [0, 4]. Defaults to 0.

  • continue_from_checkpoint (pathlib.Path | None, default: None ) –

    Path to a checkpoint to continue training from. Defaults to None.

  • model_arch (str, default: 'Unet' ) –

    Model architecture to use. Defaults to "Unet".

  • model_encoder (str, default: 'dpn107' ) –

    Encoder to use. Defaults to "dpn107".

  • model_encoder_weights (str | None, default: None ) –

    Path to the encoder weights. Defaults to None.

  • augment (bool, default: True ) –

    Whether to apply augmentations or not. Defaults to True.

  • learning_rate (float, default: 0.001 ) –

    Learning Rate. Defaults to 1e-3.

  • gamma (float, default: 0.9 ) –

    Multiplicative factor of learning rate decay. Defaults to 0.9.

  • focal_loss_alpha (float, default: None ) –

    Weight factor to balance positive and negative samples. Alpha must be in [0...1] range, high values will give more weight to positive class. None will not weight samples. Defaults to None.

  • focal_loss_gamma (float, default: 2.0 ) –

    Focal loss power factor. Defaults to 2.0.

  • batch_size (int, default: 8 ) –

    Batch Size. Defaults to 8.

  • max_epochs (int, default: 100 ) –

    Maximum number of epochs to train. Defaults to 100.

  • log_every_n_steps (int, default: 10 ) –

    Log every n steps. Defaults to 10.

  • check_val_every_n_epoch (int, default: 3 ) –

    Check validation every n epochs. Defaults to 3.

  • early_stopping_patience (int, default: 5 ) –

    Number of epochs to wait for improvement before stopping. Defaults to 5.

  • plot_every_n_val_epochs (int, default: 5 ) –

    Plot validation samples every n epochs. Defaults to 5.

  • random_seed (int, default: 42 ) –

    Random seed for deterministic training. Defaults to 42.

  • num_workers (int, default: 0 ) –

    Number of Dataloader workers. Defaults to 0.

  • device (int | str, default: 'auto' ) –

    The device to run the model on. Defaults to "auto".

  • wandb_entity (str | None, default: None ) –

    Weights and Biases Entity. Defaults to None.

  • wandb_project (str | None, default: None ) –

    Weights and Biases Project. Defaults to None.

  • wandb_group (str | None, default: None ) –

    Wandb group. Useful for CV sweeps. Defaults to None.

  • run_name (str | None, default: None ) –

    Name of this run, as a further grouping method for logs etc. If None, will generate a random one. Defaults to None.

  • run_id (str | None, default: None ) –

    ID of the run. If None, will generate a random one. Defaults to None.

  • trial_name (str | None, default: None ) –

    Name of the cross-validation run / trial. This primarily affects logging and artifact storage. If None, will do nothing. Defaults to None.

Returns:

  • Trainer ( pytorch_lightning.Trainer ) –

    The trainer object used for training.

Source code in darts/src/darts/legacy_training/train.py
def train_smp(
    *,
    # Data config
    train_data_dir: Path,
    artifact_dir: Path = Path("lightning_logs"),
    fold: int = 0,
    continue_from_checkpoint: Path | None = None,
    # Hyperparameters
    model_arch: str = "Unet",
    model_encoder: str = "dpn107",
    model_encoder_weights: str | None = None,
    augment: bool = True,
    learning_rate: float = 1e-3,
    gamma: float = 0.9,
    focal_loss_alpha: float | None = None,
    focal_loss_gamma: float = 2.0,
    batch_size: int = 8,
    # Epoch and Logging config
    max_epochs: int = 100,
    log_every_n_steps: int = 10,
    check_val_every_n_epoch: int = 3,
    early_stopping_patience: int = 5,
    plot_every_n_val_epochs: int = 5,
    # Device and Manager config
    random_seed: int = 42,
    num_workers: int = 0,
    device: int | str = "auto",
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
    wandb_group: str | None = None,
    run_name: str | None = None,
    run_id: str | None = None,
    trial_name: str | None = None,
) -> "pl.Trainer":
    """Run the training of the SMP model.

    Please see https://smp.readthedocs.io/en/latest/index.html for model configurations.

    Each training run is assigned a unique **name** and **id** pair and optionally a trial name.
    The name, which the user _can_ provide, should be used as a grouping mechanism of equal hyperparameter and code.
    Hence, different versions of the same name should only differ by random state or run settings parameter, like logs.
    Each version is assigned a unique id.
    Artifacts (metrics & checkpoints) are then stored under `{artifact_dir}/{run_name}/{run_id}` in no-crossval runs.
    If `trial_name` is specified, the artifacts are stored under `{artifact_dir}/{trial_name}/{run_name}-{run_id}`.
    Wandb logs are always stored under `{wandb_entity}/{wandb_project}/{run_name}`, regardless of `trial_name`.
    However, they are further grouped by the `trial_name` (via job_type), if specified.
    Both `run_name` and `run_id` are also stored in the hparams of each checkpoint.

    You can specify the frequency on how often logs will be written and validation will be performed.
        - `log_every_n_steps` specifies how often train-logs will be written. This does not affect validation.
        - `check_val_every_n_epoch` specifies how often validation will be performed.
            This will also affect early stopping.
        - `early_stopping_patience` specifies how many epochs to wait for improvement before stopping.
            In epochs, this would be `check_val_every_n_epoch * early_stopping_patience`.
        - `plot_every_n_val_epochs` specifies how often validation samples will be plotted.
            Since plotting is quite costly, you can reduce the frequency. Works similar like early stopping.
            In epochs, this would be `check_val_every_n_epoch * plot_every_n_val_epochs`.

    The data structure of the training data expects the "preprocessing" step to be done beforehand,
    which results in the following data structure:

    ```sh
    preprocessed-data/ # the top-level directory
    ├── config.toml
    ├── cross-val.zarr/ # this zarr group contains the dataarrays x and y for the training and validation
    ├── test.zarr/ # this zarr group contains the dataarrays x and y for the left-out-region test set
    ├── val-test.zarr/ # this zarr group contains the dataarrays x and y for the random selected validation set
    └── labels.geojson
    ```

    Args:
        train_data_dir (Path): Path to the training data directory (top-level).
        artifact_dir (Path, optional): Path to the training output directory.
            Will contain checkpoints and metrics. Defaults to Path("lightning_logs").
        fold (int, optional): The current fold to train on. Must be in [0, 4]. Defaults to 0.
        continue_from_checkpoint (Path | None, optional): Path to a checkpoint to continue training from.
            Defaults to None.
        model_arch (str, optional): Model architecture to use. Defaults to "Unet".
        model_encoder (str, optional): Encoder to use. Defaults to "dpn107".
        model_encoder_weights (str | None, optional): Path to the encoder weights. Defaults to None.
        augment (bool, optional): Weather to apply augments or not. Defaults to True.
        learning_rate (float, optional): Learning Rate. Defaults to 1e-3.
        gamma (float, optional): Multiplicative factor of learning rate decay. Defaults to 0.9.
        focal_loss_alpha (float, optional): Weight factor to balance positive and negative samples.
            Alpha must be in [0...1] range, high values will give more weight to positive class.
            None will not weight samples. Defaults to None.
        focal_loss_gamma (float, optional): Focal loss power factor. Defaults to 2.0.
        batch_size (int, optional): Batch Size. Defaults to 8.
        max_epochs (int, optional): Maximum number of epochs to train. Defaults to 100.
        log_every_n_steps (int, optional): Log every n steps. Defaults to 10.
        check_val_every_n_epoch (int, optional): Check validation every n epochs. Defaults to 3.
        early_stopping_patience (int, optional): Number of epochs to wait for improvement before stopping.
            Defaults to 5.
        plot_every_n_val_epochs (int, optional): Plot validation samples every n epochs. Defaults to 5.
        random_seed (int, optional): Random seed for deterministic training. Defaults to 42.
        num_workers (int, optional): Number of Dataloader workers. Defaults to 0.
        device (int | str, optional): The device to run the model on. Defaults to "auto".
        wandb_entity (str | None, optional): Weights and Biases Entity. Defaults to None.
        wandb_project (str | None, optional): Weights and Biases Project. Defaults to None.
        wandb_group (str | None, optional): Wandb group. Usefull for CV-Sweeps. Defaults to None.
        run_name (str | None, optional): Name of this run, as a further grouping method for logs etc.
            If None, will generate a random one. Defaults to None.
        run_id (str | None, optional): ID of the run. If None, will generate a random one. Defaults to None.
        trial_name (str | None, optional): Name of the cross-validation run / trial.
            This effects primary logging and artifact storage.
            If None, will do nothing. Defaults to None.

    Returns:
        Trainer: The trainer object used for training.

    """
    import lightning as L  # noqa: N812
    import lovely_tensors
    import torch
    from darts_segmentation.segment import SMPSegmenterConfig
    from darts_segmentation.training.callbacks import BinarySegmentationMetrics
    from darts_segmentation.training.data import DartsDataModule
    from darts_segmentation.training.module import SMPSegmenter
    from lightning.pytorch import seed_everything
    from lightning.pytorch.callbacks import EarlyStopping, RichProgressBar
    from lightning.pytorch.loggers import CSVLogger, WandbLogger

    from darts.legacy_training.util import generate_id, get_generated_name
    from darts.utils.logging import LoggingManager

    LoggingManager.apply_logging_handlers("lightning.pytorch")

    tick_fstart = time.perf_counter()

    # Create unique run identification (name can be specified by user, id can be interpreted as a 'version')
    run_name = run_name or get_generated_name(artifact_dir)
    run_id = run_id or generate_id()

    logger.info(f"Starting training '{run_name}' ('{run_id}') with data from {train_data_dir.resolve()}.")
    logger.debug(
        f"Using config:\n\t{model_arch=}\n\t{model_encoder=}\n\t{model_encoder_weights=}\n\t{augment=}\n\t"
        f"{learning_rate=}\n\t{gamma=}\n\t{batch_size=}\n\t{max_epochs=}\n\t{log_every_n_steps=}\n\t"
        f"{check_val_every_n_epoch=}\n\t{early_stopping_patience=}\n\t{plot_every_n_val_epochs=}\n\t{num_workers=}"
        f"\n\t{device=}\n\t{random_seed=}"
    )

    lovely_tensors.monkey_patch()

    torch.set_float32_matmul_precision("medium")
    seed_everything(random_seed, workers=True)

    preprocess_config = toml.load(train_data_dir / "config.toml")["darts"]

    config = SMPSegmenterConfig(
        input_combination=preprocess_config["bands"],
        model={
            "arch": model_arch,
            "encoder_name": model_encoder,
            "encoder_weights": model_encoder_weights,
            "in_channels": len(preprocess_config["bands"]),
            "classes": 1,
        },
        norm_factors=preprocess_config["norm_factors"],
    )

    # Data and model
    datamodule = DartsDataModule(
        data_dir=train_data_dir / "cross-val.zarr",
        batch_size=batch_size,
        fold=fold,
        augment=augment,
        num_workers=num_workers,
    )
    model = SMPSegmenter(
        config=config,
        learning_rate=learning_rate,
        gamma=gamma,
        focal_loss_alpha=focal_loss_alpha,
        focal_loss_gamma=focal_loss_gamma,
        # These are only stored in the hparams and are not used
        run_id=run_id,
        run_name=run_name,
        trial_name=trial_name,
        random_seed=random_seed,
    )

    # Loggers
    is_crossval = bool(trial_name)
    trainer_loggers = [
        CSVLogger(
            save_dir=artifact_dir,
            name=run_name if not is_crossval else trial_name,
            version=run_id if not is_crossval else f"{run_name}-{run_id}",
        ),
    ]
    logger.debug(f"Logging CSV to {Path(trainer_loggers[0].log_dir).resolve()}")
    if wandb_entity and wandb_project:
        wandb_logger = WandbLogger(
            save_dir=artifact_dir,
            name=run_name,
            version=run_id,
            project=wandb_project,
            entity=wandb_entity,
            resume="allow",
            group=wandb_group,
            job_type=trial_name,
        )
        trainer_loggers.append(wandb_logger)
        logger.debug(
            f"Logging to WandB with entity '{wandb_entity}' and project '{wandb_project}'."
            f"Artifacts are logged to {(Path(wandb_logger.save_dir) / 'wandb').resolve()}"
        )

    # Callbacks
    callbacks = [
        RichProgressBar(),
        BinarySegmentationMetrics(
            input_combination=config["input_combination"],
            val_set=f"val{fold}",
            plot_every_n_val_epochs=plot_every_n_val_epochs,
            is_crossval=is_crossval,
        ),
    ]
    if early_stopping_patience:
        logger.debug(f"Using EarlyStopping with patience {early_stopping_patience}")
        early_stopping = EarlyStopping(monitor="val/JaccardIndex", mode="max", patience=early_stopping_patience)
        callbacks.append(early_stopping)

    # Train
    trainer = L.Trainer(
        max_epochs=max_epochs,
        callbacks=callbacks,
        log_every_n_steps=log_every_n_steps,
        logger=trainer_loggers,
        check_val_every_n_epoch=check_val_every_n_epoch,
        accelerator="gpu" if isinstance(device, int) else device,
        devices=[device] if isinstance(device, int) else device,
        deterministic=False,
    )
    trainer.fit(model, datamodule, ckpt_path=continue_from_checkpoint)

    tick_fend = time.perf_counter()
    logger.info(f"Finished training '{run_name}' in {tick_fend - tick_fstart:.2f}s.")

    if wandb_entity and wandb_project:
        wandb_logger.finalize("success")
        wandb_logger.experiment.finish(exit_code=0)
        logger.debug(f"Finalized WandB logging for '{run_name}'")

    return trainer
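
For reference, a minimal invocation of this training entry point could look like the sketch below. The import path `darts.legacy_training.train_smp` is an assumption inferred from the sibling functions documented on this page, and all paths and values are placeholders taken from the example configuration further down.

```python
from pathlib import Path

from darts.legacy_training import train_smp  # assumed import path

# Train fold 0 of the cross-validation split on GPU 0.
# All paths below are placeholders; adapt them to your own setup.
trainer = train_smp(
    train_data_dir=Path("/fast-storage/darts-nextgen/data/training/planet_native_tcvis_896_partial"),
    artifact_dir=Path("/large-storage/darts-nextgen/artifacts"),
    fold=0,
    model_arch="Unet",
    model_encoder="dpn107",
    batch_size=8,
    max_epochs=100,
    device=0,
)
```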

darts.legacy_training.test_smp

test_smp(
    *,
    train_data_dir: pathlib.Path,
    run_id: str,
    run_name: str,
    model_ckp: pathlib.Path | None = None,
    batch_size: int = 8,
    artifact_dir: pathlib.Path = pathlib.Path(
        "lightning_logs"
    ),
    num_workers: int = 0,
    device: int | str = "auto",
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
) -> pytorch_lightning.Trainer

Run the testing of the SMP model.

The training data is expected to have been produced by the "preprocessing" step beforehand, which results in the following data structure:

preprocessed-data/ # the top-level directory
├── config.toml
├── cross-val.zarr/ # this zarr group contains the dataarrays x and y for the training and validation
├── test.zarr/ # this zarr group contains the dataarrays x and y for the left-out-region test set
├── val-test.zarr/ # this zarr group contains the dataarrays x and y for the randomly selected validation set
└── labels.geojson

Parameters:

  • train_data_dir (pathlib.Path) –

    Path to the training data directory (top-level).

  • run_id (str) –

    ID of the run.

  • run_name (str) –

    Name of the run.

  • model_ckp (pathlib.Path | None, default: None ) –

    Path to the model checkpoint. If None, try to find the latest checkpoint in artifact_dir / run_name / run_id / checkpoints. Defaults to None.

  • batch_size (int, default: 8 ) –

    Batch size. Defaults to 8.

  • artifact_dir (pathlib.Path, default: pathlib.Path('lightning_logs') ) –

    Directory to save artifacts. Defaults to Path("lightning_logs").

  • num_workers (int, default: 0 ) –

    Number of workers for the DataLoader. Defaults to 0.

  • device (int | str, default: 'auto' ) –

    Device to use. Defaults to "auto".

  • wandb_entity (str | None, default: None ) –

    WandB entity. Defaults to None.

  • wandb_project (str | None, default: None ) –

    WandB project. Defaults to None.

Returns:

  • Trainer ( pytorch_lightning.Trainer ) –

    The trainer object used for training.

Source code in darts/src/darts/legacy_training/test.py
def test_smp(
    *,
    train_data_dir: Path,
    run_id: str,
    run_name: str,
    model_ckp: Path | None = None,
    batch_size: int = 8,
    artifact_dir: Path = Path("lightning_logs"),
    num_workers: int = 0,
    device: int | str = "auto",
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
) -> "pl.Trainer":
    """Run the testing of the SMP model.

    The training data is expected to have been produced by the "preprocessing" step beforehand,
    which results in the following data structure:

    ```sh
    preprocessed-data/ # the top-level directory
    ├── config.toml
    ├── cross-val.zarr/ # this zarr group contains the dataarrays x and y for the training and validation
    ├── test.zarr/ # this zarr group contains the dataarrays x and y for the left-out-region test set
    ├── val-test.zarr/ # this zarr group contains the dataarrays x and y for the randomly selected validation set
    └── labels.geojson
    ```

    Args:
        train_data_dir (Path): Path to the training data directory (top-level).
        run_id (str): ID of the run.
        run_name (str): Name of the run.
        model_ckp (Path | None): Path to the model checkpoint.
            If None, try to find the latest checkpoint in `artifact_dir / run_name / run_id / checkpoints`.
            Defaults to None.
        batch_size (int, optional): Batch size. Defaults to 8.
        artifact_dir (Path, optional): Directory to save artifacts. Defaults to Path("lightning_logs").
        num_workers (int, optional): Number of workers for the DataLoader. Defaults to 0.
        device (int | str, optional): Device to use. Defaults to "auto".
        wandb_entity (str | None, optional): WandB entity. Defaults to None.
        wandb_project (str | None, optional): WandB project. Defaults to None.

    Returns:
        Trainer: The trainer object used for training.

    """
    import lightning as L  # noqa: N812
    import lovely_tensors
    import torch
    from darts_segmentation.training.callbacks import BinarySegmentationMetrics
    from darts_segmentation.training.data import DartsDataModule
    from darts_segmentation.training.module import SMPSegmenter
    from lightning.pytorch import seed_everything
    from lightning.pytorch.callbacks import RichProgressBar
    from lightning.pytorch.loggers import CSVLogger, WandbLogger

    from darts.utils.logging import LoggingManager

    LoggingManager.apply_logging_handlers("lightning.pytorch")

    tick_fstart = time.perf_counter()
    logger.info(f"Starting testing '{run_name}' ('{run_id}') with data from {train_data_dir.resolve()}.")
    logger.debug(f"Using config:\n\t{batch_size=}\n\t{device=}")

    lovely_tensors.monkey_patch()

    torch.set_float32_matmul_precision("medium")
    seed_everything(42, workers=True)

    preprocess_config = toml.load(train_data_dir / "config.toml")["darts"]

    # Data and model
    datamodule_val_test = DartsDataModule(
        data_dir=train_data_dir / "val-test.zarr",
        batch_size=batch_size,
        num_workers=num_workers,
    )
    datamodule_test = DartsDataModule(
        data_dir=train_data_dir / "test.zarr",
        batch_size=batch_size,
        num_workers=num_workers,
    )
    # Try to infer model checkpoint if not given
    if model_ckp is None:
        checkpoint_dir = artifact_dir / run_name / run_id / "checkpoints"
        logger.debug(f"No checkpoint provided. Looking for model checkpoint in {checkpoint_dir.resolve()}")
        model_ckp = max(checkpoint_dir.glob("*.ckpt"), key=lambda x: x.stat().st_mtime)
    model = SMPSegmenter.load_from_checkpoint(model_ckp)

    # Loggers
    trainer_loggers = [
        CSVLogger(save_dir=artifact_dir, name=run_name, version=run_id),
    ]
    logger.debug(f"Logging CSV to {Path(trainer_loggers[0].log_dir).resolve()}")
    if wandb_entity and wandb_project:
        wandb_logger = WandbLogger(
            save_dir=artifact_dir,
            name=run_name,
            id=run_id,
            project=wandb_project,
            entity=wandb_entity,
        )
        trainer_loggers.append(wandb_logger)
        logger.debug(
            f"Logging to WandB with entity '{wandb_entity}' and project '{wandb_project}'."
            f"Artifacts are logged to {(Path(wandb_logger.save_dir) / 'wandb').resolve()}"
        )

    # Callbacks
    metrics_cb = BinarySegmentationMetrics(
        input_combination=preprocess_config["bands"],
    )
    callbacks = [
        RichProgressBar(),
        metrics_cb,
    ]

    # Test
    trainer = L.Trainer(
        callbacks=callbacks,
        logger=trainer_loggers,
        accelerator="gpu" if isinstance(device, int) else device,
        devices=[device] if isinstance(device, int) else device,
        deterministic=True,
    )
    # Overwrite the names of the test sets to test against two separate sets
    metrics_cb.test_set = "val-test"
    model.test_set = "val-test"
    trainer.test(model, datamodule_val_test, ckpt_path=model_ckp)
    metrics_cb.test_set = "test"
    model.test_set = "test"
    trainer.test(model, datamodule_test)

    tick_fend = time.perf_counter()
    logger.info(f"Finished testing '{run_name}' in {tick_fend - tick_fstart:.2f}s.")

    if wandb_entity and wandb_project:
        wandb_logger.finalize("success")
        wandb_logger.experiment.finish(exit_code=0)
        logger.debug(f"Finalized WandB logging for '{run_name}'")

    return trainer
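
A minimal usage sketch of test_smp follows; the run name and id are hypothetical placeholders for the values used during training. If no checkpoint is passed, the latest one under `artifact_dir / run_name / run_id / checkpoints` is used.

```python
from pathlib import Path

from darts.legacy_training import test_smp

# "my-training-run" and "abc123" are placeholders for the run_name and run_id
# of a previous training run; adapt the paths to your own setup.
trainer = test_smp(
    train_data_dir=Path("/fast-storage/darts-nextgen/data/training/planet_native_tcvis_896_partial"),
    run_name="my-training-run",
    run_id="abc123",
    artifact_dir=Path("/large-storage/darts-nextgen/artifacts"),
    batch_size=8,
    device=0,
)
```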

darts.legacy_training.convert_lightning_checkpoint

convert_lightning_checkpoint(
    *,
    lightning_checkpoint: pathlib.Path,
    out_directory: pathlib.Path,
    checkpoint_name: str,
    framework: str = "smp",
)

Convert a lightning checkpoint to our own format.

The final checkpoint will contain the model configuration and the state dict. It will be saved to:

    out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"

Parameters:

  • lightning_checkpoint (pathlib.Path) –

    Path to the lightning checkpoint.

  • out_directory (pathlib.Path) –

    Output directory for the converted checkpoint.

  • checkpoint_name (str) –

    A unique name of the new checkpoint.

  • framework (str, default: 'smp' ) –

    The framework used for the model. Defaults to "smp".

Source code in darts/src/darts/legacy_training/util.py
def convert_lightning_checkpoint(
    *,
    lightning_checkpoint: Path,
    out_directory: Path,
    checkpoint_name: str,
    framework: str = "smp",
):
    """Convert a lightning checkpoint to our own format.

    The final checkpoint will contain the model configuration and the state dict.
    It will be saved to:

    ```python
        out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"
    ```

    Args:
        lightning_checkpoint (Path): Path to the lightning checkpoint.
        out_directory (Path): Output directory for the converted checkpoint.
        checkpoint_name (str): A unique name of the new checkpoint.
        framework (str, optional): The framework used for the model. Defaults to "smp".

    """
    import torch

    logger.debug(f"Loading checkpoint from {lightning_checkpoint.resolve()}")
    lckpt = torch.load(lightning_checkpoint, weights_only=False, map_location=torch.device("cpu"))

    now = datetime.now()
    formatted_date = now.strftime("%Y-%m-%d")
    config = lckpt["hyper_parameters"]["config"]
    del config["model"]["encoder_weights"]
    config["time"] = formatted_date
    config["name"] = checkpoint_name
    config["model_framework"] = framework

    statedict = lckpt["state_dict"]
    # Statedict has model. prefix before every weight. We need to remove them. This is an in-place function
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(statedict, "model.")

    own_ckpt = {
        "config": config,
        "statedict": lckpt["state_dict"],
    }

    out_directory.mkdir(exist_ok=True, parents=True)

    out_checkpoint = out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"

    torch.save(own_ckpt, out_checkpoint)

    logger.info(f"Saved converted checkpoint to {out_checkpoint.resolve()}")

Run a cross-validation hyperparameter sweep

Terminal Multiplexers

It is recommended to use a terminal multiplexer like tmux, screen or zellij to run multiple training runs in parallel. This way there is no need to keep multiple terminals open over the span of multiple days.

To sweep over a certain set of hyperparameters, some preparations are necessary:

  1. Create a sweep configuration file in YAML format. This file should contain the hyperparameters to sweep over and the search space for each hyperparameter.
  2. Setup a PostgreSQL database to store the results of the sweep, so we can run multiple runs in parallel with Optuna.

The sweep configuration file should look like a wandb sweep configuration. All values will be parsed and transformed to fit an Optuna sweep.
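
To illustrate what "parsed and transformed" means, the sketch below shows how a single wandb-style parameter entry could map onto an Optuna suggestion. This is only a simplified illustration, not the actual parsing code used by darts.

```python
import optuna


def suggest_from_wandb_entry(trial: optuna.Trial, name: str, entry: dict):
    """Illustrative mapping of one wandb-style sweep entry to an Optuna suggestion."""
    if "value" in entry:  # fixed value, e.g. gamma: {value: 0.997}
        return entry["value"]
    if "values" in entry:  # categorical choice, e.g. model_arch
        return trial.suggest_categorical(name, entry["values"])
    # numeric range, optionally log-uniform (e.g. learning_rate)
    log = entry.get("distribution") == "log_uniform_values"
    return trial.suggest_float(name, entry["min"], entry["max"], log=log)
```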

To set up the PostgreSQL database, refer to an appropriate guide for your environment. There are many ways to do this; the only important thing is that the database is reachable from the machine you are running the sweep on.
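
For context, this is roughly how Optuna uses such a database as shared storage. The darts CLI handles this internally, so the snippet is only useful to verify that the database URL is reachable; it assumes a PostgreSQL driver such as psycopg2 is installed, and that the sweep-id doubles as the Optuna study name.

```python
import optuna

# Storage URL matches the `sweep-db` setting from the example config below;
# the study name corresponding to the sweep-id is an assumption for illustration.
study = optuna.create_study(
    study_name="sweep-cv-large-planet",
    storage="postgresql://pguser@localhost:5432/sweeps",
    direction="maximize",
    load_if_exists=True,
)
print(f"Study '{study.study_name}' has {len(study.trials)} trials so far.")
```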

Now you can set up the sweep with the following command:

uv run darts optuna-sweep-smp --your-args-here ... --device 0

This will output some information about the sweep, especially the sweep id. In addition, it will start running trials on the CUDA:0 device.

Starting and continuing sweeps

Starting and continuing sweeps is done via the same optuna-sweep-smp command. Depending on the two arguments --sweep-id and --device, the command decides what to do (see the sketch after this list):

  • If neither --sweep-id nor --device is specified, a new sweep is created without starting any trials.
  • If --device is specified but no --sweep-id, a new sweep is started and n-trials are run sequentially on that device.
  • If both --sweep-id and --device are specified, the sweep is continued from the last run and n-trials are run sequentially on that device.
  • If --sweep-id is specified but no --device, an error is raised.
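
Summarized as code, these cases behave roughly like the following sketch. This is an illustrative restatement of the behaviour described above, not the actual implementation.

```python
def sweep_action(sweep_id: str | None, device: int | None) -> str:
    """Illustrative restatement of the optuna-sweep-smp decision logic."""
    if sweep_id is None and device is None:
        return "create a new sweep without starting any trials"
    if sweep_id is None:
        return "create a new sweep and run n-trials on the given device"
    if device is None:
        raise ValueError("--device is required to continue an existing sweep")
    return "continue the existing sweep and run n-trials on the given device"
```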

To start a second runner, you must open a new terminal (or panel/window in a terminal multiplexer) and run the following command:

uv run darts optuna-sweep-smp --your-args-here ... --device 1 --sweep-id <sweep-id>

Multiple runners

You can run as many runners as you have devices available. Each runner starts n trials sequentially, as specified by n-trials, and each trial requests a new hyperparameter combination from Optuna. Each trial further creates multiple runs, depending on the n_folds and n_randoms parameters. This is the cross-validation part: each trial, i.e. the same hyperparameter combination, is run n_folds times with n_randoms different random seeds. Therefore, the total number of runs done by a runner is n-trials * n_folds * n_randoms; with the example config below (n-trials = 100, n_folds = 3, n_randoms = 3), a single runner would perform 100 * 3 * 3 = 900 runs. This ensures that a single randomly good (or bad) run does not skew the overall result of a hyperparameter combination.

Example config and sweep-config files

For better readability, the example config file uses different sub-headings, which are not necessary and could be named differently or even removed. The only important heading is the [darts] heading, which is the root of the configuration file. Every value which is not under the darts top-level heading is ignored, as described in the Configuration Guide.

The following config.toml expects that the labels are cloned from the ML_training_labels repository and that PLANET scenes and tiles are downloaded into the /large-storage/planet_data directory. The resulting file structure would look like this:

File structure under the current working directory (./)
./
├── ../ML_training_labels/retrogressive_thaw_slumps/
├── darts/
├── logs/
└── configs/
    ├── planet-sweep-config.toml
    └── planet-tcvis-sweep.yaml
File structure under /large-storage/
/large-storage/
├── planet_data/
└── darts-nextgen/
    ├── artifacts/
    └── data/
        ├── training/
        │   └── planet_native_tcvis_896_partial/
        ├── cache/
        ├── datacubes/
        │   ├── arcticdem/
        │   └── tcvis/
        └── aux/admin/
File structure under /fast-storage/
/fast-storage/
└── darts-nextgen/
    └── data/
        └── training/
            └── planet_native_tcvis_896_partial/
configs/planet-sweep-config.toml
[darts.wandb]
wandb-project = "darts"
wandb-entity = "your-wandb-username"

[darts.sweep]
n-trials = 100
sweep-db = "postgresql://pguser@localhost:5432/sweeps"
n_folds = 3
n_randoms = 3
sweep-id = "sweep-cv-large-planet"

[darts.training]
num-workers = 16
max-epochs = 60
log-every-n-steps = 100
check-val-every-n-epoch = 5
plot-every-n-val-epochs = 4 # == 20 epochs
early-stopping-patience = 0

# These are the defaults, if not specified in the sweep-config
[darts.hyperparameters]
batch-size = 6
augment = true

[darts.training_preprocess]
ee-project = "your-ee-project"
tpi-outer-radius = 100
tpi-inner-radius = 0
bands = [
    'blue',
    'green',
    'red',
    'nir',
    'ndvi',
    'tc_brightness',
    'tc_greenness',
    'tc_wetness',
    'relative_elevation',
    'slope',
]
patch-size = 896
overlap = 0 # increase to 64 if exclude-nan = True
exclude-nopositive = false
exclude-nan = false
test-val-split = 0.05
test-regions = ['Taymyrsky Dolgano-Nenetsky District']

[darts.paths]
data-dir = "/large-storage/planet_data"
labels-dir = "../ML_training_labels/retrogressive_thaw_slumps" # (1)
arcticdem-dir = "/large-storage/darts-nextgen/data/datacubes/arcticdem"
tcvis-dir = "/large-storage/darts-nextgen/data/datacubes/tcvis"
admin-dir = "/large-storage/darts-nextgen/data/aux/admin"
train-data-dir = "/fast-storage/darts-nextgen/data/training/planet_native_tcvis_896_partial" # (2)
preprocess-cache = "/large-storage/darts-nextgen/data/cache"
sweep-config = "configs/planet-tcvis-sweep.yaml"
artifact-dir = "/large-storage/darts-nextgen/artifacts"
  1. Clone this repository to obtain the labels for the training data.
  2. The train-data-dir should point to fast read-access storage, like a locally mounted SSD, to speed up the training process.
configs/planet-tcvis-sweep.yaml
name: planet-tcvis-large
method: random
metric:
  goal: maximize
  name: val0/JaccardIndex
parameters:
  learning_rate:
    max: !!float 1e-3
    min: !!float 1e-5
    distribution: log_uniform_values
  gamma: # How fast the learning rate will decrease
    value: 0.997
  focal_loss_alpha: # How much the positive class is weighted
    min: 0.8
    max: 0.99
  focal_loss_gamma: # How much focus should be given to "bad" predictions
    min: 0.0
    max: 2.0
  model_arch:
    values:
      - Unet
      - MAnet
      - UPerNet
      - Segformer
  model_encoder:
    values:
      - resnet50
      - resnext50_32x4d
      - mit_b2
      - tu-convnextv2_tiny
      - tu-maxvit_tiny_rw_224