callbacks
darts_segmentation.training.callbacks
PyTorch Lightning Callbacks for training and validation.
Bands
Bases: collections.UserList[darts_segmentation.utils.Band]
Wrapper for the list of bands.
factors (property)
names (property)
offsets (property)
__reduce__
Source code in darts-segmentation/src/darts_segmentation/utils.py
filter
filter(
band_names: list[str],
) -> darts_segmentation.utils.Bands
Filter the bands by name.
Parameters:
- band_names (list[str]) – The names of the bands to keep.
Returns:
- Bands (darts_segmentation.utils.Bands) – The filtered Bands object.
Source code in darts-segmentation/src/darts_segmentation/utils.py
from_config (classmethod)
from_config(
config: dict[
typing.Literal[
"bands", "band_factors", "band_offsets"
],
list,
]
| dict[str, tuple[float, float]],
) -> darts_segmentation.utils.Bands
Create a Bands object from a config dictionary.
Parameters:
- config (dict) – The config dictionary containing the band information. It must have the keys "bands", "band_factors" and "band_offsets", each mapping to a list, and all three lists must have the same length.
Returns:
- Bands (darts_segmentation.utils.Bands) – The Bands object.
Source code in darts-segmentation/src/darts_segmentation/utils.py
from_dict (classmethod)
Create a Bands object from a dictionary.
Parameters:
- config (dict[str, tuple[float, float]]) – The dictionary containing the band information. The keys are the band names and the values are (factor, offset) tuples. Example: {"band1": (1.0, 0.0), "band2": (2.0, 1.0)}
Returns:
- Bands (darts_segmentation.utils.Bands) – The Bands object.
Source code in darts-segmentation/src/darts_segmentation/utils.py
to_config
Convert the Bands object to a config dictionary.
Returns:
- dict (dict[typing.Literal['bands', 'band_factors', 'band_offsets'], list]) – The config dictionary containing the band information.
Source code in darts-segmentation/src/darts_segmentation/utils.py
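The two dictionary layouts accepted by from_dict and from_config carry the same information: one maps each band name to a (factor, offset) pair, the other stores three index-aligned lists. The plain-Python sketch below converts between them. The helper names `dict_to_config` and `config_to_dict` are hypothetical and not part of darts_segmentation; the sketch only assumes the list layout described above.

```python
# Hypothetical helpers (not part of darts_segmentation) mirroring the two
# dictionary layouts documented for Bands.from_dict and Bands.from_config.


def dict_to_config(bands: dict[str, tuple[float, float]]) -> dict[str, list]:
    """Convert {name: (factor, offset)} into three index-aligned lists."""
    names = list(bands)
    return {
        "bands": names,
        "band_factors": [bands[name][0] for name in names],
        "band_offsets": [bands[name][1] for name in names],
    }


def config_to_dict(config: dict[str, list]) -> dict[str, tuple[float, float]]:
    """Inverse of dict_to_config; the three lists must have the same length."""
    names = config["bands"]
    factors = config["band_factors"]
    offsets = config["band_offsets"]
    if not (len(names) == len(factors) == len(offsets)):
        raise ValueError("bands, band_factors and band_offsets must have the same length")
    return {name: (factor, offset) for name, factor, offset in zip(names, factors, offsets)}


example = {"band1": (1.0, 0.0), "band2": (2.0, 1.0)}
config = dict_to_config(example)
# config == {"bands": ["band1", "band2"], "band_factors": [1.0, 2.0], "band_offsets": [0.0, 1.0]}
assert config_to_dict(config) == example  # the round trip is lossless
```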
BinaryBoundaryIoU
BinaryBoundaryIoU(
dilation: float | int = 0.02,
threshold: float = 0.5,
multidim_average: typing.Literal[
"global", "samplewise"
] = "global",
ignore_index: int | None = None,
validate_args: bool = True,
**kwargs: typing.Unpack[
darts_segmentation.metrics.boundary_iou.BinaryBoundaryIoUKwargs
],
)
Bases: torchmetrics.Metric
Binary Boundary IoU metric for binary segmentation tasks.
This metric is similar to the Binary Intersection over Union (IoU or Jaccard Index) metric, but instead of comparing all pixels it only compares the boundaries of each foreground object.
Create a new instance of the BinaryBoundaryIoU metric.
Please see the torchmetrics docs for more info about the **kwargs.
Parameters:
- dilation (float | int, default: 0.02) – The dilation (factor) / width of the boundary. If an int, the dilation in pixels; if a float, a ratio used to compute dilation = dilation_ratio * image_diagonal. Defaults to 0.02.
- threshold (float, default: 0.5) – Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.
- multidim_average (typing.Literal['global', 'samplewise'], default: 'global') – How the average over multiple batches is calculated. Defaults to "global".
- ignore_index (int | None, default: None) – Index of an invalid class to ignore. Defaults to None.
- validate_args (bool, default: True) – Whether to validate inputs. Defaults to True.
- **kwargs (typing.Unpack[darts_segmentation.metrics.boundary_iou.BinaryBoundaryIoUKwargs], default: {}) – Additional keyword arguments for the metric.
Other Parameters:
- zero_division (int) – Value to return when there is a zero division. Default is 0.
- compute_on_cpu (bool) – If metric state should be stored on CPU during computations. Only works for list states.
- dist_sync_on_step (bool) – If metric state should synchronize on forward(). Default is False.
- process_group (str) – The process group on which the synchronization is called. Default is the world.
- dist_sync_fn (callable) – Function that performs the allgather operation on the metric state. Default is a custom implementation that calls torch.distributed.all_gather internally.
- distributed_available_fn (callable) – Function that checks if the distributed backend is available. Defaults to a check of torch.distributed.is_available() and torch.distributed.is_initialized().
- sync_on_compute (bool) – If metric state should synchronize when compute is called. Default is True.
- compute_with_cache (bool) – If results from compute should be cached. Default is True.
Raises:
- ValueError – If dilation is not a float or int.
Source code in darts-segmentation/src/darts_segmentation/metrics/boundary_iou.py
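The boundary comparison can be illustrated with a small pure-Python sketch in the spirit of the Boundary IoU formulation: each mask's boundary is the mask minus its erosion by d pixels, and the score is the IoU of the two boundary sets. This is not the darts_segmentation implementation. The 3x3 erosion kernel, treating pixels outside the image as background, rounding a ratio-based dilation to at least one pixel, and returning 0.0 when both boundaries are empty are all assumptions.

```python
from __future__ import annotations

import math

Mask = list[list[int]]


def dilation_in_pixels(dilation: float | int, height: int, width: int) -> int:
    """Ints are used as-is; floats are a ratio of the image diagonal (rounding assumed)."""
    if isinstance(dilation, int):
        return dilation
    return max(1, round(dilation * math.hypot(height, width)))


def erode(mask: Mask, iterations: int) -> Mask:
    """Binary erosion with a 3x3 kernel; pixels outside the image count as background."""
    h, w = len(mask), len(mask[0])
    out = [row[:] for row in mask]
    for _ in range(iterations):
        prev = out
        out = [
            [
                int(
                    all(
                        0 <= y + dy < h and 0 <= x + dx < w and prev[y + dy][x + dx]
                        for dy in (-1, 0, 1)
                        for dx in (-1, 0, 1)
                    )
                )
                for x in range(w)
            ]
            for y in range(h)
        ]
    return out


def boundary(mask: Mask, d: int) -> set[tuple[int, int]]:
    """Foreground pixels within d pixels of the mask contour (mask minus its erosion)."""
    eroded = erode(mask, d)
    return {
        (y, x)
        for y, row in enumerate(mask)
        for x, value in enumerate(row)
        if value and not eroded[y][x]
    }


def boundary_iou(pred: Mask, target: Mask, dilation: float | int = 0.02) -> float:
    """IoU of the two boundary sets; 0.0 if both are empty (assumed zero_division)."""
    d = dilation_in_pixels(dilation, len(target), len(target[0]))
    bp, bt = boundary(pred, d), boundary(target, d)
    union = bp | bt
    return len(bp & bt) / len(union) if union else 0.0


# A 4x4 square and the same square shifted one pixel to the right.
target = [[1 if 1 <= y <= 4 and 1 <= x <= 4 else 0 for x in range(6)] for y in range(6)]
pred = [[1 if 1 <= y <= 4 and 2 <= x <= 5 else 0 for x in range(6)] for y in range(6)]
print(boundary_iou(pred, target, dilation=1))  # 6 shared of 18 boundary pixels, about 0.333
```

With the default ratio of 0.02, a 512x512 patch has a diagonal of roughly 724 pixels, so this sketch would use a boundary width of 14 pixels.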
dilation (instance-attribute)
dilation = dilation
ignore_index (instance-attribute)
ignore_index = ignore_index
multidim_average (instance-attribute)
multidim_average = multidim_average
threshold (instance-attribute)
threshold = threshold
validate_args (instance-attribute)
validate_args = validate_args
compute
Compute the metric.
Source code in darts-segmentation/src/darts_segmentation/metrics/boundary_iou.py
update
Update the metric state.
If the predictions are logits (values outside [0, 1]), they are first converted to probabilities with a sigmoid and then binarized using the threshold. If the predictions are already probabilities, they are binarized using the threshold directly.
Parameters:
- preds (torch.Tensor) – Predictions from the model (logits or probabilities).
- target (torch.Tensor) – Ground truth labels.
Raises:
- ValueError – If the input arguments are invalid.
- ValueError – If the input shapes are invalid.
Source code in darts-segmentation/src/darts_segmentation/metrics/boundary_iou.py
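The logits-versus-probabilities rule used by update can be sketched in plain Python on a flat list of values. Detecting logits by checking whether any value falls outside [0, 1], and using a strict `> threshold` comparison (the torchmetrics convention), are assumptions about the darts implementation. Note that the check applies to the whole input at once, so a value such as 0.1 is treated as a logit when it appears next to out-of-range values.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def binarize(preds: list[float], threshold: float = 0.5) -> list[int]:
    """Apply a sigmoid if the values look like logits, then threshold them."""
    is_logits = any(p < 0.0 or p > 1.0 for p in preds)
    if is_logits:
        preds = [sigmoid(p) for p in preds]
    return [int(p > threshold) for p in preds]


print(binarize([0.2, 0.7, 0.5]))   # probabilities, thresholded directly -> [0, 1, 0]
print(binarize([-2.0, 3.0, 0.1]))  # logits, sigmoid first -> [0, 1, 1]
print(binarize([0, 1, 1]))         # already binary, unchanged -> [0, 1, 1]
```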
BinaryInstanceAccuracy
BinaryInstanceAccuracy(
threshold: float = 0.5,
matching_threshold: float = 0.5,
multidim_average: typing.Literal[
"global", "samplewise"
] = "global",
ignore_index: int | None = None,
validate_args: bool = True,
**kwargs: typing.Any,
)
Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores
Binary instance accuracy metric.
Create a new instance of the BinaryInstanceAccuracy metric.
Parameters:
- threshold (float, default: 0.5) – Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.
- matching_threshold (float, default: 0.5) – The threshold for matching instances. Defaults to 0.5.
- multidim_average (typing.Literal['global', 'samplewise'], default: 'global') – How the average over multiple batches is calculated. Defaults to "global".
- ignore_index (int | None, default: None) – Index of an invalid class to ignore. Defaults to None.
- validate_args (bool, default: True) – Whether to validate inputs. Defaults to True.
- kwargs (typing.Any, default: {}) – Additional arguments for the Metric class, regarding compute methods. Please refer to the torchmetrics documentation for more examples.
Raises:
- ValueError – If matching_threshold is not a float in the [0, 1] range.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
ignore_index (instance-attribute)
ignore_index = ignore_index
matching_threshold (instance-attribute)
matching_threshold = matching_threshold
multidim_average (instance-attribute)
multidim_average = multidim_average
threshold (instance-attribute)
threshold = threshold
validate_args (instance-attribute)
validate_args = validate_args
compute
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
plot
update
Update the metric state.
If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.
Parameters:
- preds (torch.Tensor) – Predictions from the model (logits or probabilities).
- target (torch.Tensor) – Ground truth labels.
Raises:
- ValueError – If preds and target have different shapes.
- ValueError – If the input targets are not binary masks.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
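Instance-level scores count objects (blobs) rather than pixels. The sketch below labels 8-connected blobs in binary masks, greedily matches each target blob to the still-unmatched predicted blob with the highest IoU, and counts a true positive when that IoU reaches `matching_threshold`. The 8-connectivity, the greedy matching, the IoU criterion and the `>=` comparison are assumptions for illustration; the darts implementation may differ in these details.

```python
Mask = list[list[int]]
Blob = frozenset[tuple[int, int]]


def label_blobs(mask: Mask) -> list[Blob]:
    """Return the 8-connected foreground components of a binary mask."""
    h, w = len(mask), len(mask[0])
    seen: set[tuple[int, int]] = set()
    blobs: list[Blob] = []
    for y in range(h):
        for x in range(w):
            if not mask[y][x] or (y, x) in seen:
                continue
            stack, blob = [(y, x)], set()
            seen.add((y, x))
            while stack:  # iterative flood fill
                cy, cx = stack.pop()
                blob.add((cy, cx))
                for dy in (-1, 0, 1):
                    for dx in (-1, 0, 1):
                        ny, nx = cy + dy, cx + dx
                        if 0 <= ny < h and 0 <= nx < w and mask[ny][nx] and (ny, nx) not in seen:
                            seen.add((ny, nx))
                            stack.append((ny, nx))
            blobs.append(frozenset(blob))
    return blobs


def instance_stat_scores(pred: Mask, target: Mask, matching_threshold: float = 0.5) -> tuple[int, int, int]:
    """Count (tp, fp, fn) over blobs instead of pixels."""
    pred_blobs, target_blobs = label_blobs(pred), label_blobs(target)
    unmatched = set(range(len(pred_blobs)))
    tp = 0
    for t in target_blobs:
        best_iou, best_idx = 0.0, None
        for i in unmatched:
            p = pred_blobs[i]
            iou = len(p & t) / len(p | t)
            if iou > best_iou:
                best_iou, best_idx = iou, i
        if best_idx is not None and best_iou >= matching_threshold:
            tp += 1
            unmatched.discard(best_idx)  # each predicted blob matches at most once
    fp = len(pred_blobs) - tp  # predicted blobs without a partner
    fn = len(target_blobs) - tp  # target blobs that were missed
    return tp, fp, fn


target = [[1, 1, 0, 0, 0], [1, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 1, 1], [0, 0, 0, 1, 1]]
pred = [[1, 1, 0, 0, 0], [1, 1, 0, 0, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
print(instance_stat_scores(pred, target))  # one hit, one spurious blob, one miss -> (1, 1, 1)
```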
BinaryInstanceAveragePrecision
BinaryInstanceAveragePrecision(
thresholds: int | list[float] | torch.Tensor = None,
matching_threshold: float = 0.5,
ignore_index: int | None = None,
validate_args: bool = True,
**kwargs: typing.Any,
)
Bases: darts_segmentation.metrics.binary_instance_prc.BinaryInstancePrecisionRecallCurve
Compute the average precision for binary instance segmentation.
Create a new instance of the BinaryInstanceAveragePrecision metric.
Parameters:
- thresholds (int | list[float] | torch.Tensor, default: None) – The thresholds to use for the curve. Although the default is None, a value must be provided; passing None raises a ValueError.
- matching_threshold (float, default: 0.5) – The threshold for matching instances. Defaults to 0.5.
- ignore_index (int | None, default: None) – Index of an invalid class to ignore. Defaults to None.
- validate_args (bool, default: True) – Whether to validate inputs. Defaults to True.
- kwargs (typing.Any, default: {}) – Additional arguments for the Metric class, regarding compute methods. Please refer to the torchmetrics documentation for more examples.
Raises:
- ValueError – If thresholds is None.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
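Average precision condenses a precision-recall curve into one number. A common step-wise definition, used for example by torchmetrics, is AP = Σₙ (Rₙ − Rₙ₋₁) · Pₙ over curve points ordered by increasing recall, starting from R₀ = 0. The sketch applies that formula to (precision, recall) pairs such as those produced by a fixed-threshold curve; whether BinaryInstanceAveragePrecision uses exactly this interpolation is an assumption.

```python
def average_precision(precision: list[float], recall: list[float]) -> float:
    """Step-wise AP: precision weighted by the recall gained at each curve point."""
    points = sorted(zip(recall, precision))  # order by increasing recall
    ap, prev_recall = 0.0, 0.0
    for r, p in points:
        ap += (r - prev_recall) * p
        prev_recall = r
    return ap


# 0.25 * 1.0 + 0.25 * 0.8 + 0.5 * 0.5
print(average_precision(precision=[1.0, 0.8, 0.5], recall=[0.25, 0.5, 1.0]))  # about 0.7
```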
ignore_index (instance-attribute)
ignore_index = ignore_index
matching_threshold (instance-attribute)
matching_threshold = matching_threshold
validate_args (instance-attribute)
validate_args = validate_args
compute
plot
plot(
val: torch.Tensor
| collections.abc.Sequence[torch.Tensor]
| None = None,
ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
update
Update the metric states.
Parameters:
- preds (torch.Tensor) – The predicted mask. Shape: (batch_size, height, width)
- target (torch.Tensor) – The target mask. Shape: (batch_size, height, width)
Raises:
- ValueError – If preds and target have different shapes.
- ValueError – If the input targets are not binary masks.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
BinaryInstanceConfusionMatrix
BinaryInstanceConfusionMatrix(
normalize: bool | None = None,
threshold: float = 0.5,
matching_threshold: float = 0.5,
multidim_average: typing.Literal[
"global", "samplewise"
] = "global",
ignore_index: int | None = None,
validate_args: bool = True,
**kwargs: typing.Any,
)
Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores
Binary instance confusion matrix metric.
Create a new instance of the BinaryInstanceConfusionMatrix metric.
Parameters:
- normalize (bool | None, default: None) – If True, return the confusion matrix normalized by the number of instances. If False, return the confusion matrix without normalization. Defaults to None.
- threshold (float, default: 0.5) – Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.
- matching_threshold (float, default: 0.5) – The threshold for matching instances. Defaults to 0.5.
- multidim_average (typing.Literal['global', 'samplewise'], default: 'global') – How the average over multiple batches is calculated. Defaults to "global".
- ignore_index (int | None, default: None) – Index of an invalid class to ignore. Defaults to None.
- validate_args (bool, default: True) – Whether to validate inputs. Defaults to True.
- kwargs (typing.Any, default: {}) – Additional arguments for the Metric class, regarding compute methods. Please refer to the torchmetrics documentation for more examples.
Raises:
- ValueError – If normalize is not a bool.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
ignore_index (instance-attribute)
ignore_index = ignore_index
matching_threshold (instance-attribute)
matching_threshold = matching_threshold
multidim_average (instance-attribute)
multidim_average = multidim_average
normalize (instance-attribute)
normalize = normalize
threshold (instance-attribute)
threshold = threshold
validate_args (instance-attribute)
validate_args = validate_args
compute
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
plot
plot(
val: torch.Tensor | None = None,
ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
add_text: bool = True,
labels: list[str] | None = None,
cmap: torchmetrics.utilities.plot._CMAP_TYPE
| None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
update
Update the metric state.
If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.
Parameters:
- preds (torch.Tensor) – Predictions from the model (logits or probabilities).
- target (torch.Tensor) – Ground truth labels.
Raises:
- ValueError – If preds and target have different shapes.
- ValueError – If the input targets are not binary masks.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
BinaryInstanceF1Score
BinaryInstanceF1Score(
threshold: float = 0.5,
multidim_average: typing.Literal[
"global", "samplewise"
] = "global",
ignore_index: int | None = None,
validate_args: bool = True,
zero_division: float = 0,
**kwargs: typing.Any,
)
Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceFBetaScore
Binary instance F1 score metric.
Create a new instance of the BinaryInstanceF1Score metric.
Parameters:
- threshold (float, default: 0.5) – Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.
- multidim_average (typing.Literal['global', 'samplewise'], default: 'global') – How the average over multiple batches is calculated. Defaults to "global".
- ignore_index (int | None, default: None) – Index of an invalid class to ignore. Defaults to None.
- validate_args (bool, default: True) – Whether to validate inputs. Defaults to True.
- zero_division (float, default: 0) – Value to return when there is a zero division. Defaults to 0.
- kwargs (typing.Any, default: {}) – Additional arguments for the Metric class, regarding compute methods. Please refer to the torchmetrics documentation for more examples.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
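Once instance-level tp, fp and fn counts are available, precision, recall and the F-beta score follow the usual formulas; F1 is the beta = 1 case, F = (1 + β²)·tp / ((1 + β²)·tp + β²·fn + fp). The sketch assumes the zero_division convention described above: whenever a denominator is zero, the zero_division value is returned instead. The function name is hypothetical.

```python
def precision_recall_fbeta(
    tp: int, fp: int, fn: int, beta: float = 1.0, zero_division: float = 0.0
) -> tuple[float, float, float]:
    """Precision, recall and F-beta from instance counts (hypothetical helper)."""

    def safe_div(num: float, den: float) -> float:
        return num / den if den else zero_division

    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)
    b2 = beta**2
    fbeta = safe_div((1 + b2) * tp, (1 + b2) * tp + b2 * fn + fp)
    return precision, recall, fbeta


# 8 matched instances, 2 spurious predictions, 4 missed targets
print(precision_recall_fbeta(8, 2, 4))  # precision 0.8, recall 2/3, F1 = 16/22
```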
beta (instance-attribute)
ignore_index (instance-attribute)
ignore_index = ignore_index
matching_threshold (instance-attribute)
matching_threshold = matching_threshold
multidim_average (instance-attribute)
multidim_average = multidim_average
threshold (instance-attribute)
threshold = threshold
validate_args (instance-attribute)
validate_args = validate_args
zero_division (instance-attribute)
zero_division = zero_division
compute
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
plot
update
Update the metric state.
If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.
Parameters:
- preds (torch.Tensor) – Predictions from the model (logits or probabilities).
- target (torch.Tensor) – Ground truth labels.
Raises:
- ValueError – If preds and target have different shapes.
- ValueError – If the input targets are not binary masks.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
BinaryInstancePrecision
BinaryInstancePrecision(
threshold: float = 0.5,
matching_threshold: float = 0.5,
multidim_average: typing.Literal[
"global", "samplewise"
] = "global",
ignore_index: int | None = None,
validate_args: bool = True,
**kwargs: typing.Any,
)
Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores
Binary instance precision metric.
Create a new instance of the BinaryInstancePrecision metric.
Parameters:
- threshold (float, default: 0.5) – Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.
- matching_threshold (float, default: 0.5) – The threshold for matching instances. Defaults to 0.5.
- multidim_average (typing.Literal['global', 'samplewise'], default: 'global') – How the average over multiple batches is calculated. Defaults to "global".
- ignore_index (int | None, default: None) – Index of an invalid class to ignore. Defaults to None.
- validate_args (bool, default: True) – Whether to validate inputs. Defaults to True.
- kwargs (typing.Any, default: {}) – Additional arguments for the Metric class, regarding compute methods. Please refer to the torchmetrics documentation for more examples.
Raises:
- ValueError – If matching_threshold is not a float in the [0, 1] range.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
ignore_index (instance-attribute)
ignore_index = ignore_index
matching_threshold (instance-attribute)
matching_threshold = matching_threshold
multidim_average (instance-attribute)
multidim_average = multidim_average
threshold (instance-attribute)
threshold = threshold
validate_args (instance-attribute)
validate_args = validate_args
compute
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
plot
update
Update the metric state.
If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.
Parameters:
- preds (torch.Tensor) – Predictions from the model (logits or probabilities).
- target (torch.Tensor) – Ground truth labels.
Raises:
- ValueError – If preds and target have different shapes.
- ValueError – If the input targets are not binary masks.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
BinaryInstancePrecisionRecallCurve
BinaryInstancePrecisionRecallCurve(
thresholds: int | list[float] | torch.Tensor = None,
matching_threshold: float = 0.5,
ignore_index: int | None = None,
validate_args: bool = True,
**kwargs: typing.Any,
)
Bases: torchmetrics.Metric
Compute the precision-recall curve for binary instance segmentation.
This metric works similarly to torchmetrics.classification.PrecisionRecallCurve, with two key differences:
1. It calculates the tp, fp and fn values for each instance (blob) in the batch and then aggregates them, instead of calculating the values for each pixel.
2. The "thresholds" argument is required, because calculating the thresholds at the compute stage would cost too much memory for this use case.
Create a new instance of the BinaryInstancePrecisionRecallCurve metric.
Parameters:
- thresholds (int | list[float] | torch.Tensor, default: None) – The thresholds to use for the curve. Although the default is None, a value must be provided; passing None raises a ValueError.
- matching_threshold (float, default: 0.5) – The threshold for matching instances. Defaults to 0.5.
- ignore_index (int | None, default: None) – Index of an invalid class to ignore. Defaults to None.
- validate_args (bool, default: True) – Whether to validate inputs. Defaults to True.
- kwargs (typing.Any, default: {}) – Additional arguments for the Metric class, regarding compute methods. Please refer to the torchmetrics documentation for more examples.
Raises:
- ValueError – If thresholds is None.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
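Fixing the thresholds up front means the metric only has to keep one (tp, fp, fn) triple per threshold between update calls, instead of storing every prediction until compute. The sketch below shows that accumulation pattern. To stay short it counts pixels as a stand-in for the instance matching that BinaryInstancePrecisionRecallCurve performs; the class name, the strict `>` comparison, interpreting an int as that many evenly spaced thresholds in [0, 1], and reporting precision 1.0 when nothing is predicted positive are all assumptions.

```python
from __future__ import annotations


class FixedThresholdPRCurve:
    """Accumulate (tp, fp, fn) per fixed threshold across batches (pixel-level stand-in)."""

    def __init__(self, thresholds: int | list[float] | None = None) -> None:
        if thresholds is None:
            raise ValueError("thresholds must be given")
        if isinstance(thresholds, int):
            if thresholds < 2:
                raise ValueError("need at least two thresholds")
            # assumption: an int means that many evenly spaced thresholds in [0, 1]
            thresholds = [i / (thresholds - 1) for i in range(thresholds)]
        self.thresholds = list(thresholds)
        self.counts = [[0, 0, 0] for _ in self.thresholds]  # [tp, fp, fn] per threshold

    def update(self, probs: list[float], target: list[int]) -> None:
        for counts, t in zip(self.counts, self.thresholds):
            for p, y in zip(probs, target):
                predicted = p > t
                if predicted and y:
                    counts[0] += 1
                elif predicted:
                    counts[1] += 1
                elif y:
                    counts[2] += 1

    def compute(self) -> tuple[list[float], list[float], list[float]]:
        precision = [tp / (tp + fp) if tp + fp else 1.0 for tp, fp, _ in self.counts]
        recall = [tp / (tp + fn) if tp + fn else 0.0 for tp, _, fn in self.counts]
        return precision, recall, self.thresholds


curve = FixedThresholdPRCurve([0.25, 0.5, 0.75])
curve.update([0.9, 0.6, 0.3, 0.1], [1, 1, 0, 0])  # batch 1
curve.update([0.8, 0.4], [1, 1])  # batch 2; memory stays at 3 triples
precision, recall, thresholds = curve.compute()
print(precision, recall)  # [0.8, 1.0, 1.0] [1.0, 0.75, 0.5]
```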
ignore_index (instance-attribute)
ignore_index = ignore_index
matching_threshold (instance-attribute)
matching_threshold = matching_threshold
validate_args (instance-attribute)
validate_args = validate_args
compute
plot
plot(
curve: tuple[torch.Tensor, torch.Tensor, torch.Tensor]
| None = None,
score: torch.Tensor | bool | None = None,
ax: torchmetrics.utilities.plot._AX_TYPE | None = None,
) -> torchmetrics.utilities.plot._PLOT_OUT_TYPE
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
update
Update the metric states.
Parameters:
- preds (torch.Tensor) – The predicted mask. Shape: (batch_size, height, width)
- target (torch.Tensor) – The target mask. Shape: (batch_size, height, width)
Raises:
- ValueError – If preds and target have different shapes.
- ValueError – If the input targets are not binary masks.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_prc.py
BinaryInstanceRecall
BinaryInstanceRecall(
threshold: float = 0.5,
matching_threshold: float = 0.5,
multidim_average: typing.Literal[
"global", "samplewise"
] = "global",
ignore_index: int | None = None,
validate_args: bool = True,
**kwargs: typing.Any,
)
Bases: darts_segmentation.metrics.binary_instance_stat_scores.BinaryInstanceStatScores
Binary instance recall metric.
Create a new instance of the BinaryInstanceRecall metric.
Parameters:
- threshold (float, default: 0.5) – Threshold for binarizing the prediction. Has no effect if the prediction is already binarized. Defaults to 0.5.
- matching_threshold (float, default: 0.5) – The threshold for matching instances. Defaults to 0.5.
- multidim_average (typing.Literal['global', 'samplewise'], default: 'global') – How the average over multiple batches is calculated. Defaults to "global".
- ignore_index (int | None, default: None) – Index of an invalid class to ignore. Defaults to None.
- validate_args (bool, default: True) – Whether to validate inputs. Defaults to True.
- kwargs (typing.Any, default: {}) – Additional arguments for the Metric class, regarding compute methods. Please refer to the torchmetrics documentation for more examples.
Raises:
- ValueError – If matching_threshold is not a float in the [0, 1] range.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
ignore_index (instance-attribute)
ignore_index = ignore_index
matching_threshold (instance-attribute)
matching_threshold = matching_threshold
multidim_average (instance-attribute)
multidim_average = multidim_average
threshold (instance-attribute)
threshold = threshold
validate_args (instance-attribute)
validate_args = validate_args
compute
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
plot
update
Update the metric state.
If the predictions are logits (not between 0 and 1), they are converted to probabilities using a sigmoid and then binarized using the threshold. If the predictions are probabilities, they are binarized using the threshold.
Parameters:
- preds (torch.Tensor) – Predictions from the model (logits or probabilities).
- target (torch.Tensor) – Ground truth labels.
Raises:
- ValueError – If preds and target have different shapes.
- ValueError – If the input targets are not binary masks.
Source code in darts-segmentation/src/darts_segmentation/metrics/binary_instance_stat_scores.py
BinarySegmentationMetrics
BinarySegmentationMetrics(
*,
bands: darts_segmentation.utils.Bands,
val_set: str = "val",
test_set: str = "test",
plot_every_n_val_epochs: int = 5,
is_crossval: bool = False,
batch_size: int = 8,
patch_size: int = 512,
)
Bases: lightning.pytorch.callbacks.Callback
Callback for validation metrics and visualizations.
Initialize the BinarySegmentationMetrics callback.
Parameters:
- bands (darts_segmentation.utils.Bands) – List of bands to combine for the visualization.
- val_set (str, default: 'val') – Name of the validation set. Only used for naming the validation metrics. Defaults to "val".
- test_set (str, default: 'test') – Name of the test set. Only used for naming the test metrics. Defaults to "test".
- plot_every_n_val_epochs (int, default: 5) – Plot validation samples every n validation epochs. Defaults to 5.
- is_crossval (bool, default: False) – Whether the training is done with cross-validation. If True, scalar metrics are logged under "val" instead of {val_set}. The logging behavior of the samples is not affected. Defaults to False.
- batch_size (int, default: 8) – Batch size. Needed for throughput measurements. Defaults to 8.
- patch_size (int, default: 512) – Patch size. Needed for throughput measurements. Defaults to 512.
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
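batch_size and patch_size are documented as inputs for throughput measurements. One plausible way to turn a measured batch duration into throughput figures is sketched below; the function name and the choice of reporting samples per second and pixels per second (assuming square patch_size x patch_size patches) are assumptions, not the callback's actual logging.

```python
def throughput(batch_size: int, patch_size: int, batch_seconds: float) -> tuple[float, float]:
    """Return (samples per second, pixels per second) for one timed batch (hypothetical)."""
    if batch_seconds <= 0:
        raise ValueError("batch_seconds must be positive")
    samples_per_second = batch_size / batch_seconds
    # assumes square patches of patch_size x patch_size pixels
    pixels_per_second = batch_size * patch_size * patch_size / batch_seconds
    return samples_per_second, pixels_per_second


print(throughput(batch_size=8, patch_size=512, batch_seconds=0.5))  # (16.0, 4194304.0)
```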
band_names (instance-attribute)
band_names = bands.names
batch_size (instance-attribute)
batch_size = batch_size
is_crossval (instance-attribute)
is_crossval = is_crossval
patch_size (instance-attribute)
patch_size = patch_size
plot_every_n_val_epochs (instance-attribute)
plot_every_n_val_epochs = plot_every_n_val_epochs
test_instance_cmx (instance-attribute)
test_instance_cmx: darts_segmentation.metrics.BinaryInstanceConfusionMatrix
test_instance_prc (instance-attribute)
test_instance_prc: darts_segmentation.metrics.BinaryInstancePrecisionRecallCurve
test_set (instance-attribute)
test_set = test_set
val_set (instance-attribute)
val_set = val_set
is_val_plot_epoch
Check if the current epoch is an epoch where validation samples should be plotted.
Parameters:
- current_epoch (int) – The current epoch.
- check_val_every_n_epoch (int | None) – The number of epochs to check for plotting. If None, no plotting is done.
Returns:
- bool (bool) – True if the current epoch is a plot epoch, False otherwise.
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
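A possible reading of is_val_plot_epoch is sketched below as a free function. It relies on Lightning running validation after the 0-based epochs where (current_epoch + 1) is a multiple of check_val_every_n_epoch, and then plots on every plot_every_n_val_epochs-th validation run. Passing plot_every_n_val_epochs as an argument (rather than reading it from the callback) and the exact counting rule are assumptions.

```python
from __future__ import annotations


def is_val_plot_epoch(
    current_epoch: int, check_val_every_n_epoch: int | None, plot_every_n_val_epochs: int = 5
) -> bool:
    """Hypothetical stand-in: True on every n-th epoch that runs validation."""
    if check_val_every_n_epoch is None:
        return False  # no plotting is done
    completed = current_epoch + 1  # Lightning's current_epoch is 0-based
    if completed % check_val_every_n_epoch != 0:
        return False  # no validation after this epoch
    validation_run = completed // check_val_every_n_epoch
    return validation_run % plot_every_n_val_epochs == 0


print([e for e in range(20) if is_val_plot_epoch(e, 1)])  # [4, 9, 14, 19]
print([e for e in range(20) if is_val_plot_epoch(e, 2)])  # [9, 19]
```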
on_test_batch_end
on_test_batch_end(
trainer: lightning.Trainer,
pl_module: lightning.LightningModule,
outputs,
batch,
batch_idx,
dataloader_idx=0,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
on_test_epoch_end
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
on_train_batch_end
on_train_batch_end(
trainer: lightning.Trainer,
pl_module: lightning.LightningModule,
outputs,
batch,
batch_idx,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
on_train_epoch_end
on_validation_batch_end
on_validation_batch_end(
trainer: lightning.Trainer,
pl_module: lightning.LightningModule,
outputs,
batch,
batch_idx,
dataloader_idx=0,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
on_validation_epoch_end
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
setup
setup(
trainer: lightning.Trainer,
pl_module: lightning.LightningModule,
stage: darts_segmentation.training.callbacks.Stage,
)
Set up the callback.
Creates the metrics required for the specific stage:
- For the "fit" stage, creates training and validation metrics and visualizations.
- For the "validate" stage, only creates validation metrics and visualizations.
- For the "test" stage, only creates test metrics and visualizations.
- For the "predict" stage, no metrics or visualizations are created.
Always maps the trainer and pl_module to the callback.
Training and validation metrics are "simple" metrics from torchmetrics. The validation visualizations are more complex metrics from torchmetrics. The test metrics and visualizations are the same as the validation ones and additionally include the custom "Instance" metrics.
Parameters:
- trainer (lightning.Trainer) – The Lightning trainer.
- pl_module (lightning.LightningModule) – The Lightning module.
- stage (typing.Literal['fit', 'validate', 'test', 'predict']) – The current stage. One of: "fit", "validate", "test", "predict".
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
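The stage-dependent behavior of setup amounts to a lookup from stage to the metric groups that get created. The table below restates the list above as data; the group names ("train", "val", "test", "test_instance") and the function name are hypothetical labels, with "test_instance" standing for the custom instance metrics such as test_instance_cmx and test_instance_prc.

```python
from typing import Literal

Stage = Literal["fit", "validate", "test", "predict"]

# Hypothetical group names restating which metrics setup creates per stage.
_METRIC_GROUPS: dict[str, frozenset[str]] = {
    "fit": frozenset({"train", "val"}),
    "validate": frozenset({"val"}),
    "test": frozenset({"test", "test_instance"}),  # test also gets the instance metrics
    "predict": frozenset(),  # nothing is created for prediction
}


def metric_groups_for_stage(stage: Stage) -> frozenset[str]:
    """Return the metric groups a setup call for `stage` would create."""
    return _METRIC_GROUPS[stage]


print(sorted(metric_groups_for_stage("fit")))  # ['train', 'val']
```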
teardown
teardown(
trainer: lightning.Trainer,
pl_module: lightning.LightningModule,
stage: darts_segmentation.training.callbacks.Stage,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
BinarySegmentationPreview
BinarySegmentationPreview(
*,
bands: darts_segmentation.utils.Bands,
val_set: str = "val",
test_set: str = "test",
plot_every_n_val_epochs: int = 5,
)
Bases: lightning.pytorch.callbacks.Callback
Callback for validation metrics and visualizations.
Initialize the ValidationCallback.
Parameters:
-
bands
(darts_segmentation.utils.Bands
) –List of bands to combine for the visualization.
-
val_set
(str
, default:'val'
) –Name of the validation set. Only used for naming the validation metrics. Defaults to "val".
-
test_set
(str
, default:'test'
) –Name of the test set. Only used for naming the test metrics. Defaults to "test".
-
plot_every_n_val_epochs
(int
, default:5
) –Plot validation samples every n validation epochs. Defaults to 5.
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
band_names
instance-attribute
¶
band_names = bands.names
plot_every_n_val_epochs
instance-attribute
¶
plot_every_n_val_epochs = plot_every_n_val_epochs
test_set
instance-attribute
¶
test_set = test_set
val_set
instance-attribute
¶
val_set = val_set
is_val_plot_epoch
¶
Check if the current epoch is an epoch where validation samples should be plotted.
Parameters:
-
current_epoch
(int
) –The current epoch.
-
check_val_every_n_epoch
(int | None
) –The trainer's validation interval in epochs (the value of lightning.Trainer's check_val_every_n_epoch). If None, no plotting is done.
Returns:
-
bool
(bool
) –True if the current epoch is a plot epoch, False otherwise.
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
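One plausible way to combine the trainer's validation interval with plot_every_n_val_epochs is sketched below. It is an assumption, not the library's code: the standalone function and its exact counting rule (plot on every n-th validation run, counting from 1) are illustrative. It relies on Lightning's convention that epochs are 0-indexed and that validation runs after epoch e when (e + 1) % check_val_every_n_epoch == 0.

```python
from __future__ import annotations


def is_val_plot_epoch(
    current_epoch: int,
    check_val_every_n_epoch: int | None,
    plot_every_n_val_epochs: int = 5,
) -> bool:
    """Hypothetical re-implementation; the real method may count differently."""
    # Without an epoch-based validation interval, never plot (documented None case).
    if check_val_every_n_epoch is None:
        return False
    # Epochs are 0-indexed, so epoch e is the (e + 1)-th completed epoch.
    n_epochs_done = current_epoch + 1
    # No validation run happens after this epoch, so there is nothing to plot.
    if n_epochs_done % check_val_every_n_epoch != 0:
        return False
    # ASSUMPTION: plot on every n-th validation run (the 5th, 10th, ... by default).
    n_val_runs = n_epochs_done // check_val_every_n_epoch
    return n_val_runs % plot_every_n_val_epochs == 0
```

With validation every 2 epochs and the default of 5, this plots after epochs 9, 19, 29, and so on, because those are the 5th, 10th and 15th validation runs.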
on_test_batch_end
¶
on_test_batch_end(
trainer: lightning.Trainer,
pl_module: lightning.LightningModule,
outputs,
batch,
batch_idx,
dataloader_idx=0,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
on_test_epoch_end
¶
on_validation_batch_end
¶
on_validation_batch_end(
trainer: lightning.Trainer,
pl_module: lightning.LightningModule,
outputs,
batch,
batch_idx,
dataloader_idx=0,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
on_validation_epoch_end
¶
setup
¶
setup(
trainer: lightning.Trainer,
pl_module: lightning.LightningModule,
stage: darts_segmentation.training.callbacks.Stage,
)
Set up the callback.
Parameters:
-
trainer
(lightning.Trainer
) –The lightning trainer.
-
pl_module
(lightning.LightningModule
) –The lightning module.
-
stage
(typing.Literal['fit', 'validate', 'test', 'predict']
) –The current stage. One of: "fit", "validate", "test", "predict".
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
plot_sample
¶
plot_sample(
x: torch.Tensor,
y: torch.Tensor,
y_pred: torch.Tensor,
band_names: list[str],
) -> tuple[
matplotlib.pyplot.Figure,
dict[str, matplotlib.pyplot.Axes],
]
Plot a single sample with the input, the ground truth and the prediction.
This function makes a few assumptions about its inputs:
- The input is expected to be normalized to 0-1.
- The prediction is expected to be already converted from logits into a prediction, not passed as raw logits.
- The target is expected to be an int or long tensor with the values 0 (negative class), 1 (positive class) and 2 (invalid pixels).
Parameters:
-
x
(torch.Tensor
) –The input tensor [C, H, W] (float).
-
y
(torch.Tensor
) –The ground truth tensor [H, W] (int).
-
y_pred
(torch.Tensor
) –The prediction tensor [H, W] (float).
-
band_names
(list[str]
) –The combinations of the input bands.
Returns:
-
tuple[matplotlib.pyplot.Figure, dict[str, matplotlib.pyplot.Axes]]
–The figure and a dictionary of the axes of the plot.
Source code in darts-segmentation/src/darts_segmentation/training/viz.py
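The input assumptions listed for plot_sample can be met with a small preprocessing step. The sketch below uses NumPy arrays as a stand-in for torch.Tensor (torch.sigmoid and tensor min/max would work the same way). The helper names `prepare_sample` and `masked_prediction`, the per-channel min-max normalization, the sigmoid (assuming a single-logit binary head) and the 0.5 threshold are all illustrative assumptions, not part of darts_segmentation.

```python
import numpy as np

INVALID = 2  # target value for invalid pixels, as documented for plot_sample


def prepare_sample(x, logits, y):
    """Hypothetical helper: bring raw arrays into the form plot_sample expects."""
    x = np.asarray(x, dtype=np.float32)  # [C, H, W]
    # Per-channel min-max normalization to 0-1; constant channels map to 0.
    mins = x.min(axis=(1, 2), keepdims=True)
    ranges = x.max(axis=(1, 2), keepdims=True) - mins
    x_norm = (x - mins) / np.where(ranges > 0, ranges, 1.0)
    # ASSUMPTION: binary head with one logit per pixel, converted via a sigmoid.
    y_pred = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float32)))  # [H, W]
    y = np.asarray(y, dtype=np.int64)  # [H, W]
    if not np.isin(y, (0, 1, INVALID)).all():
        raise ValueError("target may only contain 0, 1 and 2")
    return x_norm, y_pred, y


def masked_prediction(y_pred, y, threshold=0.5):
    """Hypothetical helper: binarize and hide invalid pixels as NaN."""
    binary = (y_pred >= threshold).astype(np.float32)
    # NaN pixels are left blank by matplotlib's imshow with the default colormap.
    return np.where(y == INVALID, np.nan, binary)
```

Checking the target values up front turns a silently wrong plot into an immediate error if a label encoding other than 0/1/2 slips through.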