darts_postprocessing.binarize

Binarize the probabilities based on a threshold and a mask.

Steps for binarization
  1. Dilate the mask. This will dilate the edges of holes in the mask as well as the edges of the tile.
  2. Binarize the probabilities based on the threshold.
  3. Remove objects that overlap with either the edge of the tile or the noData mask.
  4. Remove small objects.
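
The thresholding and filtering steps can be illustrated with a dependency-free sketch. This is not the library's implementation, which works on `xarray.DataArray` objects. The names `binarize_sketch` and `valid_mask` are hypothetical, and the sketch assumes that the mask has already been dilated (step 1), that truthy mask values mark valid pixels, that objects are 8-connected, and that `min_object_size` is an inclusive lower bound on the size of kept objects.

```python
from collections import deque


def binarize_sketch(probs, valid_mask, threshold, min_object_size):
    """Threshold, drop objects touching invalid pixels, then drop small objects."""
    rows, cols = len(probs), len(probs[0])
    # Step 2: NaN > threshold is False, so missing probabilities become background.
    binary = [[probs[r][c] > threshold for c in range(cols)] for r in range(rows)]

    seen = [[False] * cols for _ in range(rows)]
    out = [[0] * cols for _ in range(rows)]
    for r0 in range(rows):
        for c0 in range(cols):
            if not binary[r0][c0] or seen[r0][c0]:
                continue
            # Collect one 8-connected object with a breadth-first search.
            component, queue = [], deque([(r0, c0)])
            seen[r0][c0] = True
            while queue:
                r, c = queue.popleft()
                component.append((r, c))
                for dr in (-1, 0, 1):
                    for dc in (-1, 0, 1):
                        nr, nc = r + dr, c + dc
                        inside = 0 <= nr < rows and 0 <= nc < cols
                        if inside and binary[nr][nc] and not seen[nr][nc]:
                            seen[nr][nc] = True
                            queue.append((nr, nc))
            # Step 3: discard the object if any of its pixels is invalid.
            touches_invalid = any(not valid_mask[r][c] for r, c in component)
            # Step 4: discard the object if it is below the minimum size.
            if not touches_invalid and len(component) >= min_object_size:
                for r, c in component:
                    out[r][c] = 1
    return out


nan = float("nan")
probs = [
    [0.9, 0.9, 0.1, 0.1, 0.8],
    [0.9, 0.1, 0.1, 0.1, 0.8],
    [0.1, 0.1, 0.1, 0.1, 0.1],
    [0.1, 0.7, 0.1, nan, 0.1],
    [0.1, 0.1, 0.1, 0.1, 0.1],
]
valid_mask = [[1] * 5 for _ in range(5)]
valid_mask[0][4] = 0  # one invalid pixel in the top-right corner

result = binarize_sketch(probs, valid_mask, threshold=0.5, min_object_size=2)
```

In this toy input only the three-pixel object in the top-left corner survives: the two-pixel object on the right touches the invalid pixel, the single pixel at row 3 is below the minimum size, and the NaN cell never passes the threshold.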

Parameters:

  • probs (xarray.DataArray) –

    Probabilities to binarize.

  • threshold (float) –

    Threshold to binarize the probabilities.

  • min_object_size (int) –

    Minimum object size to keep.

  • mask (xarray.DataArray) –

    Mask to apply to the binarized probabilities. Expects 0=negative, 1=positive.

  • device (typing.Literal['cuda', 'cpu'] | int) –

    The device to use for removing small objects.

Returns:

  • xarray.DataArray –

    Binarized probabilities.
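
How the `device` parameter is interpreted can be summarized in a small sketch based on the source listing on this page. The helper `resolve_device` and its `gpu_available` flag are hypothetical and stand in for the inline checks in `binarize`: `"cuda"` or any integer requests the GPU, `"cuda"` maps to device index 0, and the function falls back to the CPU when GPU acceleration is unavailable.

```python
def resolve_device(device, gpu_available):
    """Return the GPU index to use, or None if the work should run on the CPU."""
    wants_gpu = device == "cuda" or isinstance(device, int)
    if not (wants_gpu and gpu_available):
        # Either the CPU was requested or no GPU backend is available.
        return None
    # An integer selects a specific GPU; the string "cuda" means GPU 0.
    return device if isinstance(device, int) else 0


print(resolve_device("cuda", gpu_available=True))
print(resolve_device(1, gpu_available=True))
print(resolve_device("cpu", gpu_available=True))
print(resolve_device(1, gpu_available=False))
```

Returning `None` for the CPU case keeps it distinct from GPU index 0, which is itself a falsy value in Python.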

Source code in darts-postprocessing/src/darts_postprocessing/postprocess.py
@stopuhr.funkuhr("Binarizing probabilities", printer=logger.debug, print_kwargs=["threshold", "min_object_size"])
def binarize(
    probs: xr.DataArray,
    threshold: float,
    min_object_size: int,
    mask: xr.DataArray,
    device: Literal["cuda", "cpu"] | int,
) -> xr.DataArray:
    """Binarize the probabilities based on a threshold and a mask.

    Steps for binarization:
        1. Dilate the mask. This will dilate the edges of holes in the mask as well as the edges of the tile.
        2. Binarize the probabilities based on the threshold.
        3. Remove objects that overlap with either the edge of the tile or the noData mask.
        4. Remove small objects.

    Args:
        probs (xr.DataArray): Probabilities to binarize.
        threshold (float): Threshold to binarize the probabilities.
        min_object_size (int): Minimum object size to keep.
        mask (xr.DataArray): Mask to apply to the binarized probabilities. Expects 0=negative, 1=positive.
        device (Literal["cuda", "cpu"] | int): The device to use for removing small objects.

    Returns:
        xr.DataArray: Binarized probabilities.

    """
    use_gpu = device == "cuda" or isinstance(device, int)

    # Warn user if use_gpu is set but no GPU is available
    if use_gpu and not CUCIM_AVAILABLE:
        logger.warning(
            f"Device was set to {device}, but GPU acceleration is not available. Removing small objects on CPU."
        )
        use_gpu = False

    # Where the output from the ensemble / segmentation is nan turn it into 0, else threshold it
    # Also, where there was no valid input data, turn it into 0
    binarized = (probs.fillna(0) > threshold).astype("uint8")

    # Remove objects that overlap with either the edge of the tile or the noData mask
    labels = binarized.copy(data=label(binarized, connectivity=2))
    edge_label_ids = np.unique(xr.where(~mask, labels, 0))
    binarized = ~labels.isin(edge_label_ids) & binarized

    # Remove small objects with GPU
    if use_gpu:
        device_nr = device if isinstance(device, int) else 0
        logger.debug(f"Moving binarized to GPU:{device_nr}.")
        # If binarized is backed by dask, persist it first, since remove_small_objects_gpu can't be calculated from
        # cupy-dask arrays
        if binarized.chunks is not None:
            binarized = binarized.persist()
        with cp.cuda.Device(device_nr):
            binarized = binarized.cupy.as_cupy()
            binarized.values = remove_small_objects_gpu(
                binarized.astype(bool).expand_dims("batch", 0).data, min_size=min_object_size
            )[0]
            binarized = binarized.cupy.as_numpy()
            free_cupy()
    else:
        binarized.values = remove_small_objects(
            binarized.astype(bool).expand_dims("batch", 0).values, min_size=min_object_size
        )[0]

    # Convert back to uint8
    binarized = binarized.astype("uint8")

    return binarized