boundary_iou

darts_segmentation.metrics.boundary_iou

Boundary IoU metric for binary segmentation tasks.

MatchingMetric module-attribute

MatchingMetric = typing.Literal['iou', 'boundary']

BinaryBoundaryIoU

BinaryBoundaryIoU(
    dilation: float | int = 0.02,
    threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Unpack[
        darts_segmentation.metrics.boundary_iou.BinaryBoundaryIoUKwargs
    ],
)

Bases: torchmetrics.Metric

Binary Boundary IoU metric for binary segmentation tasks.

This metric is similar to the Binary Intersection over Union (IoU or Jaccard Index) metric, but instead of comparing all pixels it only compares the boundaries of each foreground object.
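To make the difference from plain IoU concrete, the following is a minimal plain-Python sketch, not the library's implementation. It assumes a boundary is the mask minus its erosion, uses a 4-neighbourhood erosion, and treats pixels outside the image as background; the helper names (`erode`, `boundary`, `iou`, `boundary_iou`, `square`) are hypothetical and exist only for this illustration.

```python
def erode(mask):
    """One erosion step (4-neighbourhood); pixels outside the image count as background."""
    h, w = len(mask), len(mask[0])

    def at(r, c):
        return mask[r][c] if 0 <= r < h and 0 <= c < w else 0

    return [
        [
            1 if mask[r][c] and at(r - 1, c) and at(r + 1, c) and at(r, c - 1) and at(r, c + 1) else 0
            for c in range(w)
        ]
        for r in range(h)
    ]


def boundary(mask, width=1):
    """Boundary of a binary mask: the mask minus the mask eroded `width` times."""
    eroded = mask
    for _ in range(width):
        eroded = erode(eroded)
    return [[m - e for m, e in zip(mask_row, eroded_row)] for mask_row, eroded_row in zip(mask, eroded)]


def iou(a, b):
    """Intersection over union of two binary masks given as nested lists of 0/1."""
    intersection = sum(x & y for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))
    union = sum(x | y for row_a, row_b in zip(a, b) for x, y in zip(row_a, row_b))
    return intersection / union if union else 0.0


def boundary_iou(pred, target, width=1):
    """IoU computed on the boundaries only."""
    return iou(boundary(pred, width), boundary(target, width))


def square(size, top, left, side):
    """A size x size grid containing a filled square of the given side length."""
    return [
        [1 if top <= r < top + side and left <= c < left + side else 0 for c in range(size)]
        for r in range(size)
    ]


target = square(6, top=1, left=1, side=4)
pred = square(6, top=1, left=2, side=4)  # same square, shifted one pixel to the right

print(iou(pred, target))           # 12 / 20 = 0.6
print(boundary_iou(pred, target))  # 6 / 18, roughly 0.333
```

In this toy case a one-pixel shift lowers plain IoU only to 0.6, while the boundary-only score drops to about 0.33, which shows why the boundary variant is more sensitive to contour errors.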

Create a new instance of the BinaryBoundaryIoU metric.

Please see the torchmetrics docs for more info about the **kwargs.

Parameters:

  • dilation (float | int, default: 0.02 ) –

    Width of the boundary. An int is taken as a width in pixels; a float is taken as a ratio of the image diagonal, dilation = dilation_ratio * image_diagonal, rounded to at least 1 pixel. Defaults to 0.02.

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

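    How results are aggregated. "global" accumulates intersection and union over all updates and yields a single value; "samplewise" stores them per update and yields one value per update. Defaults to "global".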

  • ignore_index (int | None, default: None ) –

    Target value to ignore. Pixels whose target equals this value are excluded from the intersection and union. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • **kwargs (typing.Unpack[darts_segmentation.metrics.boundary_iou.BinaryBoundaryIoUKwargs], default: {} ) –

    Additional keyword arguments for the metric.

Other Parameters:

  • zero_division (int) –

    Value to return when there is a zero division. Default is 0.

  • compute_on_cpu (bool) –

    If metric state should be stored on CPU during computations. Only works for list states.

  • dist_sync_on_step (bool) –

    If metric state should synchronize on forward(). Default is False.

  • process_group (str) –

    The process group on which the synchronization is called. Default is the world.

  • dist_sync_fn (callable) –

    Function that performs the allgather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.

  • distributed_available_fn (callable) –

    Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().

  • sync_on_compute (bool) –

    If metric state should synchronize when compute is called. Default is True.

  • compute_with_cache (bool) –

    If results from compute should be cached. Default is True.

Raises:

  • ValueError

    If dilation is not a float or int.

Source code in darts-segmentation/src/darts_segmentation/metrics/boundary_iou.py
def __init__(
    self,
    dilation: float | int = 0.02,
    threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Unpack[BinaryBoundaryIoUKwargs],
):
    """Create a new instance of the BinaryBoundaryIoU metric.

    Please see the
    [torchmetrics docs](https://lightning.ai/docs/torchmetrics/stable/pages/overview.html#metric-kwargs)
    for more info about the **kwargs.

    Args:
        dilation (float | int, optional): Width of the boundary. An int is taken as a width in pixels;
            a float is taken as a ratio of the image diagonal, `dilation = dilation_ratio * image_diagonal`,
            rounded to at least 1 pixel. Defaults to 0.02.
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How results are aggregated.
            "global" accumulates intersection and union over all updates and yields a single value;
            "samplewise" stores them per update and yields one value per update. Defaults to "global".
        ignore_index (int | None, optional): Target value to ignore. Pixels whose target equals this value
            are excluded from the intersection and union. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        **kwargs: Additional keyword arguments for the metric.

    Keyword Args:
        zero_division (int):
            Value to return when there is a zero division. Default is 0.
        compute_on_cpu (bool):
            If metric state should be stored on CPU during computations. Only works for list states.
        dist_sync_on_step (bool):
            If metric state should synchronize on ``forward()``. Default is ``False``.
        process_group (str):
            The process group on which the synchronization is called. Default is the world.
        dist_sync_fn (callable):
            Function that performs the allgather operation on the metric state. Default is a custom
            implementation that calls ``torch.distributed.all_gather`` internally.
        distributed_available_fn (callable):
            Function that checks if the distributed backend is available. Defaults to a
            check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
        sync_on_compute (bool):
            If metric state should synchronize when ``compute`` is called. Default is ``True``.
        compute_with_cache (bool):
            If results from ``compute`` should be cached. Default is ``True``.

    Raises:
        ValueError: If dilation is not a float or int.

    """
    zero_division = kwargs.pop("zero_division", 0)
    super().__init__(**kwargs)

    if validate_args:
        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
        if not isinstance(dilation, float | int):
            raise ValueError(f"Expected argument `dilation` to be a float or int, but got {dilation}.")

    self.dilation = dilation
    self.threshold = threshold
    self.multidim_average = multidim_average
    self.ignore_index = ignore_index
    self.validate_args = validate_args
    self.zero_division = zero_division

    if multidim_average == "samplewise":
        self.add_state("intersection", default=[], dist_reduce_fx="cat")
        self.add_state("union", default=[], dist_reduce_fx="cat")
    else:
        self.add_state("intersection", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("union", default=torch.tensor(0), dist_reduce_fx="sum")

dilation instance-attribute

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

intersection instance-attribute

intersection: torch.Tensor | list[torch.Tensor]

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

multidim_average instance-attribute

multidim_average = multidim_average

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

union instance-attribute

validate_args instance-attribute

validate_args = validate_args

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor

Compute the metric.

Returns:

  • torch.Tensor: The computed metric.

Source code in darts-segmentation/src/darts_segmentation/metrics/boundary_iou.py
def compute(self) -> Tensor:
    """Compute the metric.

    Returns:
        Tensor: The computed metric.

    """
    if self.multidim_average == "global":
        return self.intersection / self.union
    else:
        self.intersection = torch.tensor(self.intersection)
        self.union = torch.tensor(self.union)
        return self.intersection / self.union
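The two aggregation modes can be sketched in plain Python as follows. This is an illustration under the assumption that each update contributes one (intersection, union) pair of pixel counts, as in the listings on this page; the `aggregate` helper is hypothetical, and unlike the tensor division above it does not handle a union of zero.

```python
def aggregate(updates, multidim_average="global"):
    """Combine (intersection, union) pixel counts, one pair per update call."""
    if multidim_average == "global":
        # Sum the counts over all updates, then divide once.
        total_intersection = sum(i for i, _ in updates)
        total_union = sum(u for _, u in updates)
        return total_intersection / total_union
    # "samplewise": keep one ratio per update.
    return [i / u for i, u in updates]


updates = [(6, 18), (2, 4)]
print(aggregate(updates, "global"))      # 8 / 22
print(aggregate(updates, "samplewise"))  # [6 / 18, 2 / 4]
```

Note that the global value (8 / 22) is not the mean of the per-update values, because updates with larger boundaries carry more weight.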

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.
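The rule described above can be sketched without torch as follows. This is a simplified illustration on a flat list of floats, and the `binarize` helper is hypothetical. As in the description, the input counts as logits only if at least one value lies outside [0, 1].

```python
import math


def binarize(preds, threshold=0.5):
    """Binarize a flat list of float predictions (logits or probabilities)."""
    # If any value lies outside [0, 1], treat the whole input as logits.
    if not all(0.0 <= p <= 1.0 for p in preds):
        preds = [1.0 / (1.0 + math.exp(-p)) for p in preds]
    # Strict comparison: a value equal to the threshold becomes background.
    return [1 if p > threshold else 0 for p in preds]


print(binarize([0.2, 0.7, 0.5]))   # probabilities -> [0, 1, 0]
print(binarize([-2.0, 0.3, 3.0]))  # logits        -> [0, 1, 1]
```

The value 0.3 maps to background when the input is read as probabilities but to foreground when it is read as a logit, so logits that happen to lie entirely within [0, 1] would be interpreted as probabilities. Passing probabilities explicitly avoids that ambiguity.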

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If the input arguments are invalid.

  • ValueError

    If the input shapes are invalid.

Source code in darts-segmentation/src/darts_segmentation/metrics/boundary_iou.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If the input arguments are invalid.
        ValueError: If the input shapes are invalid.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.shape == target.shape:
            raise ValueError(
                f"Expected `preds` and `target` to have the same shape, but got {preds.shape} and {target.shape}."
            )
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions, but got {preds.dim()}.")

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    target = target.to(torch.uint8)
    preds = preds.to(torch.uint8)

    target_boundary = get_boundary((target == 1).to(torch.uint8), self.dilation, self.validate_args)
    preds_boundary = get_boundary(preds, self.dilation, self.validate_args)

    intersection = target_boundary & preds_boundary
    union = target_boundary | preds_boundary

    if self.ignore_index is not None:
        # Important that this is NOT the boundary, but the original mask
        valid_idx = target != self.ignore_index
        intersection &= valid_idx
        union &= valid_idx

    intersection = intersection.sum().item()
    union = union.sum().item()

    if self.multidim_average == "global":
        self.intersection += intersection
        self.union += union
    else:
        self.intersection.append(intersection)
        self.union.append(union)

BinaryBoundaryIoUKwargs

Bases: typing.TypedDict

Keyword arguments for the BinaryBoundaryIoU metric.

compute_on_cpu instance-attribute

compute_on_cpu: bool

compute_with_cache instance-attribute

compute_with_cache: bool

dist_sync_fn instance-attribute

dist_sync_fn: callable

dist_sync_on_step instance-attribute

dist_sync_on_step: bool

distributed_available_fn instance-attribute

distributed_available_fn: callable

process_group instance-attribute

process_group: str

sync_on_compute instance-attribute

sync_on_compute: bool

zero_division instance-attribute

zero_division: typing.Literal[0, 1]

get_boundary

get_boundary(
    binary_instances: torch.Tensor,
    dilation: float | int = 0.02,
    validate_args: bool = False,
)

Convert instance masks to instance boundaries.

Parameters:

  • binary_instances (torch.Tensor) –

    Target instance masks. Must be binary. Can be batched, one-hot encoded or both. (3 or 4 dimensions). The last two dimensions must be height and width.

  • dilation (float | int, default: 0.02 ) –

    Width of the boundary. An int is taken as a width in pixels; a float is taken as a ratio of the image diagonal, dilation = dilation_ratio * image_diagonal, rounded to at least 1 pixel. Defaults to 0.02.

  • validate_args (bool, default: False ) –

    Whether arguments should be validated. Defaults to False.

Returns:

  • torch.Tensor: The boundaries of the instances.

Source code in darts-segmentation/src/darts_segmentation/metrics/boundary_helpers.py
@torch.no_grad()
def get_boundary(
    binary_instances: torch.Tensor,
    dilation: float | int = 0.02,
    validate_args: bool = False,
):
    """Convert instance masks to instance boundaries.

    Args:
        binary_instances (torch.Tensor): Target instance masks. Must be binary.
            Can be batched, one-hot encoded or both. (3 or 4 dimensions).
            The last two dimensions must be height and width.
        dilation (float | int, optional): Width of the boundary. An int is taken as a width in pixels;
            a float is taken as a ratio of the image diagonal, `dilation = dilation_ratio * image_diagonal`,
            rounded to at least 1 pixel. Defaults to 0.02.
        validate_args (bool, optional): Whether arguments should be validated. Defaults to False.

    Returns:
        torch.Tensor: The boundaries of the instances.

    """
    if validate_args:
        assert binary_instances.dim() in [3, 4], f"Expected 3 or 4 dimensions, got {binary_instances.dim()}"
        assert binary_instances.dtype == torch.uint8, f"Expected torch.uint8, got {binary_instances.dtype}"
        assert (
            binary_instances.min() >= 0 and binary_instances.max() <= 1
        ), f"Expected binary mask, got range between {binary_instances.min()} and {binary_instances.max()}"
        assert isinstance(dilation, float | int), f"Expected float or int, got {type(dilation)}"
        assert dilation >= 0, f"Expected dilation >= 0, got {dilation}"

    if binary_instances.dim() == 3:
        _n, h, w = binary_instances.shape
    else:
        _n, _c, h, w = binary_instances.shape

    if isinstance(dilation, float):
        img_diag = sqrt(h**2 + w**2)
        dilation = round(dilation * img_diag)
        if dilation < 1:
            dilation = 1

    # Pad the instances to avoid boundary issues
    pad = torchvision.transforms.Pad(1)
    binary_instances_padded = pad(binary_instances)

    # Erode the instances to get the boundaries
    eroded = erode_pytorch(binary_instances_padded, iterations=dilation, validate_args=validate_args)

    # Remove the padding
    if binary_instances.dim() == 3:
        eroded = eroded[:, 1:-1, 1:-1]
    else:
        eroded = eroded[:, :, 1:-1, 1:-1]
    # Calculate the boundary of the instances
    boundaries = binary_instances - eroded

    return boundaries
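The conversion of the dilation argument to a pixel width, as done at the start of get_boundary, can be isolated in a small stdlib-only sketch. The `dilation_in_pixels` helper is hypothetical and only mirrors the float branch shown in the listing above.

```python
from math import sqrt


def dilation_in_pixels(dilation, height, width):
    """Resolve a dilation argument to a boundary width in pixels."""
    if isinstance(dilation, float):
        # A float is a ratio of the image diagonal, rounded and clamped to >= 1.
        return max(1, round(dilation * sqrt(height**2 + width**2)))
    # An int is already a width in pixels.
    return dilation


print(dilation_in_pixels(0.02, 512, 512))  # 14 (diagonal is about 724.08)
print(dilation_in_pixels(0.02, 10, 10))    # 1  (0.28 rounds to 0, clamped to 1)
print(dilation_in_pixels(3, 512, 512))     # 3
```

With the default ratio of 0.02, the boundary width therefore scales with image size, whereas an int keeps it fixed regardless of resolution.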