Skip to content

darts_segmentation.metrics.BinaryInstanceRecall

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores

Binary instance recall metric.

Create a new instance of the BinaryInstanceStatScores metric.

Parameters:

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

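    Target value marking pixels to ignore: predictions in those pixels are set to 0 and only target pixels equal to 1 are kept as foreground. Defaults to None.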

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

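    Additional keyword arguments forwarded to the torchmetrics Metric base class. A zero_division entry (default 0) is removed from them first and stored on the metric as zero_division. Please refer to torchmetrics for more examples.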

Raises:

  • ValueError

    If matching_threshold is not a float in the [0,1] range.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceStatScores metric.

    Args:
        threshold (float, optional): Threshold for binarizing floating-point predictions.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        matching_threshold (float, optional): The threshold for matching instances.
            Must be a float in the [0, 1] range. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Target value marking pixels that should be ignored.
            Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional keyword arguments forwarded to the torchmetrics Metric base class.
            A `zero_division` entry (default 0) is removed first and stored on the metric instead.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If `matching_threshold` is not a float in the [0,1] range.

    """
    # Take `zero_division` out before the base initializer sees the kwargs
    # (presumably because the base Metric does not accept it — confirm).
    zero_division = kwargs.pop("zero_division", 0)

    # Deliberately starts the MRO lookup *after* `_AbstractStatScores`, skipping its initializer.
    # NOTE(review): presumably mirrors torchmetrics' own BinaryStatScores — confirm.
    super(_AbstractStatScores, self).__init__(**kwargs)

    if validate_args:
        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
        # Short-circuits: the range comparison only runs for actual floats.
        is_valid_matching_threshold = isinstance(matching_threshold, float) and 0 <= matching_threshold <= 1
        if not is_valid_matching_threshold:
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )

    # Store configuration for use in `update` and `compute`.
    self.threshold = threshold
    self.matching_threshold = matching_threshold
    self.multidim_average = multidim_average
    self.ignore_index = ignore_index
    self.validate_args = validate_args
    self.zero_division = zero_division

    # Allocate the tp/fp/tn/fn state tensors for a single (binary) class.
    self._create_state(size=1, multidim_average=multidim_average)

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

ignore_index = ignore_index

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

matching_threshold = matching_threshold

multidim_average instance-attribute

multidim_average = multidim_average

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

threshold = threshold

validate_args instance-attribute

validate_args = validate_args

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:
    """Reduce the accumulated statistics to the binary recall score.

    Returns:
        Tensor: The result of torchmetrics' `_precision_recall_reduce` for the "recall" statistic.
        It uses `average="binary"`, the metric's `multidim_average`, and `self.zero_division`
        forwarded as its `zero_division` argument.

    """
    true_pos, false_pos, true_neg, false_neg = self._final_state()
    recall = _precision_recall_reduce(
        "recall",
        true_pos,
        false_pos,
        true_neg,
        false_neg,
        average="binary",
        multidim_average=self.multidim_average,
        zero_division=self.zero_division,
    )
    return recall

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    """Plot one or more values of this metric.

    This is a thin wrapper around the torchmetrics base-class helper `self._plot`.

    Args:
        val (Tensor | Sequence[Tensor] | None, optional): Value(s) to plot. Defaults to None.
        ax (_AX_TYPE | None, optional): Existing axis to draw into. Defaults to None.

    Returns:
        _PLOT_OUT_TYPE: Whatever `self._plot` returns.

    """
    plot_output = self._plot(val, ax)
    return plot_output

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

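If the predictions are floating point and any value lies outside [0, 1], the whole tensor is treated as logits: it is converted to probabilities using a sigmoid and then binarized using the threshold. Otherwise floating-point predictions are treated as probabilities and binarized using the threshold. Non-floating-point predictions are used as-is.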

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

  • ValueError

    If preds does not have exactly 3 dimensions (BHW).

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    Floating-point predictions are binarized first. If any value lies outside [0, 1], the whole tensor is
    treated as logits and converted to probabilities with a sigmoid. Values strictly above `threshold` then
    become positives. Non-floating-point predictions are used as they are.

    Both masks are converted to instances and matched against each other. The resulting true-positive,
    false-positive and false-negative counts are added to the state; true negatives are always recorded as 0.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities), expected as (B, H, W).
        target (Tensor): Ground truth labels, expected as (B, H, W).

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If `preds` does not have exactly 3 dimensions (BHW).
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        # NOTE(review): the shape check lives in torchmetrics' validation and may raise RuntimeError
        # rather than ValueError — confirm and align the docstring.
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if preds.dim() != 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got values greater than 1 in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format: binarize floating-point predictions; other dtypes are taken as already binary.
    if preds.is_floating_point():
        # A single value outside [0, 1] means the tensor holds logits, not probabilities.
        if not torch.all((preds >= 0) & (preds <= 1)):
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        # Clone so the in-place masking below never touches the caller's tensor
        # (non-floating-point preds were not copied by the thresholding step).
        preds = preds.clone()
        preds[invalid_idx] = 0  # This will prevent from counting instances in the ignored area
        # Keep only the positive class (1) as foreground; the ignored value and anything else become 0.
        target = (target == 1).to(torch.uint8)

    # Update state: convert both masks to instance representations
    # (presumably one entry per batch element — they are zipped below).
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        # Instance matching yields no true negatives, so 0 is recorded for tn.
        self._update_state(tp, fp, 0, fn)