darts_segmentation.metrics

Custom metrics for segmentation tasks.
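
A quick sketch of how these metrics can be used together via torchmetrics' MetricCollection (the class names come from this page; the toy tensors and exact import path are illustrative assumptions):

import torch
from torchmetrics import MetricCollection

from darts_segmentation.metrics import (
    BinaryBoundaryIoU,
    BinaryInstanceF1Score,
    BinaryInstancePrecision,
)

metrics = MetricCollection(
    {
        "boundary_iou": BinaryBoundaryIoU(),
        "instance_f1": BinaryInstanceF1Score(),
        "instance_precision": BinaryInstancePrecision(),
    }
)

# Toy (batch, height, width) data: one square object per sample.
preds = torch.zeros(2, 32, 32)
preds[:, 4:12, 4:12] = 0.9                     # predicted probabilities
target = torch.zeros(2, 32, 32, dtype=torch.long)
target[:, 4:12, 4:12] = 1                      # binary ground truth

metrics.update(preds, target)
print(metrics.compute())                        # dict of metric name -> tensor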

Classes:

BinaryBoundaryIoU

BinaryBoundaryIoU(
    dilation: float | int = 0.02,
    threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Unpack[
        darts_segmentation.metrics.boundary_iou.BinaryBoundaryIoUKwargs
    ],
)

Bases: torchmetrics.Metric

Binary Boundary IoU metric for binary segmentation tasks.

This metric is similar to the Binary Intersection over Union (IoU or Jaccard Index) metric, but instead of comparing all pixels it only compares the boundaries of each foreground object.
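
A minimal usage sketch (the toy tensors and exact import path are assumptions for illustration; the metric follows the usual torchmetrics update/compute pattern documented below):

import torch

from darts_segmentation.metrics import BinaryBoundaryIoU

# Toy batch of shape (batch, height, width): raw logits and a binary ground-truth mask.
preds = torch.randn(2, 32, 32)                 # logits; update() applies sigmoid + threshold
target = torch.zeros(2, 32, 32, dtype=torch.long)
target[:, 8:24, 8:24] = 1                      # one square object per sample

metric = BinaryBoundaryIoU(dilation=0.02, threshold=0.5)
metric.update(preds, target)
print(metric.compute())                         # boundary intersection / union as a tensor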

Create a new instance of the BinaryBoundaryIoU metric.

Please see the torchmetrics docs for more info about the **kwargs.

Parameters:

  • dilation (float | int, default: 0.02 ) –

    The dilation (factor) / width of the boundary. If an int, the dilation width in pixels; if a float, the ratio used to compute dilation = dilation_ratio * image_diagonal (see the worked example below, after the parameter lists). Default: 0.02

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • **kwargs (typing.Unpack[darts_segmentation.metrics.boundary_iou.BinaryBoundaryIoUKwargs], default: {} ) –

    Additional keyword arguments for the metric.

Other Parameters:

  • zero_division (int) –

    Value to return when there is a zero division. Default is 0.

  • compute_on_cpu (bool) –

    If metric state should be stored on CPU during computations. Only works for list states.

  • dist_sync_on_step (bool) –

    If metric state should synchronize on forward(). Default is False.

  • process_group (str) –

    The process group on which the synchronization is called. Default is the world.

  • dist_sync_fn (callable) –

    Function that performs the allgather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.

  • distributed_available_fn (callable) –

    Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().

  • sync_on_compute (bool) –

    If metric state should synchronize when compute is called. Default is True.

  • compute_with_cache (bool) –

    If results from compute should be cached. Default is True.

Raises:

  • ValueError

    If dilation is not a float or int.
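
For intuition on the dilation parameter, a rough worked example (the numbers are illustrative and not taken from the library's internals):

import math

# With the default dilation=0.02 and a 512 x 512 tile, the boundary band is roughly:
image_diagonal = math.sqrt(512**2 + 512**2)    # ~724.1 px
boundary_width = 0.02 * image_diagonal         # ~14.5 px
print(round(boundary_width, 1))

# Passing an int instead, e.g. dilation=3, uses a fixed 3 px boundary width.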

Methods:

  • compute

    Compute the metric.

  • update

    Update the metric state.

Attributes:

Source code in darts-segmentation/src/darts_segmentation/metrics/boundary_iou.py
def __init__(
    self,
    dilation: float | int = 0.02,
    threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Unpack[BinaryBoundaryIoUKwargs],
):
    """Create a new instance of the BinaryBoundaryIoU metric.

    Please see the
    [torchmetrics docs](https://lightning.ai/docs/torchmetrics/stable/pages/overview.html#metric-kwargs)
    for more info about the **kwargs.

    Args:
        dilation (float | int, optional): The dilation (factor) / width of the boundary.
            Dilation in pixels if int, else ratio to calculate `dilation = dilation_ratio * image_diagonal`.
            Default: 0.02
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class.  Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        **kwargs: Additional keyword arguments for the metric.

    Keyword Args:
        zero_division (int):
            Value to return when there is a zero division. Default is 0.
        compute_on_cpu (bool):
            If metric state should be stored on CPU during computations. Only works for list states.
        dist_sync_on_step (bool):
            If metric state should synchronize on ``forward()``. Default is ``False``.
        process_group (str):
            The process group on which the synchronization is called. Default is the world.
        dist_sync_fn (callable):
            Function that performs the allgather operation on the metric state. Default is a custom
            implementation that calls ``torch.distributed.all_gather`` internally.
        distributed_available_fn (callable):
            Function that checks if the distributed backend is available. Defaults to a
            check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
        sync_on_compute (bool):
            If metric state should synchronize when ``compute`` is called. Default is ``True``.
        compute_with_cache (bool):
            If results from ``compute`` should be cached. Default is ``True``.

    Raises:
        ValueError: If dilation is not a float or int.

    """
    zero_division = kwargs.pop("zero_division", 0)
    super().__init__(**kwargs)

    if validate_args:
        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
        if not isinstance(dilation, float | int):
            raise ValueError(f"Expected argument `dilation` to be a float or int, but got {dilation}.")

    self.dilation = dilation
    self.threshold = threshold
    self.multidim_average = multidim_average
    self.ignore_index = ignore_index
    self.validate_args = validate_args
    self.zero_division = zero_division

    if multidim_average == "samplewise":
        self.add_state("intersection", default=[], dist_reduce_fx="cat")
        self.add_state("union", default=[], dist_reduce_fx="cat")
    else:
        self.add_state("intersection", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("union", default=torch.tensor(0), dist_reduce_fx="sum")

dilation instance-attribute

dilation = dilation

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

ignore_index = ignore_index

intersection instance-attribute

intersection: torch.Tensor | list[torch.Tensor]

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

multidim_average instance-attribute

multidim_average = multidim_average

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

threshold = threshold

union instance-attribute

union: torch.Tensor | list[torch.Tensor]

validate_args instance-attribute

validate_args = validate_args

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor

Compute the metric.

Returns:

  • torch.Tensor – The computed metric.

Source code in darts-segmentation/src/darts_segmentation/metrics/boundary_iou.py
def compute(self) -> Tensor:
    """Compute the metric.

    Returns:
        Tensor: The computed metric.

    """
    if self.multidim_average == "global":
        return self.intersection / self.union
    else:
        self.intersection = torch.tensor(self.intersection)
        self.union = torch.tensor(self.union)
        return self.intersection / self.union

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If the input arguments are invalid.

  • ValueError

    If the input shapes are invalid.

Source code in darts-segmentation/src/darts_segmentation/metrics/boundary_iou.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If the input arguments are invalid.
        ValueError: If the input shapes are invalid.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.shape == target.shape:
            raise ValueError(
                f"Expected `preds` and `target` to have the same shape, but got {preds.shape} and {target.shape}."
            )
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions, but got {preds.dim()}.")

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    target = target.to(torch.uint8)
    preds = preds.to(torch.uint8)

    target_boundary = get_boundary((target == 1).to(torch.uint8), self.dilation, self.validate_args)
    preds_boundary = get_boundary(preds, self.dilation, self.validate_args)

    intersection = target_boundary & preds_boundary
    union = target_boundary | preds_boundary

    if self.ignore_index is not None:
        # Important that this is NOT the boundary, but the original mask
        valid_idx = target != self.ignore_index
        intersection &= valid_idx
        union &= valid_idx

    intersection = intersection.sum().item()
    union = union.sum().item()

    if self.multidim_average == "global":
        self.intersection += intersection
        self.union += union
    else:
        self.intersection.append(intersection)
        self.union.append(union)
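
The ignore_index handling above drops invalid pixels from both intersection and union. A toy sketch of flagging a nodata region this way (the nodata value 255 is an arbitrary choice for illustration):

import torch

from darts_segmentation.metrics import BinaryBoundaryIoU

target = torch.zeros(1, 32, 32, dtype=torch.long)
target[0, 4:12, 4:12] = 1                      # foreground object
target[0, :, 24:] = 255                        # nodata stripe, excluded via ignore_index
preds = torch.rand(1, 32, 32)                  # probabilities in [0, 1]

metric = BinaryBoundaryIoU(ignore_index=255)
metric.update(preds, target)
print(metric.compute())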

BinaryInstanceAccuracy

BinaryInstanceAccuracy(
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores

Binary instance accuracy metric.
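
A minimal usage sketch (toy masks; the import path is assumed from this page). The masks are split into instances internally, and matching_threshold controls when a predicted instance counts as a hit:

import torch

from darts_segmentation.metrics import BinaryInstanceAccuracy

preds = torch.zeros(1, 16, 16)
preds[0, 2:6, 2:6] = 0.9                       # one predicted object (probability 0.9)
target = torch.zeros(1, 16, 16, dtype=torch.long)
target[0, 2:6, 2:6] = 1                        # matching ground-truth object

metric = BinaryInstanceAccuracy(threshold=0.5, matching_threshold=0.5)
metric.update(preds, target)
print(metric.compute())                         # instance-level accuracy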

Create a new instance of the BinaryInstanceStatScores metric.

Parameters:

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Raises:

  • ValueError

    If matching_threshold is not a float in the [0,1] range.

Methods:

Attributes:

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceStatScores metric.

    Args:
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If `matching_threshold` is not a float in the [0,1] range.

    """
    zero_division = kwargs.pop("zero_division", 0)
    super(_AbstractStatScores, self).__init__(**kwargs)
    if validate_args:
        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )

    self.threshold = threshold
    self.matching_threshold = matching_threshold
    self.multidim_average = multidim_average
    self.ignore_index = ignore_index
    self.validate_args = validate_args
    self.zero_division = zero_division

    self._create_state(size=1, multidim_average=multidim_average)

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

ignore_index = ignore_index

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

matching_threshold = matching_threshold

multidim_average instance-attribute

multidim_average = multidim_average

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

threshold = threshold

validate_args instance-attribute

validate_args = validate_args

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _accuracy_reduce(
        tp,
        fp,
        tn,
        fn,
        average="binary",
        multidim_average=self.multidim_average,
    )

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This will prevent from counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)

BinaryInstanceAveragePrecision

BinaryInstanceAveragePrecision(
    thresholds: int | list[float] | torch.Tensor = None,
    matching_threshold: float = 0.5,
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_prc.BinaryInstancePrecisionRecallCurve

Compute the average precision for binary instance segmentation.
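
A minimal usage sketch (toy data; note that thresholds must be passed explicitly, since the constructor raises a ValueError otherwise). Following the torchmetrics convention, an integer value evaluates that many evenly spaced thresholds in [0, 1]:

import torch

from darts_segmentation.metrics import BinaryInstanceAveragePrecision

preds = torch.zeros(1, 16, 16)
preds[0, 2:6, 2:6] = 0.8                       # predicted object with probability 0.8
target = torch.zeros(1, 16, 16, dtype=torch.long)
target[0, 2:6, 2:6] = 1

metric = BinaryInstanceAveragePrecision(thresholds=11, matching_threshold=0.5)
metric.update(preds, target)
print(metric.compute())                         # average precision over the threshold grid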

Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

Parameters:

  • thresholds (int | list[float] | torch.Tensor, default: None ) –

    The thresholds to use for the curve. Defaults to None.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Raises:

  • ValueError

    If thresholds is None.

Methods:

Attributes:

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def __init__(
    self,
    thresholds: int | list[float] | Tensor = None,
    matching_threshold: float = 0.5,
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

    Args:
        thresholds (int | list[float] | Tensor, optional): The thresholds to use for the curve. Defaults to None.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If thresholds is None.

    """
    super().__init__(**kwargs)
    if validate_args:
        _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )
        if thresholds is None:
            raise ValueError("Argument `thresholds` must be provided for this metric.")

    self.matching_threshold = matching_threshold
    self.ignore_index = ignore_index
    self.validate_args = validate_args

    thresholds = _adjust_threshold_arg(thresholds)
    self.register_buffer("thresholds", thresholds, persistent=False)
    self.add_state("confmat", default=torch.zeros(len(thresholds), 2, 2, dtype=torch.long), dist_reduce_fx="sum")

confmat instance-attribute

confmat: torch.Tensor

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool = True

ignore_index instance-attribute

ignore_index = ignore_index

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

matching_threshold = matching_threshold

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

preds instance-attribute

preds: list[torch.Tensor]

target instance-attribute

target: list[torch.Tensor]

thresholds instance-attribute

thresholds: torch.Tensor

validate_args instance-attribute

validate_args = validate_args

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def compute(self) -> Tensor:  # type: ignore[override]  # noqa: D102
    return _binary_average_precision_compute(self.confmat, self.thresholds)

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def plot(  # type: ignore[override]  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update metric states.

Parameters:

  • preds (torch.Tensor) –

    The predicted mask. Shape: (batch_size, height, width)

  • target (torch.Tensor) –

    The target mask. Shape: (batch_size, height, width)

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update metric states.

    Args:
        preds (Tensor): The predicted mask. Shape: (batch_size, height, width)
        target (Tensor): The target mask. Shape: (batch_size, height, width)

    Raises:
        ValueError: If preds and target have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_precision_recall_curve_tensor_validation(preds, target, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if not torch.all((preds >= 0) * (preds <= 1)):
        preds = preds.sigmoid()

    if self.ignore_index is not None:
        target = (target == 1).to(torch.uint8)

    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)

    len_t = len(self.thresholds)
    confmat = self.thresholds.new_zeros((len_t, 2, 2), dtype=torch.int64)
    for i in range(len_t):
        preds_i = preds >= self.thresholds[i]

        if self.ignore_index is not None:
            invalid_idx = target == self.ignore_index
            preds_i = preds_i.clone()
            preds_i[invalid_idx] = 0  # This will prevent from counting instances in the ignored area

        instance_list_preds_i = mask_to_instances(preds_i.to(torch.uint8), self.validate_args)
        for target_i, preds_i in zip(instance_list_target, instance_list_preds_i):
            tp, fp, fn = match_instances(
                target_i,
                preds_i,
                match_threshold=self.matching_threshold,
                validate_args=self.validate_args,
            )
            confmat[i, 1, 1] += tp
            confmat[i, 0, 1] += fp
            confmat[i, 1, 0] += fn
    self.confmat += confmat

BinaryInstanceConfusionMatrix

BinaryInstanceConfusionMatrix(
    normalize: bool | None = None,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores

Binary instance confusion matrix metric.
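
A minimal usage sketch (toy masks, illustrative only). The result is a 2x2 tensor laid out as [[tn, fp], [fn, tp]]; true negatives are always 0 at the instance level:

import torch

from darts_segmentation.metrics import BinaryInstanceConfusionMatrix

preds = torch.zeros(1, 16, 16)
preds[0, 2:6, 2:6] = 0.9                       # hits the first ground-truth object
preds[0, 10:14, 10:14] = 0.9                   # spurious prediction (false positive)
target = torch.zeros(1, 16, 16, dtype=torch.long)
target[0, 2:6, 2:6] = 1                        # detected object
target[0, 8:10, 2:4] = 1                       # missed object (false negative)

metric = BinaryInstanceConfusionMatrix()        # normalize=True would divide by tp + fp + fn
metric.update(preds, target)
print(metric.compute())                         # expected [[0, 1], [1, 1]] for this toy case

The matrix can also be visualized with metric.plot(), which wraps torchmetrics' confusion-matrix plotting helper (matplotlib required).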

Create a new instance of the BinaryInstanceConfusionMatrix metric.

Parameters:

  • normalize (bool, default: None ) –

    If True, return the confusion matrix normalized by the number of instances. If False, return the confusion matrix without normalization. Defaults to None.

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Raises:

  • ValueError

    If normalize is not a bool.

Methods:

Attributes:

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    normalize: bool | None = None,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceConfusionMatrix metric.

    Args:
        normalize (bool, optional): If True, return the confusion matrix normalized by the number of instances.
            If False, return the confusion matrix without normalization. Defaults to None.
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If `normalize` is not a bool.

    """
    super().__init__(
        threshold=threshold,
        matching_threshold=matching_threshold,
        multidim_average=multidim_average,
        ignore_index=ignore_index,
        validate_args=False,
        **kwargs,
    )
    if normalize is not None and not isinstance(normalize, bool):
        raise ValueError(f"Argument `normalize` needs to be of bool type but got {type(normalize)}")
    self.normalize = normalize

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = None

ignore_index instance-attribute

ignore_index = ignore_index

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

matching_threshold = matching_threshold

multidim_average instance-attribute

multidim_average = multidim_average

normalize instance-attribute

normalize = normalize

threshold instance-attribute

threshold = threshold

validate_args instance-attribute

validate_args = validate_args

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    # tn is always 0
    if self.normalize:
        all = tp + fp + fn
        return torch.tensor([[0, fp / all], [fn / all, tp / all]], device=tp.device)
    else:
        return torch.tensor([[tn, fp], [fn, tp]], device=tp.device)

plot

plot(
    val: torch.Tensor | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
    add_text: bool = True,
    labels: list[str] | None = None,
    cmap: torchmetrics.utilities.plot._CMAP_TYPE
    | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
    add_text: bool = True,
    labels: list[str] | None = None,  # type: ignore
    cmap: _CMAP_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    val = val or self.compute()
    if not isinstance(val, Tensor):
        raise TypeError(f"Expected val to be a single tensor but got {val}")
    fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
    return fig, ax

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This will prevent from counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)

BinaryInstanceF1Score

BinaryInstanceF1Score(
    threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    zero_division: float = 0,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceFBetaScore

Binary instance F1 score metric.
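
A minimal usage sketch (toy masks; import path assumed). The standard torchmetrics forward() shortcut combines update and compute for a single batch:

import torch

from darts_segmentation.metrics import BinaryInstanceF1Score

preds = torch.zeros(1, 16, 16)
preds[0, 2:6, 2:6] = 0.9                       # one predicted object
target = torch.zeros(1, 16, 16, dtype=torch.long)
target[0, 2:6, 2:6] = 1                        # matching ground-truth object

metric = BinaryInstanceF1Score(threshold=0.5)
score = metric(preds, target)                  # forward() = update + compute for this batch
print(score)                                   # 1.0 for this perfect toy match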

Create a new instance of the BinaryInstanceF1Score metric.

Parameters:

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • zero_division (float, default: 0 ) –

    Value to return when there is a zero division. Defaults to 0.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Methods:

Attributes:

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    zero_division: float = 0,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceF1Score metric.

    Args:
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        zero_division (float, optional): Value to return when there is a zero division. Defaults to 0.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    """
    super().__init__(
        beta=1.0,
        threshold=threshold,
        multidim_average=multidim_average,
        ignore_index=ignore_index,
        validate_args=validate_args,
        zero_division=zero_division,
        **kwargs,
    )

beta instance-attribute

beta = beta

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

ignore_index = ignore_index

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

matching_threshold = matching_threshold

multidim_average instance-attribute

multidim_average = multidim_average

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

threshold = threshold

validate_args instance-attribute

validate_args = validate_args

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _fbeta_reduce(
        tp,
        fp,
        tn,
        fn,
        self.beta,
        average="binary",
        multidim_average=self.multidim_average,
        zero_division=self.zero_division,
    )

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This will prevent from counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)

BinaryInstanceFBetaScore

BinaryInstanceFBetaScore(
    beta: float,
    threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    zero_division: float = 0,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores

Binary instance F-beta score metric.
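
A minimal usage sketch (toy masks; import path assumed). A beta above 1 weights recall (missed instances) more heavily than precision, a beta below 1 does the opposite:

import torch

from darts_segmentation.metrics import BinaryInstanceFBetaScore

preds = torch.zeros(1, 16, 16)
preds[0, 2:6, 2:6] = 0.9                       # one predicted object
target = torch.zeros(1, 16, 16, dtype=torch.long)
target[0, 2:6, 2:6] = 1                        # matching ground-truth object

metric = BinaryInstanceFBetaScore(beta=2.0, threshold=0.5)
metric.update(preds, target)
print(metric.compute())                         # F2 score over matched instances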

Create a new instance of the BinaryInstanceFBetaScore metric.

Parameters:

  • beta (float) –

    The beta parameter for the F-beta score.

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • zero_division (float, default: 0 ) –

    Value to return when there is a zero division. Defaults to 0.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Methods:

Attributes:

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    beta: float,
    threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    zero_division: float = 0,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceFBetaScore metric.

    Args:
        beta (float): The beta parameter for the F-beta score.
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        zero_division (float, optional): Value to return when there is a zero division. Defaults to 0.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    """
    super().__init__(
        threshold=threshold,
        multidim_average=multidim_average,
        ignore_index=ignore_index,
        validate_args=False,
        **kwargs,
    )
    if validate_args:
        _binary_fbeta_score_arg_validation(beta, threshold, multidim_average, ignore_index, zero_division)
    self.validate_args = validate_args
    self.zero_division = zero_division
    self.beta = beta

beta instance-attribute

beta = beta

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

ignore_index = ignore_index

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

matching_threshold = matching_threshold

multidim_average instance-attribute

multidim_average = multidim_average

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

threshold = threshold

validate_args instance-attribute

validate_args = validate_args

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _fbeta_reduce(
        tp,
        fp,
        tn,
        fn,
        self.beta,
        average="binary",
        multidim_average=self.multidim_average,
        zero_division=self.zero_division,
    )

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This will prevent from counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)

BinaryInstancePrecision

BinaryInstancePrecision(
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores

Binary instance precision metric.
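
A minimal usage sketch (toy masks; import path assumed). With one correct and one spurious predicted instance, precision comes out at 0.5:

import torch

from darts_segmentation.metrics import BinaryInstancePrecision

preds = torch.zeros(1, 16, 16)
preds[0, 2:6, 2:6] = 0.9                       # correct prediction (true positive)
preds[0, 10:14, 10:14] = 0.9                   # spurious prediction (false positive)
target = torch.zeros(1, 16, 16, dtype=torch.long)
target[0, 2:6, 2:6] = 1

metric = BinaryInstancePrecision()
metric.update(preds, target)
print(metric.compute())                         # 1 TP, 1 FP -> precision 0.5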

Create a new instance of the BinaryInstanceStatScores metric.

Parameters:

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Raises:

  • ValueError

    If matching_threshold is not a float in the [0,1] range.

Methods:

Attributes:

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceStatScores metric.

    Args:
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If `matching_threshold` is not a float in the [0,1] range.

    """
    zero_division = kwargs.pop("zero_division", 0)
    super(_AbstractStatScores, self).__init__(**kwargs)
    if validate_args:
        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )

    self.threshold = threshold
    self.matching_threshold = matching_threshold
    self.multidim_average = multidim_average
    self.ignore_index = ignore_index
    self.validate_args = validate_args
    self.zero_division = zero_division

    self._create_state(size=1, multidim_average=multidim_average)

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

ignore_index = ignore_index

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

matching_threshold = matching_threshold

multidim_average instance-attribute

multidim_average = multidim_average

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

threshold = threshold

validate_args instance-attribute

validate_args = validate_args

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _precision_recall_reduce(
        "precision",
        tp,
        fp,
        tn,
        fn,
        average="binary",
        multidim_average=self.multidim_average,
        zero_division=self.zero_division,
    )

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This will prevent from counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)
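
A minimal end-to-end sketch of how this metric might be used, assuming it is exported from darts_segmentation.metrics as BinaryInstancePrecision (a hypothetical import; check the package's actual exports). Predictions may be raw logits or probabilities; update() only applies a sigmoid when values fall outside [0, 1], then binarizes with threshold:

import torch

from darts_segmentation.metrics import BinaryInstancePrecision  # assumed export name

metric = BinaryInstancePrecision(threshold=0.5, matching_threshold=0.5)

logits = torch.randn(2, 64, 64)            # raw model output, shape (B, H, W)
target = torch.randint(0, 2, (2, 64, 64))  # binary ground-truth mask

metric.update(logits, target)  # sigmoid + threshold are applied internally
precision = metric.compute()   # instance-level precision as a scalar tensor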

BinaryInstancePrecisionRecallCurve

BinaryInstancePrecisionRecallCurve(
    thresholds: int | list[float] | torch.Tensor = None,
    matching_threshold: float = 0.5,
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: torchmetrics.Metric

Compute the precision-recall curve for binary instance segmentation.

This metric works similarly to torchmetrics.classification.PrecisionRecallCurve, with two key differences:

1. It calculates the tp, fp and fn values for each instance (blob) in the batch and then aggregates them, instead of calculating the values for each pixel.
2. The "thresholds" argument is required, because calculating the thresholds at the compute stage would cost too much memory for this use case.

Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

Parameters:

  • thresholds (int | list[float] | torch.Tensor, default: None ) –

    The thresholds to use for the curve. Defaults to None.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional keyword arguments for the Metric base class, regarding its compute methods. Please refer to the torchmetrics documentation for more examples.

Raises:

  • ValueError

    If thresholds is None.

Methods:

Attributes:

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def __init__(
    self,
    thresholds: int | list[float] | Tensor = None,
    matching_threshold: float = 0.5,
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

    Args:
        thresholds (int | list[float] | Tensor, optional): The thresholds to use for the curve. Defaults to None.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If thresholds is None.

    """
    super().__init__(**kwargs)
    if validate_args:
        _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )
        if thresholds is None:
            raise ValueError("Argument `thresholds` must be provided for this metric.")

    self.matching_threshold = matching_threshold
    self.ignore_index = ignore_index
    self.validate_args = validate_args

    thresholds = _adjust_threshold_arg(thresholds)
    self.register_buffer("thresholds", thresholds, persistent=False)
    self.add_state("confmat", default=torch.zeros(len(thresholds), 2, 2, dtype=torch.long), dist_reduce_fx="sum")

confmat instance-attribute

confmat: torch.Tensor

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = None

ignore_index instance-attribute

ignore_index = ignore_index

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

matching_threshold = matching_threshold

preds instance-attribute

preds: list[torch.Tensor]

target instance-attribute

target: list[torch.Tensor]

thresholds instance-attribute

thresholds: torch.Tensor

validate_args instance-attribute

validate_args = validate_args

compute

compute() -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def compute(self) -> tuple[Tensor, Tensor, Tensor]:  # noqa: D102
    return _binary_precision_recall_curve_compute(self.confmat, self.thresholds)

plot

plot(
    curve: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    | None = None,
    score: torch.Tensor | bool | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def plot(  # noqa: D102
    self,
    curve: tuple[Tensor, Tensor, Tensor] | None = None,
    score: Tensor | bool | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    curve_computed = curve or self.compute()
    # switch order as the standard way is recall along x-axis and precision along y-axis
    curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2])

    score = (
        _auc_compute_without_check(curve_computed[0], curve_computed[1], direction=-1.0)
        if not curve and score is True
        else None
    )
    return plot_curve(
        curve_computed, score=score, ax=ax, label_names=("Recall", "Precision"), name=self.__class__.__name__
    )

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update metric states.

Parameters:

  • preds (torch.Tensor) –

    The predicted mask. Shape: (batch_size, height, width)

  • target (torch.Tensor) –

    The target mask. Shape: (batch_size, height, width)

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update metric states.

    Args:
        preds (Tensor): The predicted mask. Shape: (batch_size, height, width)
        target (Tensor): The target mask. Shape: (batch_size, height, width)

    Raises:
        ValueError: If preds and target have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_precision_recall_curve_tensor_validation(preds, target, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if not torch.all((preds >= 0) * (preds <= 1)):
        preds = preds.sigmoid()

    if self.ignore_index is not None:
        target = (target == 1).to(torch.uint8)

    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)

    len_t = len(self.thresholds)
    confmat = self.thresholds.new_zeros((len_t, 2, 2), dtype=torch.int64)
    for i in range(len_t):
        preds_i = preds >= self.thresholds[i]

        if self.ignore_index is not None:
            invalid_idx = target == self.ignore_index
            preds_i = preds_i.clone()
            preds_i[invalid_idx] = 0  # This will prevent from counting instances in the ignored area

        instance_list_preds_i = mask_to_instances(preds_i.to(torch.uint8), self.validate_args)
        for target_i, preds_i in zip(instance_list_target, instance_list_preds_i):
            tp, fp, fn = match_instances(
                target_i,
                preds_i,
                match_threshold=self.matching_threshold,
                validate_args=self.validate_args,
            )
            confmat[i, 1, 1] += tp
            confmat[i, 0, 1] += fp
            confmat[i, 1, 0] += fn
    self.confmat += confmat
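
A minimal sketch of using this curve, with an assumed import path (the exact export name may differ). The thresholds argument must be provided up front; an integer is expanded to that many evenly spaced thresholds in [0, 1], following the usual torchmetrics behaviour, and compute() returns the (precision, recall, thresholds) triple:

import torch

from darts_segmentation.metrics import BinaryInstancePrecisionRecallCurve  # assumed export name

prc = BinaryInstancePrecisionRecallCurve(thresholds=11, matching_threshold=0.5)

probs = torch.rand(2, 64, 64)              # probability map, shape (B, H, W)
target = torch.randint(0, 2, (2, 64, 64))  # binary ground-truth mask

prc.update(probs, target)
precision, recall, thresholds = prc.compute()
# Optional: fig, ax = prc.plot(score=True)  # recall on the x-axis, precision on the y-axis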

BinaryInstanceRecall

BinaryInstanceRecall(
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores

Binary instance recall metric.

Create a new instance of the BinaryInstanceStatScores metric.

Parameters:

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional keyword arguments for the Metric base class, regarding its compute methods. Please refer to the torchmetrics documentation for more examples.

Raises:

  • ValueError

    If matching_threshold is not a float in the [0,1] range.

Methods:

Attributes:

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceStatScores metric.

    Args:
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If `matching_threshold` is not a float in the [0,1] range.

    """
    zero_division = kwargs.pop("zero_division", 0)
    super(_AbstractStatScores, self).__init__(**kwargs)
    if validate_args:
        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )

    self.threshold = threshold
    self.matching_threshold = matching_threshold
    self.multidim_average = multidim_average
    self.ignore_index = ignore_index
    self.validate_args = validate_args
    self.zero_division = zero_division

    self._create_state(size=1, multidim_average=multidim_average)

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

ignore_index = ignore_index

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

matching_threshold = matching_threshold

multidim_average instance-attribute

multidim_average = multidim_average

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

threshold = threshold

validate_args instance-attribute

validate_args = validate_args

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _precision_recall_reduce(
        "recall",
        tp,
        fp,
        tn,
        fn,
        average="binary",
        multidim_average=self.multidim_average,
        zero_division=self.zero_division,
    )

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This will prevent from counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)
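
A minimal sketch showing ignore_index, which removes the invalid class from the target and suppresses predicted instances inside the ignored area before matching (import path assumed):

import torch

from darts_segmentation.metrics import BinaryInstanceRecall  # assumed export name

# Pixels labelled 255 in the target are treated as invalid and ignored.
recall = BinaryInstanceRecall(threshold=0.5, matching_threshold=0.5, ignore_index=255)

probs = torch.rand(2, 64, 64)              # probability map, shape (B, H, W)
target = torch.randint(0, 2, (2, 64, 64))  # binary ground-truth mask
target[:, :8, :] = 255                     # e.g. a nodata border

recall.update(probs, target)
print(recall.compute())  # instance-level recall as a scalar tensor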

BinaryInstanceStatScores

BinaryInstanceStatScores(
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: torchmetrics.classification.stat_scores._AbstractStatScores

Base class for binary instance segmentation metrics.

Create a new instance of the BinaryInstanceStatScores metric.

Parameters:

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional keyword arguments for the Metric base class, regarding its compute methods. Please refer to the torchmetrics documentation for more examples.

Raises:

  • ValueError

    If matching_threshold is not a float in the [0,1] range.

Methods:

Attributes:

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceStatScores metric.

    Args:
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If `matching_threshold` is not a float in the [0,1] range.

    """
    zero_division = kwargs.pop("zero_division", 0)
    super(_AbstractStatScores, self).__init__(**kwargs)
    if validate_args:
        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )

    self.threshold = threshold
    self.matching_threshold = matching_threshold
    self.multidim_average = multidim_average
    self.ignore_index = ignore_index
    self.validate_args = validate_args
    self.zero_division = zero_division

    self._create_state(size=1, multidim_average=multidim_average)

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = None

ignore_index instance-attribute

ignore_index = ignore_index

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

matching_threshold = matching_threshold

multidim_average instance-attribute

multidim_average = multidim_average

threshold instance-attribute

threshold = threshold

validate_args instance-attribute

validate_args = validate_args

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _binary_stat_scores_compute(tp, fp, tn, fn, self.multidim_average)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This will prevent from counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)
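
A minimal usage sketch of the base class itself (import path assumed). Note that the true-negative count is always zero here, since the instance matching in update() only produces tp, fp and fn; compute() returns the aggregated counts in the usual torchmetrics stat-scores layout:

import torch

from darts_segmentation.metrics import BinaryInstanceStatScores  # assumed export name

stats = BinaryInstanceStatScores(threshold=0.5, matching_threshold=0.5)

probs = torch.rand(4, 32, 32)              # probability map, shape (B, H, W)
target = torch.randint(0, 2, (4, 32, 32))  # binary ground-truth mask

stats.update(probs, target)
print(stats.compute())  # aggregated instance-level counts (tp, fp, tn, fn, support)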