darts_segmentation.metrics.BinaryInstancePrecisionRecallCurve
Bases: torchmetrics.Metric
Compute the precision-recall curve for binary instance segmentation.
This metric works similarly to torchmetrics.classification.PrecisionRecallCurve, with two key differences:
1. It calculates the tp, fp, and fn values for each instance (blob) in the batch and then aggregates them, instead of calculating the values for each pixel.
2. The "thresholds" argument is required, because calculating the thresholds at the compute stage would cost too much memory for this use case.
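The instance-level counting described in point 1 can be sketched in pure Python. This is a hypothetical illustration, not the library's implementation: it assumes that blobs are 4-connected foreground components and that a target blob and a predicted blob match when their intersection-over-union (IoU) is at least matching_threshold. The helper names label_instances, blob_pixels, and match_instances are made up for this example.

```python
from collections import deque

Mask = list[list[int]]  # one binary mask as nested lists (height x width)


def label_instances(mask: Mask) -> tuple[Mask, int]:
    """Label 4-connected foreground blobs; returns (label image, number of blobs)."""
    h, w = len(mask), len(mask[0])
    labels = [[0] * w for _ in range(h)]
    count = 0
    for y in range(h):
        for x in range(w):
            if mask[y][x] and not labels[y][x]:
                count += 1  # start a new blob and flood-fill it with BFS
                labels[y][x] = count
                queue = deque([(y, x)])
                while queue:
                    cy, cx = queue.popleft()
                    for ny, nx in ((cy - 1, cx), (cy + 1, cx), (cy, cx - 1), (cy, cx + 1)):
                        if 0 <= ny < h and 0 <= nx < w and mask[ny][nx] and not labels[ny][nx]:
                            labels[ny][nx] = count
                            queue.append((ny, nx))
    return labels, count


def blob_pixels(mask: Mask) -> list[set[tuple[int, int]]]:
    """Return one set of (row, col) coordinates per blob."""
    labels, count = label_instances(mask)
    blobs: list[set[tuple[int, int]]] = [set() for _ in range(count)]
    for y, row in enumerate(labels):
        for x, label in enumerate(row):
            if label:
                blobs[label - 1].add((y, x))
    return blobs


def match_instances(pred: Mask, target: Mask, matching_threshold: float = 0.5) -> tuple[int, int, int]:
    """Count instance-level (tp, fp, fn) for one binarized prediction/target pair."""
    pred_blobs, target_blobs = blob_pixels(pred), blob_pixels(target)
    used: set[int] = set()
    tp = 0
    for t in target_blobs:
        for i, p in enumerate(pred_blobs):
            # Assumption: blobs match when IoU >= matching_threshold (greedy, one-to-one).
            if i not in used and len(t & p) / len(t | p) >= matching_threshold:
                used.add(i)
                tp += 1
                break
    fp = len(pred_blobs) - tp  # predicted blobs without a matching target blob
    fn = len(target_blobs) - tp  # target blobs that no prediction matched
    return tp, fp, fn
```

Summing these counts over every image in every batch gives the aggregated tp, fp, and fn for one threshold; repeating this for each threshold traces out the curve.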
Create a new instance of the BinaryInstancePrecisionRecallCurve metric.
Parameters:
- thresholds (int | list[float] | torch.Tensor, default: None) – The thresholds to use for the curve. Although the default is None, a value must be provided (see Raises).
- matching_threshold (float, default: 0.5) – The threshold for matching instances.
- ignore_index (int | None, default: None) – Label of an invalid class to ignore.
- validate_args (bool, default: True) – Whether to validate inputs.
- kwargs (typing.Any, default: {}) – Additional keyword arguments passed to the torchmetrics Metric base class, e.g. to configure how compute is performed. Refer to the torchmetrics documentation for the available options.
Raises:
- ValueError – If thresholds is None.
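The handling of the thresholds argument can be sketched as follows. This is a hypothetical helper, not the library's code: the ValueError for None matches the documented behaviour, while turning an int n into n evenly spaced values between 0 and 1 is an assumption borrowed from the torchmetrics convention (torch.linspace(0, 1, n)). A torch.Tensor would need converting with .tolist() first in this plain-Python version.

```python
from typing import Optional, Sequence, Union


def normalize_thresholds(thresholds: Optional[Union[int, Sequence[float]]]) -> list[float]:
    """Turn the `thresholds` argument into an explicit list of floats in [0, 1]."""
    if thresholds is None:
        # Documented behaviour: the thresholds must be known before any update.
        raise ValueError("thresholds must be provided")
    if isinstance(thresholds, int):
        # Assumption (torchmetrics convention): an int n means n evenly spaced
        # thresholds from 0 to 1, like torch.linspace(0, 1, n).
        if thresholds < 2:
            raise ValueError("an integer thresholds argument must be at least 2")
        step = 1 / (thresholds - 1)
        return [i * step for i in range(thresholds)]
    values = [float(t) for t in thresholds]
    if not all(0.0 <= t <= 1.0 for t in values):
        raise ValueError("all thresholds must lie in [0, 1]")
    return values
```

Fixing the thresholds up front presumably lets the metric keep only a few counters per threshold instead of storing every prediction until compute, which is the memory argument behind making the argument required.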
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
ignore_index (instance attribute)
Stores the ignore_index argument passed to the constructor.
matching_threshold (instance attribute)
Stores the matching_threshold argument passed to the constructor.
validate_args (instance attribute)
Stores the validate_args argument passed to the constructor.
compute
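compute has no docstring here. Assuming it follows torchmetrics' PrecisionRecallCurve, it turns the aggregated instance counts into precision = tp / (tp + fp) and recall = tp / (tp + fn) for each threshold; the three-tensor curve argument of plot suggests it returns a (precision, recall, thresholds) tuple. The hypothetical compute_curve function below sketches that step with plain lists; how the library treats the 0 / 0 case is unknown, so it is exposed as a zero_division parameter.

```python
from typing import Sequence


def compute_curve(
    thresholds: Sequence[float],
    tps: Sequence[int],
    fps: Sequence[int],
    fns: Sequence[int],
    zero_division: float = 0.0,
) -> tuple[list[float], list[float], list[float]]:
    """Turn per-threshold instance counts into (precision, recall, thresholds)."""
    if not (len(thresholds) == len(tps) == len(fps) == len(fns)):
        raise ValueError("expected exactly one tp/fp/fn count per threshold")
    precision: list[float] = []
    recall: list[float] = []
    for tp, fp, fn in zip(tps, fps, fns):
        # precision = tp / (tp + fp), recall = tp / (tp + fn);
        # the 0 / 0 case is reported as `zero_division` (an assumption).
        precision.append(tp / (tp + fp) if tp + fp else zero_division)
        recall.append(tp / (tp + fn) if tp + fn else zero_division)
    return precision, recall, list(thresholds)
```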
plot
plot(
    curve: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None,
    score: torch.Tensor | bool | None = None,
    ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
update
Update metric states.
Parameters:
- preds (torch.Tensor) – The predicted mask. Shape: (batch_size, height, width).
- target (torch.Tensor) – The target mask. Shape: (batch_size, height, width).
Raises:
- ValueError – If preds and target have different shapes.
- ValueError – If the input targets are not binary masks.
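The two documented checks in update, plus binarizing the predictions once per threshold, can be sketched with nested lists standing in for tensors. This is a hypothetical illustration: the names shape and check_and_binarize are made up, and treating a pixel as foreground when its score is greater than or equal to the threshold is an assumption.

```python
from typing import Any, Sequence


def shape(x: Any) -> tuple[int, ...]:
    """Shape of a nested list (only the first element of each level is inspected)."""
    dims: list[int] = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return tuple(dims)


def check_and_binarize(preds: list, target: list, thresholds: Sequence[float]) -> list:
    """Validate a (batch_size, height, width) pair and binarize preds per threshold."""
    if shape(preds) != shape(target):
        raise ValueError(f"shape mismatch: {shape(preds)} vs {shape(target)}")
    if any(v not in (0, 1) for image in target for row in image for v in row):
        raise ValueError("target must be a binary mask")
    # Assumption: a pixel counts as foreground when its score is >= the threshold.
    # Result layout: one (batch_size, height, width) binary mask stack per threshold.
    return [
        [[[int(v >= th) for v in row] for row in image] for image in preds]
        for th in thresholds
    ]
```

Each binarized mask would then be split into blobs and matched against the target blobs, with the resulting tp, fp, and fn added to running per-threshold totals.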