Skip to content

darts_segmentation.metrics.BinaryInstancePrecisionRecallCurve

Bases: torchmetrics.Metric

Compute the precision-recall curve for binary instance segmentation.

This metric works similarly to torchmetrics.classification.PrecisionRecallCurve, with two key differences: 1. It calculates the tp, fp and fn values for each instance (blob) in the batch and then aggregates them, instead of calculating the values for each pixel. 2. The "thresholds" argument is required, because calculating the thresholds at the compute stage would cost too much memory for this use case.

Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

Parameters:

  • thresholds (int | list[float] | torch.Tensor, default: None ) –

    The thresholds to use for the curve. Although the signature defaults to None, this argument is required for this metric.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional keyword arguments forwarded to the torchmetrics Metric base class (e.g. settings for the compute step). Please refer to the torchmetrics documentation for details.

Raises:

  • ValueError –

    If thresholds is None.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def __init__(
    self,
    thresholds: int | list[float] | Tensor | None = None,
    matching_threshold: float = 0.5,
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

    Args:
        thresholds (int | list[float] | Tensor, optional): The thresholds to use for the curve.
            Although the signature defaults to None, this argument is required.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
            Must be a float in the [0, 1] range (checked when ``validate_args`` is True).
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If thresholds is None.
        ValueError: If ``validate_args`` is True and ``matching_threshold`` is not a float in [0, 1].

    """
    super().__init__(**kwargs)
    # Checked regardless of `validate_args`: the confmat state below is sized by the thresholds,
    # so skipping this check would only defer the failure to an opaque `len(None)` TypeError.
    if thresholds is None:
        raise ValueError("Argument `thresholds` must be provided for this metric.")
    if validate_args:
        _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )

    self.matching_threshold = matching_threshold
    self.ignore_index = ignore_index
    self.validate_args = validate_args

    # Normalize int / list thresholds into a tensor (presumably via torchmetrics' helper semantics).
    thresholds = _adjust_threshold_arg(thresholds)
    self.register_buffer("thresholds", thresholds, persistent=False)
    # One 2x2 confusion matrix per threshold, summed across processes in distributed runs.
    self.add_state("confmat", default=torch.zeros(len(thresholds), 2, 2, dtype=torch.long), dist_reduce_fx="sum")

confmat instance-attribute

confmat: torch.Tensor

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = None

ignore_index instance-attribute

ignore_index: int | None

Set from the `ignore_index` constructor argument.

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

matching_threshold: float

Set from the `matching_threshold` constructor argument.

preds instance-attribute

preds: list[torch.Tensor]

target instance-attribute

target: list[torch.Tensor]

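thresholds instance-attribute

thresholds: torch.Tensor

Registered as a non-persistent buffer in __init__ (the class-level annotation in the source is misspelled as `thesholds`).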

validate_args instance-attribute

validate_args: bool

Set from the `validate_args` constructor argument.

compute

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def compute(self) -> tuple[Tensor, Tensor, Tensor]:
    """Compute the precision-recall curve from the accumulated confusion matrices.

    Returns:
        tuple[Tensor, Tensor, Tensor]: The curve produced by torchmetrics'
        ``_binary_precision_recall_curve_compute`` for the fixed ``self.thresholds``
        (precision first and recall second, which is the order ``plot`` relies on).

    """
    accumulated = self.confmat
    return _binary_precision_recall_curve_compute(accumulated, self.thresholds)

plot

plot(
    curve: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    | None = None,
    score: torch.Tensor | bool | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def plot(
    self,
    curve: tuple[Tensor, Tensor, Tensor] | None = None,
    score: Tensor | bool | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    """Plot the precision-recall curve with recall on the x-axis and precision on the y-axis.

    Args:
        curve (tuple[Tensor, Tensor, Tensor] | None, optional): A precomputed curve in the
            ``compute`` output order (precision, recall, thresholds). If None, ``self.compute()``
            is called. Defaults to None.
        score (Tensor | bool | None, optional): Area-under-curve value to show on the plot.
            A Tensor is displayed as given. ``True`` computes the area from the curve, but only
            when ``curve`` was not provided. Anything else shows no score. Defaults to None.
        ax (_AX_TYPE | None, optional): An existing axis to draw into; passed through to
            ``plot_curve``. Defaults to None.

    Returns:
        _PLOT_OUT_TYPE: Whatever ``plot_curve`` returns.

    """
    curve_computed = curve or self.compute()
    # switch order as the standard way is recall along x-axis and precision along y-axis
    curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2])

    if isinstance(score, Tensor):
        # A user-supplied score used to be silently discarded; display it as given.
        score_value = score
    elif score is True and not curve:
        score_value = _auc_compute_without_check(curve_computed[0], curve_computed[1], direction=-1.0)
    else:
        score_value = None
    return plot_curve(
        curve_computed, score=score_value, ax=ax, label_names=("Recall", "Precision"), name=self.__class__.__name__
    )

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update metric states.

Parameters:

  • preds (torch.Tensor) –

    The predicted mask. Shape: (batch_size, height, width)

  • target (torch.Tensor) –

    The target mask. Shape: (batch_size, height, width)

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update metric states.

    For every threshold, the predictions are binarized and split into instances. Those instances
    are matched against the target instances, and the resulting tp/fp/fn counts are added to
    ``self.confmat`` (one 2x2 matrix per threshold).

    Args:
        preds (Tensor): The predicted mask. Shape: (batch_size, height, width).
            If any value lies outside [0, 1], the whole tensor is passed through a sigmoid.
        target (Tensor): The target mask. Shape: (batch_size, height, width).
            If ``ignore_index`` is set, pixels equal to it are excluded: predictions there are
            zeroed before instances are extracted, and only pixels equal to 1 count as target.

    Raises:
        ValueError: If preds and target have different shapes.
        ValueError: If preds does not have 3 dimensions.
        ValueError: If the input targets are not binary masks.

    Note:
        The checks above only run when ``validate_args`` is True.

    """
    if self.validate_args:
        _binary_precision_recall_curve_tensor_validation(preds, target, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format: inputs outside [0, 1] are treated as logits
    if not torch.all((preds >= 0) * (preds <= 1)):
        preds = preds.sigmoid()

    # Build the ignore mask from the ORIGINAL target. After binarization below, the ignore_index
    # values are gone and `target == ignore_index` would never match, so nothing would be ignored.
    invalid_idx = None
    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        target = (target == 1).to(torch.uint8)

    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)

    # Per-threshold 2x2 confusion matrix: [1, 1] = tp, [0, 1] = fp, [1, 0] = fn ([0, 0] is unused).
    confmat = self.thresholds.new_zeros((len(self.thresholds), 2, 2), dtype=torch.int64)
    for i, threshold in enumerate(self.thresholds):
        # `>=` already allocates a fresh tensor, so it can be modified in place without cloning.
        preds_bin = preds >= threshold
        if invalid_idx is not None:
            preds_bin[invalid_idx] = False  # This will prevent from counting instances in the ignored area

        instance_list_preds = mask_to_instances(preds_bin.to(torch.uint8), self.validate_args)
        # presumably one instance entry per batch sample - zip pairs them up sample by sample
        for target_instances, pred_instances in zip(instance_list_target, instance_list_preds):
            tp, fp, fn = match_instances(
                target_instances,
                pred_instances,
                match_threshold=self.matching_threshold,
                validate_args=self.validate_args,
            )
            confmat[i, 1, 1] += tp
            confmat[i, 0, 1] += fp
            confmat[i, 1, 0] += fn
    self.confmat += confmat