darts_ensemble.EnsembleV1

DARTS v1 ensemble based on a list of models.

Initialize the ensemble.

Parameters:

  • model_dict (dict) –

    The paths to the model checkpoints to ensemble. Each key should be a model identifier that is written to the outputs.

  • device (torch.device, default: darts_ensemble.ensemble_v1.DEFAULT_DEVICE ) –

    The device to run the model on. Defaults to torch.device("cuda") if cuda is available, else torch.device("cpu").

Source code in darts-ensemble/src/darts_ensemble/ensemble_v1.py
def __init__(
    self,
    model_dict,
    device: torch.device = DEFAULT_DEVICE,
):
    """Initialize the ensemble.

    Args:
        model_dict (dict): The paths to the model checkpoints to ensemble. Each key should be a model identifier
            to be written to the outputs.
        device (torch.device): The device to run the model on.
            Defaults to torch.device("cuda") if cuda is available, else torch.device("cpu").

    """
    model_paths = {k: Path(v) for k, v in model_dict.items()}
    logger.debug(
        "Loading models:\n" + "\n".join([f" - {k.upper()} model: {v.resolve()}" for k, v in model_paths.items()])
    )
    self.models = {k: SMPSegmenter(v, device=device) for k, v in model_paths.items()}
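
A minimal construction sketch. The checkpoint paths and the "tcvis"/"notcvis" identifiers are hypothetical placeholders, and EnsembleV1 is assumed to be importable from the darts_ensemble package, as the page title suggests:

import torch

from darts_ensemble import EnsembleV1

ensemble = EnsembleV1(
    model_dict={
        "tcvis": "models/tcvis.ckpt",      # hypothetical checkpoint path
        "notcvis": "models/notcvis.ckpt",  # hypothetical checkpoint path
    },
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)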

models instance-attribute

models = {
    k: darts_segmentation.segment.SMPSegmenter(v, device=device)
    for (k, v) in model_paths.items()
}

A mapping from model identifier to its loaded SMPSegmenter instance, each placed on the requested device.
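
A small access sketch, reusing the hypothetical "tcvis" identifier from the construction example above:

segmenter = ensemble.models["tcvis"]  # the SMPSegmenter loaded from that checkpoint, on the chosen device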

__call__

__call__(
    input: xarray.Dataset | list[xarray.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
    keep_inputs: bool = False,
) -> xarray.Dataset

Run the ensemble on the given tile.

Parameters:

  • input (xarray.Dataset | list[xarray.Dataset]) –

    A single tile or a list of tiles.

  • patch_size (int, default: 1024 ) –

    The size of the patches. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The size of the overlap. Defaults to 16.

  • batch_size (int, default: 8 ) –

    The batch size for the prediction, NOT the batch size of the input tiles. The tensor is sliced into patches, which are then inferred in batches. Defaults to 8.

  • reflection (int, default: 0 ) –

    Reflection padding applied to the edges of the tensor. Defaults to 0.

  • keep_inputs (bool, default: False ) –

    Whether to keep the input probabilities in the output. Defaults to False.

Returns:

  • xarray.Dataset

    xr.Dataset: Output tile with the ensemble applied.

Raises:

  • ValueError

    If the input is neither an xr.Dataset nor a list of xr.Dataset.

Source code in darts-ensemble/src/darts_ensemble/ensemble_v1.py
def __call__(
    self,
    input: xr.Dataset | list[xr.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
    keep_inputs: bool = False,
) -> xr.Dataset:
    """Run the ensemble on the given tile.

    Args:
        input (xr.Dataset | list[xr.Dataset]): A single tile or a list of tiles.
        patch_size (int): The size of the patches. Defaults to 1024.
        overlap (int): The size of the overlap. Defaults to 16.
        batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
            The tensor will be sliced into patches, which are then inferred in batches. Defaults to 8.
        reflection (int): Reflection padding applied to the edges of the tensor. Defaults to 0.
        keep_inputs (bool, optional): Whether to keep the input probabilities in the output. Defaults to False.

    Returns:
        xr.Dataset: Output tile with the ensemble applied.

    Raises:
        ValueError: If the input is neither an xr.Dataset nor a list of xr.Dataset.

    """
    if isinstance(input, xr.Dataset):
        return self.segment_tile(
            input,
            patch_size=patch_size,
            overlap=overlap,
            batch_size=batch_size,
            reflection=reflection,
            keep_inputs=keep_inputs,
        )
    elif isinstance(input, list):
        return self.segment_tile_batched(
            input,
            patch_size=patch_size,
            overlap=overlap,
            batch_size=batch_size,
            reflection=reflection,
            keep_inputs=keep_inputs,
        )
    else:
        raise ValueError("Input must be an xr.Dataset or a list of xr.Dataset.")
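
A short call sketch, assuming tile is a single preprocessed, harmonized xr.Dataset and ensemble was constructed as in the example above (both names are illustrative):

result = ensemble(tile, patch_size=1024, overlap=16, batch_size=8)
probs = result["probabilities"]  # ensemble mean, float32 in [0, 1]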

segment_tile

segment_tile(
    tile: xarray.Dataset,
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
    keep_inputs: bool = False,
) -> xarray.Dataset

Run inference on a tile.

Parameters:

  • tile (xarray.Dataset) –

    The input tile, containing preprocessed, harmonized data.

  • patch_size (int, default: 1024 ) –

    The size of the patches. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The size of the overlap. Defaults to 16.

  • batch_size (int, default: 8 ) –

    The batch size for the prediction, NOT the batch size of the input tiles. The tensor is sliced into patches, which are then inferred in batches. Defaults to 8.

  • reflection (int, default: 0 ) –

    Reflection padding applied to the edges of the tensor. Defaults to 0.

  • keep_inputs (bool, default: False ) –

    Whether to keep the input probabilities in the output. Defaults to False.

Returns:

  • xarray.Dataset

    Input tile augmented by a predicted probabilities layer with type float32 and range [0, 1].

Source code in darts-ensemble/src/darts_ensemble/ensemble_v1.py
@stopuhr.funkuhr(
    "Ensemble inference",
    printer=logger.debug,
    print_kwargs=["patch_size", "overlap", "batch_size", "reflection", "keep_inputs"],
)
def segment_tile(
    self,
    tile: xr.Dataset,
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
    keep_inputs: bool = False,
) -> xr.Dataset:
    """Run inference on a tile.

    Args:
        tile: The input tile, containing preprocessed, harmonized data.
        patch_size (int): The size of the patches. Defaults to 1024.
        overlap (int): The size of the overlap. Defaults to 16.
        batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
            The tensor will be sliced into patches, which are then inferred in batches. Defaults to 8.
        reflection (int): Reflection padding applied to the edges of the tensor. Defaults to 0.
        keep_inputs (bool, optional): Whether to keep the input probabilities in the output. Defaults to False.

    Returns:
        Input tile augmented by a predicted `probabilities` layer with type float32 and range [0, 1].

    """
    probabilities = {}
    for model_name, model in self.models.items():
        probabilities[model_name] = model.segment_tile(
            tile, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
        )["probabilities"].copy()

    # calculate the mean
    tile["probabilities"] = xr.concat(probabilities.values(), dim="model_probs").mean(dim="model_probs")

    if keep_inputs:
        for k, v in probabilities.items():
            tile[f"probabilities-{k}"] = v

    return tile
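
A sketch of the keep_inputs behaviour, where "tcvis" stands for a hypothetical model identifier, i.e. one of the keys of the model_dict passed at construction:

result = ensemble.segment_tile(tile, keep_inputs=True)
mean_probs = result["probabilities"]         # mean over all ensemble members
tcvis_probs = result["probabilities-tcvis"]  # unmerged output of the model registered as "tcvis"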

segment_tile_batched

segment_tile_batched(
    tiles: list[xarray.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
    keep_inputs: bool = False,
) -> list[xarray.Dataset]

Run inference on a list of tiles.

Parameters:

  • tiles (list[xarray.Dataset]) –

    The input tiles, containing preprocessed, harmonized data.

  • patch_size (int, default: 1024 ) –

    The size of the patches. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The size of the overlap. Defaults to 16.

  • batch_size (int, default: 8 ) –

    The batch size for the prediction, NOT the batch size of the input tiles. The tensor is sliced into patches, which are then inferred in batches. Defaults to 8.

  • reflection (int, default: 0 ) –

    Reflection padding applied to the edges of the tensor. Defaults to 0.

  • keep_inputs (bool, default: False ) –

    Whether to keep the input probabilities in the output. Defaults to False.

Returns:

  • list[xarray.Dataset]

    A list of input tiles augmented by a predicted probabilities layer with type float32 and range [0, 1].

Source code in darts-ensemble/src/darts_ensemble/ensemble_v1.py
def segment_tile_batched(
    self,
    tiles: list[xr.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
    keep_inputs: bool = False,
) -> list[xr.Dataset]:
    """Run inference on a list of tiles.

    Args:
        tiles: The input tiles, containing preprocessed, harmonized data.
        patch_size (int): The size of the patches. Defaults to 1024.
        overlap (int): The size of the overlap. Defaults to 16.
        batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
            The tensor will be sliced into patches, which are then inferred in batches. Defaults to 8.
        reflection (int): Reflection padding applied to the edges of the tensor. Defaults to 0.
        keep_inputs (bool, optional): Whether to keep the input probabilities in the output. Defaults to False.

    Returns:
        A list of input tiles augmented by a predicted `probabilities` layer with type float32 and range [0, 1].

    """
    return [
        self.segment_tile(
            tile,
            patch_size=patch_size,
            overlap=overlap,
            batch_size=batch_size,
            reflection=reflection,
            keep_inputs=keep_inputs,
        )
        for tile in tiles
    ]
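
A minimal batched sketch, assuming tiles is a list of preprocessed, harmonized xr.Dataset objects (the variable name is illustrative):

results = ensemble.segment_tile_batched(tiles, batch_size=8)
for result in results:
    probs = result["probabilities"]  # one ensemble-mean layer per input tile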