cli

darts.cli

Entrypoint for the darts-pipeline CLI.

LoggingManager module-attribute

LoggingManager = (
    darts.utils.logging.LoggingManagerSingleton()
)

__version__ module-attribute

__version__ = importlib.metadata.version('darts-nextgen')

app module-attribute

app = cyclopts.App(
    version=darts.__version__,
    console=rich.get_console(),
    config=darts.cli.config_parser,
    help_format="plaintext",
    version_format="plaintext",
)

config_parser module-attribute

config_parser = darts.utils.config.ConfigParser()

data_group module-attribute

data_group = cyclopts.Group.create_ordered('Data Commands')

logger module-attribute

logger = logging.getLogger(__name__)

pipeline_group module-attribute

pipeline_group = cyclopts.Group.create_ordered(
    "Pipeline Commands"
)

root_file module-attribute

root_file = pathlib.Path(__file__).resolve()

train_group module-attribute

train_group = cyclopts.Group.create_ordered(
    "Training Commands"
)

AOISentinel2Pipeline dataclass

AOISentinel2Pipeline(
    model_files: list[pathlib.Path] = None,
    output_data_dir: pathlib.Path = pathlib.Path(
        "data/output"
    ),
    arcticdem_dir: pathlib.Path = pathlib.Path(
        "data/download/arcticdem"
    ),
    tcvis_dir: pathlib.Path = pathlib.Path(
        "data/download/tcvis"
    ),
    device: typing.Literal["cuda", "cpu", "auto"]
    | int
    | None = None,
    ee_project: str | None = None,
    ee_use_highvolume: bool = True,
    tpi_outer_radius: int = 100,
    tpi_inner_radius: int = 0,
    patch_size: int = 1024,
    overlap: int = 256,
    batch_size: int = 8,
    reflection: int = 0,
    binarization_threshold: float = 0.5,
    mask_erosion_size: int = 10,
    min_object_size: int = 32,
    quality_level: int
    | typing.Literal[
        "high_quality", "low_quality", "none"
    ] = 1,
    export_bands: list[str] = [
        "probabilities",
        "binarized",
        "polygonized",
        "extent",
        "thumbnail",
    ],
    write_model_outputs: bool = False,
    overwrite: bool = False,
    aoi_shapefile: pathlib.Path = None,
    start_date: str = None,
    end_date: str = None,
    max_cloud_cover: int = 10,
    input_cache: pathlib.Path = pathlib.Path(
        "data/cache/input"
    ),
)

Bases: darts.pipelines.sequential_v2._BasePipeline

Pipeline for Sentinel 2 data based on an area of interest.

Parameters:

  • aoi_shapefile (pathlib.Path, default: None ) –

    The shapefile containing the area of interest.

  • start_date (str, default: None ) –

    The start date of the time series in YYYY-MM-DD format.

  • end_date (str, default: None ) –

    The end date of the time series in YYYY-MM-DD format.

  • max_cloud_cover (int, default: 10 ) –

    The maximum cloud cover percentage to use for filtering the Sentinel 2 scenes. Defaults to 10.

  • input_cache (pathlib.Path, default: pathlib.Path('data/cache/input') ) –

    The directory to use for caching the input data. Defaults to Path("data/cache/input").

  • model_files (pathlib.Path | list[pathlib.Path], default: None ) –

    The path to the models to use for segmentation. Can also be a single Path to use only one model; this implies write_model_outputs=False. If a list is provided, an ensemble of the models will be used.

  • output_data_dir (pathlib.Path, default: pathlib.Path('data/output') ) –

    The "output" directory. Defaults to Path("data/output").

  • arcticdem_dir (pathlib.Path, default: pathlib.Path('data/download/arcticdem') ) –

    The directory containing the ArcticDEM data (the datacube and the extent files). Will be created and downloaded if it does not exist. Defaults to Path("data/download/arcticdem").

  • tcvis_dir (pathlib.Path, default: pathlib.Path('data/download/tcvis') ) –

    The directory containing the TCVis data. Defaults to Path("data/download/tcvis").

  • device (typing.Literal['cuda', 'cpu', 'auto'] | int | None, default: None ) –

    The device to run the model on. If "cuda" take the first device (0), if int take the specified device. If "auto" try to automatically select a free GPU (<50% memory usage). Defaults to "cuda" if available, else "cpu".

  • ee_project (str, default: None ) –

    The Earth Engine project ID or number to use. May be omitted if project is defined within persistent API credentials obtained via earthengine authenticate.

  • ee_use_highvolume (bool, default: True ) –

    Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).

  • tpi_outer_radius (int, default: 100 ) –

    The outer radius of the annulus kernel for the tpi calculation in m. Defaults to 100m.

  • tpi_inner_radius (int, default: 0 ) –

    The inner radius of the annulus kernel for the tpi calculation in m. Defaults to 0.

  • patch_size (int, default: 1024 ) –

    The patch size to use for inference. Defaults to 1024.

  • overlap (int, default: 256 ) –

    The overlap to use for inference. Defaults to 256.

  • batch_size (int, default: 8 ) –

    The batch size to use for inference. Defaults to 8.

  • reflection (int, default: 0 ) –

    The reflection padding to use for inference. Defaults to 0.

  • binarization_threshold (float, default: 0.5 ) –

    The threshold to binarize the probabilities. Defaults to 0.5.

  • mask_erosion_size (int, default: 10 ) –

    The size of the disk to use for mask erosion and the edge-cropping. Defaults to 10.

  • min_object_size (int, default: 32 ) –

    The minimum object size to keep in pixel. Defaults to 32.

  • quality_level (int | typing.Literal['high_quality', 'low_quality', 'none'], default: 1 ) –

    The quality level to use for the segmentation. Can also be an int. In this case 0="none" 1="low_quality" 2="high_quality". Defaults to 1.

  • export_bands (list[str], default: ['probabilities', 'binarized', 'polygonized', 'extent', 'thumbnail'] ) –

    The bands to export. Can be a list of "probabilities", "binarized", "polygonized", "extent", "thumbnail", "optical", "dem", "tcvis" or concrete band-names. Defaults to ["probabilities", "binarized", "polygonized", "extent", "thumbnail"].

  • write_model_outputs (bool, default: False ) –

    Also save the model outputs, not only the ensemble result. Defaults to False.

  • overwrite (bool, default: False ) –

    Whether to overwrite existing files. Defaults to False.
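
A minimal usage sketch (the import path is assumed from the listed base class module; checkpoint and shapefile paths are hypothetical):

from pathlib import Path

from darts.pipelines.sequential_v2 import AOISentinel2Pipeline

pipeline = AOISentinel2Pipeline(
    model_files=[Path("models/tcvis.ckpt"), Path("models/notcvis.ckpt")],  # hypothetical checkpoints
    aoi_shapefile=Path("data/aoi.shp"),  # hypothetical area of interest
    start_date="2023-07-01",
    end_date="2023-08-31",
    max_cloud_cover=10,
)
pipeline.run()

Passing more than one model file runs an ensemble of the models, as described above.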

aoi_shapefile class-attribute instance-attribute

aoi_shapefile: pathlib.Path = None

arcticdem_dir class-attribute instance-attribute

arcticdem_dir: pathlib.Path = pathlib.Path(
    "data/download/arcticdem"
)

batch_size class-attribute instance-attribute

batch_size: int = 8

binarization_threshold class-attribute instance-attribute

binarization_threshold: float = 0.5

device class-attribute instance-attribute

device: (
    typing.Literal["cuda", "cpu", "auto"] | int | None
) = None

ee_project class-attribute instance-attribute

ee_project: str | None = None

ee_use_highvolume class-attribute instance-attribute

ee_use_highvolume: bool = True

end_date class-attribute instance-attribute

end_date: str = None

export_bands class-attribute instance-attribute

export_bands: list[str] = dataclasses.field(
    default_factory=lambda: [
        "probabilities",
        "binarized",
        "polygonized",
        "extent",
        "thumbnail",
    ]
)

input_cache class-attribute instance-attribute

input_cache: pathlib.Path = pathlib.Path("data/cache/input")

mask_erosion_size class-attribute instance-attribute

mask_erosion_size: int = 10

max_cloud_cover class-attribute instance-attribute

max_cloud_cover: int = 10

min_object_size class-attribute instance-attribute

min_object_size: int = 32

model_files class-attribute instance-attribute

model_files: list[pathlib.Path] = None

output_data_dir class-attribute instance-attribute

output_data_dir: pathlib.Path = pathlib.Path('data/output')

overlap class-attribute instance-attribute

overlap: int = 256

overwrite class-attribute instance-attribute

overwrite: bool = False

patch_size class-attribute instance-attribute

patch_size: int = 1024

quality_level class-attribute instance-attribute

quality_level: (
    int
    | typing.Literal["high_quality", "low_quality", "none"]
) = 1

reflection class-attribute instance-attribute

reflection: int = 0

start_date class-attribute instance-attribute

start_date: str = None

tcvis_dir class-attribute instance-attribute

tcvis_dir: pathlib.Path = pathlib.Path(
    "data/download/tcvis"
)

tpi_inner_radius class-attribute instance-attribute

tpi_inner_radius: int = 0

tpi_outer_radius class-attribute instance-attribute

tpi_outer_radius: int = 100

write_model_outputs class-attribute instance-attribute

write_model_outputs: bool = False

cli staticmethod

Run the sequential pipeline for AOI Sentinel 2 data.

Source code in darts/src/darts/pipelines/sequential_v2.py
@staticmethod
def cli(*, pipeline: "AOISentinel2Pipeline"):
    """Run the sequential pipeline for AOI Sentinel 2 data."""
    pipeline.run()

run

run()
Source code in darts/src/darts/pipelines/sequential_v2.py
def run(self):  # noqa: C901
    if self.model_files is None or len(self.model_files) == 0:
        raise ValueError("No model files provided. Please provide a list of model files.")
    if len(self.export_bands) == 0:
        raise ValueError("No export bands provided. Please provide a list of export bands.")

    current_time = time.strftime("%Y-%m-%d_%H-%M-%S")
    logger.info(f"Starting pipeline at {current_time}.")

    # Storing the configuration as JSON file
    self.output_data_dir.mkdir(parents=True, exist_ok=True)
    with open(self.output_data_dir / f"{current_time}.config.json", "w") as f:
        config = asdict(self)
        # Convert everything to json serializable
        for key, value in config.items():
            if isinstance(value, Path):
                config[key] = str(value.resolve())
            elif isinstance(value, list):
                config[key] = [str(v.resolve()) if isinstance(v, Path) else v for v in value]
        json.dump(config, f)

    from stopuhr import Chronometer

    timer = Chronometer(printer=logger.debug)

    from darts.utils.cuda import debug_info

    debug_info()

    from darts.utils.earthengine import init_ee

    init_ee(self.ee_project, self.ee_use_highvolume)

    import pandas as pd
    import smart_geocubes
    import torch
    from darts_acquisition import load_arcticdem, load_tcvis
    from darts_ensemble import EnsembleV1
    from darts_export import export_tile, missing_outputs
    from darts_postprocessing import prepare_export
    from darts_preprocessing import preprocess_legacy_fast

    from darts.utils.cuda import decide_device
    from darts.utils.logging import LoggingManager

    self.device = decide_device(self.device)

    # determine models to use
    if isinstance(self.model_files, Path):
        self.model_files = [self.model_files]
        self.write_model_outputs = False
    models = {model_file.stem: model_file for model_file in self.model_files}
    ensemble = EnsembleV1(models, device=torch.device(self.device))

    # Create the datacubes if they do not exist
    LoggingManager.apply_logging_handlers("smart_geocubes")
    arcticdem_resolution = self._arcticdem_resolution()
    if arcticdem_resolution == 2:
        accessor = smart_geocubes.ArcticDEM2m(self.arcticdem_dir)
    elif arcticdem_resolution == 10:
        accessor = smart_geocubes.ArcticDEM10m(self.arcticdem_dir)
    if not accessor.created:
        accessor.create(overwrite=False)
    accessor = smart_geocubes.TCTrend(self.tcvis_dir)
    if not accessor.created:
        accessor.create(overwrite=False)

    # Iterate over all the data
    tileinfo = self._tileinfos()
    n_tiles = 0
    logger.info(f"Found {len(tileinfo)} tiles to process.")
    results = []
    for i, (tilekey, outpath) in enumerate(tileinfo):
        tile_id = self._get_tile_id(tilekey)
        try:
            if not self.overwrite:
                mo = missing_outputs(outpath, bands=self.export_bands, ensemble_subsets=models.keys())
                if mo == "none":
                    logger.info(f"Tile {tile_id} already processed. Skipping...")
                    continue
                if mo == "some":
                    logger.warning(
                        f"Tile {tile_id} already processed. Some outputs are missing."
                        " Skipping because overwrite=False..."
                    )
                    continue

            with timer("Loading optical data", log=False):
                tile = self._load_tile(tilekey)
            with timer("Loading ArcticDEM", log=False):
                arcticdem = load_arcticdem(
                    tile.odc.geobox,
                    self.arcticdem_dir,
                    resolution=arcticdem_resolution,
                    buffer=ceil(self.tpi_outer_radius / arcticdem_resolution * sqrt(2)),
                )
            with timer("Loading TCVis", log=False):
                tcvis = load_tcvis(tile.odc.geobox, self.tcvis_dir)
            with timer("Preprocessing tile", log=False):
                tile = preprocess_legacy_fast(
                    tile,
                    arcticdem,
                    tcvis,
                    self.tpi_outer_radius,
                    self.tpi_inner_radius,
                    self.device,
                )
            with timer("Segmenting", log=False):
                tile = ensemble.segment_tile(
                    tile,
                    patch_size=self.patch_size,
                    overlap=self.overlap,
                    batch_size=self.batch_size,
                    reflection=self.reflection,
                    keep_inputs=self.write_model_outputs,
                )
            with timer("Postprosessing", log=False):
                tile = prepare_export(
                    tile,
                    bin_threshold=self.binarization_threshold,
                    mask_erosion_size=self.mask_erosion_size,
                    min_object_size=self.min_object_size,
                    quality_level=self.quality_level,
                    ensemble_subsets=models.keys() if self.write_model_outputs else [],
                    device=self.device,
                )

            with timer("Exporting", log=False):
                export_tile(
                    tile,
                    outpath,
                    bands=self.export_bands,
                    ensemble_subsets=models.keys() if self.write_model_outputs else [],
                )

            n_tiles += 1
            results.append(
                {
                    "tile_id": tile_id,
                    "output_path": str(outpath.resolve()),
                    "status": "success",
                    "error": None,
                }
            )
            logger.info(f"Processed sample {i + 1} of {len(tileinfo)} '{tilekey}' ({tile_id=}).")
        except KeyboardInterrupt:
            logger.warning("Keyboard interrupt detected.\nExiting...")
            raise KeyboardInterrupt
        except Exception as e:
            logger.warning(f"Could not process '{tilekey}' ({tile_id=}).\nSkipping...")
            logger.exception(e)
            results.append(
                {
                    "tile_id": tile_id,
                    "output_path": str(outpath.resolve()),
                    "status": "failed",
                    "error": str(e),
                }
            )
        finally:
            if len(results) > 0:
                pd.DataFrame(results).to_parquet(self.output_data_dir / f"{current_time}.results.parquet")
            if len(timer.durations) > 0:
                timer.export().to_parquet(self.output_data_dir / f"{current_time}.stopuhr.parquet")
    else:
        logger.info(f"Processed {n_tiles} tiles to {self.output_data_dir.resolve()}.")
        timer.summary(printer=logger.info)

ConfigParser

ConfigParser()

Parser for cyclopts config.

A custom implementation is needed to select our own toml structure and source. It is implemented as a class so that the config file can be provided as a parameter of the CLI.

Initialize the ConfigParser (no-op).

Source code in darts/src/darts/utils/config.py
def __init__(self) -> None:
    """Initialize the ConfigParser (no-op)."""
    self._config = None

__call__

__call__(
    apps: list[cyclopts.App],
    commands: tuple[str, ...],
    arguments: cyclopts.ArgumentCollection,
)

Parser for cyclopts config. A custom implementation is needed to select our own toml structure.

First, the configuration file at "config.toml" is loaded. Then, this config is flattened and mapped to the input arguments of the called function. Parent keys are therefore not considered.

Parameters:

  • apps (list[cyclopts.App]) –

    The cyclopts apps. Unused, but must be provided for the cyclopts hook.

  • commands (tuple[str, ...]) –

    The commands. Unused, but must be provided for the cyclopts hook.

  • arguments (cyclopts.ArgumentCollection) –

    The arguments to apply the config to.

Examples:

Setup the cyclopts App
import cyclopts
from darts.utils.config import ConfigParser

config_parser = ConfigParser()
app = cyclopts.App(config=config_parser)

# Intercept the logging behavior to add a file handler
@app.meta.default
def launcher(
    *tokens: Annotated[str, cyclopts.Parameter(show=False, allow_leading_hyphen=True)],
    log_dir: Path = Path("logs"),
    config_file: Path = Path("config.toml"),
):
    command, bound, _ = app.parse_args(tokens)
    add_logging_handlers(command.__name__, console, log_dir)
    return command(*bound.args, **bound.kwargs)

if __name__ == "__main__":
    app.meta()
Usage

Config file ./config.toml:

[darts.hello] # The parent key is completely ignored
name = "Tobias"

Function signature which is called:

# ... setup code for cyclopts
@app.command()
def hello(name: str):
    print(f"Hello {name}")

Calling the function from CLI:

$ darts hello
Hello Tobias

$ darts hello --name=Max
Hello Max
Source code in darts/src/darts/utils/config.py
def __call__(self, apps: list[cyclopts.App], commands: tuple[str, ...], arguments: cyclopts.ArgumentCollection):
    """Parser for cyclopts config. An own implementation is needed to select our own toml structure.

    First, the configuration file at "config.toml" is loaded.
    Then, this config is flattened and mapped to the input arguments of the called function.
    Parent keys are therefore not considered.

    Args:
        apps (list[cyclopts.App]): The cyclopts apps. Unused, but must be provided for the cyclopts hook.
        commands (tuple[str, ...]): The commands. Unused, but must be provided for the cyclopts hook.
        arguments (cyclopts.ArgumentCollection): The arguments to apply the config to.

    Examples:
        ### Setup the cyclopts App

        ```python
        import cyclopts
        from darts.utils.config import ConfigParser

        config_parser = ConfigParser()
        app = cyclopts.App(config=config_parser)

        # Intercept the logging behavior to add a file handler
        @app.meta.default
        def launcher(
            *tokens: Annotated[str, cyclopts.Parameter(show=False, allow_leading_hyphen=True)],
            log_dir: Path = Path("logs"),
            config_file: Path = Path("config.toml"),
        ):
            command, bound, _ = app.parse_args(tokens)
            add_logging_handlers(command.__name__, console, log_dir)
            return command(*bound.args, **bound.kwargs)

        if __name__ == "__main__":
            app.meta()
        ```


        ### Usage

        Config file `./config.toml`:

        ```toml
        [darts.hello] # The parent key is completely ignored
        name = "Tobias"
        ```

        Function signature which is called:

        ```python
        # ... setup code for cyclopts
        @app.command()
        def hello(name: str):
            print(f"Hello {name}")
        ```

        Calling the function from CLI:

        ```sh
        $ darts hello
        Hello Tobias

        $ darts hello --name=Max
        Hello Max
        ```

    """
    if self._config is None:
        config_arg, _, _ = arguments.match("--config-file")
        config_file = config_arg.convert_and_validate()
        # Use default config file if not specified
        if not config_file:
            config_file = config_arg.field_info.default
        # else never happens
        self.open_config(config_file)

    self.apply_config(arguments)

apply_config

apply_config(arguments: cyclopts.ArgumentCollection)

Apply the loaded config to the cyclopts mapping.

Parameters:

  • arguments (cyclopts.ArgumentCollection) –

    The arguments to apply the config to.

Source code in darts/src/darts/utils/config.py
def apply_config(self, arguments: cyclopts.ArgumentCollection):
    """Apply the loaded config to the cyclopts mapping.

    Args:
        arguments (cyclopts.ArgumentCollection): The arguments to apply the config to.

    """
    to_add = []
    for k in self._config.keys():
        value = self._config[k]["value"]

        try:
            argument, remaining_keys, _ = arguments.match(f"--{k}")
        except ValueError:
            # Config key not found in arguments - ignore
            continue

        # Skip if the argument already has tokens (e.g. set via the CLI) or is a var-keyword parameter
        if argument.tokens or argument.field_info.kind is argument.field_info.VAR_KEYWORD:
            continue

        # Skip if the argument was already set from a source other than the config file
        if any(x.source != "config-file" for x in argument.tokens):
            continue

        # Parse value to tuple of strings
        if not isinstance(value, list):
            value = (value,)
        value = tuple(str(x) for x in value)
        # Add the new tokens to the list
        for i, v in enumerate(value):
            to_add.append(
                (
                    argument,
                    cyclopts.Token(keyword=k, value=v, source="config-file", index=i, keys=remaining_keys),
                )
            )
    # Add here after all "arguments.match" calls, to avoid changing the list while iterating
    for argument, token in to_add:
        argument.append(token)

open_config

open_config(file_path: str | pathlib.Path) -> None

Open the config file, takes the 'darts' key, flattens the resulting dict and saves as config.

Parameters:

  • file_path (str | pathlib.Path) –

    The path to the config file.

Source code in darts/src/darts/utils/config.py
def open_config(self, file_path: str | Path) -> None:
    """Open the config file, takes the 'darts' key, flattens the resulting dict and saves as config.

    Args:
        file_path (str | Path): The path to the config file.

    """
    file_path = file_path if isinstance(file_path, Path) else Path(file_path)

    if not file_path.exists():
        logger.warning(f"No config file found at {file_path.resolve()}")
        self._config = {}
        return

    with file_path.open("rb") as f:
        config = tomllib.load(f)["darts"]

    # Flatten the config data
    self._config = flatten_dict(config)
    logger.info(f"loaded config from '{file_path.resolve()}'")

PlanetPipeline dataclass

PlanetPipeline(
    model_files: list[pathlib.Path] = None,
    output_data_dir: pathlib.Path = pathlib.Path(
        "data/output"
    ),
    arcticdem_dir: pathlib.Path = pathlib.Path(
        "data/download/arcticdem"
    ),
    tcvis_dir: pathlib.Path = pathlib.Path(
        "data/download/tcvis"
    ),
    device: typing.Literal["cuda", "cpu", "auto"]
    | int
    | None = None,
    ee_project: str | None = None,
    ee_use_highvolume: bool = True,
    tpi_outer_radius: int = 100,
    tpi_inner_radius: int = 0,
    patch_size: int = 1024,
    overlap: int = 256,
    batch_size: int = 8,
    reflection: int = 0,
    binarization_threshold: float = 0.5,
    mask_erosion_size: int = 10,
    min_object_size: int = 32,
    quality_level: int
    | typing.Literal[
        "high_quality", "low_quality", "none"
    ] = 1,
    export_bands: list[str] = [
        "probabilities",
        "binarized",
        "polygonized",
        "extent",
        "thumbnail",
    ],
    write_model_outputs: bool = False,
    overwrite: bool = False,
    orthotiles_dir: pathlib.Path = pathlib.Path(
        "data/input/planet/PSOrthoTile"
    ),
    scenes_dir: pathlib.Path = pathlib.Path(
        "data/input/planet/PSScene"
    ),
    image_ids: list = None,
)

Bases: darts.pipelines.sequential_v2._BasePipeline

Pipeline for PlanetScope data.

Parameters:

  • orthotiles_dir (pathlib.Path, default: pathlib.Path('data/input/planet/PSOrthoTile') ) –

    The directory containing the PlanetScope orthotiles.

  • scenes_dir (pathlib.Path, default: pathlib.Path('data/input/planet/PSScene') ) –

    The directory containing the PlanetScope scenes.

  • image_ids (list, default: None ) –

    The list of image ids to process. If None, all images in the directory will be processed.

  • model_files (pathlib.Path | list[pathlib.Path], default: None ) –

    The path to the models to use for segmentation. Can also be a single Path to use only one model; this implies write_model_outputs=False. If a list is provided, an ensemble of the models will be used.

  • output_data_dir (pathlib.Path, default: pathlib.Path('data/output') ) –

    The "output" directory. Defaults to Path("data/output").

  • arcticdem_dir (pathlib.Path, default: pathlib.Path('data/download/arcticdem') ) –

    The directory containing the ArcticDEM data (the datacube and the extent files). Will be created and downloaded if it does not exist. Defaults to Path("data/download/arcticdem").

  • tcvis_dir (pathlib.Path, default: pathlib.Path('data/download/tcvis') ) –

    The directory containing the TCVis data. Defaults to Path("data/download/tcvis").

  • device (typing.Literal['cuda', 'cpu', 'auto'] | int | None, default: None ) –

    The device to run the model on. If "cuda" take the first device (0), if int take the specified device. If "auto" try to automatically select a free GPU (<50% memory usage). Defaults to "cuda" if available, else "cpu".

  • ee_project (str, default: None ) –

    The Earth Engine project ID or number to use. May be omitted if project is defined within persistent API credentials obtained via earthengine authenticate.

  • ee_use_highvolume (bool, default: True ) –

    Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).

  • tpi_outer_radius (int, default: 100 ) –

    The outer radius of the annulus kernel for the tpi calculation in m. Defaults to 100m.

  • tpi_inner_radius (int, default: 0 ) –

    The inner radius of the annulus kernel for the tpi calculation in m. Defaults to 0.

  • patch_size (int, default: 1024 ) –

    The patch size to use for inference. Defaults to 1024.

  • overlap (int, default: 256 ) –

    The overlap to use for inference. Defaults to 256.

  • batch_size (int, default: 8 ) –

    The batch size to use for inference. Defaults to 8.

  • reflection (int, default: 0 ) –

    The reflection padding to use for inference. Defaults to 0.

  • binarization_threshold (float, default: 0.5 ) –

    The threshold to binarize the probabilities. Defaults to 0.5.

  • mask_erosion_size (int, default: 10 ) –

    The size of the disk to use for mask erosion and the edge-cropping. Defaults to 10.

  • min_object_size (int, default: 32 ) –

    The minimum object size to keep in pixel. Defaults to 32.

  • quality_level (int | typing.Literal['high_quality', 'low_quality', 'none'], default: 1 ) –

    The quality level to use for the segmentation. Can also be an int. In this case 0="none" 1="low_quality" 2="high_quality". Defaults to 1.

  • export_bands (list[str], default: ['probabilities', 'binarized', 'polygonized', 'extent', 'thumbnail'] ) –

    The bands to export. Can be a list of "probabilities", "binarized", "polygonized", "extent", "thumbnail", "optical", "dem", "tcvis" or concrete band-names. Defaults to ["probabilities", "binarized", "polygonized", "extent", "thumbnail"].

  • write_model_outputs (bool, default: False ) –

    Also save the model outputs, not only the ensemble result. Defaults to False.

  • overwrite (bool, default: False ) –

    Whether to overwrite existing files. Defaults to False.
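
A minimal usage sketch (the import path is assumed from the listed base class module; the checkpoint path is hypothetical):

from pathlib import Path

from darts.pipelines.sequential_v2 import PlanetPipeline

pipeline = PlanetPipeline(
    model_files=[Path("models/planet.ckpt")],  # hypothetical checkpoint
    orthotiles_dir=Path("data/input/planet/PSOrthoTile"),
    scenes_dir=Path("data/input/planet/PSScene"),
    image_ids=None,  # process all images found in both directories
)
pipeline.run()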

arcticdem_dir class-attribute instance-attribute

arcticdem_dir: pathlib.Path = pathlib.Path(
    "data/download/arcticdem"
)

batch_size class-attribute instance-attribute

batch_size: int = 8

binarization_threshold class-attribute instance-attribute

binarization_threshold: float = 0.5

device class-attribute instance-attribute

device: (
    typing.Literal["cuda", "cpu", "auto"] | int | None
) = None

ee_project class-attribute instance-attribute

ee_project: str | None = None

ee_use_highvolume class-attribute instance-attribute

ee_use_highvolume: bool = True

export_bands class-attribute instance-attribute

export_bands: list[str] = dataclasses.field(
    default_factory=lambda: [
        "probabilities",
        "binarized",
        "polygonized",
        "extent",
        "thumbnail",
    ]
)

image_ids class-attribute instance-attribute

image_ids: list = None

mask_erosion_size class-attribute instance-attribute

mask_erosion_size: int = 10

min_object_size class-attribute instance-attribute

min_object_size: int = 32

model_files class-attribute instance-attribute

model_files: list[pathlib.Path] = None

orthotiles_dir class-attribute instance-attribute

orthotiles_dir: pathlib.Path = pathlib.Path(
    "data/input/planet/PSOrthoTile"
)

output_data_dir class-attribute instance-attribute

output_data_dir: pathlib.Path = pathlib.Path('data/output')

overlap class-attribute instance-attribute

overlap: int = 256

overwrite class-attribute instance-attribute

overwrite: bool = False

patch_size class-attribute instance-attribute

patch_size: int = 1024

quality_level class-attribute instance-attribute

quality_level: (
    int
    | typing.Literal["high_quality", "low_quality", "none"]
) = 1

reflection class-attribute instance-attribute

reflection: int = 0

scenes_dir class-attribute instance-attribute

scenes_dir: pathlib.Path = pathlib.Path(
    "data/input/planet/PSScene"
)

tcvis_dir class-attribute instance-attribute

tcvis_dir: pathlib.Path = pathlib.Path(
    "data/download/tcvis"
)

tpi_inner_radius class-attribute instance-attribute

tpi_inner_radius: int = 0

tpi_outer_radius class-attribute instance-attribute

tpi_outer_radius: int = 100

write_model_outputs class-attribute instance-attribute

write_model_outputs: bool = False

cli staticmethod

Run the sequential pipeline for Planet data.

Source code in darts/src/darts/pipelines/sequential_v2.py
@staticmethod
def cli(*, pipeline: "PlanetPipeline"):
    """Run the sequential pipeline for Planet data."""
    pipeline.run()

run

run()
Source code in darts/src/darts/pipelines/sequential_v2.py
def run(self):  # noqa: C901
    if self.model_files is None or len(self.model_files) == 0:
        raise ValueError("No model files provided. Please provide a list of model files.")
    if len(self.export_bands) == 0:
        raise ValueError("No export bands provided. Please provide a list of export bands.")

    current_time = time.strftime("%Y-%m-%d_%H-%M-%S")
    logger.info(f"Starting pipeline at {current_time}.")

    # Storing the configuration as JSON file
    self.output_data_dir.mkdir(parents=True, exist_ok=True)
    with open(self.output_data_dir / f"{current_time}.config.json", "w") as f:
        config = asdict(self)
        # Convert everything to json serializable
        for key, value in config.items():
            if isinstance(value, Path):
                config[key] = str(value.resolve())
            elif isinstance(value, list):
                config[key] = [str(v.resolve()) if isinstance(v, Path) else v for v in value]
        json.dump(config, f)

    from stopuhr import Chronometer

    timer = Chronometer(printer=logger.debug)

    from darts.utils.cuda import debug_info

    debug_info()

    from darts.utils.earthengine import init_ee

    init_ee(self.ee_project, self.ee_use_highvolume)

    import pandas as pd
    import smart_geocubes
    import torch
    from darts_acquisition import load_arcticdem, load_tcvis
    from darts_ensemble import EnsembleV1
    from darts_export import export_tile, missing_outputs
    from darts_postprocessing import prepare_export
    from darts_preprocessing import preprocess_legacy_fast

    from darts.utils.cuda import decide_device
    from darts.utils.logging import LoggingManager

    self.device = decide_device(self.device)

    # determine models to use
    if isinstance(self.model_files, Path):
        self.model_files = [self.model_files]
        self.write_model_outputs = False
    models = {model_file.stem: model_file for model_file in self.model_files}
    ensemble = EnsembleV1(models, device=torch.device(self.device))

    # Create the datacubes if they do not exist
    LoggingManager.apply_logging_handlers("smart_geocubes")
    arcticdem_resolution = self._arcticdem_resolution()
    if arcticdem_resolution == 2:
        accessor = smart_geocubes.ArcticDEM2m(self.arcticdem_dir)
    elif arcticdem_resolution == 10:
        accessor = smart_geocubes.ArcticDEM10m(self.arcticdem_dir)
    if not accessor.created:
        accessor.create(overwrite=False)
    accessor = smart_geocubes.TCTrend(self.tcvis_dir)
    if not accessor.created:
        accessor.create(overwrite=False)

    # Iterate over all the data
    tileinfo = self._tileinfos()
    n_tiles = 0
    logger.info(f"Found {len(tileinfo)} tiles to process.")
    results = []
    for i, (tilekey, outpath) in enumerate(tileinfo):
        tile_id = self._get_tile_id(tilekey)
        try:
            if not self.overwrite:
                mo = missing_outputs(outpath, bands=self.export_bands, ensemble_subsets=models.keys())
                if mo == "none":
                    logger.info(f"Tile {tile_id} already processed. Skipping...")
                    continue
                if mo == "some":
                    logger.warning(
                        f"Tile {tile_id} already processed. Some outputs are missing."
                        " Skipping because overwrite=False..."
                    )
                    continue

            with timer("Loading optical data", log=False):
                tile = self._load_tile(tilekey)
            with timer("Loading ArcticDEM", log=False):
                arcticdem = load_arcticdem(
                    tile.odc.geobox,
                    self.arcticdem_dir,
                    resolution=arcticdem_resolution,
                    buffer=ceil(self.tpi_outer_radius / arcticdem_resolution * sqrt(2)),
                )
            with timer("Loading TCVis", log=False):
                tcvis = load_tcvis(tile.odc.geobox, self.tcvis_dir)
            with timer("Preprocessing tile", log=False):
                tile = preprocess_legacy_fast(
                    tile,
                    arcticdem,
                    tcvis,
                    self.tpi_outer_radius,
                    self.tpi_inner_radius,
                    self.device,
                )
            with timer("Segmenting", log=False):
                tile = ensemble.segment_tile(
                    tile,
                    patch_size=self.patch_size,
                    overlap=self.overlap,
                    batch_size=self.batch_size,
                    reflection=self.reflection,
                    keep_inputs=self.write_model_outputs,
                )
            with timer("Postprosessing", log=False):
                tile = prepare_export(
                    tile,
                    bin_threshold=self.binarization_threshold,
                    mask_erosion_size=self.mask_erosion_size,
                    min_object_size=self.min_object_size,
                    quality_level=self.quality_level,
                    ensemble_subsets=models.keys() if self.write_model_outputs else [],
                    device=self.device,
                )

            with timer("Exporting", log=False):
                export_tile(
                    tile,
                    outpath,
                    bands=self.export_bands,
                    ensemble_subsets=models.keys() if self.write_model_outputs else [],
                )

            n_tiles += 1
            results.append(
                {
                    "tile_id": tile_id,
                    "output_path": str(outpath.resolve()),
                    "status": "success",
                    "error": None,
                }
            )
            logger.info(f"Processed sample {i + 1} of {len(tileinfo)} '{tilekey}' ({tile_id=}).")
        except KeyboardInterrupt:
            logger.warning("Keyboard interrupt detected.\nExiting...")
            raise KeyboardInterrupt
        except Exception as e:
            logger.warning(f"Could not process '{tilekey}' ({tile_id=}).\nSkipping...")
            logger.exception(e)
            results.append(
                {
                    "tile_id": tile_id,
                    "output_path": str(outpath.resolve()),
                    "status": "failed",
                    "error": str(e),
                }
            )
        finally:
            if len(results) > 0:
                pd.DataFrame(results).to_parquet(self.output_data_dir / f"{current_time}.results.parquet")
            if len(timer.durations) > 0:
                timer.export().to_parquet(self.output_data_dir / f"{current_time}.stopuhr.parquet")
    else:
        logger.info(f"Processed {n_tiles} tiles to {self.output_data_dir.resolve()}.")
        timer.summary(printer=logger.info)

Sentinel2Pipeline dataclass

Sentinel2Pipeline(
    model_files: list[pathlib.Path] = None,
    output_data_dir: pathlib.Path = pathlib.Path(
        "data/output"
    ),
    arcticdem_dir: pathlib.Path = pathlib.Path(
        "data/download/arcticdem"
    ),
    tcvis_dir: pathlib.Path = pathlib.Path(
        "data/download/tcvis"
    ),
    device: typing.Literal["cuda", "cpu", "auto"]
    | int
    | None = None,
    ee_project: str | None = None,
    ee_use_highvolume: bool = True,
    tpi_outer_radius: int = 100,
    tpi_inner_radius: int = 0,
    patch_size: int = 1024,
    overlap: int = 256,
    batch_size: int = 8,
    reflection: int = 0,
    binarization_threshold: float = 0.5,
    mask_erosion_size: int = 10,
    min_object_size: int = 32,
    quality_level: int
    | typing.Literal[
        "high_quality", "low_quality", "none"
    ] = 1,
    export_bands: list[str] = [
        "probabilities",
        "binarized",
        "polygonized",
        "extent",
        "thumbnail",
    ],
    write_model_outputs: bool = False,
    overwrite: bool = False,
    sentinel2_dir: pathlib.Path = pathlib.Path(
        "data/input/sentinel2"
    ),
    image_ids: list = None,
)

Bases: darts.pipelines.sequential_v2._BasePipeline

Pipeline for Sentinel 2 data.

Parameters:

  • sentinel2_dir (pathlib.Path, default: pathlib.Path('data/input/sentinel2') ) –

    The directory containing the Sentinel 2 scenes. Defaults to Path("data/input/sentinel2").

  • image_ids (list, default: None ) –

    The list of image ids to process. If None, all images in the directory will be processed. Defaults to None.

  • model_files (pathlib.Path | list[pathlib.Path], default: None ) –

    The path to the models to use for segmentation. Can also be a single Path to use only one model; this implies write_model_outputs=False. If a list is provided, an ensemble of the models will be used.

  • output_data_dir (pathlib.Path, default: pathlib.Path('data/output') ) –

    The "output" directory. Defaults to Path("data/output").

  • arcticdem_dir (pathlib.Path, default: pathlib.Path('data/download/arcticdem') ) –

    The directory containing the ArcticDEM data (the datacube and the extent files). Will be created and downloaded if it does not exist. Defaults to Path("data/download/arcticdem").

  • tcvis_dir (pathlib.Path, default: pathlib.Path('data/download/tcvis') ) –

    The directory containing the TCVis data. Defaults to Path("data/download/tcvis").

  • device (typing.Literal['cuda', 'cpu', 'auto'] | int | None, default: None ) –

    The device to run the model on. If "cuda" take the first device (0), if int take the specified device. If "auto" try to automatically select a free GPU (<50% memory usage). Defaults to "cuda" if available, else "cpu".

  • ee_project (str, default: None ) –

    The Earth Engine project ID or number to use. May be omitted if project is defined within persistent API credentials obtained via earthengine authenticate.

  • ee_use_highvolume (bool, default: True ) –

    Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).

  • tpi_outer_radius (int, default: 100 ) –

    The outer radius of the annulus kernel for the tpi calculation in m. Defaults to 100m.

  • tpi_inner_radius (int, default: 0 ) –

    The inner radius of the annulus kernel for the tpi calculation in m. Defaults to 0.

  • patch_size (int, default: 1024 ) –

    The patch size to use for inference. Defaults to 1024.

  • overlap (int, default: 256 ) –

    The overlap to use for inference. Defaults to 256.

  • batch_size (int, default: 8 ) –

    The batch size to use for inference. Defaults to 8.

  • reflection (int, default: 0 ) –

    The reflection padding to use for inference. Defaults to 0.

  • binarization_threshold (float, default: 0.5 ) –

    The threshold to binarize the probabilities. Defaults to 0.5.

  • mask_erosion_size (int, default: 10 ) –

    The size of the disk to use for mask erosion and the edge-cropping. Defaults to 10.

  • min_object_size (int, default: 32 ) –

    The minimum object size to keep in pixel. Defaults to 32.

  • quality_level (int | typing.Literal['high_quality', 'low_quality', 'none'], default: 1 ) –

    The quality level to use for the segmentation. Can also be an int. In this case 0="none" 1="low_quality" 2="high_quality". Defaults to 1.

  • export_bands (list[str], default: ['probabilities', 'binarized', 'polygonized', 'extent', 'thumbnail'] ) –

    The bands to export. Can be a list of "probabilities", "binarized", "polygonized", "extent", "thumbnail", "optical", "dem", "tcvis" or concrete band-names. Defaults to ["probabilities", "binarized", "polygonized", "extent", "thumbnail"].

  • write_model_outputs (bool, default: False ) –

    Also save the model outputs, not only the ensemble result. Defaults to False.

  • overwrite (bool, default: False ) –

    Whether to overwrite existing files. Defaults to False.
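
A minimal usage sketch (the import path is assumed from the listed base class module; the checkpoint path is hypothetical):

from pathlib import Path

from darts.pipelines.sequential_v2 import Sentinel2Pipeline

pipeline = Sentinel2Pipeline(
    model_files=[Path("models/s2.ckpt")],  # hypothetical checkpoint
    sentinel2_dir=Path("data/input/sentinel2"),
    device="auto",  # try to select a free GPU, with the fallback behavior described above
    export_bands=["probabilities", "polygonized"],
)
pipeline.run()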

arcticdem_dir class-attribute instance-attribute

arcticdem_dir: pathlib.Path = pathlib.Path(
    "data/download/arcticdem"
)

batch_size class-attribute instance-attribute

batch_size: int = 8

binarization_threshold class-attribute instance-attribute

binarization_threshold: float = 0.5

device class-attribute instance-attribute

device: (
    typing.Literal["cuda", "cpu", "auto"] | int | None
) = None

ee_project class-attribute instance-attribute

ee_project: str | None = None

ee_use_highvolume class-attribute instance-attribute

ee_use_highvolume: bool = True

export_bands class-attribute instance-attribute

export_bands: list[str] = dataclasses.field(
    default_factory=lambda: [
        "probabilities",
        "binarized",
        "polygonized",
        "extent",
        "thumbnail",
    ]
)

image_ids class-attribute instance-attribute

image_ids: list = None

mask_erosion_size class-attribute instance-attribute

mask_erosion_size: int = 10

min_object_size class-attribute instance-attribute

min_object_size: int = 32

model_files class-attribute instance-attribute

model_files: list[pathlib.Path] = None

output_data_dir class-attribute instance-attribute

output_data_dir: pathlib.Path = pathlib.Path('data/output')

overlap class-attribute instance-attribute

overlap: int = 256

overwrite class-attribute instance-attribute

overwrite: bool = False

patch_size class-attribute instance-attribute

patch_size: int = 1024

quality_level class-attribute instance-attribute

quality_level: (
    int
    | typing.Literal["high_quality", "low_quality", "none"]
) = 1

reflection class-attribute instance-attribute

reflection: int = 0

sentinel2_dir class-attribute instance-attribute

sentinel2_dir: pathlib.Path = pathlib.Path(
    "data/input/sentinel2"
)

tcvis_dir class-attribute instance-attribute

tcvis_dir: pathlib.Path = pathlib.Path(
    "data/download/tcvis"
)

tpi_inner_radius class-attribute instance-attribute

tpi_inner_radius: int = 0

tpi_outer_radius class-attribute instance-attribute

tpi_outer_radius: int = 100

write_model_outputs class-attribute instance-attribute

write_model_outputs: bool = False

cli staticmethod

Run the sequential pipeline for Sentinel 2 data.

Source code in darts/src/darts/pipelines/sequential_v2.py
@staticmethod
def cli(*, pipeline: "Sentinel2Pipeline"):
    """Run the sequential pipeline for Sentinel 2 data."""
    pipeline.run()

run

run()
Source code in darts/src/darts/pipelines/sequential_v2.py
def run(self):  # noqa: C901
    if self.model_files is None or len(self.model_files) == 0:
        raise ValueError("No model files provided. Please provide a list of model files.")
    if len(self.export_bands) == 0:
        raise ValueError("No export bands provided. Please provide a list of export bands.")

    current_time = time.strftime("%Y-%m-%d_%H-%M-%S")
    logger.info(f"Starting pipeline at {current_time}.")

    # Storing the configuration as JSON file
    self.output_data_dir.mkdir(parents=True, exist_ok=True)
    with open(self.output_data_dir / f"{current_time}.config.json", "w") as f:
        config = asdict(self)
        # Convert everything to json serializable
        for key, value in config.items():
            if isinstance(value, Path):
                config[key] = str(value.resolve())
            elif isinstance(value, list):
                config[key] = [str(v.resolve()) if isinstance(v, Path) else v for v in value]
        json.dump(config, f)

    from stopuhr import Chronometer

    timer = Chronometer(printer=logger.debug)

    from darts.utils.cuda import debug_info

    debug_info()

    from darts.utils.earthengine import init_ee

    init_ee(self.ee_project, self.ee_use_highvolume)

    import pandas as pd
    import smart_geocubes
    import torch
    from darts_acquisition import load_arcticdem, load_tcvis
    from darts_ensemble import EnsembleV1
    from darts_export import export_tile, missing_outputs
    from darts_postprocessing import prepare_export
    from darts_preprocessing import preprocess_legacy_fast

    from darts.utils.cuda import decide_device
    from darts.utils.logging import LoggingManager

    self.device = decide_device(self.device)

    # determine models to use
    if isinstance(self.model_files, Path):
        self.model_files = [self.model_files]
        self.write_model_outputs = False
    models = {model_file.stem: model_file for model_file in self.model_files}
    ensemble = EnsembleV1(models, device=torch.device(self.device))

    # Create the datacubes if they do not exist
    LoggingManager.apply_logging_handlers("smart_geocubes")
    arcticdem_resolution = self._arcticdem_resolution()
    if arcticdem_resolution == 2:
        accessor = smart_geocubes.ArcticDEM2m(self.arcticdem_dir)
    elif arcticdem_resolution == 10:
        accessor = smart_geocubes.ArcticDEM10m(self.arcticdem_dir)
    if not accessor.created:
        accessor.create(overwrite=False)
    accessor = smart_geocubes.TCTrend(self.tcvis_dir)
    if not accessor.created:
        accessor.create(overwrite=False)

    # Iterate over all the data
    tileinfo = self._tileinfos()
    n_tiles = 0
    logger.info(f"Found {len(tileinfo)} tiles to process.")
    results = []
    for i, (tilekey, outpath) in enumerate(tileinfo):
        tile_id = self._get_tile_id(tilekey)
        try:
            if not self.overwrite:
                mo = missing_outputs(outpath, bands=self.export_bands, ensemble_subsets=models.keys())
                if mo == "none":
                    logger.info(f"Tile {tile_id} already processed. Skipping...")
                    continue
                if mo == "some":
                    logger.warning(
                        f"Tile {tile_id} already processed. Some outputs are missing."
                        " Skipping because overwrite=False..."
                    )
                    continue

            with timer("Loading optical data", log=False):
                tile = self._load_tile(tilekey)
            with timer("Loading ArcticDEM", log=False):
                arcticdem = load_arcticdem(
                    tile.odc.geobox,
                    self.arcticdem_dir,
                    resolution=arcticdem_resolution,
                    buffer=ceil(self.tpi_outer_radius / arcticdem_resolution * sqrt(2)),
                )
            with timer("Loading TCVis", log=False):
                tcvis = load_tcvis(tile.odc.geobox, self.tcvis_dir)
            with timer("Preprocessing tile", log=False):
                tile = preprocess_legacy_fast(
                    tile,
                    arcticdem,
                    tcvis,
                    self.tpi_outer_radius,
                    self.tpi_inner_radius,
                    self.device,
                )
            with timer("Segmenting", log=False):
                tile = ensemble.segment_tile(
                    tile,
                    patch_size=self.patch_size,
                    overlap=self.overlap,
                    batch_size=self.batch_size,
                    reflection=self.reflection,
                    keep_inputs=self.write_model_outputs,
                )
            with timer("Postprosessing", log=False):
                tile = prepare_export(
                    tile,
                    bin_threshold=self.binarization_threshold,
                    mask_erosion_size=self.mask_erosion_size,
                    min_object_size=self.min_object_size,
                    quality_level=self.quality_level,
                    ensemble_subsets=models.keys() if self.write_model_outputs else [],
                    device=self.device,
                )

            with timer("Exporting", log=False):
                export_tile(
                    tile,
                    outpath,
                    bands=self.export_bands,
                    ensemble_subsets=models.keys() if self.write_model_outputs else [],
                )

            n_tiles += 1
            results.append(
                {
                    "tile_id": tile_id,
                    "output_path": str(outpath.resolve()),
                    "status": "success",
                    "error": None,
                }
            )
            logger.info(f"Processed sample {i + 1} of {len(tileinfo)} '{tilekey}' ({tile_id=}).")
        except KeyboardInterrupt:
            logger.warning("Keyboard interrupt detected.\nExiting...")
            raise KeyboardInterrupt
        except Exception as e:
            logger.warning(f"Could not process '{tilekey}' ({tile_id=}).\nSkipping...")
            logger.exception(e)
            results.append(
                {
                    "tile_id": tile_id,
                    "output_path": str(outpath.resolve()),
                    "status": "failed",
                    "error": str(e),
                }
            )
        finally:
            if len(results) > 0:
                pd.DataFrame(results).to_parquet(self.output_data_dir / f"{current_time}.results.parquet")
            if len(timer.durations) > 0:
                timer.export().to_parquet(self.output_data_dir / f"{current_time}.stopuhr.parquet")
    else:
        logger.info(f"Processed {n_tiles} tiles to {self.output_data_dir.resolve()}.")
        timer.summary(printer=logger.info)

benchviz

benchviz(
    stopuhr_data: pathlib.Path,
    *,
    viz_dir: pathlib.Path | None = None,
)

Visualize a benchmark based on a Stopuhr data file produced by a pipeline run.

Note

This function changes the seaborn theme to "whitegrid" for better visualization.

Parameters:

  • stopuhr_data (pathlib.Path) –

    Path to the Stopuhr data file.

  • viz_dir (pathlib.Path | None, default: None ) –

    Path to the directory where the visualization will be saved. If None, defaults to the parent directory of the Stopuhr data file. Defaults to None.

Returns:

  • plt.Figure: A matplotlib figure containing the benchmark visualization.
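
A minimal usage sketch (the import path follows the source listing below; the parquet file name is hypothetical):

from pathlib import Path

from darts.utils.bench import benchviz

fig = benchviz(Path("data/output/2024-01-01_12-00-00.stopuhr.parquet"))
# The figure is additionally saved as a PNG next to the data file (or into viz_dir, if given).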

Source code in darts/src/darts/utils/bench.py
def benchviz(
    stopuhr_data: Path,
    *,
    viz_dir: Path | None = None,
):
    """Visulize benchmark based on a Stopuhr data file produced by a pipeline run.

    !!! note
        This function changes the seaborn theme to "whitegrid" for better visualization.

    Args:
        stopuhr_data (Path): Path to the Stopuhr data file.
        viz_dir (Path | None): Path to the directory where the visualization will be saved.
            If None, defaults to the parent directory of the Stopuhr data file.
            Defaults to None.

    Returns:
        plt.Figure: A matplotlib figure containing the benchmark visualization.

    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    import seaborn as sns

    # Visualize the results
    sns.set_theme(style="whitegrid")

    assert stopuhr_data.suffix == ".parquet", "Stopuhr data file must be a parquet file."

    times = pd.read_parquet(stopuhr_data)
    times_long = times.melt(ignore_index=False, value_name="time", var_name="step").reset_index(drop=False)
    times_desc = times.describe()
    times_sum = times.sum()

    # Pretty print the results
    for col in times_desc.columns:
        mean = times_desc[col]["mean"]
        std = times_desc[col]["std"]
        total = times_sum[col]
        n = int(times_desc[col]["count"].item())
        logger.info(f"{col} took {mean:.2f} ± {std:.2f}s ({n=} -> {total=:.2f}s)")

    # axs: hist, histlog, bar, heat
    fig, axs = plt.subplot_mosaic(
        [
            ["histlog"] * 4,
            ["histlog"] * 4,
            ["hist", "hist", "heat", "heat"],
            ["hist", "hist", "heat", "heat"],
            ["bar", "bar", "bar", "bar"],
        ],
        layout="constrained",
        figsize=(20, 15),
    )

    sns.histplot(
        data=times_long,
        x="time",
        hue="step",
        bins=100,
        # log_scale=True,
        ax=axs["hist"],
    )
    axs["hist"].set_xlabel("Time in seconds")
    axs["hist"].set_title("Histogram of time taken for each step", fontdict={"fontweight": "bold"})

    sns.histplot(
        data=times_long,
        x="time",
        hue="step",
        bins=100,
        log_scale=True,
        kde=True,
        ax=axs["histlog"],
    )
    axs["histlog"].set_xlabel("Time in seconds")
    axs["histlog"].set_title("Histogram of time taken for each step (log scale)", fontdict={"fontweight": "bold"})

    sns.heatmap(
        times.T,
        robust=True,
        cbar_kws={"label": "Time in seconds"},
        ax=axs["heat"],
    )
    axs["heat"].set_xlabel("Sample")
    axs["heat"].set_title("Heatmap of time taken for each step and sample", fontdict={"fontweight": "bold"})

    bottom = np.array([0.0])
    for i, (step, time_taken) in enumerate(times.mean().items()):
        axs["bar"].barh(["Time taken"], [time_taken], label=step, color=sns.color_palette()[i], left=bottom)
        # Add a text label to the bar
        axs["bar"].text(
            bottom[-1] + time_taken / 2,
            0,
            f"{step}:\n{time_taken:.1f} s",
            va="center",
            ha="center",
            fontsize=10,
            color="white",
        )
        bottom += time_taken
    axs["bar"].legend(loc="upper center", bbox_to_anchor=(0.5, 1.05), ncol=3)
    # Make the y-axis labels vertical
    axs["bar"].set_yticks([0.15], labels=["Time taken"], rotation=90)
    axs["bar"].set_xlabel("Time in seconds")
    axs["bar"].set_title("Avg. time taken for each step", fontdict={"fontweight": "bold"})

    # Save the figure
    viz_dir = viz_dir or stopuhr_data.parent
    viz_dir.mkdir(parents=True, exist_ok=True)
    fpath = viz_dir / stopuhr_data.name.replace(".parquet", ".png")
    fig.savefig(fpath, dpi=300, bbox_inches="tight")
    logger.info(f"Benchmark visualization saved to {fpath.resolve()}")

    return fig

convert_lightning_checkpoint

convert_lightning_checkpoint(
    *,
    lightning_checkpoint: pathlib.Path,
    out_directory: pathlib.Path,
    checkpoint_name: str,
    framework: str = "smp",
)

Convert a lightning checkpoint to our own format.

The final checkpoint will contain the model configuration and the state dict. It will be saved to:

    out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"

Parameters:

  • lightning_checkpoint (pathlib.Path) –

    Path to the lightning checkpoint.

  • out_directory (pathlib.Path) –

    Output directory for the converted checkpoint.

  • checkpoint_name (str) –

    A unique name of the new checkpoint.

  • framework (str, default: 'smp' ) –

    The framework used for the model. Defaults to "smp".

Source code in darts-segmentation/src/darts_segmentation/training/train.py
def convert_lightning_checkpoint(
    *,
    lightning_checkpoint: Path,
    out_directory: Path,
    checkpoint_name: str,
    framework: str = "smp",
):
    """Convert a lightning checkpoint to our own format.

    The final checkpoint will contain the model configuration and the state dict.
    It will be saved to:

    ```python
        out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"
    ```

    Args:
        lightning_checkpoint (Path): Path to the lightning checkpoint.
        out_directory (Path): Output directory for the converted checkpoint.
        checkpoint_name (str): A unique name of the new checkpoint.
        framework (str, optional): The framework used for the model. Defaults to "smp".

    """
    import torch

    logger.debug(f"Loading checkpoint from {lightning_checkpoint.resolve()}")
    lckpt = torch.load(lightning_checkpoint, weights_only=False, map_location=torch.device("cpu"))

    now = datetime.now()
    formatted_date = now.strftime("%Y-%m-%d")
    config = lckpt["hyper_parameters"]["config"]
    del config["model"]["encoder_weights"]
    config["time"] = formatted_date
    config["name"] = checkpoint_name
    config["model_framework"] = framework

    statedict = lckpt["state_dict"]
    # Statedict has model. prefix before every weight. We need to remove them. This is an in-place function
    torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(statedict, "model.")

    own_ckpt = {
        "config": config,
        "statedict": lckpt["state_dict"],
    }

    out_directory.mkdir(exist_ok=True, parents=True)

    out_checkpoint = out_directory / f"{checkpoint_name}_{formatted_date}.ckpt"

    torch.save(own_ckpt, out_checkpoint)

    logger.info(f"Saved converted checkpoint to {out_checkpoint.resolve()}")

cross_validation_smp
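
cross_validation_smp(
    *,
    name: str | None = None,
    tune_name: str | None = None,
    cv: CrossValidationConfig = CrossValidationConfig(),
    training_config: TrainingConfig = TrainingConfig(),
    data_config: DataConfig = DataConfig(),
    device_config: DeviceConfig = DeviceConfig(),
    hparams: Hyperparameters = Hyperparameters(),
    logging_config: LoggingConfig = LoggingConfig(),
)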

Perform cross-validation for a model with given hyperparameters.

Please see https://smp.readthedocs.io/en/latest/index.html for model configurations of architecture and encoder.

Please also consider reading our training guide (docs/guides/training.md).

This cross-validation function is designed to evaluate the performance of a single model configuration. It can be used by a tuning script to tune hyperparameters. It calls the training function, hence most functionality is the same as the training function. In general, it performs the following:

for seed in seeds:
    for fold in folds:
        train_model(seed=seed, fold=fold, ...)

and calculates a score from the results.

To specify on which metric(s) the score is calculated, the scoring_metric parameter can be specified. Each metric can be suffixed with either ":higher" or ":lower" to indicate the direction in which it improves. This allows multiple metrics to be combined correctly, by taking 1/metric before the calculation for any metric marked ":lower". If no direction is provided, ":higher" is assumed. The direction has no real effect on the single-score calculation, since only the mean is calculated there.

In a multi-score setting, the score is calculated by combine-then-reduce: first, for each fold, the metrics are combined using the specified strategy; then the results are reduced via mean. Please refer to the documentation to understand the different multi-score strategies.

If one of the metrics of any of the runs contains NaN, Inf or -Inf, or is 0, the score is reported as "unstable".
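
A small illustrative sketch of the direction handling (not the actual score_from_runs implementation; the metric names are hypothetical):

def combine_metrics(metrics: dict[str, float]) -> float:
    """Combine direction-annotated metrics into one score (illustrative mean strategy)."""
    values = []
    for name, value in metrics.items():
        _, _, direction = name.partition(":")
        # ":lower" metrics are inverted so that larger is always better.
        values.append(1 / value if direction == "lower" else value)
    return sum(values) / len(values)

print(combine_metrics({"val/JaccardIndex:higher": 0.8, "val/loss:lower": 0.25}))  # (0.8 + 1/0.25) / 2 = 2.4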

Artifacts are stored under {artifact_dir}/{tune_name} for tunes (i.e. if tune_name is not None), else under {artifact_dir}/_cross_validation.

You can specify the frequency at which logs are written and validation is performed:

- `log_every_n_steps` specifies how often train-logs will be written. This does not affect validation.
- `check_val_every_n_epoch` specifies how often validation will be performed. This will also affect early stopping.
- `early_stopping_patience` specifies how many validation rounds to wait for improvement before stopping. In epochs, this would be `check_val_every_n_epoch * early_stopping_patience`.
- `plot_every_n_val_epochs` specifies how often validation samples will be plotted. Since plotting is quite costly, you can reduce the frequency. Works similarly to early stopping. In epochs, this would be `check_val_every_n_epoch * plot_every_n_val_epochs`.

Example: There are 400 training samples and the batch size is 2, resulting in 200 training steps per epoch. If `log_every_n_steps` is set to 50, the training logs and metrics will be logged 4 times per epoch. If `check_val_every_n_epoch` is set to 5, validation will be performed every 5 epochs. If `plot_every_n_val_epochs` is set to 2, validation samples will be plotted every 10 epochs. If `early_stopping_patience` is set to 3, early stopping will trigger after 15 epochs without improvement.

The data structure of the training data expects the "preprocessing" step to be done beforehand, which results in the following data structure:

preprocessed-data/ # the top-level directory
├── config.toml
├── data.zarr/ # this zarr group contains the dataarrays x and y
├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
└── labels.geojson

Parameters:

  • name (str | None, default: None ) –

    Name of the cross-validation. If None, a name is generated automatically. Defaults to None.

  • tune_name (str | None, default: None ) –

    Name of the tuning. Should only be specified by a tuning script. Defaults to None.

  • cv (CrossValidationConfig, default: CrossValidationConfig() ) –

    Configuration for cross-validation.

  • training_config (TrainingConfig, default: TrainingConfig() ) –

    Configuration for the training.

  • data_config (DataConfig, default: DataConfig() ) –

    Configuration for the data.

  • device_config (DeviceConfig, default: DeviceConfig() ) –

    Configuration for the devices to use.

  • hparams (Hyperparameters, default: Hyperparameters() ) –

    Hyperparameters for the training.

  • logging_config (LoggingConfig, default: LoggingConfig() ) –

    Logging configuration.

Returns:

  • tuple[float, bool, pd.DataFrame] –

    A single score, a boolean indicating whether the score is unstable, and a DataFrame containing run info (seed, fold, metrics, duration, checkpoint).

Raises:

  • ValueError –

    If no runs were performed, meaning the configuration is invalid or no data was found.

Source code in darts-segmentation/src/darts_segmentation/training/cv.py
def cross_validation_smp(
    *,
    name: str | None = None,
    tune_name: str | None = None,
    cv: CrossValidationConfig = CrossValidationConfig(),
    training_config: TrainingConfig = TrainingConfig(),
    data_config: DataConfig = DataConfig(),
    device_config: DeviceConfig = DeviceConfig(),
    hparams: Hyperparameters = Hyperparameters(),
    logging_config: LoggingConfig = LoggingConfig(),
):
    """Perform cross-validation for a model with given hyperparameters.

    Please see https://smp.readthedocs.io/en/latest/index.html for model configurations of architecture and encoder.

    Please also consider reading our training guide (docs/guides/training.md).

    This cross-validation function is designed to evaluate the performance of a single model configuration.
    It can be used by a tuning script to tune hyperparameters.
    It calls the training function, hence most functionality is the same as the training function.
    In general, it performs the following:

    ```py
    for seed in seeds:
        for fold in folds:
            train_model(seed=seed, fold=fold, ...)
    ```

    and calculates a score from the results.

    To specify on which metric(s) the score is calculated, the `scoring_metric` parameter can be specified.
    Each metric can be suffixed with either ":higher" or ":lower" to indicate the direction in which it improves.
    This allows multiple metrics to be combined correctly, by taking 1/metric before the calculation
    for any metric marked ":lower". If no direction is provided, ":higher" is assumed.
    The direction has no real effect on the single-score calculation, since only the mean is calculated there.

    In a multi-score setting, the score is calculated by combine-then-reduce:
    first, for each fold, the metrics are combined using the specified strategy,
    then the results are reduced via mean.
    Please refer to the documentation to understand the different multi-score strategies.

    If one of the metrics of any of the runs contains NaN, Inf or -Inf, or is 0, the score is reported as "unstable".

    Artifacts are stored under `{artifact_dir}/{tune_name}` for tunes (i.e. if `tune_name` is not None),
    else under `{artifact_dir}/_cross_validation`.

    You can specify the frequency on how often logs will be written and validation will be performed.
        - `log_every_n_steps` specifies how often train-logs will be written. This does not affect validation.
        - `check_val_every_n_epoch` specifies how often validation will be performed.
            This will also affect early stopping.
        - `early_stopping_patience` specifies how many epochs to wait for improvement before stopping.
            In epochs, this would be `check_val_every_n_epoch * early_stopping_patience`.
        - `plot_every_n_val_epochs` specifies how often validation samples will be plotted.
            Since plotting is quite costly, you can reduce the frequency. Works similarly to early stopping.
            In epochs, this would be `check_val_every_n_epoch * plot_every_n_val_epochs`.
    Example: There are 400 training samples and the batch size is 2, resulting in 200 training steps per epoch.
    If `log_every_n_steps` is set to 50 then the training logs and metrics will be logged 4 times per epoch.
    If `check_val_every_n_epoch` is set to 5 then validation will be performed every 5 epochs.
    If `plot_every_n_val_epochs` is set to 2 then validation samples will be plotted every 10 epochs.
    If `early_stopping_patience` is set to 3 then early stopping will be performed after 15 epochs without improvement.

    The data structure of the training data expects the "preprocessing" step to be done beforehand,
    which results in the following data structure:

    ```sh
    preprocessed-data/ # the top-level directory
    ├── config.toml
    ├── data.zarr/ # this zarr group contains the dataarrays x and y
    ├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
    └── labels.geojson
    ```

    Args:
        name (str | None, optional): Name of the cross-validation. If None, a name is generated automatically.
            Defaults to None.
        tune_name (str | None, optional): Name of the tuning. Should only be specified by a tuning script.
            Defaults to None.
        cv (CrossValidationConfig): Configuration for cross-validation.
        training_config (TrainingConfig): Configuration for the training.
        data_config (DataConfig): Configuration for the data.
        device_config (DeviceConfig): Configuration for the devices to use.
        hparams (Hyperparameters): Hyperparameters for the training.
        logging_config (LoggingConfig): Logging configuration.

    Returns:
        tuple[float, bool, pd.DataFrame]: A single score, a boolean indicating if the score is unstable,
            and a DataFrame containing run info (seed, fold, metrics, duration, checkpoint)

    Raises:
        ValueError: If no runs were performed, meaning the configuration is invalid or no data was found.

    """
    import pandas as pd
    from darts_utils.namegen import generate_counted_name

    from darts_segmentation.training.adp import _adp
    from darts_segmentation.training.scoring import score_from_runs

    tick_fstart = time.perf_counter()

    artifact_dir = logging_config.artifact_dir_at_cv(tune_name)
    cv_name = name or generate_counted_name(artifact_dir)
    artifact_dir = artifact_dir / cv_name
    artifact_dir.mkdir(parents=True, exist_ok=True)

    n_folds = cv.n_folds or data_config.total_folds

    logger.info(
        f"Starting cross-validation '{cv_name}' with data from {data_config.train_data_dir.resolve()}."
        f" Artifacts will be saved to {artifact_dir.resolve()}."
        f" Will run n_randoms*n_folds = {cv.n_randoms}*{n_folds} = {cv.n_randoms * n_folds} experiments."
    )

    seeds = cv.rng_seeds
    logger.debug(f"Using seeds: {seeds}")

    # Plan which runs to perform. These are later consumed based on the parallelization strategy.
    process_inputs: list[_ProcessInputs] = []
    for i, seed in enumerate(seeds):
        for fold in range(n_folds):
            current = i * n_folds + fold
            total = n_folds * len(seeds)
            run = TrainRunConfig(
                name=f"{cv_name}-run-f{fold}s{seed}",
                cv_name=cv_name,
                tune_name=tune_name,
                fold=fold,
                random_seed=seed,
            )
            process_inputs.append(
                _ProcessInputs(
                    current=current,
                    total=total,
                    seed=seed,
                    fold=fold,
                    cv=cv,
                    run=run,
                    training_config=training_config,
                    logging_config=logging_config,
                    data_config=data_config,
                    device_config=device_config,
                    hparams=hparams,
                )
            )

    run_infos = []
    # This function abstracts away common logic for running multiprocessing
    for inp, output in _adp(
        process_inputs=process_inputs,
        is_parallel=device_config.strategy == "cv-parallel",
        devices=device_config.devices,
        available_devices=available_devices,
        _run=_run_training,
    ):
        run_infos.append(output.run_info)

    if len(run_infos) == 0:
        raise ValueError(
            "No runs were performed. Please check your configuration and data."
            " If you are using a tuning script, make sure to specify the correct parameters."
        )

    logger.debug(f"{run_infos=}")
    score = score_from_runs(run_infos, cv.scoring_metric, cv.multi_score_strategy)

    run_infos = pd.DataFrame(run_infos)
    run_infos["score"] = score
    is_unstable = run_infos["is_unstable"].any()
    run_infos["score_is_unstable"] = is_unstable
    if is_unstable:
        logger.warning("Score is unstable, meaning at least one of the metrics is NaN, Inf, -Inf or 0.")
    run_infos.to_parquet(artifact_dir / "run_infos.parquet")
    logger.debug(f"Saved run infos to {artifact_dir / 'run_infos.parquet'}")

    tick_fend = time.perf_counter()
    logger.info(
        f"Finished cross-validation '{cv_name}' in {tick_fend - tick_fstart:.2f}s"
        f" with {score=:.4f} ({'stable' if not is_unstable else 'unstable'})."
    )

    return score, is_unstable, run_infos
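
A minimal usage sketch, assuming preprocessed training data; the data path is hypothetical and the DataConfig import path is an assumption:

from pathlib import Path

from darts_segmentation.training import DataConfig  # assumed import path

score, is_unstable, run_infos = cross_validation_smp(
    name="demo-cv",
    data_config=DataConfig(train_data_dir=Path("data/train")),  # hypothetical data location
)
print(f"{score=:.4f} {is_unstable=} ({len(run_infos)} runs)")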

env_info

env_info()

Print debug information about the environment.

Source code in darts/src/darts/cli.py
@app.command
def env_info():
    """Print debug information about the environment."""
    from darts.utils.cuda import debug_info

    logger.debug(f"PATH: {os.environ.get('PATH', 'UNSET')}")
    debug_info()

hello

hello(name: str, *, n: int = 1)

Say hello to someone.

Parameters:

  • name (str) –

    The name of the person to say hello to

  • n (int, default: 1 ) –

    The number of times to say hello. Defaults to 1.

Raises:

  • ValueError –

    If n is 3.

Source code in darts/src/darts/cli.py
@app.command
def hello(name: str, *, n: int = 1):
    """Say hello to someone.

    Args:
        name (str): The name of the person to say hello to
        n (int, optional): The number of times to say hello. Defaults to 1.

    Raises:
        ValueError: If n is 3.

    """
    for i in range(n):
        logger.debug(f"Currently at {i=}")
        if n == 3:
            raise ValueError("I don't like 3")
        logger.info(f"Hello {name}")

help

help()

Display the help screen.

Source code in darts/src/darts/cli.py
@app.command
def help():
    """Display the help screen."""
    app.help_print()

launcher

launcher(
    *tokens: str,
    log_dir: pathlib.Path = pathlib.Path("logs"),
    config_file: pathlib.Path = pathlib.Path("config.toml"),
    verbose: bool = False,
    tracebacks_show_locals: bool = False,
)

Launch a darts command: parse the CLI tokens, set up logging, and forward the meta options (log_dir, config_file, verbose) to commands that accept them.

Source code in darts/src/darts/cli.py
@app.meta.default
def launcher(  # noqa: D103
    *tokens: Annotated[str, cyclopts.Parameter(show=False, allow_leading_hyphen=True)],
    log_dir: Path = Path("logs"),
    config_file: Path = Path("config.toml"),
    verbose: bool = False,
    tracebacks_show_locals: bool = False,
):
    # Parse the tokens with the main app; `ignored` contains the parameters of the
    # command that were not parsed from the tokens and can be supplied by this launcher.
    command, bound, ignored = app.parse_args(tokens, verbose=verbose)
    # Set verbose to true for debug stuff like env_info
    if command.__name__ == "env_info":
        verbose = True
    LoggingManager.add_logging_handlers(command.__name__, log_dir, verbose, tracebacks_show_locals)
    logger.debug(f"Running on Python version {sys.version} from {__name__} ({root_file})")
    # Forward the meta options to commands that also declare them.
    additional_args = {}
    if "config_file" in ignored:
        additional_args["config_file"] = config_file
    if "log_dir" in ignored:
        additional_args["log_dir"] = log_dir
    if "verbose" in ignored:
        additional_args["verbose"] = verbose
    return command(*bound.args, **bound.kwargs, **additional_args)

preprocess_planet_train_data

preprocess_planet_train_data(
    *,
    data_dir: pathlib.Path,
    labels_dir: pathlib.Path,
    train_data_dir: pathlib.Path,
    arcticdem_dir: pathlib.Path,
    tcvis_dir: pathlib.Path,
    admin_dir: pathlib.Path,
    preprocess_cache: pathlib.Path | None = None,
    force_preprocess: bool = False,
    append: bool = True,
    device: typing.Literal["cuda", "cpu", "auto"]
    | int
    | None = None,
    ee_project: str | None = None,
    ee_use_highvolume: bool = True,
    tpi_outer_radius: int = 100,
    tpi_inner_radius: int = 0,
    patch_size: int = 1024,
    overlap: int = 16,
    exclude_nopositive: bool = False,
    exclude_nan: bool = True,
    mask_erosion_size: int = 3,
)

Preprocess Planet data for training.

The data is split into a cross-validation, a validation-test and a test set:

- `cross-val` is meant to be used for train and validation
- `val-test` (5%) random leave-out for testing the randomness distribution shift of the data
- `test` leave-out region for testing the spatial distribution shift of the data

Each split is stored as a zarr group, containing a x and a y dataarray. The x dataarray contains the input data with the shape (n_patches, n_bands, patch_size, patch_size). The y dataarray contains the labels with the shape (n_patches, patch_size, patch_size). Both dataarrays are chunked along the n_patches dimension. This results in super fast random access to the data, because each sample / patch is stored in a separate chunk and therefore in a separate file.

Through the parameters test_val_split and test_regions, the test and validation split can be controlled. test_regions accepts a list of admin 1 or admin 2 region names, based on the region shapefile maintained by https://github.com/wmgeolab/geoBoundaries; intersecting scenes are removed from the dataset and put in the test split. With the test_val_split parameter, the ratio used to further split the test-validation set can be controlled.

Through exclude_nopositive and exclude_nan, the respective patches can be excluded from the final data.

Further, a config.toml file is saved in the train_data_dir containing the configuration used for the preprocessing. Additionally, a labels.geojson file is saved in the train_data_dir containing the joined label geometries used for the creation of the binarized label masks, which also contains information about the split via the mode column.

The final directory structure of train_data_dir will look like this:

train_data_dir/
├── config.toml
├── cross-val.zarr/
├── test.zarr/
├── val-test.zarr/
└── labels.geojson
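
A minimal sketch of reading one split back with xarray (the path is hypothetical; the shapes depend on your data):

import xarray as xr

# Each split is a zarr group with an "x" (inputs) and a "y" (labels) dataarray.
ds = xr.open_zarr("train_data_dir/cross-val.zarr")
print(ds["x"].shape)  # (n_patches, n_bands, patch_size, patch_size)
print(ds["y"].shape)  # (n_patches, patch_size, patch_size)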

Parameters:

  • data_dir (pathlib.Path) –

    The directory containing the Planet scenes and orthotiles.

  • labels_dir (pathlib.Path) –

    The directory containing the labels and footprints / extents.

  • train_data_dir (pathlib.Path) –

    The "output" directory where the tensors are written to.

  • arcticdem_dir (pathlib.Path) –

    The directory containing the ArcticDEM data (the datacube and the extent files). Will be created and downloaded if it does not exist.

  • tcvis_dir (pathlib.Path) –

    The directory containing the TCVis data.

  • admin_dir (pathlib.Path) –

    The directory containing the admin files.

  • preprocess_cache (pathlib.Path, default: None ) –

    The directory to store the preprocessed data. Defaults to None.

  • force_preprocess (bool, default: False ) –

    Whether to force the preprocessing of the data. Defaults to False.

  • append (bool, default: True ) –

    Whether to append the data to the existing data. Defaults to True.

  • device (typing.Literal['cuda', 'cpu', 'auto'] | int | None, default: None ) –

    The device to run the model on. If "cuda" take the first device (0), if int take the specified device. If "auto" try to automatically select a free GPU (<50% memory usage). Defaults to "cuda" if available, else "cpu".

  • ee_project (str, default: None ) –

    The Earth Engine project ID or number to use. May be omitted if project is defined within persistent API credentials obtained via earthengine authenticate.

  • ee_use_highvolume (bool, default: True ) –

    Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).

  • tpi_outer_radius (int, default: 100 ) –

    The outer radius of the annulus kernel for the tpi calculation in m. Defaults to 100m.

  • tpi_inner_radius (int, default: 0 ) –

    The inner radius of the annulus kernel for the tpi calculation in m. Defaults to 0.

  • patch_size (int, default: 1024 ) –

    The patch size to use for inference. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The overlap to use for inference. Defaults to 16.

  • exclude_nopositive (bool, default: False ) –

    Whether to exclude patches where the labels do not contain positives. Defaults to False.

  • exclude_nan (bool, default: True ) –

    Whether to exclude patches where the input data has nan values. Defaults to True.

  • mask_erosion_size (int, default: 3 ) –

    The size of the disk to use for mask erosion and the edge-cropping. Defaults to 3.

Source code in darts/src/darts/training/preprocess_planet_v2.py
def preprocess_planet_train_data(
    *,
    data_dir: Path,
    labels_dir: Path,
    train_data_dir: Path,
    arcticdem_dir: Path,
    tcvis_dir: Path,
    admin_dir: Path,
    preprocess_cache: Path | None = None,
    force_preprocess: bool = False,
    append: bool = True,
    device: Literal["cuda", "cpu", "auto"] | int | None = None,
    ee_project: str | None = None,
    ee_use_highvolume: bool = True,
    tpi_outer_radius: int = 100,
    tpi_inner_radius: int = 0,
    patch_size: int = 1024,
    overlap: int = 16,
    exclude_nopositive: bool = False,
    exclude_nan: bool = True,
    mask_erosion_size: int = 3,
):
    """Preprocess Planet data for training.

    The data is split into a cross-validation, a validation-test and a test set:

        - `cross-val` is meant to be used for train and validation
        - `val-test` (5%) random leave-out for testing the randomness distribution shift of the data
        - `test` leave-out region for testing the spatial distribution shift of the data

    Each split is stored as a zarr group, containing a x and a y dataarray.
    The x dataarray contains the input data with the shape (n_patches, n_bands, patch_size, patch_size).
    The y dataarray contains the labels with the shape (n_patches, patch_size, patch_size).
    Both dataarrays are chunked along the n_patches dimension.
    This results in super fast random access to the data, because each sample / patch is stored in a separate chunk and
    therefore in a separate file.

    Through the parameters `test_val_split` and `test_regions`, the test and validation split can be controlled.
    `test_regions` accepts a list of admin 1 or admin 2 region names, based on the region shapefile maintained by
    https://github.com/wmgeolab/geoBoundaries; intersecting scenes are removed from the dataset and put in the
    test split.
    With the `test_val_split` parameter, the ratio used to further split the test-validation set can be controlled.

    Through `exclude_nopositive` and `exclude_nan`, the respective patches can be excluded from the final data.

    Further, a `config.toml` file is saved in the `train_data_dir` containing the configuration used for the
    preprocessing.
    Additionally, a `labels.geojson` file is saved in the `train_data_dir` containing the joined label geometries used
    for the creation of the binarized label masks, which also contains information about the split via the `mode` column.

    The final directory structure of `train_data_dir` will look like this:

    ```sh
    train_data_dir/
    ├── config.toml
    ├── cross-val.zarr/
    ├── test.zarr/
    ├── val-test.zarr/
    └── labels.geojson
    ```

    Args:
        data_dir (Path): The directory containing the Planet scenes and orthotiles.
        labels_dir (Path): The directory containing the labels and footprints / extents.
        train_data_dir (Path): The "output" directory where the tensors are written to.
        arcticdem_dir (Path): The directory containing the ArcticDEM data (the datacube and the extent files).
            Will be created and downloaded if it does not exist.
        tcvis_dir (Path): The directory containing the TCVis data.
        admin_dir (Path): The directory containing the admin files.
        preprocess_cache (Path, optional): The directory to store the preprocessed data. Defaults to None.
        force_preprocess (bool, optional): Whether to force the preprocessing of the data. Defaults to False.
        append (bool, optional): Whether to append the data to the existing data. Defaults to True.
        device (Literal["cuda", "cpu", "auto"] | int | None, optional): The device to run the model on.
            If "cuda" take the first device (0), if int take the specified device.
            If "auto" try to automatically select a free GPU (<50% memory usage).
            Defaults to "cuda" if available, else "cpu".
        ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if
            project is defined within persistent API credentials obtained via `earthengine authenticate`.
        ee_use_highvolume (bool, optional): Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).
        tpi_outer_radius (int, optional): The outer radius of the annulus kernel for the tpi calculation
            in m. Defaults to 100m.
        tpi_inner_radius (int, optional): The inner radius of the annulus kernel for the tpi calculation
            in m. Defaults to 0.
        patch_size (int, optional): The patch size to use for inference. Defaults to 1024.
        overlap (int, optional): The overlap to use for inference. Defaults to 16.
        exclude_nopositive (bool, optional): Whether to exclude patches where the labels do not contain positives.
            Defaults to False.
        exclude_nan (bool, optional): Whether to exclude patches where the input data has nan values.
            Defaults to True.
        mask_erosion_size (int, optional): The size of the disk to use for mask erosion and the edge-cropping.
            Defaults to 3.

    """
    current_time = time.strftime("%Y-%m-%d_%H-%M-%S")
    logger.info(f"Starting preprocessing at {current_time}.")

    # Storing the configuration as JSON file
    train_data_dir.mkdir(parents=True, exist_ok=True)
    from darts_utils.functools import write_function_args_to_config_file

    write_function_args_to_config_file(
        fpath=train_data_dir / f"{current_time}.cli.json",
        function=preprocess_planet_train_data,
        locals_=locals(),
    )

    from stopuhr import Chronometer

    timer = Chronometer(printer=logger.debug)

    from darts.utils.cuda import debug_info

    debug_info()

    # Import here to avoid long loading times when running other commands
    import geopandas as gpd
    import pandas as pd
    import rich
    import xarray as xr
    from darts_acquisition import load_arcticdem, load_planet_masks, load_planet_scene, load_tcvis
    from darts_acquisition.admin import download_admin_files
    from darts_preprocessing import preprocess_v2
    from darts_segmentation.training.prepare_training import TrainDatasetBuilder
    from darts_segmentation.utils import Bands
    from darts_utils.tilecache import XarrayCacheManager
    from odc.stac import configure_rio
    from rich.progress import track

    from darts.utils.cuda import decide_device
    from darts.utils.earthengine import init_ee

    device = decide_device(device)
    init_ee(ee_project, ee_use_highvolume)
    configure_rio(cloud_defaults=True, aws={"aws_unsigned": True})
    logger.info("Configured Rasterio")

    labels = (gpd.read_file(labels_file) for labels_file in labels_dir.glob("*/TrainingLabel*.gpkg"))
    labels = gpd.GeoDataFrame(pd.concat(labels, ignore_index=True))

    footprints = (gpd.read_file(footprints_file) for footprints_file in labels_dir.glob("*/ImageFootprints*.gpkg"))
    footprints = gpd.GeoDataFrame(pd.concat(footprints, ignore_index=True))
    fpaths = {fpath.stem: fpath for fpath in _legacy_path_gen(data_dir)}
    footprints["fpath"] = footprints.image_id.map(fpaths)

    # Download admin files if they do not exist
    admin2_fpath = admin_dir / "geoBoundariesCGAZ_ADM2.shp"
    if not admin2_fpath.exists():
        download_admin_files(admin_dir)
    admin2 = gpd.read_file(admin2_fpath)

    # We hardcode these because they depend on the preprocessing used.
    # Each entry maps a band name to a (scale, shift) normalization pair, applied as value * scale + shift.
    bands = Bands.from_dict(
        {
            "red": (1 / 3000, 0.0),
            "green": (1 / 3000, 0.0),
            "blue": (1 / 3000, 0.0),
            "nir": (1 / 3000, 0.0),
            "ndvi": (1 / 20000, 0.0),
            "relative_elevation": (1 / 30000, 0.0),
            "slope": (1 / 90, 0.0),
            "aspect": (1 / 360, 0.0),
            "hillshade": (1.0, 0.0),
            "curvature": (1 / 10, 0.5),  # TODO: Do we even want shift?
            "tc_brightness": (1 / 255, 0.0),
            "tc_greenness": (1 / 255, 0.0),
            "tc_wetness": (1 / 255, 0.0),
        }
    )

    builder = TrainDatasetBuilder(
        train_data_dir=train_data_dir,
        patch_size=patch_size,
        overlap=overlap,
        bands=bands,
        exclude_nopositive=exclude_nopositive,
        exclude_nan=exclude_nan,
        mask_erosion_size=mask_erosion_size,
        device=device,
        append=append,
    )
    cache_manager = XarrayCacheManager(preprocess_cache / "planet_v2")

    if append and (train_data_dir / "metadata.parquet").exists():
        metadata = gpd.read_parquet(train_data_dir / "metadata.parquet")
        already_processed_planet_ids = set(metadata["planet_id"].unique())
        logger.info(f"Already processed {len(already_processed_planet_ids)} samples.")
        footprints = footprints[~footprints.image_id.isin(already_processed_planet_ids)]

    for i, footprint in track(
        footprints.iterrows(), description="Processing samples", total=len(footprints), console=rich.get_console()
    ):
        planet_id = footprint.image_id
        try:
            logger.debug(f"Processing sample {planet_id} ({i + 1} of {len(footprints)})")

            if not footprint.fpath or (not footprint.fpath.exists() and not cache_manager.exists(planet_id)):
                logger.warning(f"Footprint image {planet_id} at {footprint.fpath} does not exist. Skipping...")
                continue

            def _get_tile():
                tile = load_planet_scene(footprint.fpath)
                arcticdem_res = 2
                arcticdem_buffer = ceil(tpi_outer_radius / arcticdem_res * sqrt(2))
                arcticdem = load_arcticdem(
                    tile.odc.geobox, arcticdem_dir, resolution=arcticdem_res, buffer=arcticdem_buffer
                )
                tcvis = load_tcvis(tile.odc.geobox, tcvis_dir)
                data_masks = load_planet_masks(footprint.fpath)
                tile = xr.merge([tile, data_masks])

                tile: xr.Dataset = preprocess_v2(
                    tile,
                    arcticdem,
                    tcvis,
                    tpi_outer_radius,
                    tpi_inner_radius,
                    device,
                )
                return tile

            with timer("Loading tile"):
                tile = cache_manager.get_or_create(
                    identifier=planet_id,
                    creation_func=_get_tile,
                    force=force_preprocess,
                )

            logger.debug(f"Found tile with size {tile.sizes}")

            footprint_labels = labels[labels.image_id == planet_id]
            region = _get_region_name(footprint, admin2)

            with timer("Save as patches"):
                builder.add_tile_batched(
                    tile=tile,
                    labels=footprint_labels,
                    region=region,
                    sample_id=planet_id,
                    metadata={
                        "planet_id": planet_id,
                        "fpath": footprint.fpath,
                    },
                )

            logger.info(f"Processed sample {planet_id} ({i + 1} of {len(footprints)})")

        except (KeyboardInterrupt, SystemExit, SystemError):
            logger.info("Interrupted by user.")
            break

        except Exception as e:
            logger.warning(f"Could not process sample {planet_id} ({i + 1} of {len(footprints)}). \nSkipping...")
            logger.exception(e)

    builder.finalize(
        {
            "data_dir": data_dir,
            "labels_dir": labels_dir,
            "arcticdem_dir": arcticdem_dir,
            "tcvis_dir": tcvis_dir,
            "ee_project": ee_project,
            "ee_use_highvolume": ee_use_highvolume,
            "tpi_outer_radius": tpi_outer_radius,
            "tpi_inner_radius": tpi_inner_radius,
        }
    )
    timer.summary()
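
A minimal usage sketch; all paths below are hypothetical and depend on your local data layout:

from pathlib import Path

preprocess_planet_train_data(
    data_dir=Path("data/planet"),
    labels_dir=Path("data/labels"),
    train_data_dir=Path("data/train"),
    arcticdem_dir=Path("data/download/arcticdem"),
    tcvis_dir=Path("data/download/tcvis"),
    admin_dir=Path("data/download/admin"),
    preprocess_cache=Path("data/cache"),
    device="auto",
)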

preprocess_planet_train_data_pingo

preprocess_planet_train_data_pingo(
    *,
    data_dir: pathlib.Path,
    labels_dir: pathlib.Path,
    train_data_dir: pathlib.Path,
    arcticdem_dir: pathlib.Path,
    tcvis_dir: pathlib.Path,
    admin_dir: pathlib.Path,
    preprocess_cache: pathlib.Path | None = None,
    force_preprocess: bool = False,
    device: typing.Literal["cuda", "cpu", "auto"]
    | int
    | None = None,
    ee_project: str | None = None,
    ee_use_highvolume: bool = True,
    tpi_outer_radius: int = 100,
    tpi_inner_radius: int = 0,
    patch_size: int = 1024,
    overlap: int = 16,
    exclude_nopositive: bool = False,
    exclude_nan: bool = True,
    mask_erosion_size: int = 3,
)

Preprocess Planet data for training (pingo variant).

Unlike preprocess_planet_train_data, this variant writes each tile via builder.add_tile instead of builder.add_tile_batched and does not support appending to an existing dataset.

The data is split into a cross-validation, a validation-test and a test set:

- `cross-val` is meant to be used for train and validation
- `val-test` (5%) random leave-out for testing the randomness distribution shift of the data
- `test` leave-out region for testing the spatial distribution shift of the data

Each split is stored as a zarr group, containing a x and a y dataarray. The x dataarray contains the input data with the shape (n_patches, n_bands, patch_size, patch_size). The y dataarray contains the labels with the shape (n_patches, patch_size, patch_size). Both dataarrays are chunked along the n_patches dimension. This results in super fast random access to the data, because each sample / patch is stored in a separate chunk and therefore in a separate file.

Through the parameters test_val_split and test_regions, the test and validation split can be controlled. test_regions accepts a list of admin 1 or admin 2 region names, based on the region shapefile maintained by https://github.com/wmgeolab/geoBoundaries; intersecting scenes are removed from the dataset and put in the test split. With the test_val_split parameter, the ratio used to further split the test-validation set can be controlled.

Through exclude_nopositive and exclude_nan, the respective patches can be excluded from the final data.

Further, a config.toml file is saved in the train_data_dir containing the configuration used for the preprocessing. Additionally, a labels.geojson file is saved in the train_data_dir containing the joined label geometries used for the creation of the binarized label masks, which also contains information about the split via the mode column.

The final directory structure of train_data_dir will look like this:

train_data_dir/
├── config.toml
├── cross-val.zarr/
├── test.zarr/
├── val-test.zarr/
└── labels.geojson

Parameters:

  • data_dir (pathlib.Path) –

    The directory containing the Planet scenes and orthotiles.

  • labels_dir (pathlib.Path) –

    The directory containing the labels and footprints / extents.

  • train_data_dir (pathlib.Path) –

    The "output" directory where the tensors are written to.

  • arcticdem_dir (pathlib.Path) –

    The directory containing the ArcticDEM data (the datacube and the extent files). Will be created and downloaded if it does not exist.

  • tcvis_dir (pathlib.Path) –

    The directory containing the TCVis data.

  • admin_dir (pathlib.Path) –

    The directory containing the admin files.

  • preprocess_cache (pathlib.Path, default: None ) –

    The directory to store the preprocessed data. Defaults to None.

  • force_preprocess (bool, default: False ) –

    Whether to force the preprocessing of the data. Defaults to False.

  • device (typing.Literal['cuda', 'cpu', 'auto'] | int | None, default: None ) –

    The device to run the model on. If "cuda" take the first device (0), if int take the specified device. If "auto" try to automatically select a free GPU (<50% memory usage). Defaults to "cuda" if available, else "cpu".

  • ee_project (str, default: None ) –

    The Earth Engine project ID or number to use. May be omitted if project is defined within persistent API credentials obtained via earthengine authenticate.

  • ee_use_highvolume (bool, default: True ) –

    Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).

  • tpi_outer_radius (int, default: 100 ) –

    The outer radius of the annulus kernel for the tpi calculation in m. Defaults to 100m.

  • tpi_inner_radius (int, default: 0 ) –

    The inner radius of the annulus kernel for the tpi calculation in m. Defaults to 0.

  • patch_size (int, default: 1024 ) –

    The patch size to use for inference. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The overlap to use for inference. Defaults to 16.

  • exclude_nopositive (bool, default: False ) –

    Whether to exclude patches where the labels do not contain positives. Defaults to False.

  • exclude_nan (bool, default: True ) –

    Whether to exclude patches where the input data has nan values. Defaults to True.

  • mask_erosion_size (int, default: 3 ) –

    The size of the disk to use for mask erosion and the edge-cropping. Defaults to 3.

Source code in darts/src/darts/training/preprocess_planet_v2_pingo.py
def preprocess_planet_train_data_pingo(
    *,
    data_dir: Path,
    labels_dir: Path,
    train_data_dir: Path,
    arcticdem_dir: Path,
    tcvis_dir: Path,
    admin_dir: Path,
    preprocess_cache: Path | None = None,
    force_preprocess: bool = False,
    device: Literal["cuda", "cpu", "auto"] | int | None = None,
    ee_project: str | None = None,
    ee_use_highvolume: bool = True,
    tpi_outer_radius: int = 100,
    tpi_inner_radius: int = 0,
    patch_size: int = 1024,
    overlap: int = 16,
    exclude_nopositive: bool = False,
    exclude_nan: bool = True,
    mask_erosion_size: int = 3,
):
    """Preprocess Planet data for training.

    The data is split into a cross-validation, a validation-test and a test set:

        - `cross-val` is meant to be used for train and validation
        - `val-test` (5%) random leave-out for testing the randomness distribution shift of the data
        - `test` leave-out region for testing the spatial distribution shift of the data

    Each split is stored as a zarr group, containing a x and a y dataarray.
    The x dataarray contains the input data with the shape (n_patches, n_bands, patch_size, patch_size).
    The y dataarray contains the labels with the shape (n_patches, patch_size, patch_size).
    Both dataarrays are chunked along the n_patches dimension.
    This results in super fast random access to the data, because each sample / patch is stored in a separate chunk and
    therefore in a separate file.

    Through the parameters `test_val_split` and `test_regions`, the test and validation split can be controlled.
    `test_regions` accepts a list of admin 1 or admin 2 region names, based on the region shapefile maintained by
    https://github.com/wmgeolab/geoBoundaries; intersecting scenes are removed from the dataset and put in the
    test split.
    With the `test_val_split` parameter, the ratio used to further split the test-validation set can be controlled.

    Through `exclude_nopositive` and `exclude_nan`, the respective patches can be excluded from the final data.

    Further, a `config.toml` file is saved in the `train_data_dir` containing the configuration used for the
    preprocessing.
    Additionally, a `labels.geojson` file is saved in the `train_data_dir` containing the joined label geometries used
    for the creation of the binarized label masks, which also contains information about the split via the `mode` column.

    The final directory structure of `train_data_dir` will look like this:

    ```sh
    train_data_dir/
    ├── config.toml
    ├── cross-val.zarr/
    ├── test.zarr/
    ├── val-test.zarr/
    └── labels.geojson
    ```

    Args:
        data_dir (Path): The directory containing the Planet scenes and orthotiles.
        labels_dir (Path): The directory containing the labels and footprints / extents.
        train_data_dir (Path): The "output" directory where the tensors are written to.
        arcticdem_dir (Path): The directory containing the ArcticDEM data (the datacube and the extent files).
            Will be created and downloaded if it does not exist.
        tcvis_dir (Path): The directory containing the TCVis data.
        admin_dir (Path): The directory containing the admin files.
        preprocess_cache (Path, optional): The directory to store the preprocessed data. Defaults to None.
        force_preprocess (bool, optional): Whether to force the preprocessing of the data. Defaults to False.
        device (Literal["cuda", "cpu", "auto"] | int | None, optional): The device to run the model on.
            If "cuda" take the first device (0), if int take the specified device.
            If "auto" try to automatically select a free GPU (<50% memory usage).
            Defaults to "cuda" if available, else "cpu".
        ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if
            project is defined within persistent API credentials obtained via `earthengine authenticate`.
        ee_use_highvolume (bool, optional): Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).
        tpi_outer_radius (int, optional): The outer radius of the annulus kernel for the tpi calculation
            in m. Defaults to 100m.
        tpi_inner_radius (int, optional): The inner radius of the annulus kernel for the tpi calculation
            in m. Defaults to 0.
        patch_size (int, optional): The patch size to use for inference. Defaults to 1024.
        overlap (int, optional): The overlap to use for inference. Defaults to 16.
        exclude_nopositive (bool, optional): Whether to exclude patches where the labels do not contain positives.
            Defaults to False.
        exclude_nan (bool, optional): Whether to exclude patches where the input data has nan values.
            Defaults to True.
        mask_erosion_size (int, optional): The size of the disk to use for mask erosion and the edge-cropping.
            Defaults to 3.

    """
    current_time = time.strftime("%Y-%m-%d_%H-%M-%S")
    logger.info(f"Starting preprocessing at {current_time}.")

    # Storing the configuration as JSON file
    train_data_dir.mkdir(parents=True, exist_ok=True)
    from darts_utils.functools import write_function_args_to_config_file

    write_function_args_to_config_file(
        fpath=train_data_dir / f"{current_time}.cli.json",
        function=preprocess_planet_train_data_pingo,
        locals_=locals(),
    )

    from stopuhr import Chronometer

    timer = Chronometer(printer=logger.debug)

    from darts.utils.cuda import debug_info

    debug_info()

    # Import here to avoid long loading times when running other commands
    import geopandas as gpd
    import pandas as pd
    import rich
    import xarray as xr
    from darts_acquisition import load_arcticdem, load_planet_masks, load_planet_scene, load_tcvis
    from darts_acquisition.admin import download_admin_files
    from darts_preprocessing import preprocess_v2
    from darts_segmentation.training.prepare_training import TrainDatasetBuilder
    from darts_segmentation.utils import Bands
    from darts_utils.tilecache import XarrayCacheManager
    from odc.stac import configure_rio
    from rich.progress import track

    from darts.utils.cuda import decide_device
    from darts.utils.earthengine import init_ee

    device = decide_device(device)
    init_ee(ee_project, ee_use_highvolume)
    configure_rio(cloud_defaults=True, aws={"aws_unsigned": True})
    logger.info("Configured Rasterio")

    labels = (gpd.read_file(labels_file) for labels_file in labels_dir.glob("*/TrainingLabel*.gpkg"))
    labels = gpd.GeoDataFrame(pd.concat(labels, ignore_index=True))

    footprints = (gpd.read_file(footprints_file) for footprints_file in labels_dir.glob("*/ImageFootprints*.gpkg"))
    footprints = gpd.GeoDataFrame(pd.concat(footprints, ignore_index=True))
    footprints["fpath"] = footprints.image_id.map(_path_gen(data_dir))

    # Download admin files if they do not exist
    admin2_fpath = admin_dir / "geoBoundariesCGAZ_ADM2.shp"
    if not admin2_fpath.exists():
        download_admin_files(admin_dir)
    admin2 = gpd.read_file(admin2_fpath)

    # We hardcode these because they depend on the preprocessing used.
    # Each entry maps a band name to a (scale, shift) normalization pair, applied as value * scale + shift.
    bands = Bands.from_dict(
        {
            "red": (1 / 3000, 0.0),
            "green": (1 / 3000, 0.0),
            "blue": (1 / 3000, 0.0),
            "nir": (1 / 3000, 0.0),
            "ndvi": (1 / 20000, 0.0),
            "relative_elevation": (1 / 30000, 0.0),
            "slope": (1 / 90, 0.0),
            "aspect": (1 / 360, 0.0),
            "hillshade": (1.0, 0.0),
            "curvature": (1 / 10, 0.5),  # TODO: Do we even want shift?
            "tc_brightness": (1 / 255, 0.0),
            "tc_greenness": (1 / 255, 0.0),
            "tc_wetness": (1 / 255, 0.0),
        }
    )

    builder = TrainDatasetBuilder(
        train_data_dir=train_data_dir,
        patch_size=patch_size,
        overlap=overlap,
        bands=bands,
        exclude_nopositive=exclude_nopositive,
        exclude_nan=exclude_nan,
        mask_erosion_size=mask_erosion_size,
        device=device,
    )
    cache_manager = XarrayCacheManager(preprocess_cache / "planet_v2")

    for i, footprint in track(
        footprints.iterrows(), description="Processing samples", total=len(footprints), console=rich.get_console()
    ):
        planet_id = footprint.image_id
        try:
            logger.debug(f"Processing sample {planet_id} ({i + 1} of {len(footprints)})")

            if not footprint.fpath or (not footprint.fpath.exists() and not cache_manager.exists(planet_id)):
                logger.warning(f"Footprint image {planet_id} at {footprint.fpath} does not exist. Skipping...")
                continue

            def _get_tile():
                tile = load_planet_scene(footprint.fpath)
                arcticdem_res = 2
                arcticdem_buffer = ceil(tpi_outer_radius / arcticdem_res * sqrt(2))
                arcticdem = load_arcticdem(
                    tile.odc.geobox, arcticdem_dir, resolution=arcticdem_res, buffer=arcticdem_buffer
                )
                tcvis = load_tcvis(tile.odc.geobox, tcvis_dir)
                data_masks = load_planet_masks(footprint.fpath)
                tile = xr.merge([tile, data_masks])

                tile: xr.Dataset = preprocess_v2(
                    tile,
                    arcticdem,
                    tcvis,
                    tpi_outer_radius,
                    tpi_inner_radius,
                    device,
                )
                return tile

            with timer("Loading tile"):
                tile = cache_manager.get_or_create(
                    identifier=planet_id,
                    creation_func=_get_tile,
                    force=force_preprocess,
                )

            logger.debug(f"Found tile with size {tile.sizes}")

            footprint_labels = labels[labels.image_id == planet_id]
            region = _get_region_name(footprint, admin2)

            with timer("Save as patches"):
                builder.add_tile(
                    tile=tile,
                    labels=footprint_labels,
                    region=region,
                    sample_id=planet_id,
                    metadata={
                        "planet_id": planet_id,
                        "fpath": footprint.fpath,
                    },
                )

            logger.info(f"Processed sample {planet_id} ({i + 1} of {len(footprints)})")

        except (KeyboardInterrupt, SystemExit, SystemError):
            logger.info("Interrupted by user.")
            break

        except Exception as e:
            logger.warning(f"Could not process sample {planet_id} ({i + 1} of {len(footprints)}). \nSkipping...")
            logger.exception(e)

    builder.finalize(
        {
            "data_dir": data_dir,
            "labels_dir": labels_dir,
            "arcticdem_dir": arcticdem_dir,
            "tcvis_dir": tcvis_dir,
            "ee_project": ee_project,
            "ee_use_highvolume": ee_use_highvolume,
            "tpi_outer_radius": tpi_outer_radius,
            "tpi_inner_radius": tpi_inner_radius,
        }
    )
    timer.summary()

shell

shell()

Open an interactive shell.

Source code in darts/src/darts/cli.py
@app.command
def shell():
    """Open an interactive shell."""
    app.interactive_shell()

start_app

start_app()

Wrapper to start the app.

Source code in darts/src/darts/cli.py
def start_app():
    """Wrapp to start the app."""
    try:
        # First time initialization of the logging manager
        LoggingManager.setup_logging()
        app.meta()
    except KeyboardInterrupt:
        logger.info("Interrupted by user. Closing...")
    except SystemExit:
        logger.info("Closing...")
    except Exception as e:
        logger.exception(e)

test_smp

test_smp(
    *,
    train_data_dir: pathlib.Path,
    run_id: str,
    run_name: str,
    model_ckp: pathlib.Path | None = None,
    batch_size: int = 8,
    data_split_method: typing.Literal[
        "random", "region", "sample"
    ]
    | None = None,
    data_split_by: list[str] | str | float | None = None,
    bands: list[str] | None = None,
    artifact_dir: pathlib.Path = pathlib.Path("artifacts"),
    num_workers: int = 0,
    device_config: darts_segmentation.training.train.DeviceConfig = darts_segmentation.training.train.DeviceConfig(),
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
) -> pytorch_lightning.Trainer

Run the testing of the SMP model.

The data structure of the training data expects the "preprocessing" step to be done beforehand, which results in the following data structure:

preprocessed-data/ # the top-level directory
├── config.toml
├── data.zarr/ # this zarr group contains the dataarrays x and y
├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
└── labels.geojson

Parameters:

  • train_data_dir (pathlib.Path) –

    The path (top-level) to the data to be used for training. Expects a directory containing:

    1. a zarr group called "data.zarr" containing a "x" and "y" array
    2. a geoparquet file called "metadata.parquet" containing the metadata for the data. This metadata should contain at least the following columns:
       - "sample_id": The id of the sample
       - "region": The region the sample belongs to
       - "empty": Whether the image is empty
       The index should refer to the index of the sample in the zarr data.

    This directory should be created by a preprocessing script.

  • run_id (str) –

    ID of the run.

  • run_name (str) –

    Name of the run.

  • model_ckp (pathlib.Path | None, default: None ) –

    Path to the model checkpoint. If None, try to find the latest checkpoint in {artifact_dir}/_runs/{run_name}-{run_id}/checkpoints. Defaults to None.

  • batch_size (int, default: 8 ) –

    Batch size for training and validation.

  • data_split_method (typing.Literal['random', 'region', 'sample'] | None, default: None ) –

    The method to use for splitting the data into a train and a test set. "random" will split the data randomly, the seed is always 42 and the size of the test set can be specified by providing a float between 0 and 1 to data_split_by. "region" will split the data by one or multiple regions, which can be specified by providing a str or list of str to data_split_by. "sample" will split the data by sample ids, which can also be specified similar to "region". If None, no split is done and the complete dataset is used for both training and testing. The train split will further be split in the cross validation process. Defaults to None.

  • data_split_by (list[str] | str | float | None, default: None ) –

    Select by which seed/regions/samples split. Defaults to None.

  • bands (list[str] | None, default: None ) –

    List of bands to use. Defaults to None.

  • artifact_dir (pathlib.Path, default: pathlib.Path('artifacts') ) –

    Directory to save artifacts. Defaults to Path("artifacts").

  • num_workers (int, default: 0 ) –

    Number of workers for the DataLoader. Defaults to 0.

  • device_config (darts_segmentation.training.train.DeviceConfig, default: darts_segmentation.training.train.DeviceConfig() ) –

    Device and distributed strategy related parameters.

  • wandb_entity (str | None, default: None ) –

    WandB entity. Defaults to None.

  • wandb_project (str | None, default: None ) –

    WandB project. Defaults to None.

Returns:

  • Trainer ( pytorch_lightning.Trainer ) –

    The trainer object used for training.
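
A minimal usage sketch (run id, run name, and data path are hypothetical); if model_ckp is omitted, the latest checkpoint is discovered automatically:

from pathlib import Path

trainer = test_smp(
    train_data_dir=Path("data/train"),  # hypothetical preprocessed data
    run_id="abc123",                    # hypothetical run id
    run_name="demo-run",                # hypothetical run name
    data_split_method="random",
    data_split_by=0.2,  # hold out 20% of the samples for testing
)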

Source code in darts-segmentation/src/darts_segmentation/training/train.py
def test_smp(
    *,
    train_data_dir: Path,
    run_id: str,
    run_name: str,
    model_ckp: Path | None = None,
    batch_size: int = 8,
    data_split_method: Literal["random", "region", "sample"] | None = None,
    data_split_by: list[str] | str | float | None = None,
    bands: list[str] | None = None,
    artifact_dir: Path = Path("artifacts"),
    num_workers: int = 0,
    device_config: DeviceConfig = DeviceConfig(),
    wandb_entity: str | None = None,
    wandb_project: str | None = None,
) -> "pl.Trainer":
    """Run the testing of the SMP model.

    The data structure of the training data expects the "preprocessing" step to be done beforehand,
    which results in the following data structure:

    ```sh
    preprocessed-data/ # the top-level directory
    ├── config.toml
    ├── data.zarr/ # this zarr group contains the dataarrays x and y
    ├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
    └── labels.geojson
    ```

    Args:
        train_data_dir (Path): The path (top-level) to the data to be used for training.
            Expects a directory containing:
            1. a zarr group called "data.zarr" containing a "x" and "y" array
            2. a geoparquet file called "metadata.parquet" containing the metadata for the data.
                This metadata should contain at least the following columns:
                - "sample_id": The id of the sample
                - "region": The region the sample belongs to
                - "empty": Whether the image is empty
                The index should refer to the index of the sample in the zarr data.
            This directory should be created by a preprocessing script.
        run_id (str): ID of the run.
        run_name (str): Name of the run.
        model_ckp (Path | None): Path to the model checkpoint.
            If None, try to find the latest checkpoint in `{artifact_dir}/_runs/{run_name}-{run_id}/checkpoints`.
            Defaults to None.
        batch_size (int): Batch size for training and validation.
        data_split_method (Literal["random", "region", "sample"] | None, optional):
            The method used to split the data into train and test sets.
            "random" splits the data randomly; the seed is always 42, and the size of the test set
            can be specified by passing a float between 0 and 1 to data_split_by.
            "region" splits the data by one or multiple regions,
            specified as a str or list of str in data_split_by.
            "sample" splits the data by sample ids, specified the same way as for "region".
            If None, no split is done and the complete dataset is used for both training and testing.
            The train split is further split during cross-validation.
            Defaults to None.
        data_split_by (list[str] | str | float | None, optional): What to split by, depending on
            data_split_method: a float (test-set fraction) for "random", or region/sample ids otherwise.
            Defaults to None.
        bands (list[str] | None, optional): List of bands to use. Defaults to None.
        artifact_dir (Path, optional): Directory to save artifacts. Defaults to Path("artifacts").
        num_workers (int, optional): Number of workers for the DataLoader. Defaults to 0.
        device_config (DeviceConfig, optional): Device and distributed strategy related parameters.
        wandb_entity (str | None, optional): WandB entity. Defaults to None.
        wandb_project (str | None, optional): WandB project. Defaults to None.

    Returns:
        Trainer: The trainer object used for training.

    """
    import lightning as L  # noqa: N812
    import lovely_tensors
    import torch
    from darts.utils.logging import LoggingManager
    from lightning.pytorch import seed_everything
    from lightning.pytorch.callbacks import RichProgressBar, ThroughputMonitor
    from lightning.pytorch.loggers import CSVLogger, WandbLogger

    from darts_segmentation.training.callbacks import BinarySegmentationMetrics
    from darts_segmentation.training.data import DartsDataModule
    from darts_segmentation.training.module import LitSMP
    from darts_segmentation.utils import Bands

    LoggingManager.apply_logging_handlers("lightning.pytorch")

    tick_fstart = time.perf_counter()

    # Further nest the artifact directory to avoid cluttering the root directory
    artifact_dir = artifact_dir / "_runs"

    logger.info(
        f"Starting testing '{run_name}' ('{run_id}') with data from {train_data_dir.resolve()}."
        f" Artifacts will be saved to {(artifact_dir / f'{run_name}-{run_id}').resolve()}."
    )
    logger.debug(f"Using config:\n\t{batch_size=}\n\t{device_config}")

    lovely_tensors.set_config(color=False)
    lovely_tensors.monkey_patch()
    torch.set_float32_matmul_precision("medium")
    seed_everything(42, workers=True)

    data_config = toml.load(train_data_dir / "config.toml")["darts"]

    all_bands = Bands.from_config(data_config)
    bands = all_bands.filter(bands) if bands else all_bands

    # Data and model
    datamodule = DartsDataModule(
        data_dir=train_data_dir,
        batch_size=batch_size,
        data_split_method=data_split_method,
        data_split_by=data_split_by,
        bands=bands,
        num_workers=num_workers,
    )
    # Try to infer model checkpoint if not given
    if model_ckp is None:
        checkpoint_dir = artifact_dir / f"{run_name}-{run_id}" / "checkpoints"
        logger.debug(f"No checkpoint provided. Looking for model checkpoint in {checkpoint_dir.resolve()}")
        model_ckp = max(checkpoint_dir.glob("*.ckpt"), key=lambda x: x.stat().st_mtime)
    logger.debug(f"Using model checkpoint at {model_ckp.resolve()}")
    model = LitSMP.load_from_checkpoint(model_ckp)

    # Loggers
    trainer_loggers = [
        CSVLogger(save_dir=artifact_dir, version=f"{run_name}-{run_id}"),
    ]
    logger.debug(f"Logging CSV to {Path(trainer_loggers[0].log_dir).resolve()}")
    if wandb_entity and wandb_project:
        wandb_logger = WandbLogger(
            save_dir=artifact_dir.parent,
            name=run_name,
            version=run_id,
            project=wandb_project,
            entity=wandb_entity,
            resume="allow",
            # Using the group and job_type is a workaround for wandb's lack of support for manual sweeps
            group="none",
            job_type="none",
        )
        trainer_loggers.append(wandb_logger)
        logger.debug(
            f"Logging to WandB with entity '{wandb_entity}' and project '{wandb_project}'."
            f"Artifacts are logged to {(Path(wandb_logger.save_dir) / 'wandb').resolve()}"
        )

    # Callbacks
    callbacks = [
        RichProgressBar(),
        BinarySegmentationMetrics(
            bands=bands,
            batch_size=batch_size,
            patch_size=data_config["patch_size"],
        ),
        ThroughputMonitor(batch_size_fn=lambda batch: batch[0].size(0)),
    ]

    # Test
    trainer = L.Trainer(
        callbacks=callbacks,
        logger=trainer_loggers,
        accelerator=device_config.accelerator,
        strategy=device_config.lightning_strategy,
        num_nodes=device_config.num_nodes,
        devices=device_config.devices,
        deterministic=True,
    )

    trainer.test(model, datamodule, ckpt_path=model_ckp)

    tick_fend = time.perf_counter()
    logger.info(f"Finished testing '{run_name}' in {tick_fend - tick_fstart:.2f}s.")

    if wandb_entity and wandb_project:
        wandb_logger.finalize("success")
        wandb_logger.experiment.finish(exit_code=0)
        logger.debug(f"Finalized WandB logging for '{run_name}'")

    return trainer

train_smp

Run the training of the SMP model, specifically binary segmentation.

Please see https://smp.readthedocs.io/en/latest/index.html for model configurations of architecture and encoder.

Please also consider reading our training guide (docs/guides/training.md).

This training function is meant for single training runs but is also used for cross-validation and hyperparameter tuning by cv.py and tune.py. This strongly affects where artifacts are stored:

  • Run was created by a tune: {artifact_dir}/{tune_name}/{cv_name}/{run_name}-{run_id}
  • Run was created by a cross-validation: {artifact_dir}/_cross_validations/{cv_name}/{run_name}-{run_id}
  • Single runs: {artifact_dir}/_runs/{run_name}-{run_id}

run_name can be specified by the user; otherwise it is generated automatically. In the case of cross-validation, the run name is generated by the cross-validation itself. run_id is always generated automatically by the training function. Both are saved to the final checkpoint.

You can specify how frequently logs are written and validation is performed:

  • log_every_n_steps specifies how often train-logs will be written. This does not affect validation.
  • check_val_every_n_epoch specifies how often validation will be performed. This also affects early stopping.
  • early_stopping_patience specifies how many validation rounds to wait for improvement before stopping. In epochs, this is check_val_every_n_epoch * early_stopping_patience.
  • plot_every_n_val_epochs specifies how often validation samples will be plotted. Since plotting is quite costly, you can reduce the frequency. It works similarly to early stopping: in epochs, this is check_val_every_n_epoch * plot_every_n_val_epochs.

Example: There are 400 training samples and the batch size is 2, resulting in 200 training steps per epoch. If log_every_n_steps is set to 50, the training logs and metrics will be written 4 times per epoch. If check_val_every_n_epoch is set to 5, validation will be performed every 5 epochs. If plot_every_n_val_epochs is set to 2, validation samples will be plotted every 10 epochs. If early_stopping_patience is set to 3, early stopping triggers after 15 epochs without improvement. This arithmetic is sketched below.
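A quick sanity check of that arithmetic, using the numbers from the example above:

```python
# Worked example: 400 training samples, batch size 2.
n_samples, batch_size = 400, 2
log_every_n_steps = 50
check_val_every_n_epoch = 5
plot_every_n_val_epochs = 2
early_stopping_patience = 3

steps_per_epoch = n_samples // batch_size  # 200
logs_per_epoch = steps_per_epoch // log_every_n_steps  # 4 log events per epoch
plot_interval = check_val_every_n_epoch * plot_every_n_val_epochs  # plots every 10 epochs
early_stop_after = check_val_every_n_epoch * early_stopping_patience  # stop after 15 epochs w/o improvement
```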

The data structure of the training data expects the "preprocessing" step to be done beforehand, which results in the following data structure (a sketch of opening this layout follows the tree):

preprocessed-data/ # the top-level directory
├── config.toml
├── data.zarr/ # this zarr group contains the dataarrays x and y
├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
└── labels.geojson
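A sketch of opening this layout, assuming xarray (with the zarr backend), pandas, and toml are installed; the directory path is a placeholder:

```python
from pathlib import Path

import pandas as pd
import toml
import xarray as xr

data_dir = Path("data/preprocessed")  # placeholder for train_data_dir

config = toml.load(data_dir / "config.toml")["darts"]  # e.g. patch_size and band configuration
ds = xr.open_zarr(data_dir / "data.zarr")  # holds the "x" and "y" dataarrays
metadata = pd.read_parquet(data_dir / "metadata.parquet")  # sample_id, region, empty, ...

print(config["patch_size"], ds["x"].shape, len(metadata))
```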

Parameters:

  • run (TrainRunConfig, default: TrainRunConfig() ) –

    Run related parameters for training.

  • training_config (TrainingConfig, default: TrainingConfig() ) –

    Training related parameters for training.

  • data_config (DataConfig, default: DataConfig() ) –

    Data related parameters for training.

  • logging_config (LoggingConfig, default: LoggingConfig() ) –

    Logging related parameters for training.

  • device_config (DeviceConfig, default: DeviceConfig() ) –

    Device and distributed strategy related parameters.

  • hparams (Hyperparameters, default: Hyperparameters() ) –

    Hyperparameters for the model.

Returns:

  • pl.Trainer: The trainer object used for training. It also contains the logged metrics.
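As a usage sketch, not a prescribed recipe: the config dataclasses are populated with fields that appear in the source below, the values are placeholders, and the import path for the config classes is an assumption (DeviceConfig is documented above under darts_segmentation.training.train, so its siblings are assumed to live there too):

```python
from pathlib import Path

# Assumed import location for the config dataclasses; adjust if they live elsewhere.
from darts_segmentation.training.train import (
    DataConfig,
    Hyperparameters,
    TrainRunConfig,
    train_smp,
)

trainer = train_smp(
    run=TrainRunConfig(name="my-run"),  # run_id is generated automatically
    data_config=DataConfig(
        train_data_dir=Path("data/preprocessed"),  # placeholder path
        data_split_method="random",
        data_split_by=0.2,  # hold out 20% of the data as the test split
    ),
    hparams=Hyperparameters(batch_size=8),  # all other hyperparameters keep their defaults
)
```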

Source code in darts-segmentation/src/darts_segmentation/training/train.py
def train_smp(
    *,
    run: TrainRunConfig = TrainRunConfig(),
    training_config: TrainingConfig = TrainingConfig(),
    data_config: DataConfig = DataConfig(),
    logging_config: LoggingConfig = LoggingConfig(),
    device_config: DeviceConfig = DeviceConfig(),
    hparams: Hyperparameters = Hyperparameters(),
):
    """Run the training of the SMP model, specifically binary segmentation.

    Please see https://smp.readthedocs.io/en/latest/index.html for model configurations of architecture and encoder.

    Please also consider reading our training guide (docs/guides/training.md).

    This training function is meant for single training runs but is also used for cross-validation and hyperparameter
    tuning by cv.py and tune.py.
    This strongly affects where artifacts are stored:

    - Run was created by a tune: `{artifact_dir}/{tune_name}/{cv_name}/{run_name}-{run_id}`
    - Run was created by a cross-validation: `{artifact_dir}/_cross_validations/{cv_name}/{run_name}-{run_id}`
    - Single runs: `{artifact_dir}/_runs/{run_name}-{run_id}`

    `run_name` can be specified by the user, else it is generated automatically.
    In case of cross-validation, the run name is generated automatically by the cross-validation.
    `run_id` is generated automatically by the training function.
    Both are saved to the final checkpoint.

    You can specify how frequently logs are written and validation is performed:
        - `log_every_n_steps` specifies how often train-logs will be written. This does not affect validation.
        - `check_val_every_n_epoch` specifies how often validation will be performed.
            This will also affect early stopping.
        - `early_stopping_patience` specifies how many epochs to wait for improvement before stopping.
            In epochs, this would be `check_val_every_n_epoch * early_stopping_patience`.
        - `plot_every_n_val_epochs` specifies how often validation samples will be plotted.
            Since plotting is quite costly, you can reduce the frequency. This works similarly to early stopping.
            In epochs, this would be `check_val_every_n_epoch * plot_every_n_val_epochs`.
    Example: There are 400 training samples and the batch size is 2, resulting in 200 training steps per epoch.
    If `log_every_n_steps` is set to 50 then the training logs and metrics will be logged 4 times per epoch.
    If `check_val_every_n_epoch` is set to 5 then validation will be performed every 5 epochs.
    If `plot_every_n_val_epochs` is set to 2 then validation samples will be plotted every 10 epochs.
    If `early_stopping_patience` is set to 3 then early stopping will be performed after 15 epochs without improvement.

    The data structure of the training data expects the "preprocessing" step to be done beforehand,
    which results in the following data structure:

    ```sh
    preprocessed-data/ # the top-level directory
    ├── config.toml
    ├── data.zarr/ # this zarr group contains the dataarrays x and y
    ├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
    └── labels.geojson
    ```

    Args:
        data_config (DataConfig): Data related parameters for training.
        run (TrainRunConfig): Run related parameters for training.
        logging_config (LoggingConfig): Logging related parameters for training.
        device_config (DeviceConfig): Device and distributed strategy related parameters.
        training_config (TrainingConfig): Training related parameters for training.
        hparams (Hyperparameters): Hyperparameters for the model.

    Returns:
        pl.Trainer: The trainer object used for training. Contains also metrics.

    """
    import lightning as L  # noqa: N812
    import lovely_tensors
    import torch
    from darts.utils.logging import LoggingManager
    from darts_utils.namegen import generate_counted_name, generate_id
    from lightning.pytorch import seed_everything
    from lightning.pytorch.callbacks import EarlyStopping, RichProgressBar
    from lightning.pytorch.loggers import CSVLogger, WandbLogger

    from darts_segmentation.segment import SMPSegmenterConfig
    from darts_segmentation.training.callbacks import BinarySegmentationMetrics, BinarySegmentationPreview
    from darts_segmentation.training.data import DartsDataModule
    from darts_segmentation.training.module import LitSMP
    from darts_segmentation.utils import Bands

    LoggingManager.apply_logging_handlers("lightning.pytorch", level=logging.INFO)

    tick_fstart = time.perf_counter()

    # Get the right nesting of the artifact directory
    artifact_dir = logging_config.artifact_dir_at_run(run.cv_name, run.tune_name)

    # Create unique run identification (name can be specified by user, id can be interpreted as a 'version')
    run_name = run.name or generate_counted_name(artifact_dir)
    run_id = generate_id()  # Needed for wandb

    logger.info(
        f"Starting training '{run_name}' ('{run_id}') with data from {data_config.train_data_dir.resolve()}."
        f" Artifacts will be saved to {(artifact_dir / f'{run_name}-{run_id}').resolve()}."
    )
    logger.debug(
        f"Using config:\n\t{run}\n\t{training_config}\n\t{data_config}\n\t{logging_config}\n\t"
        f"{device_config}\n\t{hparams}"
    )
    if training_config.continue_from_checkpoint:
        logger.debug(f"Continuing from checkpoint '{training_config.continue_from_checkpoint.resolve()}'")

    lovely_tensors.monkey_patch()
    lovely_tensors.set_config(color=False)
    torch.set_float32_matmul_precision("medium")
    seed_everything(run.random_seed, workers=True, verbose=False)

    dataset_config = toml.load(data_config.train_data_dir / "config.toml")["darts"]
    all_bands = Bands.from_config(dataset_config)
    bands = all_bands.filter(hparams.bands) if hparams.bands else all_bands
    config = SMPSegmenterConfig(
        bands=bands,
        model={
            "arch": hparams.model_arch,
            "encoder_name": hparams.model_encoder,
            "encoder_weights": hparams.model_encoder_weights,
            "in_channels": len(all_bands) if bands is None else len(bands),
            "classes": 1,
        },
    )

    # Data and model
    datamodule = DartsDataModule(
        data_dir=data_config.train_data_dir,
        batch_size=hparams.batch_size,
        data_split_method=data_config.data_split_method,
        data_split_by=data_config.data_split_by,
        fold_method=data_config.fold_method,
        total_folds=data_config.total_folds,
        fold=run.fold,
        subsample=data_config.subsample,
        bands=hparams.bands,
        augment=hparams.augment,
        num_workers=training_config.num_workers,
    )
    model = LitSMP(
        config=config,
        learning_rate=hparams.learning_rate,
        gamma=hparams.gamma,
        focal_loss_alpha=hparams.focal_loss_alpha,
        focal_loss_gamma=hparams.focal_loss_gamma,
        # These are only stored in the hparams and are not used
        run_id=run_id,
        run_name=run_name,
        cv_name=run.cv_name or "none",
        tune_name=run.tune_name or "none",
        random_seed=run.random_seed,
    )

    # Loggers
    trainer_loggers = [
        CSVLogger(save_dir=artifact_dir, name=None, version=f"{run_name}-{run_id}"),
    ]
    logger.debug(f"Logging CSV to {Path(trainer_loggers[0].log_dir).resolve()}")
    if logging_config.wandb_entity and logging_config.wandb_project:
        tags = [data_config.train_data_dir.stem]
        if run.cv_name:
            tags.append(run.cv_name)
        if run.tune_name:
            tags.append(run.tune_name)
        wandb_logger = WandbLogger(
            save_dir=artifact_dir.parent.parent if run.tune_name or run.cv_name else artifact_dir.parent,
            name=run_name,
            version=run_id,
            project=logging_config.wandb_project,
            entity=logging_config.wandb_entity,
            resume="allow",
            # Using the group and job_type is a workaround for wandb's lack of support for manual sweeps
            group=run.tune_name or "none",
            job_type=run.cv_name or "none",
            # Using tags to quickly identify the run
            tags=tags,
        )
        trainer_loggers.append(wandb_logger)
        logger.debug(
            f"Logging to WandB with entity '{logging_config.wandb_entity}' and project '{logging_config.wandb_project}'"
            f"Artifacts are logged to {(Path(wandb_logger.save_dir) / 'wandb').resolve()}"
        )

    # Callbacks and profiler
    callbacks = [
        RichProgressBar(),
        BinarySegmentationMetrics(
            bands=bands,
            val_set=f"val{run.fold}",
            plot_every_n_val_epochs=logging_config.plot_every_n_val_epochs,
            is_crossval=bool(run.cv_name),
            batch_size=hparams.batch_size,
            patch_size=dataset_config["patch_size"],
        ),
        BinarySegmentationPreview(
            bands=bands,
            val_set=f"val{run.fold}",
            plot_every_n_val_epochs=logging_config.plot_every_n_val_epochs,
        ),
        # Something does not work well here...
        # ThroughputMonitor(batch_size_fn=lambda batch: batch[0].size(0), window_size=log_every_n_steps),
    ]
    if training_config.early_stopping_patience:
        logger.debug(f"Using EarlyStopping with patience {training_config.early_stopping_patience}")
        early_stopping = EarlyStopping(
            monitor="val/JaccardIndex", mode="max", patience=training_config.early_stopping_patience
        )
        callbacks.append(early_stopping)

    # Unsupported: https://github.com/Lightning-AI/pytorch-lightning/issues/19983
    # profiler_dir = artifact_dir / f"{run_name}-{run_id}" / "profiler"
    # profiler_dir.mkdir(parents=True, exist_ok=True)
    # profiler = AdvancedProfiler(dirpath=profiler_dir, filename="perf_logs", dump_stats=True)
    # logger.debug(f"Using profiler with output to {profiler.dirpath.resolve()}")

    logger.debug(
        f"Creating lightning-trainer on {device_config.accelerator} with devices {device_config.devices}"
        f" and strategy '{device_config.lightning_strategy}'"
    )
    # Train
    trainer = L.Trainer(
        max_epochs=training_config.max_epochs,
        callbacks=callbacks,
        log_every_n_steps=logging_config.log_every_n_steps,
        logger=trainer_loggers,
        check_val_every_n_epoch=logging_config.check_val_every_n_epoch,
        accelerator=device_config.accelerator,
        devices=device_config.devices if device_config.devices[0] != "auto" else "auto",
        strategy=device_config.lightning_strategy,
        num_nodes=device_config.num_nodes,
        deterministic=False,  # True does not work for some reason
        # profiler=profiler,
    )
    trainer.fit(model, datamodule, ckpt_path=training_config.continue_from_checkpoint)

    tick_fend = time.perf_counter()
    logger.info(f"Finished training '{run_name}' in {tick_fend - tick_fstart:.2f}s.")

    if logging_config.wandb_entity and logging_config.wandb_project:
        wandb_logger.finalize("success")
        wandb_logger.experiment.finish(exit_code=0)
        logger.debug(f"Finalized WandB logging for '{run_name}'")

    return trainer

tune_smp

Tune the hyper-parameters of the model using cross-validation and random states.

Please see https://smp.readthedocs.io/en/latest/index.html for model configurations of architecture and encoder.

Please also consider reading our training guide (docs/guides/training.md).

This tuning script is designed to sweep over hyperparameters with a cross-validation used to evaluate each hyperparameter configuration. Optionally, by setting retrain_and_test to True, the best hyperparameters are then selected based on the cross-validation scores and a new model is trained on the entire train-split and tested on the test-split.

Hyperparameters can be configured using a hpconfig file (YAML or TOML). Please consult the training guide or the documentation of darts_segmentation.training.hparams.parse_hyperparameters to learn how such a file should be structured. By default, a random search is performed, where the number of samples can be specified by n_trials. If n_trials is set to "grid", a grid search is performed instead. However, this expects every hyperparameter to be configured as either a constant value or a choice / list.

To specify which metric(s) the cv score is calculated on, the scoring_metric parameter can be used. Each metric can be suffixed with ":higher" or ":lower" to indicate its direction. This allows multiple metrics to be combined correctly, by taking 1/metric before the calculation for each ":lower" metric. If no direction is provided, ":higher" is assumed. The direction has no real effect on the single-score calculation, since only the mean is calculated there.

In a multi-score setting, the score is calculated by combining, then reducing, the metrics: first, for each fold, the metrics are combined using the specified strategy; then the per-fold results are reduced via the mean. Please refer to the documentation to understand the different multi-score strategies.

If any metric of any run is NaN, Inf, -Inf, or 0, the score is reported as "unstable". In such cases, the configuration is not considered for further evaluation.
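To make the direction handling and the combine-then-reduce logic concrete, here is a toy sketch; the actual implementation lives in darts_segmentation.training.scoring, so the function names and the mean-based combine strategy below are illustrative assumptions:

```python
import math

def combine_then_reduce(fold_metrics: list[dict[str, float]], scoring_metric: list[str]) -> float:
    """Toy scoring: combine the metrics per fold (here via mean), then reduce over folds via mean."""
    per_fold = []
    for metrics in fold_metrics:
        values = []
        for spec in scoring_metric:
            name, _, direction = spec.partition(":")  # "val/loss:lower" -> name="val/loss", direction="lower"
            value = metrics[name]
            values.append(1 / value if direction == "lower" else value)  # invert ":lower" metrics
        per_fold.append(sum(values) / len(values))  # combine (stand-in for the configurable strategy)
    return sum(per_fold) / len(per_fold)  # reduce via mean

def is_unstable(fold_metrics: list[dict[str, float]]) -> bool:
    # A configuration is "unstable" if any metric is NaN, +/-Inf, or 0.
    return any(not math.isfinite(v) or v == 0 for m in fold_metrics for v in m.values())

folds = [
    {"val/JaccardIndex": 0.71, "val/loss": 0.35},
    {"val/JaccardIndex": 0.68, "val/loss": 0.41},
]
print(combine_then_reduce(folds, ["val/JaccardIndex:higher", "val/loss:lower"]))
```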

Artifacts are stored under {artifact_dir}/{tune_name}.

You can specify how frequently logs are written and validation is performed:

  • log_every_n_steps specifies how often train-logs will be written. This does not affect validation.
  • check_val_every_n_epoch specifies how often validation will be performed. This also affects early stopping.
  • early_stopping_patience specifies how many validation rounds to wait for improvement before stopping. In epochs, this is check_val_every_n_epoch * early_stopping_patience.
  • plot_every_n_val_epochs specifies how often validation samples will be plotted. Since plotting is quite costly, you can reduce the frequency. It works similarly to early stopping: in epochs, this is check_val_every_n_epoch * plot_every_n_val_epochs.

Example: There are 400 training samples and the batch size is 2, resulting in 200 training steps per epoch. If log_every_n_steps is set to 50, the training logs and metrics will be written 4 times per epoch. If check_val_every_n_epoch is set to 5, validation will be performed every 5 epochs. If plot_every_n_val_epochs is set to 2, validation samples will be plotted every 10 epochs. If early_stopping_patience is set to 3, early stopping triggers after 15 epochs without improvement (the same arithmetic as sketched under train_smp above).

The data structure of the training data expects the "preprocessing" step to be done beforehand, which results in the following data structure:

preprocessed-data/ # the top-level directory
├── config.toml
├── data.zarr/ # this zarr group contains the dataarrays x and y
├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
└── labels.geojson

Parameters:

  • name (str | None, default: None ) –

    Name of the tuning run. Will be generated based on the number of existing directories in the artifact directory if None. Defaults to None.

  • n_trials (int | Literal["grid"], default: 100 ) –

    Number of trials to perform in hyperparameter tuning. If "grid", span a grid search over all configured hyperparameters. In a grid search, only constant or choice hyperparameters are allowed. Defaults to 100.

  • retrain_and_test (bool, default: False ) –

    Whether to retrain the model with the best hyperparameters and test it. Defaults to False.

  • cv_config (CrossValidationConfig, default: CrossValidationConfig() ) –

    Configuration for cross-validation.

  • training_config (TrainingConfig, default: TrainingConfig() ) –

    Configuration for training.

  • data_config (DataConfig, default: DataConfig() ) –

    Configuration for data.

  • device_config (DeviceConfig, default: DeviceConfig() ) –

    Configuration for device.

  • logging_config (LoggingConfig, default: LoggingConfig() ) –

    Configuration for logging.

  • hpconfig (pathlib.Path | None, default: None ) –

    Path to the hyperparameter configuration file. Please see the documentation of hyperparameters for more information. Defaults to None.

  • config_file (pathlib.Path | None, default: None ) –

    Path to the configuration file. If provided, it will be used instead of hpconfig if hpconfig is None. Defaults to None.

Returns:

  • tuple[float, pd.DataFrame]: The best score (if retrained and tested) and the run infos of all runs. See the invocation sketch after the raises list below.

Raises:

  • ValueError

    If no hyperparameter configuration file is provided.
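A minimal invocation sketch, calling tune_smp from Python (the CLI forwards the same parameters via cyclopts); the tune name and the path to the hyperparameter file are placeholders:

```python
from pathlib import Path

from darts_segmentation.training.tune import tune_smp  # module path taken from the source reference below

best_score, run_infos = tune_smp(
    name="my-tune",
    n_trials=20,  # random search over 20 sampled configurations
    hpconfig=Path("hpconfig.yaml"),  # placeholder; see parse_hyperparameters for the expected schema
    retrain_and_test=True,  # retrain on the full train split with the best hyperparameters, then test
)
```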

Source code in darts-segmentation/src/darts_segmentation/training/tune.py
def tune_smp(
    *,
    name: str | None = None,
    n_trials: int | Literal["grid"] = 100,
    retrain_and_test: bool = False,
    cv_config: CrossValidationConfig = CrossValidationConfig(),
    training_config: TrainingConfig = TrainingConfig(),
    data_config: DataConfig = DataConfig(),
    device_config: DeviceConfig = DeviceConfig(),
    logging_config: LoggingConfig = LoggingConfig(),
    hpconfig: Path | None = None,
    config_file: Annotated[Path | None, cyclopts.Parameter(parse=False)] = None,
):
    """Tune the hyper-parameters of the model using cross-validation and random states.

    Please see https://smp.readthedocs.io/en/latest/index.html for model configurations of architecture and encoder.

    Please also consider reading our training guide (docs/guides/training.md).

    This tuning script is designed to sweep over hyperparameters with a cross-validation
    used to evaluate each hyperparameter configuration.
    Optionally, by setting `retrain_and_test` to True, the best hyperparameters are then selected based on the
    cross-validation scores and a new model is trained on the entire train-split and tested on the test-split.

    Hyperparameters can be configured using a `hpconfig` file (YAML or Toml).
    Please consult the training guide or the documentation of
    `darts_segmentation.training.hparams.parse_hyperparameters` to learn how such a file should be structured.
    By default, a random search is performed, where the number of samples can be specified by `n_trials`.
    If `n_trials` is set to "grid", a grid search is performed instead.
    However, this expects every hyperparameter to be configured as either a constant value or a choice / list.

    To specify which metric(s) the cv score is calculated on, the `scoring_metric` parameter can be used.
    Each metric can be suffixed with ":higher" or ":lower" to indicate its direction.
    This allows multiple metrics to be combined correctly, by taking 1/metric before the calculation
    for each ":lower" metric. If no direction is provided, ":higher" is assumed.
    The direction has no real effect on the single-score calculation, since only the mean is calculated there.

    In a multi-score setting, the score is calculated by combining, then reducing, the metrics:
    first, for each fold, the metrics are combined using the specified strategy;
    then the per-fold results are reduced via the mean.
    Please refer to the documentation to understand the different multi-score strategies.

    If any metric of any run is NaN, Inf, -Inf, or 0, the score is reported as "unstable".
    In such cases, the configuration is not considered for further evaluation.

    Artifacts are stored under `{artifact_dir}/{tune_name}`.

    You can specify how frequently logs are written and validation is performed:
        - `log_every_n_steps` specifies how often train-logs will be written. This does not affect validation.
        - `check_val_every_n_epoch` specifies how often validation will be performed.
            This will also affect early stopping.
        - `early_stopping_patience` specifies how many epochs to wait for improvement before stopping.
            In epochs, this would be `check_val_every_n_epoch * early_stopping_patience`.
        - `plot_every_n_val_epochs` specifies how often validation samples will be plotted.
            Since plotting is quite costly, you can reduce the frequency. This works similarly to early stopping.
            In epochs, this would be `check_val_every_n_epoch * plot_every_n_val_epochs`.
    Example: There are 400 training samples and the batch size is 2, resulting in 200 training steps per epoch.
    If `log_every_n_steps` is set to 50 then the training logs and metrics will be logged 4 times per epoch.
    If `check_val_every_n_epoch` is set to 5 then validation will be performed every 5 epochs.
    If `plot_every_n_val_epochs` is set to 2 then validation samples will be plotted every 10 epochs.
    If `early_stopping_patience` is set to 3 then early stopping will be performed after 15 epochs without improvement.

    The data structure of the training data expects the "preprocessing" step to be done beforehand,
    which results in the following data structure:

    ```sh
    preprocessed-data/ # the top-level directory
    ├── config.toml
    ├── data.zarr/ # this zarr group contains the dataarrays x and y
    ├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
    └── labels.geojson
    ```

    Args:
        name (str | None, optional): Name of the tuning run.
            Will be generated based on the number of existing directories in the artifact directory if None.
            Defaults to None.
        n_trials (int | Literal["grid"], optional): Number of trials to perform in hyperparameter tuning.
            If "grid", span a grid search over all configured hyperparameters.
            In a grid search, only constant or choice hyperparameters are allowed.
            Defaults to 100.
        retrain_and_test (bool, optional): Whether to retrain the model with the best hyperparameters and test it.
            Defaults to False.
        cv_config (CrossValidationConfig, optional): Configuration for cross-validation.
            Defaults to CrossValidationConfig().
        training_config (TrainingConfig, optional): Configuration for training.
            Defaults to TrainingConfig().
        data_config (DataConfig, optional): Configuration for data.
            Defaults to DataConfig().
        device_config (DeviceConfig, optional): Configuration for device.
            Defaults to DeviceConfig().
        logging_config (LoggingConfig, optional): Configuration for logging.
            Defaults to LoggingConfig().
        hpconfig (Path | None, optional): Path to the hyperparameter configuration file.
            Please see the documentation of `hyperparameters` for more information.
            Defaults to None.
        config_file (Path | None, optional): Path to the configuration file. If provided,
            it will be used instead of `hpconfig` if `hpconfig` is None. Defaults to None.

    Returns:
        tuple[float, pd.DataFrame]: The best score (if retrained and tested) and the run infos of all runs.

    Raises:
        ValueError: If no hyperparameter configuration file is provided.

    """
    import pandas as pd
    from darts_utils.namegen import generate_counted_name

    from darts_segmentation.training.adp import _adp
    from darts_segmentation.training.hparams import parse_hyperparameters, sample_hyperparameters
    from darts_segmentation.training.scoring import score_from_single_run
    from darts_segmentation.training.train import test_smp, train_smp

    tick_fstart = time.perf_counter()

    tune_name = name or generate_counted_name(logging_config.artifact_dir)
    artifact_dir = logging_config.artifact_dir / tune_name
    run_infos_file = artifact_dir / f"{tune_name}.parquet"

    # Check if the artifact directory is empty
    assert not artifact_dir.exists(), f"{artifact_dir} already exists."
    artifact_dir.mkdir(parents=True, exist_ok=True)

    hpconfig = hpconfig or config_file
    if hpconfig is None:
        raise ValueError(
            "No hyperparameter configuration file provided. Please provide a valid file via the `--hpconfig` flag."
        )
    param_grid = parse_hyperparameters(hpconfig)
    logger.debug(f"Parsed hyperparameter grid: {param_grid}")
    param_list = sample_hyperparameters(param_grid, n_trials)

    logger.info(
        f"Starting tune '{tune_name}' with data from {data_config.train_data_dir.resolve()}."
        f" Artifacts will be saved to {artifact_dir.resolve()}."
        f" Will run n_trials*n_randoms*n_folds ="
        f" {len(param_list)}*{cv_config.n_randoms}*{cv_config.n_folds} ="
        f" {len(param_list) * cv_config.n_randoms * cv_config.n_folds} experiments."
    )

    # Plan which runs to perform. These are later consumed based on the parallelization strategy.
    process_inputs = [
        _ProcessInputs(
            current=i,
            total=len(param_list),
            tune_name=tune_name,
            cv=cv_config,
            training_config=training_config,
            logging_config=logging_config,
            data_config=data_config,
            device_config=device_config,
            hparams=hparams,
        )
        for i, hparams in enumerate(param_list)
    ]

    run_infos: list[pd.DataFrame] = []
    best_score = 0
    best_hp = None

    # This function abstracts away common logic for running multiprocessing
    for inp, output in _adp(
        process_inputs=process_inputs,
        is_parallel=device_config.strategy == "tune-parallel",
        devices=device_config.devices,
        available_devices=available_devices,
        _run=_run_cv,
    ):
        run_infos.append(output.run_infos)
        if not output.is_unstable and output.score > best_score:
            best_score = output.score
            best_hp = inp.hparams

        # Save already here to prevent data loss if something goes wrong
        pd.concat(run_infos).reset_index(drop=True).to_parquet(run_infos_file)
        logger.debug(f"Saved run infos to {run_infos_file}")

    if len(run_infos) == 0:
        logger.error("No hyperparameters resulted in a valid score. Please check the logs for more information.")
        return 0, run_infos

    run_infos = pd.concat(run_infos).reset_index(drop=True)

    tick_fend = time.perf_counter()

    if best_hp is None:
        logger.warning(
            f"Tuning completed in {tick_fend - tick_fstart:.2f}s."
            " No hyperparameters resulted in a valid score. Please check the logs for more information."
        )
        return 0, run_infos
    logger.info(
        f"Tuning completed in {tick_fend - tick_fstart:.2f}s. The best score was {best_score:.4f} with {best_hp}."
    )

    # =====================
    # === End of tuning ===
    # =====================

    if not retrain_and_test:
        return 0, run_infos

    logger.info("Starting retraining with the best hyperparameters.")

    tick_fstart = time.perf_counter()
    trainer = train_smp(
        run=TrainRunConfig(name=f"{tune_name}-retrain"),
        training_config=training_config,  # TODO: device and strategy
        data_config=DataConfig(
            train_data_dir=data_config.train_data_dir,
            data_split_method=data_config.data_split_method,
            data_split_by=data_config.data_split_by,
            fold_method=None,  # No fold method for retraining
            total_folds=None,  # No folds for retraining
        ),
        logging_config=LoggingConfig(
            artifact_dir=artifact_dir,
            log_every_n_steps=logging_config.log_every_n_steps,
            check_val_every_n_epoch=logging_config.check_val_every_n_epoch,
            plot_every_n_val_epochs=logging_config.plot_every_n_val_epochs,
            wandb_entity=logging_config.wandb_entity,
            wandb_project=logging_config.wandb_project,
        ),
        hparams=best_hp,
    )
    run_id = trainer.lightning_module.hparams["run_id"]
    trainer = test_smp(
        train_data_dir=data_config.train_data_dir,
        run_id=run_id,
        run_name=f"{tune_name}-retrain",
        model_ckp=trainer.checkpoint_callback.best_model_path,
        batch_size=best_hp.batch_size,
        data_split_method=data_config.data_split_method,
        data_split_by=data_config.data_split_by,
        artifact_dir=artifact_dir,
        num_workers=training_config.num_workers,
        device_config=device_config,
        wandb_entity=logging_config.wandb_entity,
        wandb_project=logging_config.wandb_project,
    )

    run_info = {k: v.item() for k, v in trainer.callback_metrics.items()}
    test_scoring_metric = (
        cv_config.scoring_metric.replace("val/", "test/")
        if isinstance(cv_config.scoring_metric, str)
        else [sm.replace("val/", "test/") for sm in cv_config.scoring_metric]
    )
    score = score_from_single_run(run_info, test_scoring_metric, cv_config.multi_score_strategy)
    is_unstable = check_score_is_unstable(run_info, cv_config.scoring_metric)
    tick_fend = time.perf_counter()
    logger.info(
        f"Retraining and testing completed successfully in {tick_fend - tick_fstart:.2f}s"
        f" with {score=:.4f} ({'stable' if not is_unstable else 'unstable'})."
    )

    return score, run_infos