binary_instance_prc

darts_segmentation.metrics.binary_instance_prc

Complex binary instance segmentation metrics.

MatchingMetric module-attribute

MatchingMetric = typing.Literal['iou', 'boundary']

BinaryInstanceAveragePrecision

BinaryInstanceAveragePrecision(
    thresholds: int | list[float] | torch.Tensor = None,
    matching_threshold: float = 0.5,
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_prc.BinaryInstancePrecisionRecallCurve

Compute the average precision for binary instance segmentation.

Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

Parameters:

  • thresholds (int | list[float] | torch.Tensor, default: None ) –

    The thresholds to use for the curve. Defaults to None.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional keyword arguments for the torchmetrics Metric base class, e.g. regarding compute behaviour. Please refer to the torchmetrics documentation for more details.

Raises:

  • ValueError

    If thresholds is None.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def __init__(
    self,
    thresholds: int | list[float] | Tensor = None,
    matching_threshold: float = 0.5,
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

    Args:
        thresholds (int | list[float] | Tensor, optional): The thresholds to use for the curve. Defaults to None.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If thresholds is None.

    """
    super().__init__(**kwargs)
    if validate_args:
        _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )
        if thresholds is None:
            raise ValueError("Argument `thresholds` must be provided for this metric.")

    self.matching_threshold = matching_threshold
    self.ignore_index = ignore_index
    self.validate_args = validate_args

    thresholds = _adjust_threshold_arg(thresholds)
    self.register_buffer("thresholds", thresholds, persistent=False)
    self.add_state("confmat", default=torch.zeros(len(thresholds), 2, 2, dtype=torch.long), dist_reduce_fx="sum")
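
A minimal usage sketch (not part of the library docs: the import path is assumed from the source location shown above, and the data is synthetic):

import torch

from darts_segmentation.metrics.binary_instance_prc import BinaryInstanceAveragePrecision

# `thresholds` must be provided for this metric (see the ValueError above).
metric = BinaryInstanceAveragePrecision(thresholds=10, matching_threshold=0.5)

# One sample with two square target instances and a matching soft prediction.
target = torch.zeros(1, 64, 64, dtype=torch.long)
target[0, 5:15, 5:15] = 1
target[0, 40:50, 40:50] = 1
preds = target.float() * 0.9 + 0.05               # high scores on the instances, low elsewhere

metric.update(preds, target)                      # preds/target shape: (batch_size, height, width)
ap = metric.compute()                             # scalar tensor with the average precision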

confmat instance-attribute

confmat: torch.Tensor

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool = True

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

preds instance-attribute

preds: list[torch.Tensor]

target instance-attribute

target: list[torch.Tensor]

thresholds instance-attribute

thresholds: torch.Tensor

validate_args instance-attribute

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def compute(self) -> Tensor:  # type: ignore[override]  # noqa: D102
    return _binary_average_precision_compute(self.confmat, self.thresholds)

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def plot(  # type: ignore[override]  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update metric states.

Parameters:

  • preds (torch.Tensor) –

    The predicted mask. Shape: (batch_size, height, width)

  • target (torch.Tensor) –

    The target mask. Shape: (batch_size, height, width)

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update metric states.

    Args:
        preds (Tensor): The predicted mask. Shape: (batch_size, height, width)
        target (Tensor): The target mask. Shape: (batch_size, height, width)

    Raises:
        ValueError: If preds and target have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_precision_recall_curve_tensor_validation(preds, target, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if not torch.all((preds >= 0) * (preds <= 1)):
        preds = preds.sigmoid()

    if self.ignore_index is not None:
        target = (target == 1).to(torch.uint8)

    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)

    len_t = len(self.thresholds)
    confmat = self.thresholds.new_zeros((len_t, 2, 2), dtype=torch.int64)
    for i in range(len_t):
        preds_i = preds >= self.thresholds[i]

        if self.ignore_index is not None:
            invalid_idx = target == self.ignore_index
            preds_i = preds_i.clone()
            preds_i[invalid_idx] = 0  # This prevents counting instances in the ignored area

        instance_list_preds_i = mask_to_instances(preds_i.to(torch.uint8), self.validate_args)
        for target_i, preds_i in zip(instance_list_target, instance_list_preds_i):
            tp, fp, fn = match_instances(
                target_i,
                preds_i,
                match_threshold=self.matching_threshold,
                validate_args=self.validate_args,
            )
            confmat[i, 1, 1] += tp
            confmat[i, 0, 1] += fp
            confmat[i, 1, 0] += fn
    self.confmat += confmat
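
Note that the `# Format` step above also accepts raw logits: if any value in `preds` lies outside [0, 1], the whole tensor is passed through a sigmoid before thresholding. Continuing the sketch above (`metric` is the instance created there; the logit values are hypothetical):

import torch

logits = torch.full((1, 64, 64), -6.0)            # strongly negative background logits ...
logits[0, 5:15, 5:15] = 6.0                       # ... and one confidently predicted blob
target = torch.zeros(1, 64, 64, dtype=torch.long)
target[0, 5:15, 5:15] = 1

metric.update(logits, target)                     # values outside [0, 1] -> sigmoid is applied internally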

BinaryInstancePrecisionRecallCurve

BinaryInstancePrecisionRecallCurve(
    thresholds: int | list[float] | torch.Tensor = None,
    matching_threshold: float = 0.5,
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: torchmetrics.Metric

Compute the precision-recall curve for binary instance segmentation.

This metric works similarly to torchmetrics.classification.PrecisionRecallCurve, with two key differences:

1. It calculates the tp, fp and fn values for each instance (blob) in the batch and then aggregates them, instead of calculating the values for each pixel.
2. The "thresholds" argument is required. Calculating the thresholds at the compute stage would cost too much memory for this use case.

Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

Parameters:

  • thresholds (int | list[float] | torch.Tensor, default: None ) –

    The thresholds to use for the curve. Defaults to None.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional keyword arguments for the torchmetrics Metric base class, e.g. regarding compute behaviour. Please refer to the torchmetrics documentation for more details.

Raises:

  • ValueError

    If thresholds is None.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def __init__(
    self,
    thresholds: int | list[float] | Tensor = None,
    matching_threshold: float = 0.5,
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

    Args:
        thresholds (int | list[float] | Tensor, optional): The thresholds to use for the curve. Defaults to None.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If thresholds is None.

    """
    super().__init__(**kwargs)
    if validate_args:
        _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )
        if thresholds is None:
            raise ValueError("Argument `thresholds` must be provided for this metric.")

    self.matching_threshold = matching_threshold
    self.ignore_index = ignore_index
    self.validate_args = validate_args

    thresholds = _adjust_threshold_arg(thresholds)
    self.register_buffer("thresholds", thresholds, persistent=False)
    self.add_state("confmat", default=torch.zeros(len(thresholds), 2, 2, dtype=torch.long), dist_reduce_fx="sum")
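
A minimal usage sketch (the import path is assumed from the source location shown above; the data is synthetic):

import torch

from darts_segmentation.metrics.binary_instance_prc import BinaryInstancePrecisionRecallCurve

# `thresholds` must be provided for this metric (see the ValueError above).
prc = BinaryInstancePrecisionRecallCurve(thresholds=11, matching_threshold=0.5)

target = torch.zeros(2, 64, 64, dtype=torch.long)
target[0, 5:15, 5:15] = 1                         # one instance in the first sample
target[1, 30:40, 30:40] = 1                       # one instance in the second sample
preds = target.float() * 0.8 + 0.1                # confident but imperfect scores

prc.update(preds, target)
precision, recall, thresholds = prc.compute()     # curve values along the threshold grid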

confmat instance-attribute

confmat: torch.Tensor

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = None

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

preds instance-attribute

preds: list[torch.Tensor]

target instance-attribute

target: list[torch.Tensor]

thresholds instance-attribute

thresholds: torch.Tensor

validate_args instance-attribute

compute

compute() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def compute(self) -> tuple[Tensor, Tensor, Tensor]:  # noqa: D102
    return _binary_precision_recall_curve_compute(self.confmat, self.thresholds)

plot

plot(
    curve: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    | None = None,
    score: torch.Tensor | bool | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def plot(  # noqa: D102
    self,
    curve: tuple[Tensor, Tensor, Tensor] | None = None,
    score: Tensor | bool | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    curve_computed = curve or self.compute()
    # switch order as the standard way is recall along x-axis and precision along y-axis
    curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2])

    score = (
        _auc_compute_without_check(curve_computed[0], curve_computed[1], direction=-1.0)
        if not curve and score is True
        else None
    )
    return plot_curve(
        curve_computed, score=score, ax=ax, label_names=("Recall", "Precision"), name=self.__class__.__name__
    )
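
A plotting sketch for the `prc` instance from the sketch above (torchmetrics' plotting utilities require matplotlib; the file name is arbitrary):

fig, ax = prc.plot(score=True)                    # recall on the x-axis, precision on the y-axis
fig.savefig("binary_instance_prc.png")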

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update metric states.

Parameters:

  • preds (torch.Tensor) –

    The predicted mask. Shape: (batch_size, height, width)

  • target (torch.Tensor) –

    The target mask. Shape: (batch_size, height, width)

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update metric states.

    Args:
        preds (Tensor): The predicted mask. Shape: (batch_size, height, width)
        target (Tensor): The target mask. Shape: (batch_size, height, width)

    Raises:
        ValueError: If preds and target have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_precision_recall_curve_tensor_validation(preds, target, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if not torch.all((preds >= 0) * (preds <= 1)):
        preds = preds.sigmoid()

    if self.ignore_index is not None:
        target = (target == 1).to(torch.uint8)

    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)

    len_t = len(self.thresholds)
    confmat = self.thresholds.new_zeros((len_t, 2, 2), dtype=torch.int64)
    for i in range(len_t):
        preds_i = preds >= self.thresholds[i]

        if self.ignore_index is not None:
            invalid_idx = target == self.ignore_index
            preds_i = preds_i.clone()
            preds_i[invalid_idx] = 0  # This prevents counting instances in the ignored area

        instance_list_preds_i = mask_to_instances(preds_i.to(torch.uint8), self.validate_args)
        for target_i, preds_i in zip(instance_list_target, instance_list_preds_i):
            tp, fp, fn = match_instances(
                target_i,
                preds_i,
                match_threshold=self.matching_threshold,
                validate_args=self.validate_args,
            )
            confmat[i, 1, 1] += tp
            confmat[i, 0, 1] += fp
            confmat[i, 1, 0] += fn
    self.confmat += confmat

mask_to_instances

mask_to_instances(
    x: torch.Tensor, validate_args: bool = False
) -> list[torch.Tensor]

Convert a binary segmentation mask into multiple instance masks. Expects a batched version of the input.

Currently only supports uint8 tensors, hence a maximum number of 255 instances per mask.

Parameters:

  • x (torch.Tensor) –

    The binary segmentation mask. Shape: (batch_size, height, width), dtype: torch.uint8

  • validate_args (bool, default: False ) –

    Whether to validate the input arguments. Defaults to False.

Returns:

  • list[torch.Tensor]

    list[torch.Tensor]: The instance masks. Length of list: batch_size. Shape of a tensor: (height, width), dtype: torch.uint8

Source code in darts-segmentation/src/darts_segmentation/metrics/instance_helpers.py
@torch.no_grad()
def mask_to_instances(x: torch.Tensor, validate_args: bool = False) -> list[torch.Tensor]:
    """Convert a binary segmentation mask into multiple instance masks. Expects a batched version of the input.

    Currently only supports uint8 tensors, hence a maximum number of 255 instances per mask.

    Args:
        x (torch.Tensor): The binary segmentation mask. Shape: (batch_size, height, width), dtype: torch.uint8
        validate_args (bool, optional): Whether to validate the input arguments. Defaults to False.

    Returns:
        list[torch.Tensor]: The instance masks. Length of list: batch_size.
            Shape of a tensor: (height, width), dtype: torch.uint8

    """
    if validate_args:
        assert x.dim() == 3, f"Expected 3 dimensions, got {x.dim()}"
        assert x.dtype == torch.uint8, f"Expected torch.uint8, got {x.dtype}"
        assert x.min() >= 0 and x.max() <= 1, f"Expected binary mask, got {x.min()} and {x.max()}"

    # A note on using lists as separation between instances instead of using a batched tensor:
    # Using a batched tensor with instance numbers (1, 2, 3, ...) would indicate that the instances of the samples
    # are identical. Using a list clearly separates the instances of the samples.

    if CUCIM_AVAILABLE:
        # Check if device is cuda
        assert x.device.type == "cuda", f"Expected device to be cuda, got {x.device.type}"
        x = cp.asarray(x).astype(cp.uint8)

        instances = []
        for x_i in x:
            instances_i = label_gpu(x_i)
            instances_i = torch.tensor(instances_i, dtype=torch.uint8)
            instances.append(instances_i)
        return instances

    else:
        instances = []
        for x_i in x:
            x_i = x_i.cpu().numpy()
            instances_i = label(x_i)
            instances_i = torch.tensor(instances_i, dtype=torch.uint8)
            instances.append(instances_i)
        return instances
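
A small illustration of the expected input and output (CPU/skimage path shown; with cuCIM installed the mask has to live on a CUDA device, and the import path is assumed from the source location above):

import torch

from darts_segmentation.metrics.instance_helpers import mask_to_instances

mask = torch.zeros(1, 6, 6, dtype=torch.uint8)    # batch of one binary mask
mask[0, 0:2, 0:2] = 1                             # first blob
mask[0, 4:6, 4:6] = 1                             # second, disconnected blob

instances = mask_to_instances(mask, validate_args=True)
# instances[0] has shape (6, 6) with background 0, the first blob labelled 1
# and the second blob labelled 2.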

match_instances

match_instances(
    instances_target: torch.Tensor,
    instances_preds: torch.Tensor,
    match_threshold: float = 0.5,
    validate_args: bool = False,
) -> tuple[int, int, int]

Match instances between target and prediction masks. Expects non-batched input from skimage.measure.label.

Parameters:

  • instances_target (torch.Tensor) –

    The instance mask of the target. Shape: (height, width), dtype: torch.uint8

  • instances_preds (torch.Tensor) –

    The instance mask of the prediction. Shape: (height, width), dtype: torch.uint8

  • match_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • validate_args (bool, default: False ) –

    Whether to validate the input arguments. Defaults to False.

Returns:

  • tuple[int, int, int]

    tuple[int, int, int]: True positives, false positives, false negatives

Source code in darts-segmentation/src/darts_segmentation/metrics/instance_helpers.py
@torch.no_grad()
def match_instances(
    instances_target: torch.Tensor,
    instances_preds: torch.Tensor,
    match_threshold: float = 0.5,
    validate_args: bool = False,
) -> tuple[int, int, int]:
    """Match instances between target and prediction masks. Expects non-batched input from skimage.measure.label.

    Args:
        instances_target (torch.Tensor): The instance mask of the target. Shape: (height, width), dtype: torch.uint8
        instances_preds (torch.Tensor): The instance mask of the prediction. Shape: (height, width), dtype: torch.uint8
        match_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        validate_args (bool, optional): Whether to validate the input arguments. Defaults to False.

    Returns:
        tuple[int, int, int]: True positives, false positives, false negatives

    """
    if validate_args:
        assert instances_target.dim() == 2, f"Expected 2 dimensions, got {instances_target.dim()}"
        assert instances_preds.dim() == 2, f"Expected 2 dimensions, got {instances_preds.dim()}"
        assert instances_target.dtype == torch.uint8, f"Expected torch.uint8, got {instances_target.dtype}"
        assert instances_preds.dtype == torch.uint8, f"Expected torch.uint8, got {instances_preds.dtype}"
        assert instances_target.shape == instances_preds.shape, (
            f"Shapes do not match: {instances_target.shape} and {instances_preds.shape}"
        )

    height, width = instances_target.shape
    ntargets = instances_target.max().item()
    npreds = instances_preds.max().item()
    # If target or predictions has no instances, return 0 for their respective metrics.
    # If none of them has instances, return 0 for all metrics. (This is implied)
    if ntargets == 0:
        return 0, npreds, 0
    if npreds == 0:
        return 0, 0, ntargets

    # TODO: These are old edge-case filters that need revision.
    # They are probably not necessary, since the instance metrics are meaningless for noisy predictions.
    # If there are too many predictions, return all as false positives (this happens when the model is very noisy)
    # if npreds > ntargets * 5:
    #     return 0, npreds, ntargets
    # If there is only one prediction, return all as false negatives (this happens when the model is very noisy)
    # if npreds == 1 and ntargets > 1:
    #     return 0, 1, ntargets

    # Create one-hot encoding of instances, so that each instance is a channel
    instances_target_onehot = torch.zeros((ntargets, height, width), dtype=torch.uint8, device=instances_target.device)
    instances_preds_onehot = torch.zeros((npreds, height, width), dtype=torch.uint8, device=instances_target.device)
    for i in range(ntargets):
        instances_target_onehot[i, :, :] = instances_target == (i + 1)
    for i in range(npreds):
        instances_preds_onehot[i, :, :] = instances_preds == (i + 1)

    # Now the instances are channels, hence tensors of shape (num_instances, height, width)

    # Calculate IoU (we need to do a n-m intersection and union, therefore we need to broadcast)
    intersection = (instances_target_onehot.unsqueeze(1) & instances_preds_onehot.unsqueeze(0)).sum(
        dim=(2, 3)
    )  # Shape: (num_instances_target, num_instances_preds)
    union = (instances_target_onehot.unsqueeze(1) | instances_preds_onehot.unsqueeze(0)).sum(
        dim=(2, 3)
    )  # Shape: (num_instances_target, num_instances_preds)
    iou = intersection / union  # Shape: (num_instances_target, num_instances_preds)

    # Match instances based on IoU
    tp = (iou >= match_threshold).sum().item()
    fp = npreds - tp
    fn = ntargets - tp

    return tp, fp, fn
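
A small worked example of the matching (instance labels as produced by mask_to_instances; the values are hypothetical and the import path is assumed from the source location above):

import torch

from darts_segmentation.metrics.instance_helpers import match_instances

instances_target = torch.zeros(8, 8, dtype=torch.uint8)
instances_target[0:3, 0:3] = 1                    # one target instance

instances_preds = torch.zeros(8, 8, dtype=torch.uint8)
instances_preds[0:3, 0:3] = 1                     # matches the target exactly -> IoU = 1.0
instances_preds[5:7, 5:7] = 2                     # spurious second prediction

tp, fp, fn = match_instances(instances_target, instances_preds, match_threshold=0.5, validate_args=True)
# tp == 1, fp == 1, fn == 0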