callbacks

darts_segmentation.training.callbacks

PyTorch Lightning Callbacks for training and validation.

Stage module-attribute

Stage = typing.Literal["fit", "validate", "test", "predict"]
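
This alias mirrors the stage argument that Lightning passes to callback hooks such as setup. A minimal sketch of using it as a type hint; the StageLoggerCallback below is hypothetical and not part of this module:

import lightning.pytorch as pl

from darts_segmentation.training.callbacks import Stage


class StageLoggerCallback(pl.Callback):
    """Hypothetical callback that logs the stage it is set up for."""

    def setup(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: Stage) -> None:
        # stage is one of "fit", "validate", "test" or "predict"
        print(f"Setting up for stage: {stage}")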

logger module-attribute

logger = logging.getLogger(
    __name__.replace("darts_", "darts.")
)

Bands

Bases: collections.UserList[darts_segmentation.utils.Band]

Wrapper for the list of bands.

factors property

factors: list[float]

Get the factors of the bands.

Returns:

  • list[float]

    list[float]: The factors of the bands.

names property

names: list[str]

Get the names of the bands.

Returns:

  • list[str]

    list[str]: The names of the bands.

offsets property

offsets: list[float]

Get the offsets of the bands.

Returns:

  • list[float]

    list[float]: The offsets of the bands.

__reduce__

__reduce__()
Source code in darts-segmentation/src/darts_segmentation/utils.py
def __reduce__(self):  # noqa: D105
    # This is needed to pickle (and unpickle) the Bands object as a dict
    # This is needed, because this way we don't need to have this class present when unpickling
    # a pytorch checkpoint
    return (dict, (self.to_config(),))
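
Because __reduce__ returns (dict, (self.to_config(),)), a pickled Bands object unpickles as its plain config dict, so checkpoints can be loaded without this class being importable. A small sketch with a made-up band name:

import pickle

from darts_segmentation.utils import Band, Bands

bands = Bands([Band(name="red", factor=1.0, offset=0.0)])
restored = pickle.loads(pickle.dumps(bands))

# The round-trip yields the config dict, not a Bands instance
assert isinstance(restored, dict)
assert restored["bands"] == ["red"]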

__repr__

__repr__() -> str
Source code in darts-segmentation/src/darts_segmentation/utils.py
def __repr__(self) -> str:  # noqa: D105
    band_info = ", ".join([f"{band.name}(*{band.factor:.5f}+{band.offset:.5f})" for band in self])
    return f"Bands({band_info})"

filter

filter(
    band_names: list[str],
) -> darts_segmentation.utils.Bands

Filter the bands by name.

Parameters:

  • band_names (list[str]) –

    The names of the bands to keep.

Returns:

  • darts_segmentation.utils.Bands –

    Bands: The filtered Bands object.

Source code in darts-segmentation/src/darts_segmentation/utils.py
def filter(self, band_names: list[str]) -> "Bands":
    """Filter the bands by name.

    Args:
        band_names (list[str]): The names of the bands to keep.

    Returns:
        Bands: The filtered Bands object.

    """
    return Bands([band for band in self if band.name in band_names])

from_config classmethod

from_config(
    config: dict[
        typing.Literal[
            "bands", "band_factors", "band_offsets"
        ],
        list,
    ]
    | dict[str, tuple[float, float]],
) -> darts_segmentation.utils.Bands

Create a Bands object from a config dictionary.

Parameters:

  • config (dict) –

    The config dictionary containing the band information. Expects config to be a dictionary with keys "bands", "band_factors" and "band_offsets", with the values to be lists of the same length.

Returns:

  • darts_segmentation.utils.Bands –

    Bands: The Bands object.

Source code in darts-segmentation/src/darts_segmentation/utils.py
@classmethod
def from_config(
    cls,
    config: dict[Literal["bands", "band_factors", "band_offsets"], list] | dict[str, tuple[float, float]],
) -> "Bands":
    """Create a Bands object from a config dictionary.

    Args:
        config (dict): The config dictionary containing the band information.
            Expects config to be a dictionary with keys "bands", "band_factors" and "band_offsets",
            with the values to be lists of the same length.

    Returns:
        Bands: The Bands object.

    """
    assert "bands" in config and "band_factors" in config and "band_offsets" in config, (
        f"Config must contain keys 'bands', 'band_factors' and 'band_offsets'.Got {config} instead."
    )
    return cls(
        [
            Band(name=name, factor=factor, offset=offset)
            for name, factor, offset in zip(config["bands"], config["band_factors"], config["band_offsets"])
        ]
    )

from_dict classmethod

from_dict(
    config: dict[str, tuple[float, float]],
) -> darts_segmentation.utils.Bands

Create a Bands object from a dictionary.

Parameters:

  • config (dict[str, tuple[float, float]]) –

    The dictionary containing the band information. Expects the keys to be the band names and the values to be tuples of (factor, offset). Example: {"band1": (1.0, 0.0), "band2": (2.0, 1.0)}

Returns:

  • darts_segmentation.utils.Bands –

    Bands: The Bands object.

Source code in darts-segmentation/src/darts_segmentation/utils.py
@classmethod
def from_dict(cls, config: dict[str, tuple[float, float]]) -> "Bands":
    """Create a Bands object from a dictionary.

    Args:
        config (dict[str, tuple[float, float]]): The dictionary containing the band information.
            Expects the keys to be the band names and the values to be tuples of (factor, offset).
            Example: {"band1": (1.0, 0.0), "band2": (2.0, 1.0)}

    Returns:
        Bands: The Bands object.

    """
    return cls([Band(name=name, factor=factor, offset=offset) for name, (factor, offset) in config.items()])

to_config

to_config() -> dict[
    typing.Literal["bands", "band_factors", "band_offsets"],
    list,
]

Convert the Bands object to a config dictionary.

Returns:

  • dict ( dict[typing.Literal['bands', 'band_factors', 'band_offsets'], list] ) –

    The config dictionary containing the band information.

Source code in darts-segmentation/src/darts_segmentation/utils.py
def to_config(self) -> dict[Literal["bands", "band_factors", "band_offsets"], list]:
    """Convert the Bands object to a config dictionary.

    Returns:
        dict: The config dictionary containing the band information.

    """
    return {
        "bands": [band.name for band in self],
        "band_factors": [band.factor for band in self],
        "band_offsets": [band.offset for band in self],
    }

to_dict

to_dict() -> dict[str, tuple[float, float]]

Convert the Bands object to a dictionary.

Returns:

  • dict[str, tuple[float, float]]

    dict[str, tuple[float, float]]: The dictionary containing the band information.

Source code in darts-segmentation/src/darts_segmentation/utils.py
def to_dict(self) -> dict[str, tuple[float, float]]:
    """Convert the Bands object to a dictionary.

    Returns:
        dict[str, tuple[float, float]]: The dictionary containing the band information.

    """
    return {band.name: (band.factor, band.offset) for band in self}
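
Putting the above together, a short usage sketch of the Bands API (band names and scaling values are made up):

from darts_segmentation.utils import Bands

# Build from a {name: (factor, offset)} mapping
bands = Bands.from_dict({"red": (1.0, 0.0), "nir": (2.0, 1.0)})

print(bands.names)    # ['red', 'nir']
print(bands.factors)  # [1.0, 2.0]
print(bands.offsets)  # [0.0, 1.0]

# Keep only a subset and serialize back to a config dict
subset = bands.filter(["nir"])
config = subset.to_config()
# {'bands': ['nir'], 'band_factors': [2.0], 'band_offsets': [1.0]}

# The config round-trips through from_config
assert Bands.from_config(config).to_dict() == {"nir": (2.0, 1.0)}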

BinaryBoundaryIoU

BinaryBoundaryIoU(
    dilation: float | int = 0.02,
    threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Unpack[
        darts_segmentation.metrics.boundary_iou.BinaryBoundaryIoUKwargs
    ],
)

Bases: torchmetrics.Metric

Binary Boundary IoU metric for binary segmentation tasks.

This metric is similar to the Binary Intersection over Union (IoU or Jaccard Index) metric, but instead of comparing all pixels it only compares the boundaries of each foreground object.

Create a new instance of the BinaryBoundaryIoU metric.

Please see the torchmetrics docs (https://lightning.ai/docs/torchmetrics/stable/pages/overview.html#metric-kwargs) for more info about the **kwargs.

Parameters:

  • dilation (float | int, default: 0.02 ) –

    The dilation (factor) / width of the boundary. If an int, the dilation in pixels; if a float, the ratio used to compute dilation = dilation_ratio * image_diagonal. Default: 0.02

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • **kwargs (typing.Unpack[darts_segmentation.metrics.boundary_iou.BinaryBoundaryIoUKwargs], default: {} ) –

    Additional keyword arguments for the metric.

Other Parameters:

  • zero_division (int) –

    Value to return when there is a zero division. Default is 0.

  • compute_on_cpu (bool) –

    If metric state should be stored on CPU during computations. Only works for list states.

  • dist_sync_on_step (bool) –

    If metric state should synchronize on forward(). Default is False.

  • process_group (str) –

    The process group on which the synchronization is called. Default is the world.

  • dist_sync_fn (callable) –

    Function that performs the allgather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.

  • distributed_available_fn (callable) –

    Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().

  • sync_on_compute (bool) –

    If metric state should synchronize when compute is called. Default is True.

  • compute_with_cache (bool) –

    If results from compute should be cached. Default is True.

Raises:

  • ValueError

    If dilation is not a float or int.

Source code in darts-segmentation/src/darts_segmentation/metrics/boundary_iou.py
def __init__(
    self,
    dilation: float | int = 0.02,
    threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Unpack[BinaryBoundaryIoUKwargs],
):
    """Create a new instance of the BinaryBoundaryIoU metric.

    Please see the
    [torchmetrics docs](https://lightning.ai/docs/torchmetrics/stable/pages/overview.html#metric-kwargs)
    for more info about the **kwargs.

    Args:
        dilation (float | int, optional): The dilation (factor) / width of the boundary.
            If an int, the dilation in pixels; if a float, the ratio used to compute
            `dilation = dilation_ratio * image_diagonal`. Default: 0.02
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class.  Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        **kwargs: Additional keyword arguments for the metric.

    Keyword Args:
        zero_division (int):
            Value to return when there is a zero division. Default is 0.
        compute_on_cpu (bool):
            If metric state should be stored on CPU during computations. Only works for list states.
        dist_sync_on_step (bool):
            If metric state should synchronize on ``forward()``. Default is ``False``.
        process_group (str):
            The process group on which the synchronization is called. Default is the world.
        dist_sync_fn (callable):
            Function that performs the allgather operation on the metric state. Default is a custom
            implementation that calls ``torch.distributed.all_gather`` internally.
        distributed_available_fn (callable):
            Function that checks if the distributed backend is available. Defaults to a
            check of ``torch.distributed.is_available()`` and ``torch.distributed.is_initialized()``.
        sync_on_compute (bool):
            If metric state should synchronize when ``compute`` is called. Default is ``True``.
        compute_with_cache (bool):
            If results from ``compute`` should be cached. Default is ``True``.

    Raises:
        ValueError: If dilation is not a float or int.

    """
    zero_division = kwargs.pop("zero_division", 0)
    super().__init__(**kwargs)

    if validate_args:
        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
        if not isinstance(dilation, float | int):
            raise ValueError(f"Expected argument `dilation` to be a float or int, but got {dilation}.")

    self.dilation = dilation
    self.threshold = threshold
    self.multidim_average = multidim_average
    self.ignore_index = ignore_index
    self.validate_args = validate_args
    self.zero_division = zero_division

    if multidim_average == "samplewise":
        self.add_state("intersection", default=[], dist_reduce_fx="cat")
        self.add_state("union", default=[], dist_reduce_fx="cat")
    else:
        self.add_state("intersection", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("union", default=torch.tensor(0), dist_reduce_fx="sum")

dilation instance-attribute

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

intersection instance-attribute

intersection: torch.Tensor | list[torch.Tensor]

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

multidim_average instance-attribute

multidim_average = multidim_average

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

union instance-attribute

validate_args instance-attribute

validate_args = validate_args

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor

Compute the metric.

Returns:

  • torch.Tensor –

    Tensor: The computed metric.

Source code in darts-segmentation/src/darts_segmentation/metrics/boundary_iou.py
def compute(self) -> Tensor:
    """Compute the metric.

    Returns:
        Tensor: The computed metric.

    """
    if self.multidim_average == "global":
        return self.intersection / self.union
    else:
        self.intersection = torch.tensor(self.intersection)
        self.union = torch.tensor(self.union)
        return self.intersection / self.union

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If the input arguments are invalid.

  • ValueError

    If the input shapes are invalid.

Source code in darts-segmentation/src/darts_segmentation/metrics/boundary_iou.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If the input arguments are invalid.
        ValueError: If the input shapes are invalid.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.shape == target.shape:
            raise ValueError(
                f"Expected `preds` and `target` to have the same shape, but got {preds.shape} and {target.shape}."
            )
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions, but got {preds.dim()}.")

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    target = target.to(torch.uint8)
    preds = preds.to(torch.uint8)

    target_boundary = get_boundary((target == 1).to(torch.uint8), self.dilation, self.validate_args)
    preds_boundary = get_boundary(preds, self.dilation, self.validate_args)

    intersection = target_boundary & preds_boundary
    union = target_boundary | preds_boundary

    if self.ignore_index is not None:
        # Important that this is NOT the boundary, but the original mask
        valid_idx = target != self.ignore_index
        intersection &= valid_idx
        union &= valid_idx

    intersection = intersection.sum().item()
    union = union.sum().item()

    if self.multidim_average == "global":
        self.intersection += intersection
        self.union += union
    else:
        self.intersection.append(intersection)
        self.union.append(union)
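
A minimal usage sketch with random tensors, following the usual torchmetrics update/compute cycle (shapes and values are illustrative):

import torch

from darts_segmentation.metrics.boundary_iou import BinaryBoundaryIoU

metric = BinaryBoundaryIoU(dilation=0.02, threshold=0.5)

# preds may be logits or probabilities; target is a binary mask of shape (B, H, W)
preds = torch.randn(4, 64, 64)  # logits, so a sigmoid is applied internally
target = (torch.rand(4, 64, 64) > 0.5).long()

metric.update(preds, target)
boundary_iou = metric.compute()  # scalar tensor for multidim_average="global"
print(float(boundary_iou))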

BinaryInstanceAccuracy

BinaryInstanceAccuracy(
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores

Binary instance accuracy metric.

Create a new instance of the BinaryInstanceStatScores metric.

Parameters:

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Raises:

  • ValueError

    If matching_threshold is not a float in the [0,1] range.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceStatScores metric.

    Args:
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If `matching_threshold` is not a float in the [0,1] range.

    """
    zero_division = kwargs.pop("zero_division", 0)
    super(_AbstractStatScores, self).__init__(**kwargs)
    if validate_args:
        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )

    self.threshold = threshold
    self.matching_threshold = matching_threshold
    self.multidim_average = multidim_average
    self.ignore_index = ignore_index
    self.validate_args = validate_args
    self.zero_division = zero_division

    self._create_state(size=1, multidim_average=multidim_average)

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

multidim_average instance-attribute

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

validate_args instance-attribute

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _accuracy_reduce(
        tp,
        fp,
        tn,
        fn,
        average="binary",
        multidim_average=self.multidim_average,
    )

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This prevents counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)
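
All binary instance metrics share this update logic: each mask is split into instances by mask_to_instances, and instances are matched by match_instances at matching_threshold. A hedged usage sketch with synthetic data:

import torch

from darts_segmentation.metrics.binary_instance_stat_scores import BinaryInstanceAccuracy

metric = BinaryInstanceAccuracy(threshold=0.5, matching_threshold=0.5)

preds = torch.rand(2, 32, 32)  # probabilities, binarized at `threshold`
target = torch.zeros(2, 32, 32, dtype=torch.long)
target[:, 4:10, 4:10] = 1  # one square instance per sample

metric.update(preds, target)
print(metric.compute())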

BinaryInstanceAveragePrecision

BinaryInstanceAveragePrecision(
    thresholds: int | list[float] | torch.Tensor = None,
    matching_threshold: float = 0.5,
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_prc.BinaryInstancePrecisionRecallCurve

Compute the average precision for binary instance segmentation.

Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

Parameters:

  • thresholds (int | list[float] | torch.Tensor, default: None ) –

    The thresholds to use for the curve. Defaults to None.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Raises:

  • ValueError

    If thresholds is None.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def __init__(
    self,
    thresholds: int | list[float] | Tensor = None,
    matching_threshold: float = 0.5,
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

    Args:
        thresholds (int | list[float] | Tensor, optional): The thresholds to use for the curve. Defaults to None.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If thresholds is None.

    """
    super().__init__(**kwargs)
    if validate_args:
        _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )
        if thresholds is None:
            raise ValueError("Argument `thresholds` must be provided for this metric.")

    self.matching_threshold = matching_threshold
    self.ignore_index = ignore_index
    self.validate_args = validate_args

    thresholds = _adjust_threshold_arg(thresholds)
    self.register_buffer("thresholds", thresholds, persistent=False)
    self.add_state("confmat", default=torch.zeros(len(thresholds), 2, 2, dtype=torch.long), dist_reduce_fx="sum")

confmat instance-attribute

confmat: torch.Tensor

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool = True

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

preds instance-attribute

preds: list[torch.Tensor]

target instance-attribute

target: list[torch.Tensor]

thresholds instance-attribute

thresholds: torch.Tensor

validate_args instance-attribute

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def compute(self) -> Tensor:  # type: ignore[override]  # noqa: D102
    return _binary_average_precision_compute(self.confmat, self.thresholds)

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def plot(  # type: ignore[override]  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update metric states.

Parameters:

  • preds (torch.Tensor) –

    The predicted mask. Shape: (batch_size, height, width)

  • target (torch.Tensor) –

    The target mask. Shape: (batch_size, height, width)

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update metric states.

    Args:
        preds (Tensor): The predicted mask. Shape: (batch_size, height, width)
        target (Tensor): The target mask. Shape: (batch_size, height, width)

    Raises:
        ValueError: If preds and target have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_precision_recall_curve_tensor_validation(preds, target, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if not torch.all((preds >= 0) * (preds <= 1)):
        preds = preds.sigmoid()

    if self.ignore_index is not None:
        target = (target == 1).to(torch.uint8)

    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)

    len_t = len(self.thresholds)
    confmat = self.thresholds.new_zeros((len_t, 2, 2), dtype=torch.int64)
    for i in range(len_t):
        preds_i = preds >= self.thresholds[i]

        if self.ignore_index is not None:
            invalid_idx = target == self.ignore_index
            preds_i = preds_i.clone()
            preds_i[invalid_idx] = 0  # This prevents counting instances in the ignored area

        instance_list_preds_i = mask_to_instances(preds_i.to(torch.uint8), self.validate_args)
        for target_i, preds_i in zip(instance_list_target, instance_list_preds_i):
            tp, fp, fn = match_instances(
                target_i,
                preds_i,
                match_threshold=self.matching_threshold,
                validate_args=self.validate_args,
            )
            confmat[i, 1, 1] += tp
            confmat[i, 0, 1] += fp
            confmat[i, 1, 0] += fn
    self.confmat += confmat
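
Unlike the pixel-wise torchmetrics counterpart, thresholds must be provided here; an int is expanded by _adjust_threshold_arg (in torchmetrics, to that many evenly spaced thresholds). A minimal sketch with synthetic data:

import torch

from darts_segmentation.metrics.binary_instance_prc import BinaryInstanceAveragePrecision

# thresholds is required; None raises a ValueError when validate_args=True
metric = BinaryInstanceAveragePrecision(thresholds=11, matching_threshold=0.5)

preds = torch.rand(2, 32, 32)
target = torch.zeros(2, 32, 32, dtype=torch.long)
target[:, 8:16, 8:16] = 1

metric.update(preds, target)
print(metric.compute())  # scalar average precision over instance matches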

BinaryInstanceConfusionMatrix

BinaryInstanceConfusionMatrix(
    normalize: bool | None = None,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores

Binary instance confusion matrix metric.

Create a new instance of the BinaryInstanceConfusionMatrix metric.

Parameters:

  • normalize (bool, default: None ) –

    If True, return the confusion matrix normalized by the number of instances; if False or None, return the raw counts. Defaults to None.

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Raises:

  • ValueError

    If normalize is not a bool.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    normalize: bool | None = None,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceConfusionMatrix metric.

    Args:
        normalize (bool, optional): If True, return the confusion matrix normalized by the number of instances.
            If False or None, return the raw counts. Defaults to None.
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If `normalize` is not a bool.

    """
    super().__init__(
        threshold=threshold,
        matching_threshold=matching_threshold,
        multidim_average=multidim_average,
        ignore_index=ignore_index,
        validate_args=False,
        **kwargs,
    )
    if normalize is not None and not isinstance(normalize, bool):
        raise ValueError(f"Argument `normalize` needs to be of bool type but got {type(normalize)}")
    self.normalize = normalize

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = None

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

multidim_average instance-attribute

normalize instance-attribute

threshold instance-attribute

validate_args instance-attribute

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    # tn is always 0
    if self.normalize:
        total = tp + fp + fn  # renamed from `all` to avoid shadowing the builtin
        return torch.tensor([[0, fp / total], [fn / total, tp / total]], device=tp.device)
    else:
        return torch.tensor([[tn, fp], [fn, tp]], device=tp.device)
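
Since only matched instances are counted, there are no true negatives and the top-left cell is always zero; rows index the target and columns the prediction, as in torchmetrics' binary confusion matrix. A short sketch of reading the result:

import torch

from darts_segmentation.metrics.binary_instance_stat_scores import BinaryInstanceConfusionMatrix

metric = BinaryInstanceConfusionMatrix(normalize=False)

preds = torch.rand(2, 32, 32)
target = torch.zeros(2, 32, 32, dtype=torch.long)
target[:, 2:9, 2:9] = 1

metric.update(preds, target)
cm = metric.compute()
# cm[0, 0] = tn (always 0), cm[0, 1] = fp, cm[1, 0] = fn, cm[1, 1] = tp
print(cm)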

plot

plot(
    val: torch.Tensor | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
    add_text: bool = True,
    labels: list[str] | None = None,
    cmap: torchmetrics.utilities.plot._CMAP_TYPE
    | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
    add_text: bool = True,
    labels: list[str] | None = None,  # type: ignore
    cmap: _CMAP_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    val = val if val is not None else self.compute()
    if not isinstance(val, Tensor):
        raise TypeError(f"Expected val to be a single tensor but got {val}")
    fig, ax = plot_confusion_matrix(val, ax=ax, add_text=add_text, labels=labels, cmap=cmap)
    return fig, ax

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This prevents counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)

BinaryInstanceF1Score

BinaryInstanceF1Score(
    threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    zero_division: float = 0,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceFBetaScore

Binary instance F1 score metric.

Create a new instance of the BinaryInstanceF1Score metric.

Parameters:

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • zero_division (float, default: 0 ) –

    Value to return when there is a zero division. Defaults to 0.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    zero_division: float = 0,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceF1Score metric.

    Args:
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        zero_division (float, optional): Value to return when there is a zero division. Defaults to 0.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    """
    super().__init__(
        beta=1.0,
        threshold=threshold,
        multidim_average=multidim_average,
        ignore_index=ignore_index,
        validate_args=validate_args,
        zero_division=zero_division,
        **kwargs,
    )

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

multidim_average instance-attribute

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

validate_args instance-attribute

zero_division instance-attribute

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _fbeta_reduce(
        tp,
        fp,
        tn,
        fn,
        self.beta,
        average="binary",
        multidim_average=self.multidim_average,
        zero_division=self.zero_division,
    )

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This prevents counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)

BinaryInstancePrecision

BinaryInstancePrecision(
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores

Binary instance precision metric.

Create a new instance of the BinaryInstanceStatScores metric.

Parameters:

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Raises:

  • ValueError

    If matching_threshold is not a float in the [0,1] range.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceStatScores metric.

    Args:
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If `matching_threshold` is not a float in the [0,1] range.

    """
    zero_division = kwargs.pop("zero_division", 0)
    super(_AbstractStatScores, self).__init__(**kwargs)
    if validate_args:
        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )

    self.threshold = threshold
    self.matching_threshold = matching_threshold
    self.multidim_average = multidim_average
    self.ignore_index = ignore_index
    self.validate_args = validate_args
    self.zero_division = zero_division

    self._create_state(size=1, multidim_average=multidim_average)

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

multidim_average instance-attribute

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

validate_args instance-attribute

zero_division instance-attribute

zero_division = zero_division

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _precision_recall_reduce(
        "precision",
        tp,
        fp,
        tn,
        fn,
        average="binary",
        multidim_average=self.multidim_average,
        zero_division=self.zero_division,
    )

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This prevents counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)

BinaryInstancePrecisionRecallCurve

BinaryInstancePrecisionRecallCurve(
    thresholds: int | list[float] | torch.Tensor = None,
    matching_threshold: float = 0.5,
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: torchmetrics.Metric

Compute the precision-recall curve for binary instance segmentation.

This metric works similarly to torchmetrics.classification.PrecisionRecallCurve, with two key differences:

  1. It calculates the tp, fp and fn values for each instance (blob) in the batch and then aggregates them, instead of calculating the values for each pixel.
  2. The "thresholds" argument is required, because calculating the thresholds at the compute stage would cost too much memory for this use case.

Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

Parameters:

  • thresholds (int | list[float] | torch.Tensor, default: None ) –

    The thresholds to use for the curve. Defaults to None.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Raises:

  • ValueError

    If thresholds is None.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def __init__(
    self,
    thresholds: int | list[float] | Tensor = None,
    matching_threshold: float = 0.5,
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstancePrecisionRecallCurve metric.

    Args:
        thresholds (int | list[float] | Tensor, optional): The thresholds to use for the curve. Defaults to None.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If thresholds is None.

    """
    super().__init__(**kwargs)
    if validate_args:
        _binary_precision_recall_curve_arg_validation(thresholds, ignore_index)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )
        if thresholds is None:
            raise ValueError("Argument `thresholds` must be provided for this metric.")

    self.matching_threshold = matching_threshold
    self.ignore_index = ignore_index
    self.validate_args = validate_args

    thresholds = _adjust_threshold_arg(thresholds)
    self.register_buffer("thresholds", thresholds, persistent=False)
    self.add_state("confmat", default=torch.zeros(len(thresholds), 2, 2, dtype=torch.long), dist_reduce_fx="sum")

confmat instance-attribute

confmat: torch.Tensor

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = None

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

preds instance-attribute

preds: list[torch.Tensor]

target instance-attribute

target: list[torch.Tensor]

thresholds instance-attribute

thresholds: torch.Tensor

validate_args instance-attribute

compute

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def compute(self) -> tuple[Tensor, Tensor, Tensor]:  # noqa: D102
    return _binary_precision_recall_curve_compute(self.confmat, self.thresholds)

plot

plot(
    curve: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    | None = None,
    score: torch.Tensor | bool | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def plot(  # noqa: D102
    self,
    curve: tuple[Tensor, Tensor, Tensor] | None = None,
    score: Tensor | bool | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    curve_computed = curve or self.compute()
    # switch order as the standard way is recall along x-axis and precision along y-axis
    curve_computed = (curve_computed[1], curve_computed[0], curve_computed[2])

    score = (
        _auc_compute_without_check(curve_computed[0], curve_computed[1], direction=-1.0)
        if not curve and score is True
        else None
    )
    return plot_curve(
        curve_computed, score=score, ax=ax, label_names=("Recall", "Precision"), name=self.__class__.__name__
    )

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update metric states.

Parameters:

  • preds (torch.Tensor) –

    The predicted mask. Shape: (batch_size, height, width)

  • target (torch.Tensor) –

    The target mask. Shape: (batch_size, height, width)

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update metric states.

    Args:
        preds (Tensor): The predicted mask. Shape: (batch_size, height, width)
        target (Tensor): The target mask. Shape: (batch_size, height, width)

    Raises:
        ValueError: If preds and target have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_precision_recall_curve_tensor_validation(preds, target, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if not torch.all((preds >= 0) * (preds <= 1)):
        preds = preds.sigmoid()

    if self.ignore_index is not None:
        target = (target == 1).to(torch.uint8)

    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)

    len_t = len(self.thresholds)
    confmat = self.thresholds.new_zeros((len_t, 2, 2), dtype=torch.int64)
    for i in range(len_t):
        preds_i = preds >= self.thresholds[i]

        if self.ignore_index is not None:
            invalid_idx = target == self.ignore_index
            preds_i = preds_i.clone()
            preds_i[invalid_idx] = 0  # This prevents counting instances in the ignored area

        instance_list_preds_i = mask_to_instances(preds_i.to(torch.uint8), self.validate_args)
        for target_i, preds_i in zip(instance_list_target, instance_list_preds_i):
            tp, fp, fn = match_instances(
                target_i,
                preds_i,
                match_threshold=self.matching_threshold,
                validate_args=self.validate_args,
            )
            confmat[i, 1, 1] += tp
            confmat[i, 0, 1] += fp
            confmat[i, 1, 0] += fn
    self.confmat += confmat

BinaryInstanceRecall

BinaryInstanceRecall(
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: typing.Literal[
        "global", "samplewise"
    ] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: typing.Any,
)

Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores

Binary instance recall metric.

Create a new instance of the BinaryInstanceStatScores metric.

Parameters:

  • threshold (float, default: 0.5 ) –

    Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.

  • matching_threshold (float, default: 0.5 ) –

    The threshold for matching instances. Defaults to 0.5.

  • multidim_average (typing.Literal['global', 'samplewise'], default: 'global' ) –

    How the average over multiple batches is calculated. Defaults to "global".

  • ignore_index (int | None, default: None ) –

    Ignores an invalid class. Defaults to None.

  • validate_args (bool, default: True ) –

    Whether to validate inputs. Defaults to True.

  • kwargs (typing.Any, default: {} ) –

    Additional arguments for the Metric class, regarding compute-methods. Please refer to torchmetrics for more examples.

Raises:

  • ValueError

    If matching_threshold is not a float in the [0,1] range.
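
A minimal usage sketch (the import path is assumed from the source file location; the tensors are illustrative):

import torch

from darts_segmentation.metrics.binary_instance_stat_scores import BinaryInstanceRecall

metric = BinaryInstanceRecall(threshold=0.5, matching_threshold=0.5)
preds = torch.randn(2, 64, 64)  # logits are fine, a sigmoid is applied internally
target = (torch.rand(2, 64, 64) > 0.7).long()  # binary mask of shape (B, H, W)
metric.update(preds, target)
recall = metric.compute()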

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def __init__(
    self,
    threshold: float = 0.5,
    matching_threshold: float = 0.5,
    multidim_average: Literal["global", "samplewise"] = "global",
    ignore_index: int | None = None,
    validate_args: bool = True,
    **kwargs: Any,
) -> None:
    """Create a new instance of the BinaryInstanceStatScores metric.

    Args:
        threshold (float, optional): Threshold for binarizing the prediction.
            Has no effect if the prediction is already binarized. Defaults to 0.5.
        matching_threshold (float, optional): The threshold for matching instances. Defaults to 0.5.
        multidim_average (Literal["global", "samplewise"], optional): How the average over multiple batches is
            calculated. Defaults to "global".
        ignore_index (int | None, optional): Ignores an invalid class. Defaults to None.
        validate_args (bool, optional): Whether to validate inputs. Defaults to True.
        kwargs: Additional arguments for the Metric class, regarding compute-methods.
            Please refer to torchmetrics for more examples.

    Raises:
        ValueError: If `matching_threshold` is not a float in the [0,1] range.

    """
    zero_division = kwargs.pop("zero_division", 0)
    super(_AbstractStatScores, self).__init__(**kwargs)
    if validate_args:
        _binary_stat_scores_arg_validation(threshold, multidim_average, ignore_index, zero_division)
        if not (isinstance(matching_threshold, float) and (0 <= matching_threshold <= 1)):
            raise ValueError(
                f"Expected arg `matching_threshold` to be a float in the [0,1] range, but got {matching_threshold}."
            )

    self.threshold = threshold
    self.matching_threshold = matching_threshold
    self.multidim_average = multidim_average
    self.ignore_index = ignore_index
    self.validate_args = validate_args
    self.zero_division = zero_division

    self._create_state(size=1, multidim_average=multidim_average)

full_state_update class-attribute instance-attribute

full_state_update: bool = False

higher_is_better class-attribute instance-attribute

higher_is_better: bool | None = True

ignore_index instance-attribute

is_differentiable class-attribute instance-attribute

is_differentiable: bool = False

matching_threshold instance-attribute

multidim_average instance-attribute

plot_lower_bound class-attribute instance-attribute

plot_lower_bound: float = 0.0

plot_upper_bound class-attribute instance-attribute

plot_upper_bound: float = 1.0

threshold instance-attribute

validate_args instance-attribute

zero_division instance-attribute

compute

compute() -> torch.Tensor
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def compute(self) -> Tensor:  # noqa: D102
    tp, fp, tn, fn = self._final_state()
    return _precision_recall_reduce(
        "recall",
        tp,
        fp,
        tn,
        fn,
        average="binary",
        multidim_average=self.multidim_average,
        zero_division=self.zero_division,
    )

plot

plot(
    val: torch.Tensor
    | collections.abc.Sequence[torch.Tensor]
    | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def plot(  # noqa: D102
    self,
    val: Tensor | Sequence[Tensor] | None = None,
    ax: _AX_TYPE | None = None,  # type: ignore
) -> _PLOT_OUT_TYPE:  # type: ignore
    return self._plot(val, ax)

update

update(preds: torch.Tensor, target: torch.Tensor) -> None

Update the metric state.

If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.
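
For illustration, a sketch of how the two kinds of input are distinguished (the tensors here are illustrative, not part of the API):

import torch

logits = torch.tensor([[-2.0, 3.0]])  # values outside [0, 1] -> treated as logits, sigmoid is applied first
probs = torch.tensor([[0.2, 0.9]])  # values already in [0, 1] -> used as-is
# both are then binarized with `preds > threshold` (default threshold: 0.5)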

Parameters:

  • preds (torch.Tensor) –

    Predictions from model (logits or probabilities).

  • target (torch.Tensor) –

    Ground truth labels.

Raises:

  • ValueError

    If preds and target have different shapes.

  • ValueError

    If the input targets are not binary masks.

Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
def update(self, preds: Tensor, target: Tensor) -> None:
    """Update the metric state.

    If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and
    then binarized using the threshold.
    If the predictions are probabilities, they are binarized using the threshold.

    Args:
        preds (Tensor): Predictions from model (logits or probabilities).
        target (Tensor): Ground truth labels.

    Raises:
        ValueError: If `preds` and `target` have different shapes.
        ValueError: If the input targets are not binary masks.

    """
    if self.validate_args:
        _binary_stat_scores_tensor_validation(preds, target, self.multidim_average, self.ignore_index)
        if not preds.dim() == 3:
            raise ValueError(f"Expected `preds` and `target` to have 3 dimensions (BHW), but got {preds.dim()}.")
        if self.ignore_index is None and target.max() > 1:
            raise ValueError(
                "Expected binary mask, got more than 1 unique value in target."
                " You can set 'ignore_index' to ignore a class."
            )

    # Format
    if preds.is_floating_point():
        if not torch.all((preds >= 0) * (preds <= 1)):
            # preds is logits, convert with sigmoid
            preds = preds.sigmoid()
        preds = preds > self.threshold

    if self.ignore_index is not None:
        invalid_idx = target == self.ignore_index
        preds = preds.clone()
        preds[invalid_idx] = 0  # This prevents counting instances in the ignored area
        target = (target == 1).to(torch.uint8)

    # Update state
    instance_list_target = mask_to_instances(target.to(torch.uint8), self.validate_args)
    instance_list_preds = mask_to_instances(preds.to(torch.uint8), self.validate_args)

    for target_i, preds_i in zip(instance_list_target, instance_list_preds):
        tp, fp, fn = match_instances(
            target_i,
            preds_i,
            match_threshold=self.matching_threshold,
            validate_args=self.validate_args,
        )
        self._update_state(tp, fp, 0, fn)

BinarySegmentationMetrics

BinarySegmentationMetrics(
    *,
    bands: darts_segmentation.utils.Bands,
    val_set: str = "val",
    test_set: str = "test",
    plot_every_n_val_epochs: int = 5,
    is_crossval: bool = False,
    batch_size: int = 8,
    patch_size: int = 512,
)

Bases: lightning.pytorch.callbacks.Callback

Callback for validation metrics and visualizations.

Initialize the BinarySegmentationMetrics callback.

Parameters:

  • bands (darts_segmentation.utils.Bands) –

    List of bands to combine for the visualization.

  • val_set (str, default: 'val' ) –

    Name of the validation set. Only used for naming the validation metrics. Defaults to "val".

  • test_set (str, default: 'test' ) –

    Name of the test set. Only used for naming the test metrics. Defaults to "test".

  • plot_every_n_val_epochs (int, default: 5 ) –

    Plot validation samples every n epochs. Defaults to 5.

  • is_crossval (bool, default: False ) –

    Whether the training is done with cross-validation. If True, scalar metrics are logged under "val" instead of {val_set}. The logging behavior of the samples is not affected. Defaults to False.

  • batch_size (int, default: 8 ) –

    Batch size. Needed for throughput measurements. Defaults to 8.

  • patch_size (int, default: 512 ) –

    Patch size. Needed for throughput measurements. Defaults to 512.
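
A minimal setup sketch (the Bands config values below are illustrative, not from a real sensor setup):

import lightning as L

from darts_segmentation.training.callbacks import BinarySegmentationMetrics
from darts_segmentation.utils import Bands

bands = Bands.from_config(
    {
        "bands": ["red", "green", "blue"],
        "band_factors": [1 / 255, 1 / 255, 1 / 255],
        "band_offsets": [0.0, 0.0, 0.0],
    }
)
metrics_cb = BinarySegmentationMetrics(bands=bands, batch_size=8, patch_size=512)
trainer = L.Trainer(callbacks=[metrics_cb], max_epochs=10)
# trainer.fit(model, datamodule)  # the 'training_step' must return {"loss": ..., "y_hat": ...}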

Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def __init__(
    self,
    *,
    bands: Bands,
    val_set: str = "val",
    test_set: str = "test",
    plot_every_n_val_epochs: int = 5,
    is_crossval: bool = False,
    batch_size: int = 8,
    patch_size: int = 512,
):
    """Initialize the ValidationCallback.

    Args:
        bands (Bands): List of bands to combine for the visualization.
        val_set (str, optional): Name of the validation set. Only used for naming the validation metrics.
            Defaults to "val".
        test_set (str, optional): Name of the test set. Only used for naming the test metrics. Defaults to "test".
        plot_every_n_val_epochs (int, optional): Plot validation samples every n epochs. Defaults to 5.
        is_crossval (bool, optional): Whether the training is done with cross-validation.
            If True, scalar metrics are logged under "val" instead of {val_set}.
            The logging behavior of the samples is not affected.
            Defaults to False.
        batch_size (int, optional): Batch size. Needed for throughput measurements. Defaults to 8.
        patch_size (int, optional): Patch size. Needed for throughput measurements. Defaults to 512.

    """
    assert "/" not in val_set, "val_set must not contain '/'"
    assert "/" not in test_set, "test_set must not contain '/'"
    self.val_set = val_set
    self.test_set = test_set
    self.plot_every_n_val_epochs = plot_every_n_val_epochs
    self.band_names = bands.names
    self.is_crossval = is_crossval
    self.batch_size = batch_size
    self.patch_size = patch_size

band_names instance-attribute

batch_size instance-attribute

is_crossval instance-attribute

patch_size instance-attribute

pl_module instance-attribute

pl_module: lightning.LightningModule

plot_every_n_val_epochs instance-attribute

plot_every_n_val_epochs: int

stage instance-attribute

test_cmx instance-attribute

test_cmx: torchmetrics.ConfusionMatrix

test_instance_cmx instance-attribute

test_instance_prc instance-attribute

test_metrics instance-attribute

test_metrics: torchmetrics.MetricCollection

test_prc instance-attribute

test_prc: torchmetrics.PrecisionRecallCurve

test_roc instance-attribute

test_roc: torchmetrics.ROC

test_set instance-attribute

train_metrics instance-attribute

train_metrics: torchmetrics.MetricCollection

trainer instance-attribute

trainer: lightning.Trainer

val_cmx instance-attribute

val_cmx: torchmetrics.ConfusionMatrix

val_metrics instance-attribute

val_metrics: torchmetrics.MetricCollection

val_prc instance-attribute

val_prc: torchmetrics.PrecisionRecallCurve

val_roc instance-attribute

val_roc: torchmetrics.ROC

val_set instance-attribute

is_val_plot_epoch

is_val_plot_epoch(
    current_epoch: int, check_val_every_n_epoch: int | None
) -> bool

Check if the current epoch is an epoch where validation samples should be plotted.

Parameters:

  • current_epoch (int) –

    The current epoch.

  • check_val_every_n_epoch (int | None) –

    The trainer's check_val_every_n_epoch setting. If None, no plotting is done.

Returns:

  • bool ( bool ) –

    True if the current epoch is a plot epoch, False otherwise.

Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def is_val_plot_epoch(self, current_epoch: int, check_val_every_n_epoch: int | None) -> bool:
    """Check if the current epoch is an epoch where validation samples should be plotted.

    Args:
        current_epoch (int): The current epoch.
        check_val_every_n_epoch (int | None): The trainer's check_val_every_n_epoch setting.
            If None, no plotting is done.

    Returns:
        bool: True if the current epoch is a plot epoch, False otherwise.

    """
    if check_val_every_n_epoch is None:
        return False
    n = self.plot_every_n_val_epochs * check_val_every_n_epoch
    return ((current_epoch + 1) % n) == 0 or current_epoch == 0
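
For example, with the default plot_every_n_val_epochs=5 and a trainer that validates every 2 epochs, samples are plotted on the very first epoch and then on every 10th epoch (reusing metrics_cb from the sketch above):

metrics_cb.is_val_plot_epoch(0, check_val_every_n_epoch=2)  # True: the first epoch is always plotted
metrics_cb.is_val_plot_epoch(5, check_val_every_n_epoch=2)  # False: (5 + 1) % 10 != 0
metrics_cb.is_val_plot_epoch(9, check_val_every_n_epoch=2)  # True: (9 + 1) % 10 == 0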

on_test_batch_end

on_test_batch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    outputs,
    batch,
    batch_idx,
    dataloader_idx=0,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_test_batch_end(  # noqa: D102
    self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, dataloader_idx=0
):
    pl_module.log(f"{self.test_set}/loss", outputs["loss"])
    _x, y = batch
    assert "y_hat" in outputs, (
        "Output does not contain 'y_hat' tensor."
        " Please make sure the 'test_step' method returns a dict with 'y_hat' and 'loss' keys."
        " The 'y_hat' should be the model's prediction (a pytorch tensor of shape [B, C, H, W])."
        " The 'loss' should be the loss value (a scalar tensor).",
    )
    y_hat = outputs["y_hat"]

    pl_module.test_metrics.update(y_hat, y)
    pl_module.test_roc.update(y_hat, y)
    pl_module.test_prc.update(y_hat, y)
    pl_module.test_cmx.update(y_hat, y)
    pl_module.test_instance_prc.update(y_hat, y)
    pl_module.test_instance_cmx.update(y_hat, y)

on_test_epoch_end

on_test_epoch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule):  # noqa: D102
    pl_module.test_cmx.compute()
    pl_module.test_roc.compute()
    pl_module.test_prc.compute()
    pl_module.test_instance_prc.compute()
    pl_module.test_instance_cmx.compute()

    # Plot roc, prc and confusion matrix to disk and wandb
    fig_cmx, _ = pl_module.test_cmx.plot(cmap="Blues")
    fig_roc, _ = pl_module.test_roc.plot(score=True)
    fig_prc, _ = pl_module.test_prc.plot(score=True)
    fig_instance_cmx, _ = pl_module.test_instance_cmx.plot(cmap="Blues")
    fig_instance_prc, _ = pl_module.test_instance_prc.plot(score=True)

    # Check for a wandb or csv logger to log the images
    for pllogger in pl_module.loggers:
        if isinstance(pllogger, CSVLogger):
            fig_dir = Path(pllogger.log_dir) / "figures" / f"{self.test_set}-samples"
            fig_dir.mkdir(exist_ok=True, parents=True)
            fig_cmx.savefig(fig_dir / f"cmx_{pl_module.global_step}.png")
            fig_roc.savefig(fig_dir / f"roc_{pl_module.global_step}.png")
            fig_prc.savefig(fig_dir / f"prc_{pl_module.global_step}.png")
            fig_instance_cmx.savefig(fig_dir / f"instance_cmx_{pl_module.global_step}.png")
            fig_instance_prc.savefig(fig_dir / f"instance_prc_{pl_module.global_step}.png")
        if isinstance(pllogger, WandbLogger):
            wandb_run: Run = pllogger.experiment
            wandb_run.log({f"{self.test_set}/cmx": wandb.Image(fig_cmx)}, commit=False)
            wandb_run.log({f"{self.test_set}/roc": wandb.Image(fig_roc)}, commit=False)
            wandb_run.log({f"{self.test_set}/prc": wandb.Image(fig_prc)}, commit=False)
            wandb_run.log({f"{self.test_set}/instance_cmx": wandb.Image(fig_instance_cmx)}, commit=False)
            wandb_run.log({f"{self.test_set}/instance_prc": wandb.Image(fig_instance_prc)}, commit=False)

    fig_cmx.clear()
    fig_roc.clear()
    fig_prc.clear()
    fig_instance_cmx.clear()
    fig_instance_prc.clear()
    plt.close("all")

    # This will also commit the accumulated plots
    pl_module.log_dict(pl_module.test_metrics.compute())

    pl_module.test_metrics.reset()
    pl_module.test_roc.reset()
    pl_module.test_prc.reset()
    pl_module.test_cmx.reset()
    pl_module.test_instance_prc.reset()
    pl_module.test_instance_cmx.reset()

on_train_batch_end

on_train_batch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    outputs,
    batch,
    batch_idx,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx):  # noqa: D102
    pl_module.log("train/loss", outputs["loss"])
    _, y = batch
    # Expect the output to have a tensor called "y_hat"
    assert "y_hat" in outputs, (
        "Output does not contain 'y_hat' tensor."
        " Please make sure the 'training_step' method returns a dict with 'y_hat' and 'loss' keys."
        " The 'y_hat' should be the model's prediction (a pytorch tensor of shape [B, C, H, W])."
        " The 'loss' should be the loss value (a scalar tensor).",
    )
    y_hat = outputs["y_hat"]
    pl_module.train_metrics(y_hat, y)
    pl_module.log_dict(pl_module.train_metrics, on_step=True, on_epoch=False)

on_train_epoch_end

on_train_epoch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):  # noqa: D102
    pl_module.train_metrics.reset()

on_validation_batch_end

on_validation_batch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    outputs,
    batch,
    batch_idx,
    dataloader_idx=0,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_validation_batch_end(  # noqa: D102
    self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, dataloader_idx=0
):
    pl_module.log(f"{self._val_prefix}/loss", outputs["loss"])
    _x, y = batch
    # Expect the output to have a tensor called "y_hat"
    assert "y_hat" in outputs, (
        "Output does not contain 'y_hat' tensor."
        " Please make sure the 'validation_step' method returns a dict with 'y_hat' and 'loss' keys."
        " The 'y_hat' should be the model's prediction (a pytorch tensor of shape [B, C, H, W])."
        " The 'loss' should be the loss value (a scalar tensor).",
    )
    y_hat = outputs["y_hat"]

    pl_module.val_metrics.update(y_hat, y)
    pl_module.val_roc.update(y_hat, y)
    pl_module.val_prc.update(y_hat, y)
    pl_module.val_cmx.update(y_hat, y)

on_validation_epoch_end

on_validation_epoch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):  # noqa: D102
    # Only do this every self.plot_every_n_val_epochs epochs
    is_val_plot_epoch = self.is_val_plot_epoch(pl_module.current_epoch, trainer.check_val_every_n_epoch)
    if is_val_plot_epoch and trainer.state.stage != "sanity_check":
        pl_module.val_cmx.compute()
        pl_module.val_roc.compute()
        pl_module.val_prc.compute()

        # Plot roc, prc and confusion matrix to disk and wandb
        fig_cmx, _ = pl_module.val_cmx.plot(cmap="Blues")
        fig_roc, _ = pl_module.val_roc.plot(score=True)
        fig_prc, _ = pl_module.val_prc.plot(score=True)

        # Check for a wandb or csv logger to log the images
        for pllogger in pl_module.loggers:
            if isinstance(pllogger, CSVLogger):
                fig_dir = Path(pllogger.log_dir) / "figures" / f"{self._val_prefix}-samples"
                fig_dir.mkdir(exist_ok=True, parents=True)
                fig_cmx.savefig(fig_dir / f"cmx_{pl_module.global_step}.png")
                fig_roc.savefig(fig_dir / f"roc_{pl_module.global_step}.png")
                fig_prc.savefig(fig_dir / f"prc_{pl_module.global_step}.png")
            if isinstance(pllogger, WandbLogger):
                wandb_run: Run = pllogger.experiment
                wandb_run.log({f"{self._val_prefix}/cmx": wandb.Image(fig_cmx)}, commit=False)
                wandb_run.log({f"{self._val_prefix}/roc": wandb.Image(fig_roc)}, commit=False)
                wandb_run.log({f"{self._val_prefix}/prc": wandb.Image(fig_prc)}, commit=False)

        fig_cmx.clear()
        fig_roc.clear()
        fig_prc.clear()
        plt.close("all")

    # This will also commit the accumulated plots
    pl_module.log_dict(pl_module.val_metrics.compute())

    pl_module.val_metrics.reset()
    pl_module.val_roc.reset()
    pl_module.val_prc.reset()
    pl_module.val_cmx.reset()

setup

setup(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    stage: darts_segmentation.training.callbacks.Stage,
)

Sets up the callback.

Creates metrics required for the specific stage:

  • For the "fit" stage, creates training and validation metrics and visualizations.
  • For the "validate" stage, only creates validation metrics and visualizations.
  • For the "test" stage, only creates test metrics and visualizations.
  • For the "predict" stage, no metrics or visualizations are created.

Always maps the trainer and pl_module to the callback.

Training and validation metrics are "simple" metrics from torchmetrics. The validation visualizations are more complex metrics from torchmetrics. The test metrics and visualizations are the same as the validation ones, and additionally include custom "Instance" metrics.

Parameters:

  • trainer (lightning.Trainer) –

    The lightning trainer.

  • pl_module (lightning.LightningModule) –

    The lightning module.

  • stage (typing.Literal['fit', 'validate', 'test', 'predict']) –

    The current stage. One of: "fit", "validate", "test", "predict".

Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Stage):
    """Setups the callback.

    Creates metrics required for the specific stage:

    - For the "fit" stage, creates training and validation metrics and visualizations.
    - For the "validate" stage, only creates validation metrics and visualizations.
    - For the "test" stage, only creates test metrics and visualizations.
    - For the "predict" stage, no metrics or visualizations are created.

    Always maps the trainer and pl_module to the callback.

    Training and validation metrics are "simple" metrics from torchmetrics.
    The validation visualizations are more complex metrics from torchmetrics.
    The test metrics and visualizations are the same as the validation ones,
    and also include custom "Instance" metrics.

    Args:
        trainer (Trainer): The lightning trainer.
        pl_module (LightningModule): The lightning module.
        stage (Literal["fit", "validate", "test", "predict"]): The current stage.
            One of: "fit", "validate", "test", "predict".

    """
    # Save references to the trainer and pl_module
    self.trainer = trainer
    self.pl_module = pl_module
    self.stage = stage

    # We don't want to use memory in the predict stage
    if stage == "predict":
        return

    # Add throughput metric, meant to be consumed by the ThroughputMonitor callback
    # ! This will assume that the batch size does not change during training!
    with torch.device("meta"):
        model: torch.nn.Module = copy.deepcopy(self.pl_module.model).to(device="meta")

        def sample_forward():
            batch = torch.randn(
                self.batch_size,
                len(self.band_names),
                self.patch_size,
                self.patch_size,
                device="meta",
            )
            return model(batch)

        if stage == "fit":
            # We use sum as a dummy loss function because we don't have a second input available
            self.pl_module.flops_per_batch = measure_flops(model, sample_forward, loss_fn=torch.Tensor.sum)
        else:
            # Don't compute backward pass for validation and test
            self.pl_module.flops_per_batch = measure_flops(model, sample_forward)
        logger.debug(f"FLOPS per batch: {self.pl_module.flops_per_batch}")

    metric_kwargs = {"task": "binary", "validate_args": False, "ignore_index": 2}
    metrics = MetricCollection(
        {
            "Accuracy": Accuracy(**metric_kwargs),
            "Precision": Precision(**metric_kwargs),
            "Specificity": Specificity(**metric_kwargs),
            "Recall": Recall(**metric_kwargs),
            "F1Score": F1Score(**metric_kwargs),
            "JaccardIndex": JaccardIndex(**metric_kwargs),
            "CohenKappa": CohenKappa(**metric_kwargs),
            "HammingDistance": HammingDistance(**metric_kwargs),
        }
    )

    added_metrics: defaultdict[str, list[str]] = defaultdict(list)

    # Train metrics only for the fit stage
    if stage == "fit":
        pl_module.train_metrics = metrics.clone(prefix="train/")
        added_metrics["train"] += list(pl_module.train_metrics.keys(keep_base=True))
    # Validation metrics and visualizations for the fit and validate stages
    if stage == "fit" or stage == "validate":
        pl_module.val_metrics = metrics.clone(prefix=f"{self._val_prefix}/")
        pl_module.val_metrics.add_metrics(
            {
                "AUROC": AUROC(thresholds=20, **metric_kwargs),
                "AveragePrecision": AveragePrecision(thresholds=20, **metric_kwargs),
            }
        )
        pl_module.val_roc = ROC(thresholds=20, **metric_kwargs)
        pl_module.val_prc = PrecisionRecallCurve(thresholds=20, **metric_kwargs)
        pl_module.val_cmx = ConfusionMatrix(normalize="true", **metric_kwargs)
        added_metrics[self._val_prefix] += list(pl_module.val_metrics.keys(keep_base=True))
        added_metrics[self._val_prefix] += ["roc", "prc", "cmx"]

    # Test metrics and visualizations for the test stage
    if stage == "test":
        pl_module.test_metrics = metrics.clone(prefix=f"{self.test_set}/")
        pl_module.test_metrics.add_metrics(
            {
                "AUROC": AUROC(thresholds=20, **metric_kwargs),
                "AveragePrecision": AveragePrecision(thresholds=20, **metric_kwargs),
            }
        )
        pl_module.test_roc = ROC(thresholds=20, **metric_kwargs)
        pl_module.test_prc = PrecisionRecallCurve(thresholds=20, **metric_kwargs)
        pl_module.test_cmx = ConfusionMatrix(normalize="true", **metric_kwargs)

        # Instance Metrics
        instance_metric_kwargs = {"validate_args": False, "ignore_index": 2, "matching_threshold": 0.3}
        pl_module.test_metrics.add_metrics(
            {
                "InstanceAccuracy": BinaryInstanceAccuracy(**instance_metric_kwargs),
                "InstancePrecision": BinaryInstancePrecision(**instance_metric_kwargs),
                "InstanceRecall": BinaryInstanceRecall(**instance_metric_kwargs),
                "InstanceF1Score": BinaryInstanceF1Score(**instance_metric_kwargs),
                "InstanceAveragePrecision": BinaryInstanceAveragePrecision(thresholds=20, **instance_metric_kwargs),
            }
        )
        boundary_metric_kwargs = {"validate_args": False, "ignore_index": 2}
        pl_module.test_metrics.add_metrics(
            {
                "InstanceBoundaryIoU": BinaryBoundaryIoU(**boundary_metric_kwargs),
            }
        )
        pl_module.test_instance_prc = BinaryInstancePrecisionRecallCurve(thresholds=20, **instance_metric_kwargs)
        pl_module.test_instance_cmx = BinaryInstanceConfusionMatrix(normalize=True, **instance_metric_kwargs)

        added_metrics[self.test_set] += list(pl_module.test_metrics.keys(keep_base=True))
        added_metrics[self.test_set] += ["roc", "prc", "cmx", "instance_prc", "instance_cmx"]

    # Log the added metrics
    added_metrics = {k: str(sorted(v)) for k, v in added_metrics.items()}
    logger.debug(f"Added metrics:{added_metrics}")

teardown

teardown(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    stage: darts_segmentation.training.callbacks.Stage,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: Stage):  # noqa: D102
    # Delete the references to the trainer and pl_module
    del self.trainer
    del self.pl_module
    del self.stage

    # No need to delete anything if we are in the predict stage
    if stage == "predict":
        return

    if stage == "fit":
        del pl_module.train_metrics

    if stage == "fit" or stage == "validate":
        del pl_module.val_metrics
        del pl_module.val_roc
        del pl_module.val_prc
        del pl_module.val_cmx

    if stage == "test":
        del pl_module.test_metrics
        del pl_module.test_roc
        del pl_module.test_prc
        del pl_module.test_cmx

BinarySegmentationPreview

BinarySegmentationPreview(
    *,
    bands: darts_segmentation.utils.Bands,
    val_set: str = "val",
    test_set: str = "test",
    plot_every_n_val_epochs: int = 5,
)

Bases: lightning.pytorch.callbacks.Callback

Callback for plotting validation and test sample previews.

Initialize the BinarySegmentationPreview callback.

Parameters:

  • bands (darts_segmentation.utils.Bands) –

    List of bands to combine for the visualization.

  • val_set (str, default: 'val' ) –

    Name of the validation set. Only used for naming the logged validation samples. Defaults to "val".

  • test_set (str, default: 'test' ) –

    Name of the test set. Only used for naming the logged test samples. Defaults to "test".

  • plot_every_n_val_epochs (int, default: 5 ) –

    Plot validation samples every n epochs. Defaults to 5.
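
A minimal sketch combining this callback with the metrics callback, reusing the bands object from the earlier example:

from darts_segmentation.training.callbacks import BinarySegmentationMetrics, BinarySegmentationPreview

preview_cb = BinarySegmentationPreview(bands=bands, plot_every_n_val_epochs=5)
trainer = L.Trainer(callbacks=[BinarySegmentationMetrics(bands=bands), preview_cb])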

Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def __init__(
    self,
    *,
    bands: Bands,
    val_set: str = "val",
    test_set: str = "test",
    plot_every_n_val_epochs: int = 5,
):
    """Initialize the ValidationCallback.

    Args:
        bands (Bands): List of bands to combine for the visualization.
        val_set (str, optional): Name of the validation set. Only used for naming the logged validation samples.
            Defaults to "val".
        test_set (str, optional): Name of the test set. Only used for naming the logged test samples.
            Defaults to "test".
        plot_every_n_val_epochs (int, optional): Plot validation samples every n epochs. Defaults to 5.

    """
    assert "/" not in val_set, "val_set must not contain '/'"
    assert "/" not in test_set, "test_set must not contain '/'"
    self.val_set = val_set
    self.test_set = test_set
    self.plot_every_n_val_epochs = plot_every_n_val_epochs
    self.band_names = bands.names

band_names instance-attribute

pl_module instance-attribute

pl_module: lightning.LightningModule

plot_every_n_val_epochs instance-attribute

plot_every_n_val_epochs: int

stage instance-attribute

test_set instance-attribute

trainer instance-attribute

trainer: lightning.Trainer

val_set instance-attribute

is_val_plot_epoch

is_val_plot_epoch(
    current_epoch: int, check_val_every_n_epoch: int | None
) -> bool

Check if the current epoch is an epoch where validation samples should be plotted.

Parameters:

  • current_epoch (int) –

    The current epoch.

  • check_val_every_n_epoch (int | None) –

    The trainer's check_val_every_n_epoch setting. If None, no plotting is done.

Returns:

  • bool ( bool ) –

    True if the current epoch is a plot epoch, False otherwise.

Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def is_val_plot_epoch(self, current_epoch: int, check_val_every_n_epoch: int | None) -> bool:
    """Check if the current epoch is an epoch where validation samples should be plotted.

    Args:
        current_epoch (int): The current epoch.
        check_val_every_n_epoch (int | None): The trainer's check_val_every_n_epoch setting.
            If None, no plotting is done.

    Returns:
        bool: True if the current epoch is a plot epoch, False otherwise.

    """
    if check_val_every_n_epoch is None:
        return False

    n = self.plot_every_n_val_epochs * check_val_every_n_epoch
    return ((current_epoch + 1) % n) == 0 or current_epoch == 0

on_test_batch_end

on_test_batch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    outputs,
    batch,
    batch_idx,
    dataloader_idx=0,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_test_batch_end(  # noqa: D102
    self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, dataloader_idx=0
):
    # Only do this every self.plot_every_n_val_epochs epochs
    is_val_plot_epoch = self.is_val_plot_epoch(pl_module.current_epoch, trainer.check_val_every_n_epoch)
    if not is_val_plot_epoch:
        return

    x, y = batch
    assert "y_hat" in outputs, (
        "Output does not contain 'y_hat' tensor."
        " Please make sure the 'test_step' method returns a dict with 'y_hat' and 'loss' keys."
        " The 'y_hat' should be the model's prediction (a pytorch tensor of shape [B, C, H, W])."
        " The 'loss' should be the loss value (a scalar tensor).",
    )
    y_hat = outputs["y_hat"]

    # Create figures for the samples (plot at maximum 30)
    # We want to plot at max 20 POSITIVE samples and 10 NEGATIVE samples in a single epoch
    # These should also be the same over all epochs
    for i in range(x.shape[0]):
        if self._test_pos_visualizations >= 20 and self._test_neg_visualizations >= 10:
            break

        # Plot positive sample
        if y[i].sum() > 0 and self._test_pos_visualizations < 20:
            fig, _ = plot_sample(x[i], y[i], y_hat[i], self.band_names)
            self._test_pos_visualizations += 1
        # Plot negative sample
        elif y[i].sum() == 0 and self._test_neg_visualizations < 10:
            fig, _ = plot_sample(x[i], y[i], y_hat[i], self.band_names)
            self._test_neg_visualizations += 1
        # Either the number of positive or negative samples is already full
        else:
            continue

        for pllogger in pl_module.loggers:
            if isinstance(pllogger, CSVLogger):
                fig_dir = Path(pllogger.log_dir) / "figures" / f"{self.test_set}-samples"
                fig_dir.mkdir(exist_ok=True, parents=True)
                fig.savefig(fig_dir / f"sample_{pl_module.global_step}_{batch_idx}_{i}.png")
            if isinstance(pllogger, WandbLogger):
                wandb_run: Run = pllogger.experiment
                # We don't commit the log yet, so that the step is increased with the next lightning log
                # Which happens at the end of the validation epoch
                img_name = f"{self.test_set}-samples/sample_{batch_idx}_{i}"
                wandb_run.log({img_name: wandb.Image(fig)}, commit=False)
        fig.clear()
        plt.close(fig)

on_test_epoch_end

on_test_epoch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule):  # noqa: D102
    self._test_pos_visualizations = 0
    self._test_neg_visualizations = 0

on_validation_batch_end

on_validation_batch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    outputs,
    batch,
    batch_idx,
    dataloader_idx=0,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_validation_batch_end(  # noqa: D102
    self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, dataloader_idx=0
):
    # Only do this every self.plot_every_n_val_epochs epochs
    is_val_plot_epoch = self.is_val_plot_epoch(pl_module.current_epoch, trainer.check_val_every_n_epoch)
    if not is_val_plot_epoch:
        return

    x, y = batch
    # Expect the output to have a tensor called "y_hat"
    assert "y_hat" in outputs, (
        "Output does not contain 'y_hat' tensor."
        " Please make sure the 'validation_step' method returns a dict with 'y_hat' and 'loss' keys."
        " The 'y_hat' should be the model's prediction (a pytorch tensor of shape [B, C, H, W])."
        " The 'loss' should be the loss value (a scalar tensor).",
    )
    y_hat = outputs["y_hat"]

    # Create figures for the samples (plot at maximum 30)
    # We want to plot at max 20 POSITIVE samples and 10 NEGATIVE samples in a single epoch
    # These should also be the same over all epochs
    for i in range(x.shape[0]):
        if self._val_pos_visualizations >= 20 and self._val_neg_visualizations >= 10:
            break

        # Don't plot in sanity check
        if trainer.state.stage == "sanity_check":
            break

        # Plot positive sample
        is_positive = (y[i] == 1).sum() > 0
        if is_positive and self._val_pos_visualizations < 20:
            fig, _ = plot_sample(x[i], y[i], y_hat[i], self.band_names)
            self._val_pos_visualizations += 1
        # Plot negative sample
        elif not is_positive and self._val_neg_visualizations < 10:
            fig, _ = plot_sample(x[i], y[i], y_hat[i], self.band_names)
            self._val_neg_visualizations += 1
        # Either the number of positive or negative samples is already full
        else:
            continue

        for pllogger in pl_module.loggers:
            if isinstance(pllogger, CSVLogger):
                fig_dir = Path(pllogger.log_dir) / "figures" / f"{self.val_set}-samples"
                fig_dir.mkdir(exist_ok=True, parents=True)
                fig.savefig(fig_dir / f"sample_{pl_module.global_step}_{batch_idx}_{i}.png")
            if isinstance(pllogger, WandbLogger):
                wandb_run: Run = pllogger.experiment
                # We don't commit the log yet, so that the step is increased with the next lightning log
                # Which happens at the end of the validation epoch
                img_name = f"{self.val_set}-samples/sample_{batch_idx}_{i}"
                wandb_run.log({img_name: wandb.Image(fig)}, commit=False)
        fig.clear()
        plt.close(fig)

on_validation_epoch_end

on_validation_epoch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):  # noqa: D102
    self._val_pos_visualizations = 0
    self._val_neg_visualizations = 0

setup

setup(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    stage: darts_segmentation.training.callbacks.Stage,
)

Sets up the callback.

Parameters:

  • trainer (lightning.Trainer) –

    The lightning trainer.

  • pl_module (lightning.LightningModule) –

    The lightning module.

  • stage (typing.Literal['fit', 'validate', 'test', 'predict']) –

    The current stage. One of: "fit", "validate", "test", "predict".

Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Stage):
    """Setups the callback.

    Args:
        trainer (Trainer): The lightning trainer.
        pl_module (LightningModule): The lightning module.
        stage (Literal["fit", "validate", "test", "predict"]): The current stage.
            One of: "fit", "validate", "test", "predict".

    """
    # We don't want to use memory in the predict stage
    if stage == "predict":
        return

    # Validation metrics and visualizations for the fit and validate stages
    if stage == "fit" or stage == "validate":
        # Internal state to track how many visualizations have been generated in an epoch
        self._val_pos_visualizations = 0
        self._val_neg_visualizations = 0

    # Test metrics and visualizations for the test stage
    if stage == "test":
        # Internal state to track how many visualizations have been generated in an epoch
        self._test_pos_visualizations = 0
        self._test_neg_visualizations = 0

plot_sample

plot_sample(
    x: torch.Tensor,
    y: torch.Tensor,
    y_pred: torch.Tensor,
    band_names: list[str],
) -> tuple[
    matplotlib.pyplot.Figure,
    dict[str, matplotlib.pyplot.Axes],
]

Plot a single sample with the input, the ground truth and the prediction.

This function makes a few assumptions about the input:

  • The input is expected to be normalized to 0-1.
  • The prediction is expected to be already converted from logits to probabilities.
  • The target is expected to be an int or long tensor with values of 0 (negative class), 1 (positive class) and 2 (invalid pixels).

Parameters:

  • x (torch.Tensor) –

    The input tensor [C, H, W] (float).

  • y (torch.Tensor) –

    The ground truth tensor [H, W] (int).

  • y_pred (torch.Tensor) –

    The prediction tensor [H, W] (float).

  • band_names (list[str]) –

    The names of the input bands, in channel order.

Returns:

  • tuple[matplotlib.pyplot.Figure, dict[str, matplotlib.pyplot.Axes]]

    tuple[Figure, dict[str, Axes]]: The figure and the axes of the plot.
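
A minimal usage sketch (the import path is assumed from the source file location; random tensors stand in for real data):

import torch

from darts_segmentation.training.viz import plot_sample

x = torch.rand(3, 256, 256)  # input bands, normalized to 0-1
y = torch.randint(0, 2, (256, 256))  # binary ground truth without invalid pixels
y_pred = torch.rand(256, 256)  # probabilities, already converted from logits
fig, axs = plot_sample(x, y, y_pred, band_names=["red", "green", "blue"])
fig.savefig("sample.png")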

Source code in darts-segmentation/src/darts_segmentation/training/viz.py
def plot_sample(
    x: torch.Tensor, y: torch.Tensor, y_pred: torch.Tensor, band_names: list[str]
) -> tuple[plt.Figure, dict[str, plt.Axes]]:
    """Plot a single sample with the input, the ground truth and the prediction.

    This function makes a few assumptions about the input:
    - The input is expected to be normalized to 0-1.
    - The prediction is expected to be already converted from logits to probabilities.
    - The target is expected to be an int or long tensor with values of:
        0 (negative class)
        1 (positive class) and
        2 (invalid pixels).

    Args:
        x (torch.Tensor): The input tensor [C, H, W] (float).
        y (torch.Tensor): The ground truth tensor [H, W] (int).
        y_pred (torch.Tensor): The prediction tensor [H, W] (float).
        band_names (list[str]): The names of the input bands, in channel order.

    Returns:
        tuple[Figure, dict[str, Axes]]: The figure and the axes of the plot.

    """
    x = x.cpu()
    y = y.cpu()
    y_pred = y_pred.detach().cpu()

    # Make y class 2 invalids (replace 2 with nan)
    x = x.where(y != 2, torch.nan)
    y_pred = y_pred.where(y != 2, torch.nan)
    y = y.where(y != 2, torch.nan)

    # pred == 0, y == 0 -> 0 (true negative)
    # pred == 1, y == 0 -> 1 (false positive)
    # pred == 0, y == 1 -> 2 (false negative)
    # pred == 1, y == 1 -> 3 (true positive)
    classification_labels = (y_pred > 0.5).int() + y * 2

    # Calculate acc, f1 and iou before masking out the true negatives,
    # otherwise the true-negative count would always be zero
    true_positive = (classification_labels == 3).sum()
    false_positive = (classification_labels == 1).sum()
    false_negative = (classification_labels == 2).sum()
    true_negative = (classification_labels == 0).sum()
    acc = (true_positive + true_negative) / (true_positive + true_negative + false_positive + false_negative)
    f1 = 2 * true_positive / (2 * true_positive + false_positive + false_negative)
    iou = true_positive / (true_positive + false_positive + false_negative)

    # Hide the true negatives in the overlay (replace 0 with nan)
    classification_labels = classification_labels.where(classification_labels != 0, torch.nan)

    cmap = mcolors.ListedColormap(["#cd43b2", "#3e0f2f", "#6cd875"])
    fig, axs = plt.subplot_mosaic(
        # [["rgb", "rgb", "ndvi", "tcvis", "stats"], ["rgb", "rgb", "pred", "slope", "elev"]],
        [["rgb", "rgb", "pred", "tcvis"], ["rgb", "rgb", "ndvi", "slope"], ["none", "stats", "stats", "stats"]],
        # layout="constrained",
        figsize=(11, 8),
    )

    # Disable none plot
    axs["none"].axis("off")

    # RGB Plot
    ax_rgb = axs["rgb"]
    # disable axis
    ax_rgb.axis("off")
    is_rgb = "red" in band_names and "green" in band_names and "blue" in band_names
    if is_rgb:
        red_band = band_names.index("red")
        green_band = band_names.index("green")
        blue_band = band_names.index("blue")
        rgb = x[[red_band, green_band, blue_band]].transpose(0, 2).transpose(0, 1)
        ax_rgb.imshow(rgb ** (1 / 1.4))
        ax_rgb.set_title(f"Acc: {acc:.1%} F1: {f1:.1%} IoU: {iou:.1%}")
    else:
        # Plot empty with message that RGB is not provided
        ax_rgb.set_title("No RGB values are provided!")
    ax_rgb.imshow(classification_labels, alpha=0.6, cmap=cmap, vmin=1, vmax=3)
    # Add a legend
    patches = [
        mpatches.Patch(color="#6cd875", label="True Positive"),
        mpatches.Patch(color="#3e0f2f", label="False Negative"),
        mpatches.Patch(color="#cd43b2", label="False Positive"),
    ]
    ax_rgb.legend(handles=patches, loc="upper left")

    # NDVI Plot
    ax_ndvi = axs["ndvi"]
    ax_ndvi.axis("off")
    is_ndvi = "ndvi" in band_names
    if is_ndvi:
        ndvi_band = band_names.index("ndvi")
        ndvi = x[ndvi_band]
        ax_ndvi.imshow(ndvi, vmin=0, vmax=1, cmap="RdYlGn")
        ax_ndvi.set_title("NDVI")
    else:
        # Plot empty with message that NDVI is not provided
        ax_ndvi.set_title("No NDVI values are provided!")

    # TCVIS Plot
    ax_tcv = axs["tcvis"]
    ax_tcv.axis("off")
    is_tcvis = "tc_brightness" in band_names and "tc_greenness" in band_names and "tc_wetness" in band_names
    if is_tcvis:
        tcb_band = band_names.index("tc_brightness")
        tcg_band = band_names.index("tc_greenness")
        tcw_band = band_names.index("tc_wetness")
        tcvis = x[[tcb_band, tcg_band, tcw_band]].transpose(0, 2).transpose(0, 1)
        ax_tcv.imshow(tcvis)
        ax_tcv.set_title("TCVIS")
    else:
        ax_tcv.set_title("No TCVIS values are provided!")

    # Statistics Plot
    ax_stat = axs["stats"]
    if (y == 1).sum() > 0:
        n_bands = x.shape[0]
        n_pixel = x.shape[1] * x.shape[2]
        x_flat = x.flatten().cpu()
        y_flat = y.flatten().repeat(n_bands).cpu()
        bands = list(itertools.chain.from_iterable([band_names[i]] * n_pixel for i in range(n_bands)))
        plot_data = pd.DataFrame({"x": x_flat, "y": y_flat, "band": bands})
        if len(plot_data) > 50000:
            plot_data = plot_data.sample(50000)
        plot_data = plot_data.sort_values("band")
        sns.violinplot(
            x="x",
            y="band",
            hue="y",
            data=plot_data,
            split=True,
            inner="quart",
            fill=False,
            palette={1: "g", 0: ".35"},
            density_norm="width",
            ax=ax_stat,
        )
        ax_stat.set_title("Band Statistics")
    else:
        ax_stat.set_title("No positive labels in this sample!")
        ax_stat.axis("off")

    # Prediction Plot
    ax_mask = axs["pred"]
    ax_mask.imshow(y_pred, vmin=0, vmax=1)
    ax_mask.axis("off")
    ax_mask.set_title("Model Output")

    # Slope Plot
    ax_slope = axs["slope"]
    ax_slope.axis("off")
    is_slope = "slope" in band_names
    if is_slope:
        slope_band = band_names.index("slope")
        slope = x[slope_band]
        ax_slope.imshow(slope, cmap="cividis")
        # Add TPI as contour lines
        is_rel_elev = "relative_elevation" in band_names
        if is_rel_elev:
            rel_elev_band = band_names.index("relative_elevation")
            rel_elev = x[rel_elev_band]
            cs = ax_slope.contour(rel_elev, [0], colors="red", linewidths=0.3, alpha=0.6)
            ax_slope.clabel(cs, inline=True, fontsize=5, fmt="%.1f")

        ax_slope.set_title("Slope")
    else:
        # Plot empty with message that slope is not provided
        ax_slope.set_title("No Slope values are provided!")

    # Relative Elevation Plot
    # rel_elev_band = band_names.index("relative_elevation")
    # rel_elev = x[rel_elev_band]
    # ax_rel_elev = axs["elev"]
    # ax_rel_elev.imshow(rel_elev, cmap="cividis")
    # ax_rel_elev.axis("off")
    # ax_rel_elev.set_title("Relative Elevation")

    return fig, axs