Skip to content

ensemble_v1

darts_ensemble.ensemble_v1

DARTS v1 ensemble based on two models, one trained with TCVIS data and the other without.

DEFAULT_DEVICE module-attribute

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
DEFAULT_DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

logger module-attribute

# Module logger: a package prefix such as "darts_ensemble" is mapped onto the "darts.ensemble"
# logger hierarchy by replacing "darts_" with "darts." in the module name.
logger = logging.getLogger(__name__.replace("darts_", "darts."))

EnsembleV1

EnsembleV1(
    model_dict,
    device: torch.device = darts_ensemble.ensemble_v1.DEFAULT_DEVICE,
)

DARTS v1 ensemble based on a dictionary of named models.

Initialize the ensemble.

Parameters:

  • model_dict (dict) –

    The paths to the model checkpoints to ensemble. Each key should be a model identifier, which is written to the outputs.

  • device (torch.device, default: darts_ensemble.ensemble_v1.DEFAULT_DEVICE ) –

    The device to run the model on. Defaults to torch.device("cuda") if cuda is available, else torch.device("cpu").

Source code in darts-ensemble/src/darts_ensemble/ensemble_v1.py
def __init__(
    self,
    model_dict,
    device: torch.device = DEFAULT_DEVICE,
):
    """Load every checkpoint of the ensemble as an `SMPSegmenter` on the given device.

    Args:
        model_dict (dict): Mapping from a model identifier (used to label that model's outputs)
            to the path of its checkpoint.
        device (torch.device): The device to run the models on.
            Defaults to torch.device("cuda") if cuda is available, else torch.device("cpu").

    """
    checkpoint_paths = {}
    for identifier, raw_path in model_dict.items():
        checkpoint_paths[identifier] = Path(raw_path)

    listing = "\n".join(
        f" - {identifier.upper()} model: {ckpt_path.resolve()}" for identifier, ckpt_path in checkpoint_paths.items()
    )
    logger.debug("Loading models:\n" + listing)

    # Build into a local dict first so `self.models` is only bound once every model has loaded.
    loaded = {}
    for identifier, ckpt_path in checkpoint_paths.items():
        loaded[identifier] = SMPSegmenter(ckpt_path, device=device)
    self.models = loaded

models instance-attribute

models = {
    k: darts_segmentation.segment.SMPSegmenter(v, device=device)
    for (k, v) in model_paths.items()
}

__call__

__call__(
    input: xarray.Dataset | list[xarray.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
    keep_inputs: bool = False,
) -> xarray.Dataset

Run the ensemble on the given tile.

Parameters:

  • input (xarray.Dataset | list[xarray.Dataset]) –

    A single tile or a list of tiles.

  • patch_size (int, default: 1024 ) –

    The size of the patches. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The size of the overlap. Defaults to 16.

  • batch_size (int, default: 8 ) –

    The batch size for the prediction, NOT the batch_size of input tiles. The tensor will be sliced into patches, which in turn are inferred in batches. Defaults to 8.

  • reflection (int, default: 0 ) –

    Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

  • keep_inputs (bool, default: False ) –

    Whether to keep the input probabilities in the output. Defaults to False.

Returns:

  • xarray.Dataset

    xr.Dataset: Output tile with the ensemble applied.

Raises:

  • ValueError

    in case the input is not an xr.Dataset or a list of xr.Dataset

Source code in darts-ensemble/src/darts_ensemble/ensemble_v1.py
def __call__(
    self,
    input: xr.Dataset | list[xr.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
    keep_inputs: bool = False,
) -> xr.Dataset | list[xr.Dataset]:
    """Run the ensemble on a single tile or on a list of tiles.

    Args:
        input (xr.Dataset | list[xr.Dataset]): A single tile or a list of tiles.
        patch_size (int): The size of the patches. Defaults to 1024.
        overlap (int): The size of the overlap. Defaults to 16.
        batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
            The tensor will be sliced into patches, which in turn are inferred in batches. Defaults to 8.
        reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.
        keep_inputs (bool, optional): Whether to keep the input probabilities in the output. Defaults to False.

    Returns:
        xr.Dataset | list[xr.Dataset]: The output tile with the ensemble applied when a single tile is passed,
            or a list of such tiles (one per input tile, in input order) when a list is passed.

    Raises:
        ValueError: in case the input is not an xr.Dataset or a list of xr.Dataset

    """
    inference_kwargs = dict(
        patch_size=patch_size,
        overlap=overlap,
        batch_size=batch_size,
        reflection=reflection,
        keep_inputs=keep_inputs,
    )
    if isinstance(input, xr.Dataset):
        return self.segment_tile(input, **inference_kwargs)
    if isinstance(input, list):
        return self.segment_tile_batched(input, **inference_kwargs)
    # Report the offending type to make misuse easier to diagnose.
    raise ValueError(f"Input must be an xr.Dataset or a list of xr.Dataset, got {type(input)}.")

segment_tile

segment_tile(
    tile: xarray.Dataset,
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
    keep_inputs: bool = False,
) -> xarray.Dataset

Run inference on a tile.

Parameters:

  • tile (xarray.Dataset) –

    The input tile, containing preprocessed, harmonized data.

  • patch_size (int, default: 1024 ) –

    The size of the patches. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The size of the overlap. Defaults to 16.

  • batch_size (int, default: 8 ) –

    The batch size for the prediction, NOT the batch_size of input tiles. The tensor will be sliced into patches, which in turn are inferred in batches. Defaults to 8.

  • reflection (int, default: 0 ) –

    Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

  • keep_inputs (bool, default: False ) –

    Whether to keep the input probabilities in the output. Defaults to False.

Returns:

  • xarray.Dataset

    Input tile augmented by a predicted probabilities layer with type float32 and range [0, 1].

Source code in darts-ensemble/src/darts_ensemble/ensemble_v1.py
@stopwatch.f(
    "Ensemble inference",
    printer=logger.debug,
    print_kwargs=["patch_size", "overlap", "batch_size", "reflection", "keep_inputs"],
)
def segment_tile(
    self,
    tile: xr.Dataset,
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
    keep_inputs: bool = False,
) -> xr.Dataset:
    """Run every model of the ensemble on a tile and average their probabilities.

    The given `tile` object is modified in place and returned. A `probabilities` variable holding the
    mean over all models is added. If `keep_inputs` is set, one `probabilities-<model name>` variable
    per model is added as well.

    Args:
        tile: The input tile, containing preprocessed, harmonized data.
        patch_size (int): The size of the patches. Defaults to 1024.
        overlap (int): The size of the overlap. Defaults to 16.
        batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
            The tensor will be sliced into patches, which in turn are inferred in batches. Defaults to 8.
        reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.
        keep_inputs (bool, optional): Whether to keep the input probabilities in the output. Defaults to False.

    Returns:
        Input tile augmented by a predicted `probabilities` layer with type float32 and range [0, 1].

    """
    probabilities = {}
    for model_name, model in self.models.items():
        probabilities[model_name] = model.segment_tile(
            tile, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
        )["probabilities"].copy()

    # Stack the per-model outputs along a temporary "model_probs" axis and average over it.
    # `xr.concat` copies the attrs of the first member (combine_attrs="override"). `keep_attrs=True`
    # stops `mean` from discarding them, so the ensemble layer keeps the same metadata as each member's
    # output (e.g. the long_name / nodata written by SMPSegmenter.segment_tile).
    stacked = xr.concat(list(probabilities.values()), dim="model_probs")
    tile["probabilities"] = stacked.mean(dim="model_probs", keep_attrs=True)

    if keep_inputs:
        for k, v in probabilities.items():
            tile[f"probabilities-{k}"] = v

    return tile

segment_tile_batched

segment_tile_batched(
    tiles: list[xarray.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
    keep_inputs: bool = False,
) -> list[xarray.Dataset]

Run inference on a list of tiles.

Parameters:

  • tiles (list[xarray.Dataset]) –

    The input tiles, containing preprocessed, harmonized data.

  • patch_size (int, default: 1024 ) –

    The size of the patches. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The size of the overlap. Defaults to 16.

  • batch_size (int, default: 8 ) –

    The batch size for the prediction, NOT the batch_size of input tiles. The tensor will be sliced into patches, which in turn are inferred in batches. Defaults to 8.

  • reflection (int, default: 0 ) –

    Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

  • keep_inputs (bool, default: False ) –

    Whether to keep the input probabilities in the output. Defaults to False.

Returns:

  • list[xarray.Dataset]

    A list of input tiles augmented by a predicted probabilities layer with type float32 and range [0, 1].

Source code in darts-ensemble/src/darts_ensemble/ensemble_v1.py
def segment_tile_batched(
    self,
    tiles: list[xr.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
    keep_inputs: bool = False,
) -> list[xr.Dataset]:
    """Run the ensemble on each tile of a list, one tile after another.

    Args:
        tiles: The input tiles, containing preprocessed, harmonized data.
        patch_size (int): The size of the patches. Defaults to 1024.
        overlap (int): The size of the overlap. Defaults to 16.
        batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
            The tensor will be sliced into patches, which in turn are inferred in batches. Defaults to 8.
        reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.
        keep_inputs (bool, optional): Whether to keep the input probabilities in the output. Defaults to False.

    Returns:
        The result of `segment_tile` for every input tile, in input order: each tile augmented by a
        predicted `probabilities` layer with type float32 and range [0, 1].

    """
    results = []
    for current_tile in tiles:
        segmented = self.segment_tile(
            current_tile,
            patch_size=patch_size,
            overlap=overlap,
            batch_size=batch_size,
            reflection=reflection,
            keep_inputs=keep_inputs,
        )
        results.append(segmented)
    return results

SMPSegmenter

SMPSegmenter(
    model_checkpoint: pathlib.Path | str,
    device: torch.device = darts_segmentation.segment.DEFAULT_DEVICE,
)

An actor that keeps a model as its state and segments tiles.

Initialize the segmenter.

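Parameters:

  • model_checkpoint (pathlib.Path | str) –

    The path to the model checkpoint.

  • device (torch.device, default: darts_segmentation.segment.DEFAULT_DEVICE ) –

    The device to run the model on. Defaults to torch.device("cuda") if cuda is available, else torch.device("cpu").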

Source code in darts-segmentation/src/darts_segmentation/segment.py
def __init__(self, model_checkpoint: Path | str, device: torch.device = DEFAULT_DEVICE):
    """Load a segmentation model from a checkpoint and prepare it for inference.

    Args:
        model_checkpoint (Path | str): The path to the model checkpoint.
        device (torch.device): The device to run the model on.
            Defaults to torch.device("cuda") if cuda is available, else torch.device("cpu").

    """
    checkpoint_path = Path(model_checkpoint) if isinstance(model_checkpoint, str) else model_checkpoint
    self.device = device

    # NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects, so only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
    self.config = SMPSegmenterConfig.from_ckpt(checkpoint)
    # The trained weights are loaded from the checkpoint below, so smp must not fetch pretrained
    # encoder weights of its own.
    self.config["model"] |= {"encoder_weights": None}
    self.model = smp.create_model(**self.config["model"])
    self.model.to(self.device)

    if "statedict" in checkpoint:
        # Legacy checkpoint layout.
        state = checkpoint["statedict"]
    else:
        state = checkpoint["state_dict"]
        # Lightning checkpoints prefix parameter names with "model."; this helper strips it in place.
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state, "model.")
    self.model.load_state_dict(state)
    self.model.eval()

    logger.debug(f"Successfully loaded model from {checkpoint_path.resolve()} with inputs: {self.config['bands']}")

device instance-attribute

model instance-attribute

model: torch.nn.Module = segmentation_models_pytorch.create_model(
    **self.config["model"]
)

__call__

__call__(
    input: xarray.Dataset | list[xarray.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
) -> xarray.Dataset | list[xarray.Dataset]

Run inference on a single tile or a list of tiles.

Parameters:

  • input (xarray.Dataset | list[xarray.Dataset]) –

    A single tile or a list of tiles.

  • patch_size (int, default: 1024 ) –

    The size of the patches. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The size of the overlap. Defaults to 16.

  • batch_size (int, default: 8 ) –

    The batch size for the prediction, NOT the batch_size of input tiles. The tensor will be sliced into patches, which in turn are inferred in batches. Defaults to 8.

  • reflection (int, default: 0 ) –

    Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

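Returns:

  • xarray.Dataset

    For a single input tile, the tile augmented by a predicted probabilities layer with type float32 and range [0, 1]. Passing a list of tiles is currently not supported.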

Raises:

  • ValueError

    in case the input is not an xr.Dataset or a list of xr.Dataset

Source code in darts-segmentation/src/darts_segmentation/segment.py
def __call__(
    self,
    input: xr.Dataset | list[xr.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
) -> xr.Dataset | list[xr.Dataset]:
    """Run inference on a single tile.

    Args:
        input (xr.Dataset | list[xr.Dataset]): A single tile. Lists of tiles are rejected (see Raises).
        patch_size (int): The size of the patches. Defaults to 1024.
        overlap (int): The size of the overlap. Defaults to 16.
        batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
            The tensor will be sliced into patches, which in turn are inferred in batches. Defaults to 8.
        reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

    Returns:
        The result of `segment_tile` for the given tile, i.e. a tile augmented by a predicted
        `probabilities` layer with type float32 and range [0, 1].

    Raises:
        NotImplementedError: if a list of tiles is passed; multiple datasets at once are not supported yet.
        ValueError: in case the input is not an xr.Dataset or a list of xr.Dataset

    """
    if isinstance(input, xr.Dataset):
        return self.segment_tile(
            input, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
        )
    if isinstance(input, list):
        # Must be raised, not returned: returning it would hand callers an exception object as a "result".
        raise NotImplementedError("Currently passing multiple datasets at once is not supported.")
    raise ValueError(f"Expected xr.Dataset or list of xr.Dataset, got {type(input)}")

segment_tile

segment_tile(
    tile: xarray.Dataset,
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
) -> xarray.Dataset

Run inference on a tile.

Parameters:

  • tile (xarray.Dataset) –

    The input tile, containing preprocessed, harmonized data.

  • patch_size (int, default: 1024 ) –

    The size of the patches. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The size of the overlap. Defaults to 16.

  • batch_size (int, default: 8 ) –

    The batch size for the prediction, NOT the batch_size of input tiles. The tensor will be sliced into patches, which in turn are inferred in batches. Defaults to 8.

  • reflection (int, default: 0 ) –

    Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

Returns:

  • xarray.Dataset

    The tile restricted to the model's (normalized) input bands, augmented by a predicted probabilities layer with type float32 and range [0, 1].

Source code in darts-segmentation/src/darts_segmentation/segment.py
@stopwatch.f(
    "Segmenting tile",
    printer=logger.debug,
    print_kwargs=["patch_size", "overlap", "batch_size", "reflection"],
)
def segment_tile(
    self, tile: xr.Dataset, patch_size: int = 1024, overlap: int = 16, batch_size: int = 8, reflection: int = 0
) -> xr.Dataset:
    """Run inference on a tile.

    Only the bands listed in the model config are used. They are selected, transposed to ("y", "x"),
    normalized, converted to a tensor and predicted patch-wise.

    Args:
        tile: The input tile, containing preprocessed, harmonized data.
        patch_size (int): The size of the patches. Defaults to 1024.
        overlap (int): The size of the overlap. Defaults to 16.
        batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
            The tensor will be sliced into patches, which in turn are inferred in batches. Defaults to 8.
        reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

    Returns:
        The tile restricted to the model's (normalized) input bands, augmented by a predicted
        `probabilities` layer with type float32 and range [0, 1]. This is a new Dataset built from the
        band selection, not the caller's `tile` object.

    """
    # Keep only the bands the model was trained on.
    tile = tile[self.config["bands"]].transpose("y", "x")
    tile = manager.normalize(tile)
    # ? The heavy operation is .to_dataarray()
    tensor_tile = torch.as_tensor(tile.to_dataarray().data)

    # Create a batch dimension, because predict expects it
    tensor_tile = tensor_tile.unsqueeze(0)

    probabilities = predict_in_patches(
        self.model, tensor_tile, patch_size, overlap, batch_size, reflection, self.device
    ).squeeze(0)

    # Use an input band as a template so the output inherits its dims and coordinates. "red" stays the
    # preferred template for backward compatibility. Models whose configured bands do not include "red"
    # fall back to their first band instead of failing with a KeyError.
    template_band = "red" if "red" in tile.data_vars else self.config["bands"][0]
    # TODO: is there a better way to pass metadata?
    tile["probabilities"] = tile[template_band].copy(data=probabilities.cpu().numpy())
    tile["probabilities"].attrs = {"long_name": "Probabilities"}
    tile["probabilities"] = tile["probabilities"].fillna(float("nan")).rio.write_nodata(float("nan"))

    # Cleanup cuda memory
    del tensor_tile, probabilities
    free_torch()

    return tile