binary_instance_prc
darts_segmentation.metrics.binary_instance_prc
Complex binary instance segmentation metrics.
BinaryInstanceAveragePrecision
BinaryInstanceAveragePrecision(
thresholds: int | list[float] | torch.Tensor = None,
matching_threshold: float = 0.5,
ignore_index: int | None = None,
validate_args: bool = True,
**kwargs: typing.Any,
)
Bases: darts_segmentation.metrics.binary_instance_prc.BinaryInstancePrecisionRecallCurve
Compute the average precision for binary instance segmentation.
Create a new instance of the BinaryInstancePrecisionRecallCurve metric.
Parameters:
- thresholds (int | list[float] | torch.Tensor, default: None) – The thresholds to use for the curve. Although the default is None, a value is required: passing None raises a ValueError.
- matching_threshold (float, default: 0.5) – The threshold for matching instances.
- ignore_index (int | None, default: None) – An invalid class to ignore.
- validate_args (bool, default: True) – Whether to validate inputs.
- kwargs (typing.Any, default: {}) – Additional keyword arguments for the torchmetrics Metric base class, for example options that control how compute is run. Refer to the torchmetrics documentation for details.
Raises:
- ValueError – If thresholds is None.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
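Average precision condenses a precision-recall curve into one number. The sketch below uses the step-wise definition applied by torchmetrics' binary average precision: with curve points ordered by increasing threshold (so recall is non-increasing and the curve ends at precision 1, recall 0), AP = Σ_i (R_i − R_{i+1}) · P_i. Whether BinaryInstanceAveragePrecision uses exactly this formula is an assumption, the helper name average_precision_from_curve is hypothetical, and plain lists stand in for torch tensors.

```python
def average_precision_from_curve(precision: list[float], recall: list[float]) -> float:
    """Step-wise AP from a precision-recall curve.

    Assumes torchmetrics-style ordering: points sorted by increasing threshold,
    recall non-increasing, and a final (precision=1, recall=0) point.
    """
    if len(precision) != len(recall):
        raise ValueError("precision and recall must have the same length")
    # Each step contributes (drop in recall) * (precision at the start of the step).
    return sum(
        (recall[i] - recall[i + 1]) * precision[i]
        for i in range(len(recall) - 1)
    )


# Hypothetical curve for thresholds [0.25, 0.5, 0.75] plus the closing (1, 0) point.
precision = [0.5, 0.75, 1.0, 1.0]
recall = [1.0, 0.75, 0.5, 0.0]
print(average_precision_from_curve(precision, recall))  # 0.8125
```

Because every count behind this curve is a number of matched blobs rather than pixels, the resulting AP is an instance-level score.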
ignore_index
instance-attribute
Set from the ignore_index constructor argument.
matching_threshold
instance-attribute
Set from the matching_threshold constructor argument.
validate_args
instance-attribute
Set from the validate_args constructor argument.
compute
plot
plot(
val: torch.Tensor
| collections.abc.Sequence[torch.Tensor]
| None = None,
ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
update
Update metric states.
Parameters:
- preds (torch.Tensor) – The predicted mask. Shape: (batch_size, height, width).
- target (torch.Tensor) – The target mask. Shape: (batch_size, height, width).
Raises:
- ValueError – If preds and target have different shapes.
- ValueError – If the input targets are not binary masks.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
BinaryInstancePrecisionRecallCurve
BinaryInstancePrecisionRecallCurve(
thresholds: int | list[float] | torch.Tensor = None,
matching_threshold: float = 0.5,
ignore_index: int | None = None,
validate_args: bool = True,
**kwargs: typing.Any,
)
Bases: torchmetrics.Metric
Compute the precision-recall curve for binary instance segmentation.
This metric works similarly to torchmetrics.classification.PrecisionRecallCurve, with two key differences:
1. It calculates the tp, fp, and fn values for each instance (blob) in the batch and then aggregates them, instead of calculating the values for each pixel.
2. The "thresholds" argument is required. Calculating the thresholds at the compute stage would cost too much memory for this use case, since every prediction would have to be kept until compute is called.
Create a new instance of the BinaryInstancePrecisionRecallCurve metric.
Parameters:
- thresholds (int | list[float] | torch.Tensor, default: None) – The thresholds to use for the curve. Although the default is None, a value is required: passing None raises a ValueError.
- matching_threshold (float, default: 0.5) – The threshold for matching instances.
- ignore_index (int | None, default: None) – An invalid class to ignore.
- validate_args (bool, default: True) – Whether to validate inputs.
- kwargs (typing.Any, default: {}) – Additional keyword arguments for the torchmetrics Metric base class, for example options that control how compute is run. Refer to the torchmetrics documentation for details.
Raises:
- ValueError – If thresholds is None.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
ignore_index
instance-attribute
Set from the ignore_index constructor argument.
matching_threshold
instance-attribute
Set from the matching_threshold constructor argument.
validate_args
instance-attribute
Set from the validate_args constructor argument.
compute
plot
plot(
curve: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| None = None,
score: torch.Tensor | bool | None = None,
ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
update
Update metric states.
Parameters:
- preds (torch.Tensor) – The predicted mask. Shape: (batch_size, height, width).
- target (torch.Tensor) – The target mask. Shape: (batch_size, height, width).
Raises:
- ValueError – If preds and target have different shapes.
- ValueError – If the input targets are not binary masks.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
mask_to_instances
Convert a binary segmentation mask into multiple instance masks. Expects a batched input.
Currently only supports uint8 tensors, hence a maximum of 255 instances per mask.
Parameters:
- x (torch.Tensor) – The binary segmentation mask. Shape: (batch_size, height, width), dtype: torch.uint8.
- validate_args (bool, default: False) – Whether to validate the input arguments.
Returns:
- list[torch.Tensor] – The instance masks. Length of list: batch_size. Shape of each tensor: (height, width), dtype: torch.uint8.
Source code in darts-segmentation/src/darts_segmentation/metrics/instance_helpers.py
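Converting a binary mask into instance masks is connected-component labelling: each foreground blob receives its own integer label and background stays 0. The plain-Python sketch below uses a breadth-first flood fill. Defaulting to 8-connectivity mirrors the default of skimage.measure.label for 2D input, but which connectivity mask_to_instances actually uses is an assumption, as is raising an error once more than 255 labels would be needed. label_instances and mask_to_instances_sketch are hypothetical names.

```python
from collections import deque


def label_instances(
    mask: list[list[int]], connectivity: int = 8, max_label: int = 255
) -> list[list[int]]:
    """Label the connected foreground blobs of one binary mask as 1, 2, 3, ... (0 = background)."""
    if connectivity == 4:
        offsets = [(-1, 0), (1, 0), (0, -1), (0, 1)]
    elif connectivity == 8:
        offsets = [(dy, dx) for dy in (-1, 0, 1) for dx in (-1, 0, 1) if (dy, dx) != (0, 0)]
    else:
        raise ValueError("connectivity must be 4 or 8")
    height = len(mask)
    width = len(mask[0]) if height else 0
    labels = [[0] * width for _ in range(height)]
    next_label = 0
    for y in range(height):
        for x in range(width):
            if not mask[y][x] or labels[y][x]:
                continue  # background, or already part of a labelled blob
            next_label += 1
            if next_label > max_label:
                # Assumption: fail loudly instead of overflowing a uint8 label map.
                raise ValueError(f"more than {max_label} instances in one mask")
            labels[y][x] = next_label
            queue = deque([(y, x)])
            while queue:  # breadth-first flood fill of the current blob
                cy, cx = queue.popleft()
                for dy, dx in offsets:
                    ny, nx = cy + dy, cx + dx
                    if 0 <= ny < height and 0 <= nx < width and mask[ny][nx] and not labels[ny][nx]:
                        labels[ny][nx] = next_label
                        queue.append((ny, nx))
    return labels


def mask_to_instances_sketch(batch: list[list[list[int]]]) -> list[list[list[int]]]:
    """Batched wrapper: one label map per mask, mirroring the documented list output."""
    return [label_instances(mask) for mask in batch]


print(mask_to_instances_sketch([[[1, 1, 0, 0], [0, 0, 0, 1], [1, 0, 0, 1]]]))
# [[[1, 1, 0, 0], [0, 0, 0, 2], [3, 0, 0, 2]]]
```

The connectivity choice matters for touching diagonals: [[1, 0], [0, 1]] is one instance under 8-connectivity but two under 4-connectivity.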
match_instances
match_instances(
instances_target: torch.Tensor,
instances_preds: torch.Tensor,
match_threshold: float = 0.5,
validate_args: bool = False,
) -> tuple[int, int, int]
Match instances between target and prediction masks. Expects non-batched instance masks, such as those produced by skimage.measure.label.
Parameters:
- instances_target (torch.Tensor) – The instance mask of the target. Shape: (height, width), dtype: torch.uint8.
- instances_preds (torch.Tensor) – The instance mask of the prediction. Shape: (height, width), dtype: torch.uint8.
- match_threshold (float, default: 0.5) – The threshold for matching instances.
- validate_args (bool, default: False) – Whether to validate the input arguments.
Returns: