
darts.pipelines.AOISentinel2Pipeline

Bases: darts.pipelines.sequential_v2._BasePipeline

Pipeline for Sentinel 2 data based on an area of interest.

Parameters:

  • aoi_shapefile (pathlib.Path, default: None ) –

    The shapefile containing the area of interest.

  • start_date (str, default: None ) –

    The start date of the time series in YYYY-MM-DD format.

  • end_date (str, default: None ) –

    The end date of the time series in YYYY-MM-DD format.

  • max_cloud_cover (int, default: 10 ) –

    The maximum cloud cover percentage to use for filtering the Sentinel 2 scenes. Defaults to 10.

  • input_cache (pathlib.Path, default: pathlib.Path('data/cache/input') ) –

    The directory to use for caching the input data. Defaults to Path("data/cache/input").

  • model_files (pathlib.Path | list[pathlib.Path], default: None ) –

    The path to the models to use for segmentation. Can also be a single Path to use only one model; this implies write_model_outputs=False. If a list is provided, an ensemble of the models is used.

  • output_data_dir (pathlib.Path, default: pathlib.Path('data/output') ) –

    The output directory. Defaults to Path("data/output").

  • arcticdem_dir (pathlib.Path, default: pathlib.Path('data/download/arcticdem') ) –

    The directory containing the ArcticDEM data (the datacube and the extent files). Will be created and downloaded if it does not exist. Defaults to Path("data/download/arcticdem").

  • tcvis_dir (pathlib.Path, default: pathlib.Path('data/download/tcvis') ) –

    The directory containing the TCVis data. Defaults to Path("data/download/tcvis").

  • device (typing.Literal['cuda', 'cpu', 'auto'] | int | None, default: None ) –

    The device to run the model on. If "cuda", use the first device (0); if an int, use the specified device; if "auto", try to automatically select a free GPU (<50% memory usage). Defaults to "cuda" if available, else "cpu".

  • ee_project (str, default: None ) –

    The Earth Engine project ID or number to use. May be omitted if the project is defined within persistent API credentials obtained via earthengine authenticate.

  • ee_use_highvolume (bool, default: True ) –

    Whether to use the high volume server (https://earthengine-highvolume.googleapis.com).

  • tpi_outer_radius (int, default: 100 ) –

    The outer radius of the annulus kernel for the TPI calculation, in meters. Defaults to 100 m.

  • tpi_inner_radius (int, default: 0 ) –

    The inner radius of the annulus kernel for the TPI calculation, in meters. Defaults to 0.

  • patch_size (int, default: 1024 ) –

    The patch size to use for inference. Defaults to 1024.

  • overlap (int, default: 256 ) –

    The overlap to use for inference. Defaults to 256.

  • batch_size (int, default: 8 ) –

    The batch size to use for inference. Defaults to 8.

  • reflection (int, default: 0 ) –

    The reflection padding to use for inference. Defaults to 0.

  • binarization_threshold (float, default: 0.5 ) –

    The threshold to binarize the probabilities. Defaults to 0.5.

  • mask_erosion_size (int, default: 10 ) –

    The size of the disk to use for mask erosion and the edge-cropping. Defaults to 10.

  • min_object_size (int, default: 32 ) –

    The minimum object size to keep, in pixels. Defaults to 32.

  • quality_level (int | typing.Literal['high_quality', 'low_quality', 'none'], default: 1 ) –

    The quality level to use for the segmentation. Can also be an int, in which case 0="none", 1="low_quality" and 2="high_quality". Defaults to 1.

  • export_bands (list[str], default: ['probabilities', 'binarized', 'polygonized', 'extent', 'thumbnail'] ) –

    The bands to export. Can be a list of "probabilities", "binarized", "polygonized", "extent", "thumbnail", "optical", "dem", "tcvis" or concrete band-names. Defaults to ["probabilities", "binarized", "polygonized", "extent", "thumbnail"].

  • write_model_outputs (bool, default: False ) –

    Whether to also save the outputs of the individual models, not only the ensemble result. Defaults to False.

  • overwrite (bool, default: False ) –

    Whether to overwrite existing files. Defaults to False.

_s2ids cached property

_s2ids: list[str]
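_s2ids is rendered as a cached property. Assuming it is a functools.cached_property, its getter runs on first access and the result is stored on the instance, so _tileinfos() can iterate over it without repeating the lookup. A minimal sketch of that behaviour; the TileSource class and the IDs it returns are hypothetical, and what the real _s2ids queries is not shown on this page:

```python
from functools import cached_property


class TileSource:
    """Hypothetical stand-in for a pipeline with a cached scene-ID lookup."""

    def __init__(self) -> None:
        self.lookups = 0  # counts how often the expensive lookup actually runs

    @cached_property
    def _s2ids(self) -> list[str]:
        # Placeholder for an expensive query; these IDs are made up.
        self.lookups += 1
        return ["scene_b", "scene_a"]


source = TileSource()
first = source._s2ids   # runs the getter and caches the result on the instance
second = source._s2ids  # served from the cache, no second lookup
```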

aoi_shapefile class-attribute instance-attribute

aoi_shapefile: pathlib.Path = None

arcticdem_dir class-attribute instance-attribute

arcticdem_dir: pathlib.Path = pathlib.Path(
    "data/download/arcticdem"
)

batch_size class-attribute instance-attribute

batch_size: int = 8

binarization_threshold class-attribute instance-attribute

binarization_threshold: float = 0.5

device class-attribute instance-attribute

device: (
    typing.Literal["cuda", "cpu", "auto"] | int | None
) = None
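The pipeline resolves device with decide_device from darts.utils.cuda (called in run()), whose internals are not shown here. The sketch below only mirrors the documented semantics of the parameter; resolve_device is a hypothetical helper, the per-GPU memory usage is passed in instead of queried, and falling back to "cpu" when "auto" finds no free GPU is an assumption:

```python
from typing import Literal, Sequence, Union

Device = Union[Literal["cuda", "cpu", "auto"], int, None]


def resolve_device(
    device: Device,
    cuda_available: bool,
    gpu_memory_usage: Sequence[float] = (),
) -> str:
    """Hypothetical helper; gpu_memory_usage holds the used fraction (0.0 to 1.0) per GPU."""
    if device is None:
        # Documented default: "cuda" if available, else "cpu".
        device = "cuda" if cuda_available else "cpu"
    if device == "cpu":
        return "cpu"
    if device == "cuda":
        return "cuda:0"  # "cuda" means the first device
    if device == "auto":
        for index, usage in enumerate(gpu_memory_usage):
            if usage < 0.5:  # "free" means less than 50% memory in use
                return f"cuda:{index}"
        return "cpu"  # assumption: fall back to the CPU if no GPU is free
    if isinstance(device, int):
        return f"cuda:{device}"
    raise ValueError(f"Unsupported device: {device!r}")
```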

ee_project class-attribute instance-attribute

ee_project: str | None = None

ee_use_highvolume class-attribute instance-attribute

ee_use_highvolume: bool = True

end_date class-attribute instance-attribute

end_date: str = None

export_bands class-attribute instance-attribute

export_bands: list[str] = dataclasses.field(
    default_factory=lambda: [
        "probabilities",
        "binarized",
        "polygonized",
        "extent",
        "thumbnail",
    ]
)
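The dataclasses module rejects a plain list as a field default (it raises ValueError for mutable defaults), which is why export_bands uses default_factory: each instance gets its own fresh list. A small sketch with a hypothetical ExportSettings class that copies the field definition above:

```python
import dataclasses


@dataclasses.dataclass
class ExportSettings:
    """Hypothetical stand-in using the same export_bands field definition."""

    export_bands: list[str] = dataclasses.field(
        default_factory=lambda: [
            "probabilities",
            "binarized",
            "polygonized",
            "extent",
            "thumbnail",
        ]
    )


a = ExportSettings()
b = ExportSettings()
a.export_bands.append("dem")  # mutates only a's list; b keeps the defaults
```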

input_cache class-attribute instance-attribute

input_cache: pathlib.Path = pathlib.Path("data/cache/input")

mask_erosion_size class-attribute instance-attribute

mask_erosion_size: int = 10

max_cloud_cover class-attribute instance-attribute

max_cloud_cover: int = 10
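This page does not show where max_cloud_cover is applied (presumably while the scene list behind _s2ids is built). As a rough illustration of the filter it describes, here is a sketch over hypothetical scene metadata; SceneInfo and filter_by_cloud_cover are not darts APIs, and keeping scenes exactly at the threshold is an assumption:

```python
from dataclasses import dataclass


@dataclass
class SceneInfo:
    """Hypothetical scene metadata record, not a darts type."""

    scene_id: str
    cloud_cover: float  # percent, 0 to 100


def filter_by_cloud_cover(scenes: list[SceneInfo], max_cloud_cover: int = 10) -> list[SceneInfo]:
    # Assumption: a scene at exactly max_cloud_cover is kept (<=).
    return [scene for scene in scenes if scene.cloud_cover <= max_cloud_cover]


scenes = [SceneInfo("a", 3.2), SceneInfo("b", 10.0), SceneInfo("c", 47.5)]
kept = filter_by_cloud_cover(scenes)
```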

min_object_size class-attribute instance-attribute

min_object_size: int = 32

model_files class-attribute instance-attribute

model_files: list[pathlib.Path] = None

output_data_dir class-attribute instance-attribute

output_data_dir: pathlib.Path = pathlib.Path("data/output")

overlap class-attribute instance-attribute

overlap: int = 256

overwrite class-attribute instance-attribute

overwrite: bool = False

patch_size class-attribute instance-attribute

patch_size: int = 1024
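patch_size and overlap control how a tile is split into windows for inference; with the defaults, consecutive windows advance by 1024 - 256 = 768 pixels (assuming both values are in pixels). The following is a generic one-axis sliding-window sketch, not the scheme EnsembleV1.segment_tile actually uses, which is not shown here; patch_starts is a hypothetical helper:

```python
def patch_starts(length: int, patch_size: int = 1024, overlap: int = 256) -> list[int]:
    """Start offsets of overlapping windows covering [0, length) along one axis."""
    if not 0 <= overlap < patch_size:
        raise ValueError("overlap must be in [0, patch_size)")
    if length <= patch_size:
        return [0]  # a single window covers the whole axis
    stride = patch_size - overlap
    starts = list(range(0, length - patch_size, stride))
    starts.append(length - patch_size)  # last window sits flush with the edge
    return starts
```

Applying the same offsets along both axes yields the 2D patch grid; batch_size then only decides how many of those patches are passed to the model at once.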

quality_level class-attribute instance-attribute

quality_level: (
    int
    | typing.Literal["high_quality", "low_quality", "none"]
) = 1
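run() passes quality_level unchanged to prepare_export, so where the int form is translated is not visible on this page. A sketch of the documented mapping (0="none", 1="low_quality", 2="high_quality"); normalize_quality_level is a hypothetical helper:

```python
from typing import Literal, Union

QualityLevel = Literal["high_quality", "low_quality", "none"]

# Documented mapping for the int form of quality_level.
_INT_TO_QUALITY: dict[int, QualityLevel] = {0: "none", 1: "low_quality", 2: "high_quality"}


def normalize_quality_level(level: Union[int, QualityLevel]) -> QualityLevel:
    """Map an int quality level to its name; validate names as well."""
    if isinstance(level, int):
        try:
            return _INT_TO_QUALITY[level]
        except KeyError:
            raise ValueError(f"quality_level must be 0, 1 or 2, got {level}") from None
    if level not in _INT_TO_QUALITY.values():
        raise ValueError(f"Unknown quality_level: {level!r}")
    return level
```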

reflection class-attribute instance-attribute

reflection: int = 0

start_date class-attribute instance-attribute

start_date: str = None
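start_date and end_date are plain strings in YYYY-MM-DD format, and this page does not show how they are validated. A sketch of checking them with the standard library before a run; parse_date_range is a hypothetical helper, and whether the end date is inclusive is not stated here:

```python
from datetime import date


def parse_date_range(start_date: str, end_date: str) -> tuple[date, date]:
    """Parse YYYY-MM-DD strings and check that the range is ordered."""
    start = date.fromisoformat(start_date)  # raises ValueError on malformed input
    end = date.fromisoformat(end_date)
    if start > end:
        raise ValueError(f"start_date {start_date} is after end_date {end_date}")
    return start, end
```

Note that on Python 3.11+ date.fromisoformat also accepts other ISO 8601 forms such as "20230601", so a strict YYYY-MM-DD check would need an extra pattern test.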

tcvis_dir class-attribute instance-attribute

tcvis_dir: pathlib.Path = pathlib.Path(
    "data/download/tcvis"
)

tpi_inner_radius class-attribute instance-attribute

tpi_inner_radius: int = 0

tpi_outer_radius class-attribute instance-attribute

tpi_outer_radius: int = 100
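The TPI (topographic position index) compares each DEM cell with the mean elevation of its surroundings; here the surroundings are an annulus between tpi_inner_radius and tpi_outer_radius. The sketch assumes the radii are converted to pixels with the 10 m resolution returned by this pipeline's _arcticdem_resolution(), and that the ring includes both boundaries; the real kernel construction in darts_preprocessing may differ. The last line reproduces the buffer that run() passes to load_arcticdem for the default outer radius:

```python
from math import ceil, hypot, sqrt


def meters_to_pixels(radius_m: float, resolution_m: float = 10) -> int:
    # Assumption: radius in meters divided by the DEM resolution, rounded.
    return round(radius_m / resolution_m)


def annulus_kernel(outer_px: int, inner_px: int = 0) -> list[list[int]]:
    """Binary ring kernel: 1 where inner_px <= distance from center <= outer_px."""
    size = 2 * outer_px + 1
    kernel = []
    for y in range(size):
        row = []
        for x in range(size):
            distance = hypot(y - outer_px, x - outer_px)
            row.append(1 if inner_px <= distance <= outer_px else 0)
        kernel.append(row)
    return kernel


kernel = annulus_kernel(meters_to_pixels(100))  # defaults: outer 100 m, inner 0 m

# Same expression as in run(): ceil(tpi_outer_radius / 2 * sqrt(2)).
buffer = ceil(100 / 2 * sqrt(2))
```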

write_model_outputs class-attribute instance-attribute

write_model_outputs: bool = False

_arcticdem_resolution

_arcticdem_resolution() -> typing.Literal[10]
Source code in darts/src/darts/pipelines/sequential_v2.py
def _arcticdem_resolution(self) -> Literal[10]:
    return 10

_get_tile_id

_get_tile_id(tilekey)
Source code in darts/src/darts/pipelines/sequential_v2.py
def _get_tile_id(self, tilekey):
    # In the case of GEE, the tilekey is also the s2id
    return tilekey

_load_tile

_load_tile(s2id: str) -> xarray.Dataset
Source code in darts/src/darts/pipelines/sequential_v2.py
def _load_tile(self, s2id: str) -> "xr.Dataset":
    from darts_acquisition.s2 import load_s2_from_gee

    tile = load_s2_from_gee(s2id, cache=self.input_cache)
    return tile

_tileinfos

_tileinfos() -> list[tuple[str, pathlib.Path]]
Source code in darts/src/darts/pipelines/sequential_v2.py
def _tileinfos(self) -> list[tuple[str, Path]]:
    out = []
    for s2id in self._s2ids:
        outpath = self.output_data_dir / s2id
        out.append((s2id, outpath))
    out.sort()
    return out

cli staticmethod

cli(
    *,
    pipeline: darts.pipelines.sequential_v2.AOISentinel2Pipeline,
)

Run the sequential pipeline for AOI Sentinel 2 data.

Source code in darts/src/darts/pipelines/sequential_v2.py
@staticmethod
def cli(*, pipeline: "AOISentinel2Pipeline"):
    """Run the sequential pipeline for AOI Sentinel 2 data."""
    pipeline.run()

run

run()
Source code in darts/src/darts/pipelines/sequential_v2.py
def run(self):  # noqa: C901
    if not self.model_files:  # None or empty list; a single Path is normalized below
        raise ValueError("No model files provided. Please provide a list of model files.")
    if len(self.export_bands) == 0:
        raise ValueError("No export bands provided. Please provide a list of export bands.")

    current_time = time.strftime("%Y-%m-%d_%H-%M-%S")
    logger.info(f"Starting pipeline at {current_time}.")

    # Storing the configuration as JSON file
    self.output_data_dir.mkdir(parents=True, exist_ok=True)
    with open(self.output_data_dir / f"{current_time}.config.json", "w") as f:
        config = asdict(self)
        # Convert everything to json serializable
        for key, value in config.items():
            if isinstance(value, Path):
                config[key] = str(value.resolve())
            elif isinstance(value, list):
                config[key] = [str(v.resolve()) if isinstance(v, Path) else v for v in value]
        json.dump(config, f)

    from stopuhr import StopUhr

    stopuhr = StopUhr(printer=logger.debug)

    from darts.utils.cuda import debug_info

    debug_info()

    from darts.utils.earthengine import init_ee

    init_ee(self.ee_project, self.ee_use_highvolume)

    import pandas as pd
    import smart_geocubes
    import torch
    from darts_acquisition import load_arcticdem, load_tcvis
    from darts_ensemble import EnsembleV1
    from darts_export import export_tile, missing_outputs
    from darts_postprocessing import prepare_export
    from darts_preprocessing import preprocess_legacy_fast

    from darts.utils.cuda import decide_device
    from darts.utils.logging import LoggingManager

    self.device = decide_device(self.device)

    # determine models to use
    if isinstance(self.model_files, Path):
        self.model_files = [self.model_files]
        self.write_model_outputs = False
    models = {model_file.stem: model_file for model_file in self.model_files}
    ensemble = EnsembleV1(models, device=torch.device(self.device))

    # Create the datacubes if they do not exist
    LoggingManager.apply_logging_handlers("smart_geocubes")
    arcticdem_resolution = self._arcticdem_resolution()
    if arcticdem_resolution == 2:
        accessor = smart_geocubes.ArcticDEM2m(self.arcticdem_dir)
    elif arcticdem_resolution == 10:
        accessor = smart_geocubes.ArcticDEM10m(self.arcticdem_dir)
    if not accessor.created:
        accessor.create(overwrite=False)
    accessor = smart_geocubes.TCTrend(self.tcvis_dir)
    if not accessor.created:
        accessor.create(overwrite=False)

    # Iterate over all the data
    tileinfo = self._tileinfos()
    n_tiles = 0
    logger.info(f"Found {len(tileinfo)} tiles to process.")
    results = []
    for i, (tilekey, outpath) in enumerate(tileinfo):
        tile_id = self._get_tile_id(tilekey)
        try:
            if not self.overwrite:
                mo = missing_outputs(outpath, bands=self.export_bands, ensemble_subsets=models.keys())
                if mo == "none":
                    logger.info(f"Tile {tile_id} already processed. Skipping...")
                    continue
                if mo == "some":
                    logger.warning(
                        f"Tile {tile_id} already processed. Some outputs are missing."
                        " Skipping because overwrite=False..."
                    )
                    continue

            with stopuhr("Loading optical data", log=False):
                tile = self._load_tile(tilekey)
            with stopuhr("Loading ArcticDEM", log=False):
                arcticdem = load_arcticdem(
                    tile.odc.geobox,
                    self.arcticdem_dir,
                    resolution=arcticdem_resolution,
                    buffer=ceil(self.tpi_outer_radius / 2 * sqrt(2)),
                )
            with stopuhr("Loading TCVis", log=False):
                tcvis = load_tcvis(tile.odc.geobox, self.tcvis_dir)
            with stopuhr("Preprocessing tile", log=False):
                tile = preprocess_legacy_fast(
                    tile,
                    arcticdem,
                    tcvis,
                    self.tpi_outer_radius,
                    self.tpi_inner_radius,
                    self.device,
                )
            with stopuhr("Segmenting", log=False):
                tile = ensemble.segment_tile(
                    tile,
                    patch_size=self.patch_size,
                    overlap=self.overlap,
                    batch_size=self.batch_size,
                    reflection=self.reflection,
                    keep_inputs=self.write_model_outputs,
                )
            with stopuhr("Postprocessing", log=False):
                tile = prepare_export(
                    tile,
                    bin_threshold=self.binarization_threshold,
                    mask_erosion_size=self.mask_erosion_size,
                    min_object_size=self.min_object_size,
                    quality_level=self.quality_level,
                    ensemble_subsets=models.keys() if self.write_model_outputs else [],
                    device=self.device,
                )

            with stopuhr("Exporting", log=False):
                export_tile(
                    tile,
                    outpath,
                    bands=self.export_bands,
                    ensemble_subsets=models.keys() if self.write_model_outputs else [],
                )

            n_tiles += 1
            results.append(
                {
                    "tile_id": tile_id,
                    "output_path": str(outpath.resolve()),
                    "status": "success",
                    "error": None,
                }
            )
            logger.info(f"Processed sample {i + 1} of {len(tileinfo)} '{tilekey}' ({tile_id=}).")
        except KeyboardInterrupt:
            logger.warning("Keyboard interrupt detected.\nExiting...")
            raise KeyboardInterrupt
        except Exception as e:
            logger.warning(f"Could not process '{tilekey}' ({tile_id=}).\nSkipping...")
            logger.exception(e)
            results.append(
                {
                    "tile_id": tile_id,
                    "output_path": str(outpath.resolve()),
                    "status": "failed",
                    "error": str(e),
                }
            )
        finally:
            if len(results) > 0:
                pd.DataFrame(results).to_parquet(self.output_data_dir / f"{current_time}.results.parquet")
            stopuhr.export().to_parquet(self.output_data_dir / f"{current_time}.stopuhr.parquet")
    else:
        logger.info(f"Processed {n_tiles} tiles to {self.output_data_dir.resolve()}.")
        stopuhr.summary()
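Before processing any tile, run() writes its configuration to {current_time}.config.json, converting Path values (including Paths inside lists) to absolute strings so json.dump accepts them. The conversion loop can be exercised on its own; MiniConfig is a hypothetical stand-in holding a few of the pipeline's fields, and to_json_config is a hypothetical wrapper around the same loop:

```python
import dataclasses
import json
from pathlib import Path


@dataclasses.dataclass
class MiniConfig:
    """Hypothetical subset of the pipeline fields, for illustration only."""

    output_data_dir: Path = Path("data/output")
    model_files: list = dataclasses.field(default_factory=lambda: [Path("models/a.pt")])
    batch_size: int = 8


def to_json_config(obj) -> str:
    config = dataclasses.asdict(obj)
    # Same conversion as in run(): Paths, also inside lists, become absolute strings.
    for key, value in config.items():
        if isinstance(value, Path):
            config[key] = str(value.resolve())
        elif isinstance(value, list):
            config[key] = [str(v.resolve()) if isinstance(v, Path) else v for v in value]
    return json.dumps(config)


loaded = json.loads(to_json_config(MiniConfig()))
```

Only top-level Path and list values are converted; a Path nested inside a dict or tuple would still make json.dump raise a TypeError.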