binary_instance_stat_scores

darts_segmentation.metrics.binary_instance_stat_scores

Binary instance segmentation metrics.

BinaryInstanceAccuracy

BinaryInstanceAccuracy(
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores

Binary instance accuracy metric.

Create a new instance of the BinaryInstanceStatScores metric.

Parameters:

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Raises:

  • ValueError

    If matching_threshold is not a float in the [0,1] range.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceStatScores metric.

    Args:
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If `matching_threshold` is not a float in the [0,1] range.

    """
    zero_division = kwargs.pop("zero_division", 0)
    super(_AbstractStatScores, self).__init__(**kwargs)
    if validate_args:
        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )

    self.threshold = threshold
    self.matching_threshold = matching_threshold
    self.multidim_average = multidim_average
    self.ignore_index = ignore_index
    self.validate_args = validate_args
    self.zero_division = zero_division

    self._create_state(size=1, multidim_average=multidim_average)

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

multidim_average instance-attribute

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

validate_args instance-attribute

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _accuracy_reduce(
        tp,
        fp,
        tn,
        fn,
        average="binary",
        multidim_average=self.multidim_average,
    )

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This prevents counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)
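
Example: a minimal usage sketch; the import path follows this page's module, and the shapes, blob positions, and values are made up for illustration.

import torch
from darts_segmentation.metrics.binary_instance_stat_scores import BinaryInstanceAccuracy

metric = BinaryInstanceAccuracy(threshold=0.5, matching_threshold=0.5)

# Probabilities of shape (batch, height, width); values already in [0, 1] are thresholded directly.
preds = torch.zeros(1, 32, 32)
preds[0, 4:10, 4:10] = 0.9    # predicted instance that matches a target instance
preds[0, 20:26, 20:26] = 0.8  # spurious predicted instance

# Binary ground-truth mask of the same shape.
target = torch.zeros(1, 32, 32, dtype=torch.long)
target[0, 4:10, 4:10] = 1     # matched instance
target[0, 12:18, 12:18] = 1   # missed instance

# Runs on CPU when cuCIM is not installed; with cuCIM, move both tensors to a CUDA device first.
metric.update(preds, target)
print(metric.compute())  # 1 matched, 1 spurious, 1 missed -> (tp + tn) / (tp + fp + tn + fn) = 1/3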

BinaryInstanceConfusionMatrix

BinaryInstanceConfusionMatrix(
    normalize: bool | None = None,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores

Binary instance confusion matrix metric.

Create a new instance of the BinaryInstanceConfusionMatrix metric.

Parameters:

  • normalize (bool, default: None ) –

    If True, return the confusion matrix normalized by the number of instances. If False, return the confusion matrix without normalization. Defaults to None.

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Raises:

  • ValueError

    If normalize is not a bool.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    normalize: bool | None = None,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceConfusionMatrix metric.

    Args:
        normalize (bool, optional): If True, return the confusion matrix normalized by the number of instances.
            If False, return the confusion matrix without normalization. Defaults to None.
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If `normalize` is not a bool.

    """
    super().__init__(
        threshold=threshold,
        matching_threshold=matching_threshold,
        multidim_average=multidim_average,
        ignore_index=ignore_index,
        validate_args=False,
        **kwargs,
    )
    if normalize is not None and not isinstance(normalize, bool):
        raise ValueError(f"Argument `normalize` needs to be of bool type but got {type(normalize)}")
    self.normalize = normalize

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = None

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

multidim_average instance-attribute

normalize instance-attribute

threshold instance-attribute

validate_args instance-attribute

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    # tn is always 0
    if self.normalize:
        all = tp + fp + fn
        return torch.tensor([[0, fp / all], [fn / all, tp / all]], device=tp.device)
    else:
        return torch.tensor([[tn, fp], [fn, tp]], device=tp.device)

plot

plot(
    val: torch.Tensor | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
    add_text: bool = True,
    labels: list[str] | None = None,
    cmap: torchmetrics.utilities.plot._CMAP_TYPE
    | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
    add_text: bool = True,
    labels: list[str] | None = None,  # type: ignore
    cmap: _CMAP_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    val = val or self.compute()
    if not isinstance(val, Tensor):
        raise TypeError(f"Expected val to be a single tensor but got {val}")
    fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
    return fig, ax

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This prevents counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)
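
Example: a minimal usage sketch with illustrative shapes and values. The returned matrix is laid out as [[tn, fp], [fn, tp]], and tn is always 0 for instance matching.

import torch
from darts_segmentation.metrics.binary_instance_stat_scores import BinaryInstanceConfusionMatrix

cm = BinaryInstanceConfusionMatrix(normalize=False)

preds = torch.zeros(1, 32, 32)
preds[0, 4:10, 4:10] = 0.9    # matches the first target instance
preds[0, 20:26, 20:26] = 0.8  # false positive

target = torch.zeros(1, 32, 32, dtype=torch.long)
target[0, 4:10, 4:10] = 1
target[0, 12:18, 12:18] = 1   # missed by the prediction

cm.update(preds, target)
print(cm.compute())  # [[tn, fp], [fn, tp]] -> here [[0, 1], [1, 1]]
# With normalize=True the entries are divided by tp + fp + fn instead.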

BinaryInstanceF1Score

BinaryInstanceF1Score(
    threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    zero_division: float = 0,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceFBetaScore

Binary instance F1 score metric.

Create a new instance of the BinaryInstanceF1Score metric.

Parameters:

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • zero_division (float, default: 0 ) –

    Value to return when there is a zero division. Defaults to 0.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    zero_division: float = 0,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceF1Score metric.

    Args:
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        zero_division (float, optional): Value to return when there is a zero division. Defaults to 0.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    """
    super().__init__(
        beta=1.0,
        threshold=threshold,
        multidim_average=multidim_average,
        ignore_index=ignore_index,
        validate_args=validate_args,
        zero_division=zero_division,
        **kwargs,
    )

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

multidim_average instance-attribute

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

validate_args instance-attribute

zero_division instance-attribute

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _fbeta_reduce(
        tp,
        fp,
        tn,
        fn,
        self.beta,
        average="binary",
        multidim_average=self.multidim_average,
        zero_division=self.zero_division,
    )

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This prevents counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)
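
Example: a minimal sketch of accumulating the score over several batches (shapes and values are illustrative; the usual torchmetrics update/compute/reset workflow is assumed).

import torch
from darts_segmentation.metrics.binary_instance_stat_scores import BinaryInstanceF1Score

f1 = BinaryInstanceF1Score(threshold=0.5, zero_division=0)

target = torch.zeros(1, 32, 32, dtype=torch.long)
target[0, 4:10, 4:10] = 1

good = torch.zeros(1, 32, 32)
good[0, 4:10, 4:10] = 0.9     # matches the target instance
bad = torch.zeros(1, 32, 32)
bad[0, 20:26, 20:26] = 0.9    # misses it and adds a false positive

for preds in (good, bad):     # accumulate over several "batches"
    f1.update(preds, target)

print(f1.compute())  # 2*tp / (2*tp + fp + fn) = 2 / 4 = 0.5 over both updates
f1.reset()           # clear the state before the next evaluation run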

BinaryInstanceFBetaScore

BinaryInstanceFBetaScore(
    beta: float,
    threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    zero_division: float = 0,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores

Binary instance F-beta score metric.

Create a new instance of the BinaryInstanceFBetaScore metric.

Parameters:

  • beta (float) –

    The beta parameter for the F-beta score.

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • zero_division (float, default: 0 ) –

    Value to return when there is a zero division. Defaults to 0.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    beta: float,
    threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    zero_division: float = 0,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceFBetaScore metric.

    Args:
        beta (float): The beta parameter for the F-beta score.
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        zero_division (float, optional): Value to return when there is a zero division. Defaults to 0.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    """
    super().__init__(
        threshold=threshold,
        multidim_average=multidim_average,
        ignore_index=ignore_index,
        validate_args=False,
        **kwargs,
    )
    if validate_args:
        _binary_fbeta_score_arg_validation(beta, threshold, multidim_average, ignore_index, zero_division)
    self.validate_args = validate_args
    self.zero_division = zero_division
    self.beta = beta

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

multidim_average instance-attribute

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

validate_args instance-attribute

zero_division instance-attribute

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _fbeta_reduce(
        tp,
        fp,
        tn,
        fn,
        self.beta,
        average="binary",
        multidim_average=self.multidim_average,
        zero_division=self.zero_division,
    )

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This prevents counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)
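
Example: a minimal sketch showing the effect of beta (values are illustrative). With beta > 1 the score weights recall more heavily, so a missed instance hurts more than a spurious one.

import torch
from darts_segmentation.metrics.binary_instance_stat_scores import BinaryInstanceFBetaScore

f2 = BinaryInstanceFBetaScore(beta=2.0, threshold=0.5)

preds = torch.zeros(1, 32, 32)
preds[0, 4:10, 4:10] = 0.9    # one correct detection

target = torch.zeros(1, 32, 32, dtype=torch.long)
target[0, 4:10, 4:10] = 1
target[0, 20:26, 20:26] = 1   # missed by the prediction

f2.update(preds, target)
print(f2.compute())  # (1 + 4)*tp / ((1 + 4)*tp + 4*fn + fp) = 5/9; an F0.5 score here would be about 0.83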

BinaryInstancePrecision

BinaryInstancePrecision(
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores

Binary instance precision metric.

Create a new instance of the BinaryInstanceStatScores metric.

Parameters:

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Raises:

  • ValueError

    If matching_threshold is not a float in the [0,1] range.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceStatScores metric.

    Args:
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If `matching_threshold` is not a float in the [0,1] range.

    """
    zero_division = kwargs.pop("zero_division", 0)
    super(_AbstractStatScores, self).__init__(**kwargs)
    if validate_args:
        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )

    self.threshold = threshold
    self.matching_threshold = matching_threshold
    self.multidim_average = multidim_average
    self.ignore_index = ignore_index
    self.validate_args = validate_args
    self.zero_division = zero_division

    self._create_state(size=1, multidim_average=multidim_average)

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

multidim_average instance-attribute

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

validate_args instance-attribute

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _precision_recall_reduce(
        "precision",
        tp,
        fp,
        tn,
        fn,
        average="binary",
        multidim_average=self.multidim_average,
        zero_division=self.zero_division,
    )

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This prevents counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)
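
Example: a minimal sketch of how matching_threshold interacts with precision (values are illustrative). A prediction only counts as a true positive if its IoU with a target instance reaches the matching threshold.

import torch
from darts_segmentation.metrics.binary_instance_stat_scores import BinaryInstancePrecision

precision = BinaryInstancePrecision(matching_threshold=0.75)

preds = torch.zeros(1, 32, 32)
preds[0, 4:10, 4:13] = 0.9    # covers the target instance but is larger (IoU = 36/54, about 0.67)

target = torch.zeros(1, 32, 32, dtype=torch.long)
target[0, 4:10, 4:10] = 1

precision.update(preds, target)
print(precision.compute())  # 0.0: the 0.67 IoU is below the 0.75 matching threshold, so tp = 0, fp = 1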

BinaryInstanceRecall

BinaryInstanceRecall(
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores

Binary instance recall metric.

Create a new instance of the BinaryInstanceStatScores metric.

Parameters:

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Raises:

  • ValueError

    If matching_threshold is not a float in the [0,1] range.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceStatScores metric.

    Args:
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If `matching_threshold` is not a float in the [0,1] range.

    """
    zero_division = kwargs.pop("zero_division", 0)
    super(_AbstractStatScores, self).__init__(**kwargs)
    if validate_args:
        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )

    self.threshold = threshold
    self.matching_threshold = matching_threshold
    self.multidim_average = multidim_average
    self.ignore_index = ignore_index
    self.validate_args = validate_args
    self.zero_division = zero_division

    self._create_state(size=1, multidim_average=multidim_average)

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

multidim_average instance-attribute

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

validate_args instance-attribute

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _precision_recall_reduce(
        "recall",
        tp,
        fp,
        tn,
        fn,
        average="binary",
        multidim_average=self.multidim_average,
        zero_division=self.zero_division,
    )

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This prevents counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)
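
Example: a minimal sketch of using ignore_index (values are illustrative). Target pixels equal to ignore_index are treated as invalid; predictions falling inside that area are zeroed out and therefore not counted as instances.

import torch
from darts_segmentation.metrics.binary_instance_stat_scores import BinaryInstanceRecall

recall = BinaryInstanceRecall(ignore_index=2)

preds = torch.zeros(1, 32, 32)
preds[0, 4:10, 4:10] = 0.9    # valid detection
preds[0, 20:26, 20:26] = 0.9  # falls entirely inside the ignored region

target = torch.zeros(1, 32, 32, dtype=torch.long)
target[0, 4:10, 4:10] = 1
target[0, 18:28, 18:28] = 2   # invalid / unlabelled area

recall.update(preds, target)
print(recall.compute())  # 1.0: the only valid target instance was found; the ignored prediction does not count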

BinaryInstanceStatScores

BinaryInstanceStatScores(
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: torchmetrics.classification.stat_scores._AbstractStatScores

Base class for binary instance segmentation metrics.

Create a new instance of the BinaryInstanceStatScores metric.

Parameters:

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Raises:

  • ValueError

    If matching_threshold is not a float in the [0,1] range.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceStatScores metric.

    Args:
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If `matching_threshold` is not a float in the [0,1] range.

    """
    zero_division = kwargs.pop("zero_division", 0)
    super(_AbstractStatScores, self).__init__(**kwargs)
    if validate_args:
        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )

    self.threshold = threshold
    self.matching_threshold = matching_threshold
    self.multidim_average = multidim_average
    self.ignore_index = ignore_index
    self.validate_args = validate_args
    self.zero_division = zero_division

    self._create_state(size=1, multidim_average=multidim_average)

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = None

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

multidim_average instance-attribute

threshold instance-attribute

validate_args instance-attribute

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _binary_stat_scores_compute(tp, fp, tn, fn, self.multidim_average)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This prevents counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)
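
Example: a minimal sketch of the base class (values are illustrative). Assuming the usual torchmetrics stat-scores convention for the compute helper used above, the result should hold tp, fp, tn, fn and the support (tp + fn); tn is always 0 for instance matching.

import torch
from darts_segmentation.metrics.binary_instance_stat_scores import BinaryInstanceStatScores

stats = BinaryInstanceStatScores(threshold=0.5, matching_threshold=0.5)

preds = torch.zeros(1, 32, 32)
preds[0, 4:10, 4:10] = 0.9
preds[0, 20:26, 20:26] = 0.8

target = torch.zeros(1, 32, 32, dtype=torch.long)
target[0, 4:10, 4:10] = 1
target[0, 12:18, 12:18] = 1

stats.update(preds, target)
print(stats.compute())  # here tp = 1, fp = 1, tn = 0, fn = 1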

mask_to_instances

mask_to_instances(
    x: torch.Tensor, validate_args: bool = False
) -> list[torch.Tensor]

Convert a binary segmentation mask into multiple instance masks. Expects a batched version of the input.

Currently only supports uint8 tensors, hence a maximum number of 255 instances per mask.

Parameters:

  • x (torch.Tensor) –

    The binary segmentation mask. Shape: (batch_size, height, width), dtype: torch.uint8

  • validate_args (bool, default: False ) –

    Whether to validate the input arguments. Defaults to False.

Returns:

  • list[torch.Tensor]

    list[torch.Tensor]: The instance masks. Length of list: batch_size. Shape of a tensor: (height, width), dtype: torch.uint8

Source code in darts-segmentation/src/darts_segmentation/metrics/instance_helpers.py
@torch.no_grad()
def mask_to_instances(x: torch.Tensor, validate_args: bool = False) -> list[torch.Tensor]:
    """Convert a binary segmentation mask into multiple instance masks. Expects a batched version of the input.

    Currently only supports uint8 tensors, hence a maximum number of 255 instances per mask.

    Args:
        x (torch.Tensor): The binary segmentation mask. Shape: (batch_size, height, width), dtype: torch.uint8
        validate_args (bool, optional): Whether to validate the input arguments. Defaults to False.

    Returns:
        list[torch.Tensor]: The instance masks. Length of list: batch_size.
            Shape of a tensor: (height, width), dtype: torch.uint8

    """
    if validate_args:
        assert x.dim() == 3, f"Expected 3 dimensions, got {x.dim()}"
        assert x.dtype == torch.uint8, f"Expected torch.uint8, got {x.dtype}"
        assert x.min() >= 0 and x.max() <= 1, f"Expected binary mask, got {x.min()} and {x.max()}"

    # A note on using lists as separation between instances instead of using a batched tensor:
    # Using a batched tensor with instance numbers (1, 2, 3, ...) would indicate that the instances of the samples
    # are identical. Using a list clearly separates the instances of the samples.

    if CUCIM_AVAILABLE:
        # Check if device is cuda
        assert x.device.type == "cuda", f"Expected device to be cuda, got {x.device.type}"
        x = cp.asarray(x).astype(cp.uint8)

        instances = []
        for x_i in x:
            instances_i = label_gpu(x_i)
            instances_i = torch.tensor(instances_i, dtype=torch.uint8)
            instances.append(instances_i)
        return instances

    else:
        instances = []
        for x_i in x:
            x_i = x_i.cpu().numpy()
            instances_i = label(x_i)
            instances_i = torch.tensor(instances_i, dtype=torch.uint8)
            instances.append(instances_i)
        return instances
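
Example: a minimal sketch (the import path is assumed from the source path above). Each sample of the batched binary mask is turned into an instance label map via connected-component labelling.

import torch
from darts_segmentation.metrics.instance_helpers import mask_to_instances

mask = torch.zeros(2, 16, 16, dtype=torch.uint8)
mask[0, 2:6, 2:6] = 1          # sample 0: one blob
mask[1, 1:4, 1:4] = 1          # sample 1: two blobs
mask[1, 10:14, 10:14] = 1

# On a machine without cuCIM this runs on CPU; with cuCIM installed, move `mask` to CUDA first.
instances = mask_to_instances(mask, validate_args=True)
print(len(instances))           # 2, one instance map per sample
print(instances[1].unique())    # tensor([0, 1, 2], dtype=torch.uint8): background plus two instances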

match_instances

match_instances(
    instances_target: torch.Tensor,
    instances_preds: torch.Tensor,
    match_threshold: float = 0.5,
    validate_args: bool = False,
) -> tuple[int, int, int]

Match instances between target and prediction masks. Expects non-batched input from skimage.measure.label.

Parameters:

  • instances_target (torch.Tensor) –

    The instance mask of the target. Shape: (height, width), dtype: torch.uint8

  • instances_preds (torch.Tensor) –

    The instance mask of the prediction. Shape: (height, width), dtype: torch.uint8

  • match_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • validate_args (bool, default: False ) –

    Whether to validate the input arguments. Defaults to False.

Returns:

  • tuple[int, int, int]

    tuple[int, int, int]: True positives, false positives, false negatives

Source code in darts-segmentation/src/darts_segmentation/metrics/instance_helpers.py
@torch.no_grad()
def match_instances(
    instances_target: torch.Tensor,
    instances_preds: torch.Tensor,
    match_threshold: float = 0.5,
    validate_args: bool = False,
) -> tuple[int, int, int]:
    """Match instances between target and prediction masks. Expects non-batched input from skimage.measure.label.

    Args:
        instances_target (torch.Tensor): The instance mask of the target. Shape: (height, width), dtype: torch.uint8
        instances_preds (torch.Tensor): The instance mask of the prediction. Shape: (height, width), dtype: torch.uint8
        match_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        validate_args (bool, optional): Whether to validate the input arguments. Defaults to False.

    Returns:
        tuple[int, int, int]: True positives, false positives, false negatives

    """
    if validate_args:
        assert instances_target.dim() == 2, f"Expected 2 dimensions, got {instances_target.dim()}"
        assert instances_preds.dim() == 2, f"Expected 2 dimensions, got {instances_preds.dim()}"
        assert instances_target.dtype == torch.uint8, f"Expected torch.uint8, got {instances_target.dtype}"
        assert instances_preds.dtype == torch.uint8, f"Expected torch.uint8, got {instances_preds.dtype}"
        assert instances_target.shape == instances_preds.shape, (
            f"Shapes do not match: {instances_target.shape} and {instances_preds.shape}"
        )

    height, width = instances_target.shape
    ntargets = instances_target.max().item()
    npreds = instances_preds.max().item()
    # If target or predictions has no instances, return 0 for their respective metrics.
    # If none of them has instances, return 0 for all metrics. (This is implied)
    if ntargets == 0:
        return 0, npreds, 0
    if npreds == 0:
        return 0, 0, ntargets

    # TODO: These are old edge case filter that need revision.
    # They are probably not necessary, since the instance metrics are meaningless for noisy predictions.
    # If there are too many predictions, return all as false positives (this happens when the model is very noisy)
    # if npreds > ntargets * 5:
    #     return 0, npreds, ntargets
    # If there is only one prediction, return all as false negatives (this happens when the model is very noisy)
    # if npreds == 1 and ntargets > 1:
    #     return 0, 1, ntargets

    # Create one-hot encoding of instances, so that each instance is a channel
    instances_target_onehot = torch.zeros((ntargets, height, width), dtype=torch.uint8, device=instances_target.device)
    instances_preds_onehot = torch.zeros((npreds, height, width), dtype=torch.uint8, device=instances_target.device)
    for i in range(ntargets):
        instances_target_onehot[i, :, :] = instances_target == (i + 1)
    for i in range(npreds):
        instances_preds_onehot[i, :, :] = instances_preds == (i + 1)

    # Now the instances are channels, hence tensors of shape (num_instances, height, width)

    # Calculate IoU (we need to do a n-m intersection and union, therefore we need to broadcast)
    intersection = (instances_target_onehot.unsqueeze(1) & instances_preds_onehot.unsqueeze(0)).sum(
        dim=(2, 3)
    )  # Shape: (num_instances_target, num_instances_preds)
    union = (instances_target_onehot.unsqueeze(1) | instances_preds_onehot.unsqueeze(0)).sum(
        dim=(2, 3)
    )  # Shape: (num_instances_target, num_instances_preds)
    iou = intersection / union  # Shape: (num_instances_target, num_instances_preds)

    # Match instances based on IoU
    tp = (iou >= match_threshold).sum().item()
    fp = npreds - tp
    fn = ntargets - tp

    return tp, fp, fn
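
Example: a minimal sketch built from the function above (the import path is assumed from the source path shown). Instances are matched by IoU; a prediction counts as a true positive only if its IoU with a target instance reaches match_threshold.

import torch
from darts_segmentation.metrics.instance_helpers import match_instances

target_inst = torch.zeros(16, 16, dtype=torch.uint8)
target_inst[2:6, 2:6] = 1       # target instance 1
target_inst[10:14, 10:14] = 2   # target instance 2

pred_inst = torch.zeros(16, 16, dtype=torch.uint8)
pred_inst[2:6, 2:6] = 1         # perfect overlap with target instance 1 (IoU = 1.0)
pred_inst[10:12, 10:12] = 2     # covers a quarter of target instance 2 (IoU = 0.25)

tp, fp, fn = match_instances(target_inst, pred_inst, match_threshold=0.5, validate_args=True)
print(tp, fp, fn)  # 1 1 1: one match, one low-IoU prediction, one effectively missed target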