Skip to content

darts_segmentation.metrics.BinaryInstanceAveragePrecision

Bases: darts_segmentation.metrics.binary_instance_prc.BinaryInstancePrecisionRecallCurve

Compute the average precision for binary instance segmentation.

Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

Parameters:

  • thresholds (int | list[float] | torch.Tensor, default: None ) –

    The thresholds to use for the curve. Although the signature default is None, a value must be provided; passing None raises an error.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

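Raises:

  • ValueError –

    If thresholds is None, or if matching_threshold is not a float in the [0, 1] range (these checks run when validate_args is True).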

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def __init__(
    self,
    thresholds: int | list[float] | Tensor | None = None,
    matching_threshold: float = 0.5,
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

    Args:
        thresholds (int | list[float] | Tensor | None, optional): The thresholds to use for the curve.
            Although the signature defaults to None, a value is required and None is rejected.
            Defaults to None.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If thresholds is None (checked regardless of ``validate_args``).
        ValueError: If ``validate_args`` is True and ``matching_threshold`` is not a float in [0, 1].

    """
    super().__init__(**kwargs)
    if validate_args:
        _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )
    # Checked outside `validate_args` on purpose: the confusion-matrix state below needs
    # `len(thresholds)`, so a missing value would otherwise surface later as an opaque TypeError.
    if thresholds is None:
        raise ValueError("Argument `thresholds` must be provided for this metric.")

    self.matching_threshold = matching_threshold
    self.ignore_index = ignore_index
    self.validate_args = validate_args

    thresholds = _adjust_threshold_arg(thresholds)
    # Non-persistent buffer: moves with the module's device but is not written to the state_dict.
    self.register_buffer("thresholds", thresholds, persistent=False)
    # One 2x2 confusion matrix per threshold; `update` fills [1, 1]=TP, [0, 1]=FP, [1, 0]=FN.
    # The state is reduced with "sum" when metric states are synchronized.
    self.add_state("confmat", default=torch.zeros(len(thresholds), 2, 2, dtype=torch.long), dist_reduce_fx="sum")

confmat instance-attribute

confmat: torch.Tensor

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool = True

ignore_index instance-attribute

ignore_index: int | None

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

matching_threshold: float

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

preds instance-attribute

preds: list[torch.Tensor]

target instance-attribute

target: list[torch.Tensor]

thresholds instance-attribute

thresholds: torch.Tensor

validate_args instance-attribute

validate_args: bool

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def compute(self) -> Tensor:  # type: ignore[override]
    """Compute the average precision from the accumulated per-threshold confusion matrices.

    Returns:
        Tensor: The result of ``_binary_average_precision_compute`` applied to the
            accumulated ``confmat`` state and the registered ``thresholds`` buffer.

    """
    accumulated_confmat, threshold_values = self.confmat, self.thresholds
    return _binary_average_precision_compute(accumulated_confmat, threshold_values)

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def plot(  # type: ignore[override]
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    """Plot one or more metric values.

    This is a thin wrapper: both arguments are forwarded unchanged to ``self._plot``.

    Args:
        val (Tensor | Sequence[Tensor] | None, optional): The value(s) to plot. Defaults to None;
            how None is handled is decided by ``_plot`` (presumably it plots the computed value — TODO confirm).
        ax (_AX_TYPE | None, optional): An existing axis to draw on. Defaults to None.

    Returns:
        _PLOT_OUT_TYPE: Whatever ``self._plot`` returns.

    """
    plot_output = self._plot(val, ax)
    return plot_output

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update metric states.

Parameters:

  • preds (torch.Tensor) –

    The predicted mask. Shape: (batch_size, height, width)

  • target (torch.Tensor) –

    The target mask. Shape: (batch_size, height, width)

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update metric states.

    For every value in ``self.thresholds`` the predictions are binarized, both masks are split
    into instances with ``mask_to_instances`` and matched with ``match_instances``. The returned
    true-positive, false-positive and false-negative counts are added to ``self.confmat``
    (per threshold: ``[1, 1]`` = TP, ``[0, 1]`` = FP, ``[1, 0]`` = FN).

    Args:
        preds (Tensor): The predicted mask. Shape: (batch_size, height, width).
            If any value lies outside [0, 1], the whole tensor is passed through a sigmoid.
        target (Tensor): The target mask. Shape: (batch_size, height, width)

    Raises:
        ValueError: If preds and target have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_precision_recall_curve_tensor_validation(preds, target, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format: predictions outside [0, 1] are treated as logits.
    if not torch.all((preds >= 0) & (preds <= 1)):
        preds = preds.sigmoid()

    # The ignore mask must be taken from the ORIGINAL target. After binarization the target only
    # holds 0/1, so comparing it with ignore_index (e.g. 2 or 255) would never match and the
    # ignored area would silently contribute false-positive instances.
    invalid_mask = None
    if self.ignore_index is not None:
        invalid_mask = target == self.ignore_index
        target = (target == 1).to(torch.uint8)

    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)

    confmat = self.thresholds.new_zeros((len(self.thresholds), 2, 2), dtype=torch.int64)
    for i, threshold in enumerate(self.thresholds):
        binary_preds = preds >= threshold
        if invalid_mask is not None:
            # Zero predictions in the ignored area so they cannot form instances.
            binary_preds = binary_preds & ~invalid_mask

        instance_list_preds = mask_to_instances(binary_preds.to(torch.uint8), self.validate_args)
        # Assumes mask_to_instances yields one entry per batch element, in order — TODO confirm.
        for target_instances, pred_instances in zip(instance_list_target, instance_list_preds):
            tp, fp, fn = match_instances(
                target_instances,
                pred_instances,
                match_threshold=self.matching_threshold,
                validate_args=self.validate_args,
            )
            confmat[i, 1, 1] += tp
            confmat[i, 0, 1] += fp
            confmat[i, 1, 0] += fn
    self.confmat += confmat