Skip to content

Combined Reference

All references on one page

darts

DARTS processing pipeline.

__version__ = version('darts-nextgen') module-attribute

run_native_planet_pipeline(orthotiles_dir, scenes_dir, output_data_dir, arcticdem_slope_vrt, arcticdem_elevation_vrt, model_dir, tcvis_model_name='RTS_v6_tcvis.pt', notcvis_model_name='RTS_v6_notcvis.pt', cache_dir=None, ee_project=None, patch_size=1024, overlap=16, batch_size=8, reflection=0)

Search for all PlanetScope scenes in the given directory and runs the segmentation pipeline on them.

Parameters:

Name Type Description Default
orthotiles_dir Path

The directory containing the PlanetScope orthotiles.

required
scenes_dir Path

The directory containing the PlanetScope scenes.

required
output_data_dir Path

The "output" directory.

required
arcticdem_slope_vrt Path

The path to the ArcticDEM slope VRT file.

required
arcticdem_elevation_vrt Path

The path to the ArcticDEM elevation VRT file.

required
model_dir Path

The path to the models to use for segmentation.

required
tcvis_model_name str

The name of the model to use for TCVis. Defaults to "RTS_v6_tcvis.pt".

'RTS_v6_tcvis.pt'
notcvis_model_name str

The name of the model to use for not TCVis. Defaults to "RTS_v6_notcvis.pt".

'RTS_v6_notcvis.pt'
cache_dir Path | None

The cache directory. If None, no caching will be used. Defaults to None.

None
ee_project str

The Earth Engine project ID or number to use. May be omitted if project is defined within persistent API credentials obtained via earthengine authenticate.

None
patch_size int

The patch size to use for inference. Defaults to 1024.

1024
overlap int

The overlap to use for inference. Defaults to 16.

16
batch_size int

The batch size to use for inference. Defaults to 8.

8
reflection int

The reflection padding to use for inference. Defaults to 0.

0
Todo

Document the structure of the input data dir.

Source code in darts/src/darts/native.py
def run_native_planet_pipeline(
    orthotiles_dir: Path,
    scenes_dir: Path,
    output_data_dir: Path,
    arcticdem_slope_vrt: Path,
    arcticdem_elevation_vrt: Path,
    model_dir: Path,
    tcvis_model_name: str = "RTS_v6_tcvis.pt",
    notcvis_model_name: str = "RTS_v6_notcvis.pt",
    cache_dir: Path | None = None,
    ee_project: str | None = None,
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
):
    """Search for all PlanetScope scenes in the given directory and runs the segmentation pipeline on them.

    Args:
        orthotiles_dir (Path): The directory containing the PlanetScope orthotiles.
        scenes_dir (Path): The directory containing the PlanetScope scenes.
        output_data_dir (Path): The "output" directory.
        arcticdem_slope_vrt (Path): The path to the ArcticDEM slope VRT file.
        arcticdem_elevation_vrt (Path): The path to the ArcticDEM elevation VRT file.
        model_dir (Path): The path to the models to use for segmentation.
        tcvis_model_name (str, optional): The name of the model to use for TCVis. Defaults to "RTS_v6_tcvis.pt".
        notcvis_model_name (str, optional): The name of the model to use for not TCVis.
            Defaults to "RTS_v6_notcvis.pt".
        cache_dir (Path | None, optional): The cache directory. If None, no caching will be used. Defaults to None.
        ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if
            project is defined within persistent API credentials obtained via `earthengine authenticate`.
        patch_size (int, optional): The patch size to use for inference. Defaults to 1024.
        overlap (int, optional): The overlap to use for inference. Defaults to 16.
        batch_size (int, optional): The batch size to use for inference. Defaults to 8.
        reflection (int, optional): The reflection padding to use for inference. Defaults to 0.

    Todo:
        Document the structure of the input data dir.

    """
    # Import here to avoid long loading times when running other commands
    from darts_ensemble.ensemble_v1 import EnsembleV1
    from darts_export.inference import InferenceResultWriter
    from darts_postprocessing import prepare_export
    from darts_preprocessing import load_and_preprocess_planet_scene

    from darts.utils.earthengine import init_ee

    init_ee(ee_project)

    # Create the ensemble once: it depends only on loop-invariant arguments,
    # so reloading the model checkpoints for every scene would be wasted work.
    ensemble = EnsembleV1(model_dir / tcvis_model_name, model_dir / notcvis_model_name)

    # Find all PlanetScope orthotiles
    for fpath, outpath in planet_file_generator(orthotiles_dir, scenes_dir, output_data_dir):
        # Load the scene and enrich it with NDVI, ArcticDEM and TCVis layers
        tile = load_and_preprocess_planet_scene(fpath, arcticdem_slope_vrt, arcticdem_elevation_vrt, cache_dir)

        # Run both models (TCVis / no-TCVis) and combine their predictions
        tile = ensemble.segment_tile(
            tile, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
        )
        tile = prepare_export(tile)

        # Export probabilities, the binarized mask and the vectorized segments
        outpath.mkdir(parents=True, exist_ok=True)
        writer = InferenceResultWriter(tile)
        writer.export_probabilities(outpath)
        writer.export_binarized(outpath)
        writer.export_polygonized(outpath)

run_native_sentinel2_pipeline(sentinel2_dir, output_data_dir, arcticdem_slope_vrt, arcticdem_elevation_vrt, model_dir, tcvis_model_name='RTS_v6_tcvis.pt', notcvis_model_name='RTS_v6_notcvis.pt', cache_dir=None, ee_project=None, patch_size=1024, overlap=16, batch_size=8, reflection=0)

Search for all Sentinel 2 scenes in the given directory and runs the segmentation pipeline on them.

Parameters:

Name Type Description Default
sentinel2_dir Path

The directory containing the Sentinel 2 scenes.

required
output_data_dir Path

The "output" directory.

required
arcticdem_slope_vrt Path

The path to the ArcticDEM slope VRT file.

required
arcticdem_elevation_vrt Path

The path to the ArcticDEM elevation VRT file.

required
model_dir Path

The path to the models to use for segmentation.

required
tcvis_model_name str

The name of the model to use for TCVis. Defaults to "RTS_v6_tcvis.pt".

'RTS_v6_tcvis.pt'
notcvis_model_name str

The name of the model to use for not TCVis. Defaults to "RTS_v6_notcvis.pt".

'RTS_v6_notcvis.pt'
cache_dir Path | None

The cache directory. If None, no caching will be used. Defaults to None.

None
ee_project str

The Earth Engine project ID or number to use. May be omitted if project is defined within persistent API credentials obtained via earthengine authenticate.

None
patch_size int

The patch size to use for inference. Defaults to 1024.

1024
overlap int

The overlap to use for inference. Defaults to 16.

16
batch_size int

The batch size to use for inference. Defaults to 8.

8
reflection int

The reflection padding to use for inference. Defaults to 0.

0
Todo

Document the structure of the input data dir.

Source code in darts/src/darts/native.py
def run_native_sentinel2_pipeline(
    sentinel2_dir: Path,
    output_data_dir: Path,
    arcticdem_slope_vrt: Path,
    arcticdem_elevation_vrt: Path,
    model_dir: Path,
    tcvis_model_name: str = "RTS_v6_tcvis.pt",
    notcvis_model_name: str = "RTS_v6_notcvis.pt",
    cache_dir: Path | None = None,
    ee_project: str | None = None,
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
):
    """Search for all Sentinel 2 scenes in the given directory and runs the segmentation pipeline on them.

    Args:
        sentinel2_dir (Path): The directory containing the Sentinel 2 scenes.
        output_data_dir (Path): The "output" directory.
        arcticdem_slope_vrt (Path): The path to the ArcticDEM slope VRT file.
        arcticdem_elevation_vrt (Path): The path to the ArcticDEM elevation VRT file.
        model_dir (Path): The path to the models to use for segmentation.
        tcvis_model_name (str, optional): The name of the model to use for TCVis. Defaults to "RTS_v6_tcvis.pt".
        notcvis_model_name (str, optional): The name of the model to use for not TCVis.
            Defaults to "RTS_v6_notcvis.pt".
        cache_dir (Path | None, optional): The cache directory. If None, no caching will be used. Defaults to None.
        ee_project (str, optional): The Earth Engine project ID or number to use. May be omitted if
            project is defined within persistent API credentials obtained via `earthengine authenticate`.
        patch_size (int, optional): The patch size to use for inference. Defaults to 1024.
        overlap (int, optional): The overlap to use for inference. Defaults to 16.
        batch_size (int, optional): The batch size to use for inference. Defaults to 8.
        reflection (int, optional): The reflection padding to use for inference. Defaults to 0.

    Todo:
        Document the structure of the input data dir.

    """
    # Import here to avoid long loading times when running other commands
    from darts_ensemble.ensemble_v1 import EnsembleV1
    from darts_export.inference import InferenceResultWriter
    from darts_postprocessing import prepare_export
    from darts_preprocessing import load_and_preprocess_sentinel2_scene

    from darts.utils.earthengine import init_ee

    init_ee(ee_project)

    # Create the ensemble once: it depends only on loop-invariant arguments,
    # so reloading the model checkpoints for every scene would be wasted work.
    ensemble = EnsembleV1(model_dir / tcvis_model_name, model_dir / notcvis_model_name)

    # Find all Sentinel 2 scenes (one sub-directory per scene)
    for fpath in sentinel2_dir.glob("*/"):
        scene_id = fpath.name
        outpath = output_data_dir / scene_id
        # Load the scene and enrich it with NDVI, ArcticDEM and TCVis layers
        tile = load_and_preprocess_sentinel2_scene(fpath, arcticdem_slope_vrt, arcticdem_elevation_vrt, cache_dir)

        # Run both models (TCVis / no-TCVis) and combine their predictions
        tile = ensemble.segment_tile(
            tile, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
        )
        tile = prepare_export(tile)

        # Export probabilities, the binarized mask and the vectorized segments
        outpath.mkdir(parents=True, exist_ok=True)
        writer = InferenceResultWriter(tile)
        writer.export_probabilities(outpath)
        writer.export_binarized(outpath)
        writer.export_polygonized(outpath)

darts_acquisition

Acquisition of data from various sources for the DARTS dataset.

hello(name)

Say hello to the user.

Parameters:

Name Type Description Default
name str

Name of the user.

required

Returns:

Name Type Description
str str

Greeting message.

Source code in darts-acquisition/src/darts_acquisition/__init__.py
def hello(name: str) -> str:
    """Say hello to the user.

    Args:
        name (str): Name of the user.

    Returns:
        str: Greeting message.

    """
    greeting = "Hello, {}, from darts-acquisition!".format(name)
    return greeting

darts_ensemble

Inference and model ensembling for the DARTS dataset.

hello()

Say hello to the user.

Returns:

Name Type Description
str str

Greeting message.

Source code in darts-ensemble/src/darts_ensemble/__init__.py
def hello() -> str:
    """Say hello to the user.

    Returns:
        str: Greeting message.

    """
    message = "Hello from darts-ensemble!"
    return message

darts_export

Dataset export for the DARTS dataset.

InferenceResultWriter

Writer class to export inference result datasets.

Source code in darts-export/src/darts_export/inference.py
class InferenceResultWriter:
    """Writer class to export inference result datasets."""

    def __init__(self, ds) -> None:
        """Initialize the writer.

        Args:
            ds: The inference result dataset (an xarray.Dataset) to export.

        """
        self.ds: xarray.Dataset = ds

    def export_probabilities(self, path: Path, filename: str = "pred_probabilities.tif", tags: dict | None = None):
        """Export the probabilities layer to a file.

        Args:
            path (Path): The path where to export to.
            filename (str, optional): the filename. Defaults to "pred_probabilities.tif".
            tags (dict | None, optional): optional GeoTIFF metadata to be written.
                Defaults to None (no additional metadata).

        Returns:
            the Path of the written file

        """
        # Use a fresh dict per call instead of a shared mutable default argument.
        tags = {} if tags is None else tags
        # write the probability layer from the raster to a GeoTiff
        file_path = path / filename
        self.ds.probabilities.rio.to_raster(file_path, driver="GTiff", tags=tags, compress="LZW")
        return file_path

    def export_binarized(self, path: Path, filename: str = "pred_binarized.tif", tags: dict | None = None):
        """Export the binarized segmentation result of the inference Result.

        Args:
            path (Path): The path where to export to.
            filename (str, optional): the filename. Defaults to "pred_binarized.tif".
            tags (dict | None, optional): optional GeoTIFF metadata to be written.
                Defaults to None (no additional metadata).

        Returns:
            the Path of the written file

        """
        # Use a fresh dict per call instead of a shared mutable default argument.
        tags = {} if tags is None else tags
        file_path = path / filename
        self.ds.binarized_segmentation.rio.to_raster(file_path, driver="GTiff", tags=tags, compress="LZW")
        return file_path

    def export_polygonized(self, path: Path, filename_prefix: str = "pred_segments", minimum_mapping_unit: int = 32):
        """Export the binarized probabilities as a vector dataset in GeoPackage and GeoParquet format.

        Args:
            path (Path): The path where to export the files
            filename_prefix (str, optional): the file prefix of the exported files. Defaults to "pred_segments".
            minimum_mapping_unit (int, optional): segments covering fewer pixels are removed. Defaults to 32.

        Returns:
            tuple[Path, Path]: the Paths of the written GeoPackage and GeoParquet files.

        """
        polygon_gdf = vectorization.vectorize(self.ds, minimum_mapping_unit=minimum_mapping_unit)

        path_gpkg = path / f"{filename_prefix}.gpkg"
        path_parquet = path / f"{filename_prefix}.parquet"

        polygon_gdf.to_file(path_gpkg, layer=filename_prefix)
        polygon_gdf.to_parquet(path_parquet)
        # Return the written paths for consistency with the other export methods.
        return path_gpkg, path_parquet

ds: xarray.Dataset = ds instance-attribute

__init__(ds)

Initialize the dataset.

Source code in darts-export/src/darts_export/inference.py
def __init__(self, ds) -> None:
    """Initialize the dataset."""
    # Keep a reference to the inference result dataset; the export methods
    # read its "probabilities" and "binarized_segmentation" layers from it.
    self.ds: xarray.Dataset = ds

export_binarized(path, filename='pred_binarized.tif', tags={})

Export the binarized segmentation result of the inference Result.

Parameters:

Name Type Description Default
path Path

The path where to export to.

required
filename str

the filename. Defaults to "pred_binarized.tif".

'pred_binarized.tif'
tags dict

optional GeoTIFF metadata to be written. Defaults to no additional metadata.

{}

Returns:

Type Description

the Path of the written file

Source code in darts-export/src/darts_export/inference.py
def export_binarized(self, path: Path, filename: str = "pred_binarized.tif", tags: dict | None = None):
    """Export the binarized segmentation result of the inference Result.

    Args:
        path (Path): The path where to export to.
        filename (str, optional): the filename. Defaults to "pred_binarized.tif".
        tags (dict | None, optional): optional GeoTIFF metadata to be written.
            Defaults to None (no additional metadata).

    Returns:
        the Path of the written file

    """
    # Use a fresh dict per call instead of a shared mutable default argument.
    tags = {} if tags is None else tags
    file_path = path / filename
    self.ds.binarized_segmentation.rio.to_raster(file_path, driver="GTiff", tags=tags, compress="LZW")
    return file_path

export_polygonized(path, filename_prefix='pred_segments', minimum_mapping_unit=32)

Export the binarized probabilities as a vector dataset in GeoPackage and GeoParquet format.

Parameters:

Name Type Description Default
path Path

The path where to export the files

required
filename_prefix str

the file prefix of the exported files. Defaults to "pred_segments".

'pred_segments'
minimum_mapping_unit int

segments covering fewer pixels are removed. Defaults to 32.

32
Source code in darts-export/src/darts_export/inference.py
def export_polygonized(self, path: Path, filename_prefix="pred_segments", minimum_mapping_unit=32):
    """Export the binarized probabilities as a vector dataset in GeoPackage and GeoParquet format.

    Args:
        path (Path): The path where to export the files
        filename_prefix (str, optional): the file prefix of the exported files. Defaults to "pred_segments".
        minimum_mapping_unit (int, optional): segments covering fewer pixels are removed. Defaults to 32.

    """
    # Vectorize the binarized segmentation, dropping tiny segments.
    gdf = vectorization.vectorize(self.ds, minimum_mapping_unit=minimum_mapping_unit)

    # Write the same polygons side by side in both vector formats.
    gdf.to_file(path / f"{filename_prefix}.gpkg", layer=filename_prefix)
    gdf.to_parquet(path / f"{filename_prefix}.parquet")

export_probabilities(path, filename='pred_probabilities.tif', tags={})

Export the probabilities layer to a file.

Parameters:

Name Type Description Default
path Path

The path where to export to.

required
filename str

the filename. Defaults to "pred_probabilities.tif".

'pred_probabilities.tif'
tags dict

optional GeoTIFF metadata to be written. Defaults to no additional metadata.

{}

Returns:

Type Description

the Path of the written file

Source code in darts-export/src/darts_export/inference.py
def export_probabilities(self, path: Path, filename: str = "pred_probabilities.tif", tags: dict | None = None):
    """Export the probabilities layer to a file.

    Args:
        path (Path): The path where to export to.
        filename (str, optional): the filename. Defaults to "pred_probabilities.tif".
        tags (dict | None, optional): optional GeoTIFF metadata to be written.
            Defaults to None (no additional metadata).

    Returns:
        the Path of the written file

    """
    # Use a fresh dict per call instead of a shared mutable default argument.
    tags = {} if tags is None else tags
    # write the probability layer from the raster to a GeoTiff
    file_path = path / filename
    self.ds.probabilities.rio.to_raster(file_path, driver="GTiff", tags=tags, compress="LZW")
    return file_path

darts_postprocessing

Postprocessing steps for the DARTS dataset.

darts_preprocessing

Data preprocessing and feature engineering for the DARTS dataset.

load_and_preprocess_planet_scene(planet_scene_path, slope_vrt, elevation_vrt, cache_dir=None)

Load and preprocess a Planet Scene (PSOrthoTile or PSScene) into an xr.Dataset.

Parameters:

Name Type Description Default
planet_scene_path Path

path to the Planet Scene

required
slope_vrt Path

path to the ArcticDEM slope VRT file

required
elevation_vrt Path

path to the ArcticDEM elevation VRT file

required
cache_dir Path | None

The cache directory. If None, no caching will be used. Defaults to None.

None

Returns:

Type Description
Dataset

xr.Dataset: preprocessed Planet Scene

Examples:

PS Orthotile

Data directory structure:

    data/input
    ├── ArcticDEM
       ├── elevation.vrt
       ├── slope.vrt
       ├── relative_elevation
          └── 4372514_relative_elevation_100.tif
       └── slope
           └── 4372514_slope.tif
    └── planet
        └── PSOrthoTile
            └── 4372514/5790392_4372514_2022-07-16_2459
                ├── 5790392_4372514_2022-07-16_2459_BGRN_Analytic_metadata.xml
                ├── 5790392_4372514_2022-07-16_2459_BGRN_DN_udm.tif
                ├── 5790392_4372514_2022-07-16_2459_BGRN_SR.tif
                ├── 5790392_4372514_2022-07-16_2459_metadata.json
                └── 5790392_4372514_2022-07-16_2459_udm2.tif

Load and preprocess a Planet Scene:

    from pathlib import Path
    from darts_preprocessing.preprocess import load_and_preprocess_planet_scene

    fpath = Path("data/input/planet/PSOrthoTile/4372514/5790392_4372514_2022-07-16_2459")
    arcticdem_dir = input_data_dir / "ArcticDEM"
    tile = load_and_preprocess_planet_scene(fpath, arcticdem_dir / "slope.vrt", arcticdem_dir / "elevation.vrt")
PS Scene

Data directory structure:

    data/input
    ├── ArcticDEM
       ├── elevation.vrt
       ├── slope.vrt
       ├── relative_elevation
          └── 4372514_relative_elevation_100.tif
       └── slope
           └── 4372514_slope.tif
    └── planet
        └── PSScene
            └── 20230703_194241_43_2427
                ├── 20230703_194241_43_2427_3B_AnalyticMS_metadata.xml
                ├── 20230703_194241_43_2427_3B_AnalyticMS_SR.tif
                ├── 20230703_194241_43_2427_3B_udm2.tif
                ├── 20230703_194241_43_2427_metadata.json
                └── 20230703_194241_43_2427.json

Load and preprocess a Planet Scene:

    from pathlib import Path
    from darts_preprocessing.preprocess import load_and_preprocess_planet_scene

    fpath = Path("data/input/planet/PSOrthoTile/20230703_194241_43_2427")
    arcticdem_dir = input_data_dir / "ArcticDEM"
    tile = load_and_preprocess_planet_scene(fpath, arcticdem_dir / "slope.vrt", arcticdem_dir / "elevation.vrt")
Source code in darts-preprocessing/src/darts_preprocessing/preprocess.py
def load_and_preprocess_planet_scene(
    planet_scene_path: Path, slope_vrt: Path, elevation_vrt: Path, cache_dir: Path | None = None
) -> xr.Dataset:
    """Load and preprocess a Planet Scene (PSOrthoTile or PSScene) into an xr.Dataset.

    Args:
        planet_scene_path (Path): path to the Planet Scene
        slope_vrt (Path): path to the ArcticDEM slope VRT file
        elevation_vrt (Path): path to the ArcticDEM elevation VRT file
        cache_dir (Path | None): The cache directory. If None, no caching will be used. Defaults to None.

    Returns:
        xr.Dataset: preprocessed Planet Scene

    Examples:
        ### PS Orthotile

        Data directory structure:

        ```sh
            data/input
            ├── ArcticDEM
            │   ├── elevation.vrt
            │   ├── slope.vrt
            │   ├── relative_elevation
            │   │   └── 4372514_relative_elevation_100.tif
            │   └── slope
            │       └── 4372514_slope.tif
            └── planet
                └── PSOrthoTile
                    └── 4372514/5790392_4372514_2022-07-16_2459
                        ├── 5790392_4372514_2022-07-16_2459_BGRN_Analytic_metadata.xml
                        ├── 5790392_4372514_2022-07-16_2459_BGRN_DN_udm.tif
                        ├── 5790392_4372514_2022-07-16_2459_BGRN_SR.tif
                        ├── 5790392_4372514_2022-07-16_2459_metadata.json
                        └── 5790392_4372514_2022-07-16_2459_udm2.tif
        ```

        Load and preprocess a Planet Scene:

        ```python
            from pathlib import Path
            from darts_preprocessing.preprocess import load_and_preprocess_planet_scene

            fpath = Path("data/input/planet/PSOrthoTile/4372514/5790392_4372514_2022-07-16_2459")
            arcticdem_dir = input_data_dir / "ArcticDEM"
            tile = load_and_preprocess_planet_scene(fpath, arcticdem_dir / "slope.vrt", arcticdem_dir / "elevation.vrt")
        ```


        ### PS Scene

        Data directory structure:

        ```sh
            data/input
            ├── ArcticDEM
            │   ├── elevation.vrt
            │   ├── slope.vrt
            │   ├── relative_elevation
            │   │   └── 4372514_relative_elevation_100.tif
            │   └── slope
            │       └── 4372514_slope.tif
            └── planet
                └── PSScene
                    └── 20230703_194241_43_2427
                        ├── 20230703_194241_43_2427_3B_AnalyticMS_metadata.xml
                        ├── 20230703_194241_43_2427_3B_AnalyticMS_SR.tif
                        ├── 20230703_194241_43_2427_3B_udm2.tif
                        ├── 20230703_194241_43_2427_metadata.json
                        └── 20230703_194241_43_2427.json
        ```

        Load and preprocess a Planet Scene:

        ```python
            from pathlib import Path
            from darts_preprocessing.preprocess import load_and_preprocess_planet_scene

            fpath = Path("data/input/planet/PSOrthoTile/20230703_194241_43_2427")
            arcticdem_dir = input_data_dir / "ArcticDEM"
            tile = load_and_preprocess_planet_scene(fpath, arcticdem_dir / "slope.vrt", arcticdem_dir / "elevation.vrt")
        ```

    """
    # Load the optical bands first: the other layers are reprojected /
    # aligned against this dataset's grid.
    optical = load_planet_scene(planet_scene_path)

    # Gather every input layer for the scene, then merge them in one go.
    layers = [
        optical,
        calculate_ndvi(optical),
        load_arcticdem(slope_vrt, elevation_vrt, optical),
        load_tcvis(optical, cache_dir),
        load_planet_masks(planet_scene_path),  # udm2 data masks
    ]
    return xr.merge(layers)

load_and_preprocess_sentinel2_scene(s2_scene_path, slope_vrt, elevation_vrt, cache_dir=None)

Load and preprocess a Sentinel 2 scene into an xr.Dataset.

Parameters:

Name Type Description Default
s2_scene_path Path

path to the Sentinel 2 Scene

required
slope_vrt Path

path to the ArcticDEM slope VRT file

required
elevation_vrt Path

path to the ArcticDEM elevation VRT file

required
cache_dir Path | None

The cache directory. If None, no caching will be used. Defaults to None.

None

Returns:

Type Description
Dataset

xr.Dataset: preprocessed Sentinel Scene

Examples:

Data directory structure:

    data/input
    ├── ArcticDEM
       ├── elevation.vrt
       ├── slope.vrt
       ├── relative_elevation
          └── 4372514_relative_elevation_100.tif
       └── slope
           └── 4372514_slope.tif
    └── sentinel2
        └── 20220826T200911_20220826T200905_T17XMJ/
            ├── 20220826T200911_20220826T200905_T17XMJ_SCL_clip.tif
            └── 20220826T200911_20220826T200905_T17XMJ_SR_clip.tif

Load and preprocess a Sentinel Scene:

    from pathlib import Path
    from darts_preprocessing.preprocess import load_and_preprocess_sentinel2_scene

    fpath = Path("data/input/sentinel2/20220826T200911_20220826T200905_T17XMJ")
    arcticdem_dir = input_data_dir / "ArcticDEM"
    tile = load_and_preprocess_sentinel2_scene(fpath, arcticdem_dir / "slope.vrt", arcticdem_dir / "elevation.vrt")
Source code in darts-preprocessing/src/darts_preprocessing/preprocess.py
def load_and_preprocess_sentinel2_scene(
    s2_scene_path: Path, slope_vrt: Path, elevation_vrt: Path, cache_dir: Path | None = None
) -> xr.Dataset:
    """Load and preprocess a Sentinel 2 scene into an xr.Dataset.

    Args:
        s2_scene_path (Path): path to the Sentinel 2 Scene
        slope_vrt (Path): path to the ArcticDEM slope VRT file
        elevation_vrt (Path): path to the ArcticDEM elevation VRT file
        cache_dir (Path | None): The cache directory. If None, no caching will be used. Defaults to None.

    Returns:
        xr.Dataset: preprocessed Sentinel Scene

    Examples:
        Data directory structure:

        ```sh
            data/input
            ├── ArcticDEM
            │   ├── elevation.vrt
            │   ├── slope.vrt
            │   ├── relative_elevation
            │   │   └── 4372514_relative_elevation_100.tif
            │   └── slope
            │       └── 4372514_slope.tif
            └── sentinel2
                └── 20220826T200911_20220826T200905_T17XMJ/
                    ├── 20220826T200911_20220826T200905_T17XMJ_SCL_clip.tif
                    └── 20220826T200911_20220826T200905_T17XMJ_SR_clip.tif
        ```

        Load and preprocess a Sentinel Scene:

        ```python
            from pathlib import Path
            from darts_preprocessing.preprocess import load_and_preprocess_sentinel2_scene

            fpath = Path("data/input/sentinel2/20220826T200911_20220826T200905_T17XMJ")
            arcticdem_dir = input_data_dir / "ArcticDEM"
            tile = load_and_preprocess_sentinel2_scene(fpath, arcticdem_dir / "slope.vrt", arcticdem_dir / "elevation.vrt")
        ```

    """
    # load the Sentinel 2 scene
    ds_s2 = load_s2_scene(s2_scene_path)

    # calculate NDVI from the optical bands
    ds_ndvi = calculate_ndvi(ds_s2)

    # ArcticDEM (slope + elevation) aligned to the scene's grid
    ds_arcticdem = load_arcticdem(slope_vrt, elevation_vrt, ds_s2)

    ds_tcvis = load_tcvis(ds_s2, cache_dir)

    # load the SCL-based data masks
    ds_data_masks = load_s2_masks(s2_scene_path)

    # merge to final dataset
    ds_merged = xr.merge([ds_s2, ds_ndvi, ds_arcticdem, ds_tcvis, ds_data_masks])

    return ds_merged

darts_segmentation

Image segmentation of thaw-slumps for the DARTS dataset.

SMPSegmenter

An actor that keeps a model as its state and segments tiles.

Source code in darts-segmentation/src/darts_segmentation/segment.py
class SMPSegmenter:
    """An actor that keeps a model as its state and segments tiles."""

    config: SMPSegmenterConfig
    model: nn.Module
    device: torch.device

    def __init__(self, model_checkpoint: Path | str):
        """Initialize the segmenter.

        Args:
            model_checkpoint (Path): The path to the model checkpoint.

        """
        self.device = torch.device("cpu") if not torch.cuda.is_available() else torch.device("cuda")
        ckpt = torch.load(model_checkpoint, map_location=self.device)
        self.config = validate_config(ckpt["config"])
        # encoder_weights=None: the weights come from the checkpoint below, not from a pretrained download
        self.model = smp.create_model(**self.config["model"], encoder_weights=None)
        self.model.to(self.device)
        self.model.load_state_dict(ckpt["statedict"])
        self.model.eval()

    def tile2tensor(self, tile: xr.Dataset) -> torch.Tensor:
        """Take a tile and convert it to a pytorch tensor.

        Respects the input combination from the config.

        Returns:
            A torch tensor of shape (C, H, W) for the full tile, consisting of the bands
            specified in the config's `input_combination`, each scaled by its norm factor.

        """
        bands = []
        # e.g. input_combination: ["red", "green", "blue", "relative_elevation", ...]
        # tile.data_vars: ["red", "green", "blue", "relative_elevation", ...]
        for feature_name in self.config["input_combination"]:
            norm = self.config["norm_factors"][feature_name]
            # Normalize the band data
            band_data = tile[feature_name] * norm
            bands.append(torch.from_numpy(band_data.to_numpy().astype("float32")))

        return torch.stack(bands, dim=0)

    def tile2tensor_batched(self, tiles: list[xr.Dataset]) -> torch.Tensor:
        """Take a list of tiles and convert them to a pytorch tensor.

        Respects the input combination from the config.

        Returns:
            A torch tensor of shape (N, C, H, W), consisting of the bands specified in the
            config's `input_combination`, each scaled by its norm factor.

        """
        bands = []
        # Iterate tiles in the OUTER loop so each tile's channels are contiguous in `bands`.
        # With the loops the other way around, the reshape below would interleave channels
        # across tiles (tile 0 would receive feature 0 of the first N tiles as its channels).
        for tile in tiles:
            for feature_name in self.config["input_combination"]:
                norm = self.config["norm_factors"][feature_name]
                # Normalize the band data
                band_data = tile[feature_name] * norm
                bands.append(torch.from_numpy(band_data.to_numpy().astype("float32")))
        return torch.stack(bands, dim=0).reshape(len(tiles), len(self.config["input_combination"]), *bands[0].shape)

    def segment_tile(
        self, tile: xr.Dataset, patch_size: int = 1024, overlap: int = 16, batch_size: int = 8, reflection: int = 0
    ) -> xr.Dataset:
        """Run inference on a tile.

        Args:
            tile: The input tile, containing preprocessed, harmonized data.
            patch_size (int): The size of the patches. Defaults to 1024.
            overlap (int): The size of the overlap. Defaults to 16.
            batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
                Tensor will be sliced into patches and these again will be infered in batches. Defaults to 8.
            reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

        Returns:
            Input tile augmented by a predicted `probabilities` layer with type float32 and range [0, 1].

        """
        # Convert the tile to a tensor and add the batch dimension which predict_in_patches expects
        tensor_tile = self.tile2tensor(tile).unsqueeze(0)

        probabilities = predict_in_patches(
            self.model, tensor_tile, patch_size, overlap, batch_size, reflection, self.device
        ).squeeze(0)

        # Attach the prediction to the tile, inheriting coords and metadata from an existing band
        # TODO: is there a better way to pass metadata?
        tile["probabilities"] = tile["red"].copy(data=probabilities.cpu().numpy())
        tile["probabilities"].attrs = {
            "long_name": "Probabilities",
        }
        tile["probabilities"] = tile["probabilities"].fillna(float("nan")).rio.write_nodata(float("nan"))
        return tile

    def segment_tile_batched(
        self,
        tiles: list[xr.Dataset],
        patch_size: int = 1024,
        overlap: int = 16,
        batch_size: int = 8,
        reflection: int = 0,
    ) -> list[xr.Dataset]:
        """Run inference on a list of tiles.

        Args:
            tiles: The input tiles, containing preprocessed, harmonized data.
            patch_size (int): The size of the patches. Defaults to 1024.
            overlap (int): The size of the overlap. Defaults to 16.
            batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
                Tensor will be sliced into patches and these again will be infered in batches. Defaults to 8.
            reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

        Returns:
            A list of input tiles augmented by a predicted `probabilities` layer with type float32 and range [0, 1].

        """
        # tile2tensor_batched already returns a batched tensor of shape (BS, C, H, W);
        # re-stacking it along dim 0 (as an earlier revision did) only produced a redundant copy.
        tensor_tiles = self.tile2tensor_batched(tiles)

        probabilities = predict_in_patches(
            self.model, tensor_tiles, patch_size, overlap, batch_size, reflection, self.device
        )

        # Attach each prediction to its tile, inheriting coords and metadata from an existing band
        for tile, probs in zip(tiles, probabilities):
            # TODO: is there a better way to pass metadata?
            tile["probabilities"] = tile["red"].copy(data=probs.cpu().numpy())
            tile["probabilities"].attrs = {
                "long_name": "Probabilities",
            }
            tile["probabilities"] = tile["probabilities"].fillna(float("nan")).rio.write_nodata(float("nan"))
        return tiles

    def __call__(
        self,
        input: xr.Dataset | list[xr.Dataset],
        patch_size: int = 1024,
        overlap: int = 16,
        batch_size: int = 8,
        reflection: int = 0,
    ) -> xr.Dataset | list[xr.Dataset]:
        """Run inference on a single tile or a list of tiles.

        Args:
            input (xr.Dataset | list[xr.Dataset]): A single tile or a list of tiles.
            patch_size (int): The size of the patches. Defaults to 1024.
            overlap (int): The size of the overlap. Defaults to 16.
            batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
                Tensor will be sliced into patches and these again will be infered in batches. Defaults to 8.
            reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

        Returns:
            A single tile or a list of tiles augmented by a predicted `probabilities` layer, depending on the input.
            Each `probability` has type float32 and range [0, 1].

        Raises:
            ValueError: in case the input is not an xr.Dataset or a list of xr.Dataset

        """
        if isinstance(input, xr.Dataset):
            return self.segment_tile(
                input, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
            )
        elif isinstance(input, list):
            return self.segment_tile_batched(
                input, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
            )
        else:
            raise ValueError(f"Expected xr.Dataset or list of xr.Dataset, got {type(input)}")

config: SMPSegmenterConfig = validate_config(ckpt['config']) instance-attribute

device: torch.device = torch.device('cpu') if not torch.cuda.is_available() else torch.device('cuda') instance-attribute

model: nn.Module = smp.create_model(**self.config['model'], encoder_weights=None) instance-attribute

__call__(input, patch_size=1024, overlap=16, batch_size=8, reflection=0)

Run inference on a single tile or a list of tiles.

Parameters:

Name Type Description Default
input Dataset | list[Dataset]

A single tile or a list of tiles.

required
patch_size int

The size of the patches. Defaults to 1024.

1024
overlap int

The size of the overlap. Defaults to 16.

16
batch_size int

The batch size for the prediction, NOT the batch_size of input tiles.

8
reflection int

Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

0

Returns:

Type Description
Dataset | list[Dataset]

A single tile or a list of tiles augmented by a predicted probabilities layer, depending on the input.

Dataset | list[Dataset]

Each probability has type float32 and range [0, 1].

Raises:

Type Description
ValueError

in case the input is not an xr.Dataset or a list of xr.Dataset

Source code in darts-segmentation/src/darts_segmentation/segment.py
def __call__(
    self,
    input: xr.Dataset | list[xr.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
) -> xr.Dataset | list[xr.Dataset]:
    """Segment a single tile or a batch of tiles, dispatching on the input type.

    Args:
        input (xr.Dataset | list[xr.Dataset]): A single tile or a list of tiles.
        patch_size (int): The size of the patches. Defaults to 1024.
        overlap (int): The size of the overlap. Defaults to 16.
        batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
            Tensor will be sliced into patches and these again will be infered in batches. Defaults to 8.
        reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

    Returns:
        A single tile or a list of tiles augmented by a predicted `probabilities` layer, depending on the input.
        Each `probability` has type float32 and range [0, 1].

    Raises:
        ValueError: in case the input is not an xr.Dataset or a list of xr.Dataset

    """
    # Bundle the shared keyword arguments once; both branches forward them unchanged.
    kwargs = {"patch_size": patch_size, "overlap": overlap, "batch_size": batch_size, "reflection": reflection}
    if isinstance(input, xr.Dataset):
        return self.segment_tile(input, **kwargs)
    if isinstance(input, list):
        return self.segment_tile_batched(input, **kwargs)
    raise ValueError(f"Expected xr.Dataset or list of xr.Dataset, got {type(input)}")

__init__(model_checkpoint)

Initialize the segmenter.

Parameters:

Name Type Description Default
model_checkpoint Path

The path to the model checkpoint.

required
Source code in darts-segmentation/src/darts_segmentation/segment.py
def __init__(self, model_checkpoint: Path | str):
    """Load the checkpoint and build the model in evaluation mode.

    Args:
        model_checkpoint (Path): The path to the model checkpoint.

    """
    # Prefer the GPU when one is available, otherwise fall back to CPU
    self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    checkpoint = torch.load(model_checkpoint, map_location=self.device)
    self.config = validate_config(checkpoint["config"])
    # encoder_weights=None: weights are restored from the checkpoint, not downloaded
    model = smp.create_model(**self.config["model"], encoder_weights=None)
    model.to(self.device)
    model.load_state_dict(checkpoint["statedict"])
    model.eval()
    self.model = model

segment_tile(tile, patch_size=1024, overlap=16, batch_size=8, reflection=0)

Run inference on a tile.

Parameters:

Name Type Description Default
tile Dataset

The input tile, containing preprocessed, harmonized data.

required
patch_size int

The size of the patches. Defaults to 1024.

1024
overlap int

The size of the overlap. Defaults to 16.

16
batch_size int

The batch size for the prediction, NOT the batch_size of input tiles.

8
reflection int

Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

0

Returns:

Type Description
Dataset

Input tile augmented by a predicted probabilities layer with type float32 and range [0, 1].

Source code in darts-segmentation/src/darts_segmentation/segment.py
def segment_tile(
    self, tile: xr.Dataset, patch_size: int = 1024, overlap: int = 16, batch_size: int = 8, reflection: int = 0
) -> xr.Dataset:
    """Run inference on a tile.

    Args:
        tile: The input tile, containing preprocessed, harmonized data.
        patch_size (int): The size of the patches. Defaults to 1024.
        overlap (int): The size of the overlap. Defaults to 16.
        batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
            Tensor will be sliced into patches and these again will be infered in batches. Defaults to 8.
        reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

    Returns:
        Input tile augmented by a predicted `probabilities` layer with type float32 and range [0, 1].

    """
    # predict_in_patches expects a batch dimension, so add one before and strip it after
    batched = self.tile2tensor(tile).unsqueeze(0)
    probs = predict_in_patches(
        self.model, batched, patch_size, overlap, batch_size, reflection, self.device
    ).squeeze(0)

    # Attach the prediction to the tile, inheriting coords and metadata from an existing band
    # TODO: is there a better way to pass metadata?
    tile["probabilities"] = tile["red"].copy(data=probs.cpu().numpy())
    tile["probabilities"].attrs = {
        "long_name": "Probabilities",
    }
    tile["probabilities"] = tile["probabilities"].fillna(float("nan")).rio.write_nodata(float("nan"))
    return tile

segment_tile_batched(tiles, patch_size=1024, overlap=16, batch_size=8, reflection=0)

Run inference on a list of tiles.

Parameters:

Name Type Description Default
tiles list[Dataset]

The input tiles, containing preprocessed, harmonized data.

required
patch_size int

The size of the patches. Defaults to 1024.

1024
overlap int

The size of the overlap. Defaults to 16.

16
batch_size int

The batch size for the prediction, NOT the batch_size of input tiles.

8
reflection int

Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

0

Returns:

Type Description
list[Dataset]

A list of input tiles augmented by a predicted probabilities layer with type float32 and range [0, 1].

Source code in darts-segmentation/src/darts_segmentation/segment.py
def segment_tile_batched(
    self,
    tiles: list[xr.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
) -> list[xr.Dataset]:
    """Run inference on a list of tiles.

    Args:
        tiles: The input tiles, containing preprocessed, harmonized data.
        patch_size (int): The size of the patches. Defaults to 1024.
        overlap (int): The size of the overlap. Defaults to 16.
        batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
            Tensor will be sliced into patches and these again will be infered in batches. Defaults to 8.
        reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

    Returns:
        A list of input tiles augmented by a predicted `probabilities` layer with type float32 and range [0, 1].

    """
    # tile2tensor_batched already returns a batched tensor of shape (BS, C, H, W);
    # re-stacking it along dim 0 (as an earlier revision did) only produced a redundant copy.
    tensor_tiles = self.tile2tensor_batched(tiles)

    probabilities = predict_in_patches(
        self.model, tensor_tiles, patch_size, overlap, batch_size, reflection, self.device
    )

    # Attach each prediction to its tile, inheriting coords and metadata from an existing band
    for tile, probs in zip(tiles, probabilities):
        # TODO: is there a better way to pass metadata?
        tile["probabilities"] = tile["red"].copy(data=probs.cpu().numpy())
        tile["probabilities"].attrs = {
            "long_name": "Probabilities",
        }
        tile["probabilities"] = tile["probabilities"].fillna(float("nan")).rio.write_nodata(float("nan"))
    return tiles

tile2tensor(tile)

Take a tile and convert it to a pytorch tensor.

Respects the input combination from the config.

Returns:

Type Description
Tensor

A torch tensor for the full tile consisting of the bands specified in the config's input_combination.

Source code in darts-segmentation/src/darts_segmentation/segment.py
def tile2tensor(self, tile: xr.Dataset) -> torch.Tensor:
    """Take a tile and convert it to a pytorch tensor.

    Respects the input combination from the config.

    Returns:
        A torch tensor of shape (C, H, W) for the full tile, consisting of the bands
        specified in the config's `input_combination`, each scaled by its norm factor.

    """
    # Scale each configured band by its norm factor and collect them in config order,
    # e.g. input_combination: ["red", "green", "blue", "relative_elevation", ...]
    bands = [
        torch.from_numpy((tile[name] * self.config["norm_factors"][name]).to_numpy().astype("float32"))
        for name in self.config["input_combination"]
    ]
    return torch.stack(bands, dim=0)

tile2tensor_batched(tiles)

Take a list of tiles and convert them to a pytorch tensor.

Respects the input combination from the config.

Returns:

Type Description
Tensor

A torch tensor for the full tile consisting of the bands specified in the config's input_combination.

Source code in darts-segmentation/src/darts_segmentation/segment.py
def tile2tensor_batched(self, tiles: list[xr.Dataset]) -> torch.Tensor:
    """Take a list of tiles and convert them to a pytorch tensor.

    Respects the input combination from the config.

    Returns:
        A torch tensor of shape (N, C, H, W), consisting of the bands specified in the
        config's `input_combination`, each scaled by its norm factor.

    """
    bands = []
    # Iterate tiles in the OUTER loop so each tile's channels are contiguous in `bands`.
    # With the loops the other way around, the reshape below would interleave channels
    # across tiles (tile 0 would receive feature 0 of the first N tiles as its channels).
    for tile in tiles:
        for feature_name in self.config["input_combination"]:
            norm = self.config["norm_factors"][feature_name]
            # Normalize the band data
            band_data = tile[feature_name] * norm
            bands.append(torch.from_numpy(band_data.to_numpy().astype("float32")))
    return torch.stack(bands, dim=0).reshape(len(tiles), len(self.config["input_combination"]), *bands[0].shape)

SMPSegmenterConfig

Bases: TypedDict

Configuration for the segmentor.

Source code in darts-segmentation/src/darts_segmentation/segment.py
class SMPSegmenterConfig(TypedDict):
    """Configuration for the segmenter.

    Stored under the "config" key of a model checkpoint; validated via `validate_config` on load.
    """

    # Ordered names of the bands/features fed to the model as input channels
    input_combination: list[str]
    # Keyword arguments passed through to `smp.create_model` (architecture, encoder, ...)
    model: dict[str, Any]
    # Multiplicative normalization factor per feature, keyed by feature name
    norm_factors: dict[str, float]

input_combination: list[str] instance-attribute

model: dict[str, Any] instance-attribute

norm_factors: dict[str, float] instance-attribute

create_patches(tensor_tiles, patch_size, overlap, return_coords=False)

Create patches from a tensor.

Parameters:

Name Type Description Default
tensor_tiles Tensor

The input tensor. Shape: (BS, C, H, W).

required
patch_size int

The size of the patches.

required
overlap int

The size of the overlap.

required
return_coords bool

Whether to return the coordinates of the patches. Can be used for debugging. Defaults to False.

False

Returns:

Type Description
Tensor

torch.Tensor: The patches. Shape: (BS, N_h, N_w, C, patch_size, patch_size).

Source code in darts-segmentation/src/darts_segmentation/utils.py
@torch.no_grad()
def create_patches(
    tensor_tiles: torch.Tensor, patch_size: int, overlap: int, return_coords: bool = False
) -> torch.Tensor:
    """Create patches from a tensor.

    Args:
        tensor_tiles (torch.Tensor): The input tensor. Shape: (BS, C, H, W).
        patch_size (int): The size of the patches.
        overlap (int): The size of the overlap.
        return_coords (bool, optional): Whether to also return the coordinates of the patches.
            Can be used for debugging. Defaults to False.

    Returns:
        torch.Tensor: The patches. Shape: (BS, N_h, N_w, C, patch_size, patch_size).
        If `return_coords` is True, a tuple (patches, coords) is returned instead, where coords
        has shape (N_h, N_w, 5) holding (patch_index, y, x, patch_idx_h, patch_idx_w).

    """
    start_time = time.time()
    logger.debug(
        f"Creating patches from a tensor with shape {tensor_tiles.shape} "
        f"with patch_size {patch_size} and overlap {overlap}"
    )
    assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to has shape (BS, C, H, W), got {tensor_tiles.shape}"
    bs, c, h, w = tensor_tiles.shape
    assert h > patch_size > overlap
    assert w > patch_size > overlap

    step_size = patch_size - overlap

    # The problem with unfold is that is cuts off the last patch if it doesn't fit exactly
    # Padding could help, but then the next problem is that the view needs to get reshaped (copied in memory)
    # to fit the model input shape. Such a complex view can't be inserted into the model.
    # Since we need, doing it manually is currently our best choice, since be can avoid the padding.
    # patches = (
    #     tensor_tiles.unfold(2, patch_size, step_size).unfold(3, patch_size, step_size).transpose(1, 2).transpose(2, 3)
    # )
    # return patches

    nh, nw = math.ceil((h - overlap) / step_size), math.ceil((w - overlap) / step_size)
    # Create Patches of size (BS, N_h, N_w, C, patch_size, patch_size)
    patches = torch.zeros((bs, nh, nw, c, patch_size, patch_size), device=tensor_tiles.device)
    # Only allocate and fill the coordinate bookkeeping when the caller asked for it
    coords = torch.zeros((nh, nw, 5)) if return_coords else None
    for i, (y, x, patch_idx_h, patch_idx_w) in enumerate(patch_coords(h, w, patch_size, overlap)):
        patches[:, patch_idx_h, patch_idx_w, :] = tensor_tiles[:, :, y : y + patch_size, x : x + patch_size]
        if return_coords:
            coords[patch_idx_h, patch_idx_w, :] = torch.tensor([i, y, x, patch_idx_h, patch_idx_w])

    logger.debug(f"Creating {nh * nw} patches took {time.time() - start_time:.2f}s")
    if return_coords:
        return patches, coords
    else:
        return patches

patch_coords(h, w, patch_size, overlap)

Yield patch coordinates based on height, width, patch size and margin size.

Parameters:

Name Type Description Default
h int

Height of the image.

required
w int

Width of the image.

required
patch_size int

Patch size.

required
overlap int

Margin size.

required

Yields:

Type Description
tuple[int, int, int, int]

tuple[int, int, int, int]: The patch coordinates y, x, patch_idx_y and patch_idx_x.

Source code in darts-segmentation/src/darts_segmentation/utils.py
def patch_coords(h: int, w: int, patch_size: int, overlap: int) -> Generator[tuple[int, int, int, int], None, None]:
    """Yield patch coordinates based on height, width, patch size and margin size.

    Args:
        h (int): Height of the image.
        w (int): Width of the image.
        patch_size (int): Patch size.
        overlap (int): Margin size.

    Yields:
        tuple[int, int, int, int]: The patch coordinates y, x, patch_idx_y and patch_idx_x.

    """
    step = patch_size - overlap
    # Stop `overlap` short of the edge so an exactly-fitting last patch is not emitted twice
    y_starts = list(enumerate(range(0, h - overlap, step)))
    x_starts = list(enumerate(range(0, w - overlap, step)))
    for patch_idx_y, y in y_starts:
        for patch_idx_x, x in x_starts:
            # Clamp the start so the final patch ends exactly at the image border
            yield min(y, h - patch_size), min(x, w - patch_size), patch_idx_y, patch_idx_x

predict_in_patches(model, tensor_tiles, patch_size, overlap, batch_size, reflection, device=torch.device, return_weights=False)

Predict on a tensor.

Parameters:

Name Type Description Default
model Module

The model to use for prediction.

required
tensor_tiles Tensor

The input tensor. Shape: (BS, C, H, W).

required
patch_size int

The size of the patches.

required
overlap int

The size of the overlap.

required
batch_size int

The batch size for the prediction, NOT the batch_size of input tiles. Tensor will be sliced into patches and these again will be infered in batches.

required
reflection int

Reflection-Padding which will be applied to the edges of the tensor.

required
device device

The device to use for the prediction.

device
return_weights bool

Whether to return the weights. Can be used for debugging. Defaults to False.

False

Returns:

Type Description
Tensor

The predicted tensor.

Source code in darts-segmentation/src/darts_segmentation/utils.py
@torch.no_grad()
def predict_in_patches(
    model: nn.Module,
    tensor_tiles: torch.Tensor,
    patch_size: int,
    overlap: int,
    batch_size: int,
    reflection: int,
    device=torch.device("cpu"),
    return_weights: bool = False,
) -> torch.Tensor:
    """Predict on a tensor.

    Args:
        model: The model to use for prediction.
        tensor_tiles: The input tensor. Shape: (BS, C, H, W).
        patch_size (int): The size of the patches.
        overlap (int): The size of the overlap.
        batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
            Tensor will be sliced into patches and these again will be infered in batches.
        reflection (int): Reflection-Padding which will be applied to the edges of the tensor.
        device (torch.device): The device to use for the prediction. Defaults to CPU.
            (The previous default was the `torch.device` class itself, which would raise
            if the argument was ever omitted.)
        return_weights (bool, optional): Whether to return the weights. Can be used for debugging. Defaults to False.

    Returns:
        The predicted tensor.

    """
    start_time = time.time()
    logger.debug(
        f"Predicting on a tensor with shape {tensor_tiles.shape} "
        f"with patch_size {patch_size}, overlap {overlap} and batch_size {batch_size} on device {device}"
    )
    assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to has shape (BS, C, H, W), got {tensor_tiles.shape}"
    # Add a 1px + reflection border to avoid pixel loss when applying the soft margin and to reduce edge-artefacts
    p = 1 + reflection
    tensor_tiles = torch.nn.functional.pad(tensor_tiles, (p, p, p, p), mode="reflect")
    bs, c, h, w = tensor_tiles.shape
    step_size = patch_size - overlap
    nh, nw = math.ceil((h - overlap) / step_size), math.ceil((w - overlap) / step_size)

    # Create Patches of size (BS, N_h, N_w, C, patch_size, patch_size)
    patches = create_patches(tensor_tiles, patch_size=patch_size, overlap=overlap)

    # Flatten the patches so they fit to the model
    # (BS, N_h, N_w, C, patch_size, patch_size) -> (BS * N_h * N_w, C, patch_size, patch_size)
    patches = patches.view(bs * nh * nw, c, patch_size, patch_size)

    # Create a soft margin for the patches: ramps 0->1 over the overlap, flat 1 in the middle
    margin_ramp = torch.cat(
        [
            torch.linspace(0, 1, overlap),
            torch.ones(patch_size - 2 * overlap),
            torch.linspace(1, 0, overlap),
        ]
    )
    soft_margin = margin_ramp.reshape(1, 1, patch_size) * margin_ramp.reshape(1, patch_size, 1)
    soft_margin = soft_margin.to(patches.device)

    # Infer logits with model and turn into probabilities with sigmoid in a batched manner
    # TODO: check with ingmar and jonas if moving all patches to the device at the same time is a good idea
    patched_probabilities = torch.zeros_like(patches[:, 0, :, :])
    patches = patches.split(batch_size)
    for i, batch in track(enumerate(patches), total=len(patches), description="Predicting on patches"):
        batch = batch.to(device)
        patched_probabilities[i * batch_size : (i + 1) * batch_size] = (
            torch.sigmoid(model(batch)).squeeze(1).to(patched_probabilities.device)
        )
        # Drop the (potentially GPU-resident) batch reference right away. The previous
        # implementation copied it back to the host instead, which had no effect beyond
        # an extra transfer, since the local was rebound and never used again.
        del batch

    patched_probabilities = patched_probabilities.view(bs, nh, nw, patch_size, patch_size)

    # Reconstruct the image from the patches
    prediction = torch.zeros(bs, h, w, device=tensor_tiles.device)
    weights = torch.zeros(bs, h, w, device=tensor_tiles.device)

    for y, x, patch_idx_h, patch_idx_w in patch_coords(h, w, patch_size, overlap):
        patch = patched_probabilities[:, patch_idx_h, patch_idx_w]
        prediction[:, y : y + patch_size, x : x + patch_size] += patch * soft_margin
        weights[:, y : y + patch_size, x : x + patch_size] += soft_margin

    # Avoid division by zero
    weights = torch.where(weights == 0, torch.ones_like(weights), weights)
    prediction = prediction / weights

    # Remove the 1px border and the padding
    prediction = prediction[:, p:-p, p:-p]
    logger.debug(f"Predicting took {time.time() - start_time:.2f}s")

    if return_weights:
        return prediction, weights
    else:
        return prediction

darts_superresolution

Image superresolution of Sentinel 2 imagery for the DARTS dataset.

hello()

Say hello to the user.

Returns:

Name Type Description
str str

Greating message.

Source code in darts-superresolution/src/darts_superresolution/__init__.py
def hello() -> str:
    """Say hello to the user.

    Returns:
        str: Greeting message.

    """
    return "Hello from darts-superresolution!"