Skip to content

darts_segmentation.training

Training related functions and classes for Image Segmentation.

Classes:

Functions:

BinarySegmentationMetrics

BinarySegmentationMetrics(
    *,
    input_combination: list[str],
    val_set: str = "val",
    test_set: str = "test",
    plot_every_n_val_epochs: int = 5,
    is_crossval: bool = False,
)

Bases: lightning.pytorch.callbacks.Callback

Callback for validation metrics and visualizations.

Initialize the BinarySegmentationMetrics callback.

Parameters:

  • input_combination (list[str]) –

    List of input names to combine for the visualization.

  • val_set (str, default: 'val' ) –

    Name of the validation set. Only used for naming the validation metrics. Defaults to "val".

  • test_set (str, default: 'test' ) –

    Name of the test set. Only used for naming the test metrics. Defaults to "test".

  • plot_every_n_val_epochs (int, default: 5 ) –

    Plot validation samples every n epochs. Defaults to 5.

  • is_crossval (bool, default: False ) –

    Whether the training is done with cross-validation. This will change the logging behavior of scalar metrics from logging to {val_set} to just "val". The logging behaviour of the samples is not affected. Defaults to False.

Methods:

Attributes:

Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def __init__(
    self,
    *,
    input_combination: list[str],
    val_set: str = "val",
    test_set: str = "test",
    plot_every_n_val_epochs: int = 5,
    is_crossval: bool = False,
):
    """Store the configuration of the metrics callback.

    Args:
        input_combination (list[str]): Input names which are combined for the sample visualization.
        val_set (str, optional): Name of the validation set, only used to name validation outputs.
            Must not contain "/". Defaults to "val".
        test_set (str, optional): Name of the test set, only used to name test outputs.
            Must not contain "/". Defaults to "test".
        plot_every_n_val_epochs (int, optional): Interval (in validation epochs) at which validation
            samples are plotted. Defaults to 5.
        is_crossval (bool, optional): Whether the training runs as part of a cross-validation.
            If True, scalar metrics are logged to "val" instead of {val_set};
            the logging of the samples is not affected. Defaults to False.

    Raises:
        AssertionError: If `val_set` or `test_set` contains a "/".

    """
    # The set names are used as the leading part of "<set>/<metric>" keys, hence no slash is allowed in them
    assert not ("/" in val_set), "val_set must not contain '/'"
    assert not ("/" in test_set), "test_set must not contain '/'"

    # Plain configuration values, read later by the hooks of this callback
    self.input_combination = input_combination
    self.plot_every_n_val_epochs = plot_every_n_val_epochs
    self.is_crossval = is_crossval
    self.val_set = val_set
    self.test_set = test_set

input_combination instance-attribute

input_combination = darts_segmentation.training.callbacks.BinarySegmentationMetrics(
    input_combination
)

is_crossval instance-attribute

is_crossval = darts_segmentation.training.callbacks.BinarySegmentationMetrics(
    is_crossval
)

pl_module instance-attribute

pl_module: lightning.LightningModule

plot_every_n_val_epochs instance-attribute

plot_every_n_val_epochs = darts_segmentation.training.callbacks.BinarySegmentationMetrics(
    plot_every_n_val_epochs
)

stage instance-attribute

stage: darts_segmentation.training.callbacks.Stage

test_cmx instance-attribute

test_cmx: torchmetrics.ConfusionMatrix

test_instance_cmx instance-attribute

test_instance_prc instance-attribute

test_metrics instance-attribute

test_metrics: torchmetrics.MetricCollection

test_prc instance-attribute

test_prc: torchmetrics.PrecisionRecallCurve

test_roc instance-attribute

test_roc: torchmetrics.ROC

test_set instance-attribute

test_set = darts_segmentation.training.callbacks.BinarySegmentationMetrics(
    test_set
)

train_metrics instance-attribute

train_metrics: torchmetrics.MetricCollection

trainer instance-attribute

trainer: lightning.Trainer

val_cmx instance-attribute

val_cmx: torchmetrics.ConfusionMatrix

val_metrics instance-attribute

val_metrics: torchmetrics.MetricCollection

val_prc instance-attribute

val_prc: torchmetrics.PrecisionRecallCurve

val_roc instance-attribute

val_roc: torchmetrics.ROC

val_set instance-attribute

val_set = darts_segmentation.training.callbacks.BinarySegmentationMetrics(
    val_set
)

is_val_plot_epoch

is_val_plot_epoch(
    current_epoch: int, check_val_every_n_epoch: int | None
) -> bool

Check if the current epoch is an epoch where validation samples should be plotted.

Parameters:

  • current_epoch (int) –

    The current epoch.

  • check_val_every_n_epoch (int | None) –

    The number of epochs to check for plotting. If None, no plotting is done.

Returns:

  • bool ( bool ) –

    True if the current epoch is a plot epoch, False otherwise.

Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def is_val_plot_epoch(self, current_epoch: int, check_val_every_n_epoch: int | None) -> bool:
    """Check if the current epoch is an epoch where validation samples should be plotted.

    Args:
        current_epoch (int): The current epoch.
        check_val_every_n_epoch (int | None): The number of epochs to check for plotting.
            If None, no plotting is done.

    Returns:
        bool: True if the current epoch is a plot epoch, False otherwise.

    """
    if check_val_every_n_epoch is None:
        return False
    n = self.plot_every_n_val_epochs * check_val_every_n_epoch
    return ((current_epoch + 1) % n) == 0 or current_epoch == 0

on_test_batch_end

on_test_batch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    outputs,
    batch,
    batch_idx,
    dataloader_idx=0,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_test_batch_end(
    self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, dataloader_idx=0
):
    """Update the test metrics with a finished test batch and plot a limited number of samples.

    Logs the batch loss as "{test_set}/loss", updates every test metric stored on `pl_module`
    and writes sample figures to the CSV and/or wandb logger for the first batches.

    Args:
        trainer (Trainer): The lightning trainer. Its `num_test_batches` is used to detect the last batch.
        pl_module (LightningModule): The lightning module which holds the test metrics.
        outputs: Result of the test step. Must contain the keys "loss" and "y_hat".
        batch: The current batch, unpacked as `(x, y)`.
        batch_idx: Index of the current batch.
        dataloader_idx: Index of the dataloader the batch belongs to. Defaults to 0.

    Raises:
        AssertionError: If `outputs` does not contain "y_hat".

    """
    pl_module.log(f"{self.test_set}/loss", outputs["loss"])
    x, y = batch
    # Expect the output to have a tensor called "y_hat"
    # (no trailing comma inside the parentheses, else the message becomes a tuple instead of a string)
    assert "y_hat" in outputs, (
        "Output does not contain 'y_hat' tensor."
        " Please make sure the 'test_step' method returns a dict with 'y_hat' and 'loss' keys."
        " The 'y_hat' should be the model's prediction (a pytorch tensor of shape [B, C, H, W])."
        " The 'loss' should be the loss value (a scalar tensor)."
    )
    y_hat = outputs["y_hat"]

    pl_module.test_metrics.update(y_hat, y)
    pl_module.test_roc.update(y_hat, y)
    pl_module.test_prc.update(y_hat, y)
    pl_module.test_cmx.update(y_hat, y)
    pl_module.test_instance_prc.update(y_hat, y)
    pl_module.test_instance_cmx.update(y_hat, y)

    # This is the test hook, hence the number of *test* batches is the relevant one.
    # Lightning may report the batch counts as one entry per dataloader; pick the entry of the current one.
    num_batches = trainer.num_test_batches
    if isinstance(num_batches, (list, tuple)):
        num_batches = num_batches[dataloader_idx]

    # Create figures for the samples (plot at maximum 24)
    is_last_batch = num_batches == (batch_idx + 1)
    max_batch_idx = (24 // x.shape[0]) - 1  # Does only work if NOT last batch, since last batch may be smaller
    # If there is only a single batch then this batch is the last one, but we still want to log it despite its size
    # Will plot the first 24 samples of the first batch if batch-size is larger than 24
    should_log_batch = (
        (max_batch_idx >= batch_idx and not is_last_batch)
        or num_batches == 1
        or (max_batch_idx == -1 and batch_idx == 0)
    )
    if should_log_batch:
        for i in range(min(x.shape[0], 24)):
            fig, _ = plot_sample(x[i], y[i], y_hat[i], self.input_combination)
            for pllogger in pl_module.loggers:
                if isinstance(pllogger, CSVLogger):
                    fig_dir = Path(pllogger.log_dir) / "figures" / f"{self.test_set}-samples"
                    fig_dir.mkdir(exist_ok=True, parents=True)
                    fig.savefig(fig_dir / f"sample_{pl_module.global_step}_{batch_idx}_{i}.png")
                if isinstance(pllogger, WandbLogger):
                    wandb_run: Run = pllogger.experiment
                    # We don't commit the log yet, so that the step is increased with the next lightning log
                    # Which happens at the end of the test epoch
                    img_name = f"{self.test_set}-samples/sample_{batch_idx}_{i}"
                    wandb_run.log({img_name: wandb.Image(fig)}, commit=False)
            # Free the figure right away, else open figures pile up over the batches
            fig.clear()
            plt.close(fig)

on_test_epoch_end

on_test_epoch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
    """Plot the test curves / confusion matrices, log the scalar test metrics and reset all test metrics.

    The figures are written to the "figures/{test_set}-samples" directory of a CSV logger
    and logged as "{test_set}/<name>" images to a wandb logger, whichever is present.

    Args:
        trainer (Trainer): The lightning trainer (unused).
        pl_module (LightningModule): The lightning module which holds the test metrics.

    """
    for metric in (
        pl_module.test_cmx,
        pl_module.test_roc,
        pl_module.test_prc,
        pl_module.test_instance_prc,
        pl_module.test_instance_cmx,
    ):
        metric.compute()

    # Plot roc, prc and confusion matrix; the dict keys are used for the file names and the wandb keys
    plot_specs = (
        ("cmx", pl_module.test_cmx, {"cmap": "Blues"}),
        ("roc", pl_module.test_roc, {"score": True}),
        ("prc", pl_module.test_prc, {"score": True}),
        ("instance_cmx", pl_module.test_instance_cmx, {"cmap": "Blues"}),
        ("instance_prc", pl_module.test_instance_prc, {"score": True}),
    )
    figures = {}
    for name, metric, plot_kwargs in plot_specs:
        figures[name], _ = metric.plot(**plot_kwargs)

    # Check for a wandb or csv logger to log the images
    for pllogger in pl_module.loggers:
        if isinstance(pllogger, CSVLogger):
            fig_dir = Path(pllogger.log_dir) / "figures" / f"{self.test_set}-samples"
            fig_dir.mkdir(exist_ok=True, parents=True)
            for name, fig in figures.items():
                fig.savefig(fig_dir / f"{name}_{pl_module.global_step}.png")
        if isinstance(pllogger, WandbLogger):
            wandb_run: Run = pllogger.experiment
            for name, fig in figures.items():
                wandb_run.log({f"{self.test_set}/{name}": wandb.Image(fig)}, commit=False)

    for fig in figures.values():
        fig.clear()
    plt.close("all")

    # This will also commit the accumulated plots
    pl_module.log_dict(pl_module.test_metrics.compute())

    for metric in (
        pl_module.test_metrics,
        pl_module.test_roc,
        pl_module.test_prc,
        pl_module.test_cmx,
        pl_module.test_instance_prc,
        pl_module.test_instance_cmx,
    ):
        metric.reset()

on_train_batch_end

on_train_batch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    outputs,
    batch,
    batch_idx,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx):
    """Log the training loss and the training metrics of a finished training batch.

    Args:
        trainer (Trainer): The lightning trainer (unused).
        pl_module (LightningModule): The lightning module which holds the training metrics.
        outputs: Result of the training step. Must contain the keys "loss" and "y_hat".
        batch: The current batch, unpacked as `(x, y)`; only `y` is used.
        batch_idx: Index of the current batch (unused).

    Raises:
        AssertionError: If `outputs` does not contain "y_hat".

    """
    pl_module.log("train/loss", outputs["loss"])
    _, y = batch
    # Expect the output to have a tensor called "y_hat"
    # (no trailing comma inside the parentheses, else the message becomes a tuple instead of a string)
    assert "y_hat" in outputs, (
        "Output does not contain 'y_hat' tensor."
        " Please make sure the 'training_step' method returns a dict with 'y_hat' and 'loss' keys."
        " The 'y_hat' should be the model's prediction (a pytorch tensor of shape [B, C, H, W])."
        " The 'loss' should be the loss value (a scalar tensor)."
    )
    y_hat = outputs["y_hat"]
    # Calling the collection updates it with the batch; the values are then logged per step
    pl_module.train_metrics(y_hat, y)
    pl_module.log_dict(pl_module.train_metrics, on_step=True, on_epoch=False)

on_train_epoch_end

on_train_epoch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
    """Reset the training metrics of the module once a training epoch has finished.

    Args:
        trainer (Trainer): The lightning trainer (unused).
        pl_module (LightningModule): The lightning module which holds the training metrics.

    """
    train_metrics = pl_module.train_metrics
    train_metrics.reset()

on_validation_batch_end

on_validation_batch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    outputs,
    batch,
    batch_idx,
    dataloader_idx=0,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_validation_batch_end(
    self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, dataloader_idx=0
):
    """Update the validation metrics with a finished validation batch and plot a limited number of samples.

    Logs the batch loss as "{_val_prefix}/loss" and updates every validation metric stored on `pl_module`.
    In plot epochs (see `is_val_plot_epoch`) sample figures of the first batches are written
    to the CSV and/or wandb logger.

    Args:
        trainer (Trainer): The lightning trainer. Its `num_val_batches` is used to detect the last batch.
        pl_module (LightningModule): The lightning module which holds the validation metrics.
        outputs: Result of the validation step. Must contain the keys "loss" and "y_hat".
        batch: The current batch, unpacked as `(x, y)`.
        batch_idx: Index of the current batch.
        dataloader_idx: Index of the dataloader the batch belongs to. Defaults to 0.

    Raises:
        AssertionError: If `outputs` does not contain "y_hat".

    """
    pl_module.log(f"{self._val_prefix}/loss", outputs["loss"])
    x, y = batch
    # Expect the output to have a tensor called "y_hat"
    # (no trailing comma inside the parentheses, else the message becomes a tuple instead of a string)
    assert "y_hat" in outputs, (
        "Output does not contain 'y_hat' tensor."
        " Please make sure the 'validation_step' method returns a dict with 'y_hat' and 'loss' keys."
        " The 'y_hat' should be the model's prediction (a pytorch tensor of shape [B, C, H, W])."
        " The 'loss' should be the loss value (a scalar tensor)."
    )
    y_hat = outputs["y_hat"]

    pl_module.val_metrics.update(y_hat, y)
    pl_module.val_roc.update(y_hat, y)
    pl_module.val_prc.update(y_hat, y)
    pl_module.val_cmx.update(y_hat, y)

    # Lightning may report the batch counts as one entry per dataloader; pick the entry of the current one.
    # Comparing such a list directly with an integer would never be equal.
    num_batches = trainer.num_val_batches
    if isinstance(num_batches, (list, tuple)):
        num_batches = num_batches[dataloader_idx]

    # Create figures for the samples (plot at maximum 24)
    is_last_batch = num_batches == (batch_idx + 1)
    max_batch_idx = (24 // x.shape[0]) - 1  # Does only work if NOT last batch, since last batch may be smaller
    # If there is only a single batch then this batch is the last one, but we still want to log it despite its size
    # Will plot the first 24 samples of the first batch if batch-size is larger than 24
    should_log_batch = (
        (max_batch_idx >= batch_idx and not is_last_batch)
        or num_batches == 1
        or (max_batch_idx == -1 and batch_idx == 0)
    )
    is_val_plot_epoch = self.is_val_plot_epoch(pl_module.current_epoch, trainer.check_val_every_n_epoch)
    if is_val_plot_epoch and should_log_batch:
        for i in range(min(x.shape[0], 24)):
            fig, _ = plot_sample(x[i], y[i], y_hat[i], self.input_combination)
            for pllogger in pl_module.loggers:
                if isinstance(pllogger, CSVLogger):
                    fig_dir = Path(pllogger.log_dir) / "figures" / f"{self.val_set}-samples"
                    fig_dir.mkdir(exist_ok=True, parents=True)
                    fig.savefig(fig_dir / f"sample_{pl_module.global_step}_{batch_idx}_{i}.png")
                if isinstance(pllogger, WandbLogger):
                    wandb_run: Run = pllogger.experiment
                    # We don't commit the log yet, so that the step is increased with the next lightning log
                    # Which happens at the end of the validation epoch
                    img_name = f"{self.val_set}-samples/sample_{batch_idx}_{i}"
                    wandb_run.log({img_name: wandb.Image(fig)}, commit=False)
            # Free the figure right away, else open figures pile up over the batches
            fig.clear()
            plt.close(fig)

on_validation_epoch_end

on_validation_epoch_end(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule):
    """Plot the validation curves in plot epochs, log the scalar validation metrics and reset all metrics.

    Args:
        trainer (Trainer): The lightning trainer. Its `check_val_every_n_epoch` determines the plot epochs.
        pl_module (LightningModule): The lightning module which holds the validation metrics.

    """
    # Only do this every self.plot_every_n_val_epochs epochs
    is_val_plot_epoch = self.is_val_plot_epoch(pl_module.current_epoch, trainer.check_val_every_n_epoch)
    if is_val_plot_epoch:
        pl_module.val_cmx.compute()
        pl_module.val_roc.compute()
        pl_module.val_prc.compute()

        # Plot roc, prc and confusion matrix to disk and wandb
        fig_cmx, _ = pl_module.val_cmx.plot(cmap="Blues")
        fig_roc, _ = pl_module.val_roc.plot(score=True)
        fig_prc, _ = pl_module.val_prc.plot(score=True)

        # Check for a wandb or csv logger to log the images
        for pllogger in pl_module.loggers:
            if isinstance(pllogger, CSVLogger):
                fig_dir = Path(pllogger.log_dir) / "figures" / f"{self._val_prefix}-samples"
                fig_dir.mkdir(exist_ok=True, parents=True)
                # All three files get a proper ".png" suffix (the dot was missing for cmx and roc)
                fig_cmx.savefig(fig_dir / f"cmx_{pl_module.global_step}.png")
                fig_roc.savefig(fig_dir / f"roc_{pl_module.global_step}.png")
                fig_prc.savefig(fig_dir / f"prc_{pl_module.global_step}.png")
            if isinstance(pllogger, WandbLogger):
                wandb_run: Run = pllogger.experiment
                wandb_run.log({f"{self._val_prefix}/cmx": wandb.Image(fig_cmx)}, commit=False)
                wandb_run.log({f"{self._val_prefix}/roc": wandb.Image(fig_roc)}, commit=False)
                wandb_run.log({f"{self._val_prefix}/prc": wandb.Image(fig_prc)}, commit=False)

        fig_cmx.clear()
        fig_roc.clear()
        fig_prc.clear()
        plt.close("all")

    # This will also commit the accumulated plots
    pl_module.log_dict(pl_module.val_metrics.compute())

    # Reset in every epoch (not only in plot epochs), since the metrics are updated in every validation batch
    pl_module.val_metrics.reset()
    pl_module.val_roc.reset()
    pl_module.val_prc.reset()
    pl_module.val_cmx.reset()

setup

setup(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    stage: darts_segmentation.training.callbacks.Stage,
)

Sets up the callback.

Creates metrics required for the specific stage:

  • For the "fit" stage, creates training and validation metrics and visualizations.
  • For the "validate" stage, only creates validation metrics and visualizations.
  • For the "test" stage, only creates test metrics and visualizations.
  • For the "predict" stage, no metrics or visualizations are created.

Always maps the trainer and pl_module to the callback.

Training and validation metrics are "simple" metrics from torchmetrics. The validation visualizations are more complex metrics from torchmetrics. The test metrics and visualizations are the same as the validation ones, and also include custom "Instance" metrics.

Parameters:

  • trainer (lightning.Trainer) –

    The lightning trainer.

  • pl_module (lightning.LightningModule) –

    The lightning module.

  • stage (typing.Literal['fit', 'validate', 'test', 'predict']) –

    The current stage. One of: "fit", "validate", "test", "predict".

Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Stage):
    """Setups the callback.

    Creates metrics required for the specific stage:

    - For the "fit" stage, creates training and validation metrics and visualizations.
    - For the "validate" stage, only creates validation metrics and visualizations.
    - For the "test" stage, only creates test metrics and visualizations.
    - For the "predict" stage, no metrics or visualizations are created.

    Always maps the trainer and pl_module to the callback.

    Training and validation metrics are "simple" metrics from torchmetrics.
    The validation visualizations are more complex metrics from torchmetrics.
    The test metrics and vsiualizations are the same as the validation ones,
    and also include custom "Instance" metrics.

    Args:
        trainer (Trainer): The lightning trainer.
        pl_module (LightningModule): The lightning module.
        stage (Literal["fit", "validate", "test", "predict"]): The current stage.
            One of: "fit", "validate", "test", "predict".

    """
    # Save references to the trainer and pl_module
    self.trainer = trainer
    self.pl_module = pl_module
    self.stage = stage

    # We don't want to use memory in the predict stage
    if stage == "predict":
        return

    metric_kwargs = {"task": "binary", "validate_args": False, "ignore_index": 2}
    metrics = MetricCollection(
        {
            "Accuracy": Accuracy(**metric_kwargs),
            "Precision": Precision(**metric_kwargs),
            "Specificity": Specificity(**metric_kwargs),
            "Recall": Recall(**metric_kwargs),
            "F1Score": F1Score(**metric_kwargs),
            "JaccardIndex": JaccardIndex(**metric_kwargs),
            "CohenKappa": CohenKappa(**metric_kwargs),
            "HammingDistance": HammingDistance(**metric_kwargs),
        }
    )

    added_metrics: list[str] = []

    # Train metrics only for the fit stage
    if stage == "fit":
        pl_module.train_metrics = metrics.clone(prefix="train/")
        added_metrics += list(pl_module.train_metrics.keys())
    # Validation metrics and visualizations for the fit and validate stages
    if stage == "fit" or stage == "validate":
        pl_module.val_metrics = metrics.clone(prefix=f"{self._val_prefix}/")
        pl_module.val_metrics.add_metrics(
            {
                "AUROC": AUROC(thresholds=20, **metric_kwargs),
                "AveragePrecision": AveragePrecision(thresholds=20, **metric_kwargs),
            }
        )
        pl_module.val_roc = ROC(thresholds=20, **metric_kwargs)
        pl_module.val_prc = PrecisionRecallCurve(thresholds=20, **metric_kwargs)
        pl_module.val_cmx = ConfusionMatrix(normalize="true", **metric_kwargs)
        added_metrics += list(pl_module.val_metrics.keys())
        added_metrics += [f"{self._val_prefix}/{m}" for m in ["roc", "prc", "cmx"]]

    # Test metrics and visualizations for the test stage
    if stage == "test":
        pl_module.test_metrics = metrics.clone(prefix=f"{pl_module.test_set}/")
        pl_module.test_metrics.add_metrics(
            {
                "AUROC": AUROC(thresholds=20, **metric_kwargs),
                "AveragePrecision": AveragePrecision(thresholds=20, **metric_kwargs),
            }
        )
        pl_module.test_roc = ROC(thresholds=20, **metric_kwargs)
        pl_module.test_prc = PrecisionRecallCurve(thresholds=20, **metric_kwargs)
        pl_module.test_cmx = ConfusionMatrix(normalize="true", **metric_kwargs)

        # Instance Metrics
        instance_metric_kwargs = {"validate_args": False, "ignore_index": 2, "matching_threshold": 0.3}
        pl_module.test_metrics.add_metrics(
            {
                "InstanceAccuracy": BinaryInstanceAccuracy(**instance_metric_kwargs),
                "InstancePrecision": BinaryInstancePrecision(**instance_metric_kwargs),
                "InstanceRecall": BinaryInstanceRecall(**instance_metric_kwargs),
                "InstanceF1Score": BinaryInstanceF1Score(**instance_metric_kwargs),
                "InstanceAveragePrecision": BinaryInstanceAveragePrecision(thresholds=20, **instance_metric_kwargs),
            }
        )
        boundary_metric_kwargs = {"validate_args": False, "ignore_index": 2}
        pl_module.test_metrics.add_metrics(
            {
                "InstanceBoundaryIoU": BinaryBoundaryIoU(**boundary_metric_kwargs),
            }
        )
        pl_module.test_instance_prc = BinaryInstancePrecisionRecallCurve(thresholds=20, **instance_metric_kwargs)
        pl_module.test_instance_cmx = BinaryInstanceConfusionMatrix(normalize=True, **instance_metric_kwargs)

        added_metrics += list(pl_module.test_metrics.keys())
        added_metrics += [f"{self.test_set}/{m}" for m in ["roc", "prc", "cmx", "instance_prc", "instance_cmx"]]

    # Log the added metrics
    sep = "\n\t- "
    logger.debug(f"Added metrics:{sep + sep.join(added_metrics)}")

teardown

teardown(
    trainer: lightning.Trainer,
    pl_module: lightning.LightningModule,
    stage: darts_segmentation.training.callbacks.Stage,
)
Source code in darts-segmentation/src/darts_segmentation/training/callbacks.py
def teardown(self, trainer: Trainer, pl_module: LightningModule, stage: Stage):  # noqa: D102
    # Delete the references to the trainer and pl_module
    del self.trainer
    del self.pl_module
    del self.stage

    # No need to delete anything if we are in the predict stage
    if stage == "predict":
        return

    if stage == "fit":
        del pl_module.train_metrics

    if stage == "fit" or stage == "validate":
        del pl_module.val_metrics
        del pl_module.val_roc
        del pl_module.val_prc
        del pl_module.val_cmx

    if stage == "test":
        del pl_module.test_metrics
        del pl_module.test_roc
        del pl_module.test_prc
        del pl_module.test_cmx

DartsDataModule

DartsDataModule(
    data_dir: pathlib.Path,
    batch_size: int,
    fold: int = 0,
    augment: bool = True,
    num_workers: int = 0,
    in_memory: bool = False,
)

Bases: lightning.LightningDataModule

Methods:

Attributes:

Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __init__(
    self,
    data_dir: Path,
    batch_size: int,
    fold: int = 0,  # Not used for test
    augment: bool = True,  # Not used for test
    num_workers: int = 0,
    in_memory: bool = False,
):
    """Store the data configuration and read the number of samples from the zarr group at `data_dir`.

    Args:
        data_dir (Path): Directory of the dataset. Must be a zarr group containing an "x" array.
        batch_size (int): Batch size used by all dataloaders.
        fold (int, optional): Index of the fold used for the train / validation split. Defaults to 0.
        augment (bool, optional): Whether the training dataset is augmented. Defaults to True.
        num_workers (int, optional): Number of workers used by all dataloaders. Defaults to 0.
        in_memory (bool, optional): Whether the in-memory dataset is used instead of the zarr dataset.
            Defaults to False.

    """
    super().__init__()
    self.save_hyperparameters()

    # Normalize before storing, so that the stored attribute matches its annotation also for string input
    data_dir = Path(data_dir)

    self.data_dir = data_dir
    self.batch_size = batch_size
    self.fold = fold
    self.augment = augment
    self.num_workers = num_workers
    self.in_memory = in_memory

    # Open the store the same way DartsDatasetZarr does, so that both work with the same zarr version
    store = zarr.storage.LocalStore(data_dir)
    zroot = zarr.group(store=store)
    self.nsamples = len(zroot["x"])

augment instance-attribute

augment = darts_segmentation.training.data.DartsDataModule(
    augment
)

batch_size instance-attribute

batch_size = (
    darts_segmentation.training.data.DartsDataModule(
        batch_size
    )
)

data_dir instance-attribute

data_dir = darts_segmentation.training.data.DartsDataModule(
    data_dir
)

fold instance-attribute

in_memory instance-attribute

in_memory = (
    darts_segmentation.training.data.DartsDataModule(
        in_memory
    )
)

nsamples instance-attribute

nsamples = len(zroot['x'])

num_workers instance-attribute

num_workers = (
    darts_segmentation.training.data.DartsDataModule(
        num_workers
    )
)

setup

setup(
    stage: typing.Literal[
        "fit", "validate", "test", "predict"
    ]
    | None = None,
)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def setup(self, stage: Literal["fit", "validate", "test", "predict"] | None = None):
    if stage in ["fit", "validate"]:
        kf = KFold(n_splits=5)
        train_idx, val_idx = list(kf.split(range(self.nsamples)))[self.fold]

        dsclass = DartsDatasetInMemory if self.in_memory else DartsDatasetZarr
        self.train = dsclass(self.data_dir, self.augment, train_idx)
        self.val = dsclass(self.data_dir, False, val_idx)
    if stage == "test":
        dsclass = DartsDatasetInMemory if self.in_memory else DartsDatasetZarr
        self.test = dsclass(self.data_dir, False)

test_dataloader

test_dataloader()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def test_dataloader(self):
    """Build the dataloader for the test dataset, using the configured batch size and number of workers."""
    loader_kwargs = {"batch_size": self.batch_size, "num_workers": self.num_workers}
    return DataLoader(self.test, **loader_kwargs)

train_dataloader

train_dataloader()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def train_dataloader(self):
    """Build the shuffling dataloader for the training dataset."""
    loader_kwargs = {
        "batch_size": self.batch_size,
        "num_workers": self.num_workers,
        "shuffle": True,
    }
    return DataLoader(self.train, **loader_kwargs)

val_dataloader

val_dataloader()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def val_dataloader(self):
    """Build the dataloader for the validation dataset, using the configured batch size and number of workers."""
    loader_kwargs = {"batch_size": self.batch_size, "num_workers": self.num_workers}
    return DataLoader(self.val, **loader_kwargs)

DartsDataset

DartsDataset(
    data_dir: pathlib.Path | str,
    augment: bool,
    indices: list[int] | None = None,
)

Bases: torch.utils.data.Dataset

Methods:

Attributes:

Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __init__(self, data_dir: Path | str, augment: bool, indices: list[int] | None = None):
    if isinstance(data_dir, str):
        data_dir = Path(data_dir)

    self.x_files = sorted((data_dir / "x").glob("*.pt"))
    self.y_files = sorted((data_dir / "y").glob("*.pt"))
    assert len(self.x_files) == len(self.y_files), (
        f"Dataset corrupted! Got {len(self.x_files)=} and {len(self.y_files)=}!"
    )
    if indices is not None:
        self.x_files = [self.x_files[i] for i in indices]
        self.y_files = [self.y_files[i] for i in indices]

    self.transform = (
        A.Compose(
            [
                A.HorizontalFlip(),
                A.VerticalFlip(),
                A.RandomRotate90(),
                # A.Blur(),
                A.RandomBrightnessContrast(),
                A.MultiplicativeNoise(per_channel=True, elementwise=True),
                # ToTensorV2(),
            ]
        )
        if augment
        else None
    )

transform instance-attribute

transform = (
    albumentations.Compose(
        [
            albumentations.HorizontalFlip(),
            albumentations.VerticalFlip(),
            albumentations.RandomRotate90(),
            albumentations.RandomBrightnessContrast(),
            albumentations.MultiplicativeNoise(
                per_channel=True, elementwise=True
            ),
        ]
    )
    if darts_segmentation.training.data.DartsDataset(
        augment
    )
    else None
)

x_files instance-attribute

x_files = sorted(
    darts_segmentation.training.data.DartsDataset(data_dir)
    / "x".glob("*.pt")
)

y_files instance-attribute

y_files = sorted(
    darts_segmentation.training.data.DartsDataset(data_dir)
    / "y".glob("*.pt")
)

__getitem__

__getitem__(idx)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __getitem__(self, idx):
    """Load the sample at position `idx` from disk and augment it, if a transform is configured.

    Args:
        idx: Position of the sample within the file lists.

    Returns:
        The tuple `(x, y)` as numpy arrays, where `y` was cast to int before the conversion.

    Raises:
        AssertionError: If the "x" and "y" file at this position do not share the same stem.

    """
    xfile, yfile = self.x_files[idx], self.y_files[idx]
    assert xfile.stem == yfile.stem, f"Dataset corrupted! Files must have the same name, but got {xfile=} {yfile=}!"

    x = torch.load(xfile).numpy()
    y = torch.load(yfile).int().numpy()

    if self.transform is None:
        return x, y

    # The transform gets the image with its first axis moved to the end; this is reverted afterwards
    augmented = self.transform(image=x.transpose(1, 2, 0), mask=y)
    return augmented["image"].transpose(2, 0, 1), augmented["mask"]

__len__

__len__()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __len__(self):
    """Return the number of samples, i.e. the number of collected "x" files."""
    n_samples = len(self.x_files)
    return n_samples

DartsDatasetInMemory

DartsDatasetInMemory(
    data_dir: pathlib.Path | str,
    augment: bool,
    indices: list[int] | None = None,
)

Bases: torch.utils.data.Dataset

Methods:

Attributes:

Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __init__(self, data_dir: Path | str, augment: bool, indices: list[int] | None = None):
    if isinstance(data_dir, str):
        data_dir = Path(data_dir)

    x_files = sorted((data_dir / "x").glob("*.pt"))
    y_files = sorted((data_dir / "y").glob("*.pt"))
    assert len(x_files) == len(y_files), f"Dataset corrupted! Got {len(x_files)=} and {len(y_files)=}!"
    if indices is not None:
        x_files = [x_files[i] for i in indices]
        y_files = [y_files[i] for i in indices]

    self.x = []
    self.y = []
    for xfile, yfile in zip(x_files, y_files):
        assert xfile.stem == yfile.stem, (
            f"Dataset corrupted! Files must have the same name, but got {xfile=} {yfile=}!"
        )
        x = torch.load(xfile).numpy()
        y = torch.load(yfile).int().numpy()
        self.x.append(x)
        self.y.append(y)

    self.transform = (
        A.Compose(
            [
                A.HorizontalFlip(),
                A.VerticalFlip(),
                A.RandomRotate90(),
                # A.Blur(),
                A.RandomBrightnessContrast(),
                A.MultiplicativeNoise(per_channel=True, elementwise=True),
                # ToTensorV2(),
            ]
        )
        if augment
        else None
    )

transform instance-attribute

transform = (
    albumentations.Compose(
        [
            albumentations.HorizontalFlip(),
            albumentations.VerticalFlip(),
            albumentations.RandomRotate90(),
            albumentations.RandomBrightnessContrast(),
            albumentations.MultiplicativeNoise(
                per_channel=True, elementwise=True
            ),
        ]
    )
    if darts_segmentation.training.data.DartsDatasetInMemory(
        augment
    )
    else None
)

x instance-attribute

x = []

y instance-attribute

y = []

__getitem__

__getitem__(idx)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __getitem__(self, idx):
    """Return the ``(x, y)`` pair stored at position ``idx``.

    When ``self.transform`` is set, the pair is passed through it first; otherwise the
    stored arrays are returned as they are.
    """
    x, y = self.x[idx], self.y[idx]
    if self.transform is None:
        return x, y

    # The transform receives the input with its first axis moved to the end
    # (presumably channel-first -> channel-last); the result is moved back afterwards.
    augmented = self.transform(image=x.transpose(1, 2, 0), mask=y)
    return augmented["image"].transpose(2, 0, 1), augmented["mask"]

__len__

__len__()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __len__(self):
    """Return the number of samples held in ``self.x``."""
    n_samples = len(self.x)
    return n_samples

DartsDatasetZarr

DartsDatasetZarr(
    data_dir: pathlib.Path | str,
    augment: bool,
    indices: list[int] | None = None,
)

Bases: torch.utils.data.Dataset

Methods:

Attributes:

Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __init__(self, data_dir: Path | str, augment: bool, indices: list[int] | None = None):
    """Bind the dataset to a zarr group stored in the local directory ``data_dir``.

    Args:
        data_dir (Path | str): Directory of the local zarr store. The group must contain
            both an ``x`` and a ``y`` member.
        augment (bool): If truthy, an albumentations pipeline is stored in ``self.transform``;
            otherwise ``self.transform`` is ``None``.
        indices (list[int] | None, optional): Indices along the first axis of ``x``/``y`` that
            make up this dataset. ``None`` selects every index of ``x``. Defaults to None.

    """
    if isinstance(data_dir, str):
        data_dir = Path(data_dir)

    store = zarr.storage.LocalStore(data_dir)
    self.zroot = zarr.group(store=store)

    # Both members are required, so the message says "and" (it used to say "or").
    assert "x" in self.zroot and "y" in self.zroot, (
        f"Dataset corrupted! {self.zroot.info=} must contain 'x' and 'y' arrays!"
    )

    # Without an explicit selection, every entry along the first axis of "x" is used.
    self.indices = indices if indices is not None else list(range(self.zroot["x"].shape[0]))

    self.transform = (
        A.Compose(
            [
                A.HorizontalFlip(),
                A.VerticalFlip(),
                A.RandomRotate90(),
                A.RandomBrightnessContrast(),
                A.MultiplicativeNoise(per_channel=True, elementwise=True),
            ]
        )
        if augment
        else None
    )

indices instance-attribute

indices = (
    indices
    if indices is not None
    else list(range(self.zroot["x"].shape[0]))
)

transform instance-attribute

transform = (
    albumentations.Compose(
        [
            albumentations.HorizontalFlip(),
            albumentations.VerticalFlip(),
            albumentations.RandomRotate90(),
            albumentations.RandomBrightnessContrast(),
            albumentations.MultiplicativeNoise(
                per_channel=True, elementwise=True
            ),
        ]
    )
    if augment
    else None
)

zroot instance-attribute

zroot = zarr.group(store=store)

__getitem__

__getitem__(idx)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __getitem__(self, idx):
    """Return the ``(x, y)`` pair for dataset position ``idx``.

    ``idx`` is first translated through ``self.indices`` into an index of the zarr arrays.
    When ``self.transform`` is set, the pair is passed through it before being returned.
    """
    sample_index = self.indices[idx]
    x, y = self.zroot["x"][sample_index], self.zroot["y"][sample_index]
    if self.transform is None:
        return x, y

    # The transform receives the input with its first axis moved to the end
    # (presumably channel-first -> channel-last); the result is moved back afterwards.
    augmented = self.transform(image=x.transpose(1, 2, 0), mask=y)
    return augmented["image"].transpose(2, 0, 1), augmented["mask"]

__len__

__len__()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __len__(self):
    """Return the number of selected samples, i.e. the length of ``self.indices``."""
    n_selected = len(self.indices)
    return n_selected

SMPSegmenter

SMPSegmenter(
    config: darts_segmentation.segment.SMPSegmenterConfig,
    learning_rate: float = 1e-05,
    gamma: float = 0.9,
    focal_loss_alpha: float | None = None,
    focal_loss_gamma: float = 2.0,
    **kwargs: dict[str, typing.Any],
)

Bases: lightning.LightningModule

Lightning module for training a segmentation model using the segmentation_models_pytorch library.

Initialize the SMPSegmenter.

Parameters:

  • config (darts_segmentation.segment.SMPSegmenterConfig) –

    Configuration for the segmentation model.

  • learning_rate (float, default: 1e-05 ) –

    Initial learning rate. Defaults to 1e-5.

  • gamma (float, default: 0.9 ) –

    Multiplicative factor of learning rate decay. Defaults to 0.9.

  • focal_loss_alpha (float, default: None ) –

    Weight factor to balance positive and negative samples. Alpha must be in [0...1] range, high values will give more weight to positive class. None will not weight samples. Defaults to None.

  • focal_loss_gamma (float, default: 2.0 ) –

    Focal loss power factor. Defaults to 2.0.

  • kwargs (dict[str, typing.Any], default: {} ) –

    Additional keyword arguments which should be saved to the hyperparameter file.

Methods:

Attributes:

Source code in darts-segmentation/src/darts_segmentation/training/module.py
def __init__(
    self,
    config: SMPSegmenterConfig,
    learning_rate: float = 1e-5,
    gamma: float = 0.9,
    focal_loss_alpha: float | None = None,
    focal_loss_gamma: float = 2.0,
    **kwargs: dict[str, Any],
):
    """Initialize the SMPSegmenter.

    Stores the hyperparameters, builds the model from ``config["model"]`` and sets up a
    binary focal loss that ignores target pixels with the value 2.

    Args:
        config (SMPSegmenterConfig): Configuration for the segmentation model.
        learning_rate (float, optional): Initial learning rate. Defaults to 1e-5.
        gamma (float, optional): Multiplicative factor of learning rate decay. Defaults to 0.9.
        focal_loss_alpha (float, optional): Weight factor to balance positive and negative samples.
            Alpha must be in [0...1] range, high values will give more weight to positive class.
            None will not weight samples. Defaults to None.
        focal_loss_gamma (float, optional): Focal loss power factor. Defaults to 2.0.
        kwargs (dict[str, Any]): Additional keyword arguments which should be saved to the hyperparameter file.

    """
    super().__init__()

    # Stores the constructor arguments under self.hparams, skipping any named "test_set" or "val_set".
    # NOTE(review): presumably the arguments are collected from this constructor's frame, so the
    # parameters should not be rebound before this call — confirm against the Lightning docs.
    self.save_hyperparameters(ignore=["test_set", "val_set"])
    # The model is built from the "model" section of the config with a sigmoid output activation.
    self.model = smp.create_model(**config["model"], activation="sigmoid")

    # Assumes that the training preparation was done with setting invalid pixels in the mask to 2
    # NOTE(review): the model output already went through a sigmoid — verify that smp's FocalLoss
    # expects probabilities here and not raw logits.
    self.loss_fn = smp.losses.FocalLoss(
        mode="binary", alpha=focal_loss_alpha, gamma=focal_loss_gamma, ignore_index=2
    )

loss_fn instance-attribute

loss_fn = segmentation_models_pytorch.losses.FocalLoss(
    mode="binary",
    alpha=focal_loss_alpha,
    gamma=focal_loss_gamma,
    ignore_index=2,
)

model instance-attribute

model = segmentation_models_pytorch.create_model(
    **config["model"],
    activation="sigmoid",
)

__repr__

__repr__()
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def __repr__(self):
    """Return ``SMPSegmenter(...)`` wrapping the ``"model"`` section of the stored config."""
    model_config = self.hparams["config"]["model"]
    return f"SMPSegmenter({model_config})"

configure_optimizers

configure_optimizers()
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def configure_optimizers(self):
    """Create the optimizer and learning-rate scheduler used for training.

    Returns:
        tuple[list, list]: A one-element list with an AdamW optimizer (learning rate taken from
            ``self.hparams.learning_rate``) and a one-element list with an ExponentialLR scheduler
            (decay factor taken from ``self.hparams.gamma``).

    """
    adamw = optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
    exponential_decay = optim.lr_scheduler.ExponentialLR(adamw, gamma=self.hparams.gamma)
    return [adamw], [exponential_decay]

on_train_epoch_end

on_train_epoch_end()
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def on_train_epoch_end(self):
    """Log the first entry of the scheduler's last learning rates as ``"learning_rate"``."""
    scheduler = self.lr_schedulers()
    current_lr = scheduler.get_last_lr()[0]
    self.log("learning_rate", current_lr)

test_step

test_step(batch, batch_idx)
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def test_step(self, batch, batch_idx):
    """Run the model on one test batch and compute its loss.

    Args:
        batch: Pair ``(x, y)`` of model input and target mask.
        batch_idx: Index of the batch; not used here.

    Returns:
        dict: ``"loss"`` holds the loss value, ``"y_hat"`` the model output with dimension 1 squeezed.

    """
    inputs, targets = batch
    prediction = self.model(inputs).squeeze(1)
    return {
        "loss": self.loss_fn(prediction, targets.long()),
        "y_hat": prediction,
    }

training_step

training_step(batch, batch_idx)
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def training_step(self, batch, batch_idx):
    """Run the model on one training batch and compute its loss.

    Args:
        batch: Pair ``(x, y)`` of model input and target mask.
        batch_idx: Index of the batch; not used here.

    Returns:
        dict: ``"loss"`` holds the loss value, ``"y_hat"`` the model output with dimension 1 squeezed.

    """
    inputs, targets = batch
    prediction = self.model(inputs).squeeze(1)
    return {
        "loss": self.loss_fn(prediction, targets.long()),
        "y_hat": prediction,
    }

validation_step

validation_step(batch, batch_idx)
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def validation_step(self, batch, batch_idx):
    """Run the model on one validation batch and compute its loss.

    Args:
        batch: Pair ``(x, y)`` of model input and target mask.
        batch_idx: Index of the batch; not used here.

    Returns:
        dict: ``"loss"`` holds the loss value, ``"y_hat"`` the model output with dimension 1 squeezed.

    """
    inputs, targets = batch
    prediction = self.model(inputs).squeeze(1)
    return {
        "loss": self.loss_fn(prediction, targets.long()),
        "y_hat": prediction,
    }

create_training_patches

create_training_patches(
    tile: xarray.Dataset,
    labels: geopandas.GeoDataFrame,
    bands: list[str],
    norm_factors: dict[str, float],
    patch_size: int,
    overlap: int,
    exclude_nopositive: bool,
    exclude_nan: bool,
    device: typing.Literal["cuda", "cpu"] | int,
    mask_erosion_size: int,
) -> collections.abc.Generator[
    tuple[torch.tensor, torch.tensor]
]

Create training patches from a tile and labels.

Parameters:

  • tile (xarray.Dataset) –

    The input tile, containing preprocessed, harmonized data.

  • labels (geopandas.GeoDataFrame) –

    The labels to be used for training.

  • bands (list[str]) –

    The bands to be used for training. Must be present in the tile.

  • norm_factors (dict[str, float]) –

    The normalization factors for the bands.

  • patch_size (int) –

    The size of the patches.

  • overlap (int) –

    The size of the overlap.

  • exclude_nopositive (bool) –

    Whether to exclude patches where the labels do not contain positives.

  • exclude_nan (bool) –

    Whether to exclude patches where the input data has nan values.

  • device (typing.Literal['cuda', 'cpu'] | int) –

    The device to use for the erosion.

  • mask_erosion_size (int) –

    The size of the disk to use for erosion.

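Yields:

  • collections.abc.Generator[tuple[torch.tensor, torch.tensor]] –

    A tuple containing the input and the labels as pytorch tensors. The input has the format (C, H, W), the labels (H, W).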

Raises:

  • ValueError

    If a band is not found in the preprocessed data.

Source code in darts-segmentation/src/darts_segmentation/training/prepare_training.py
def create_training_patches(
    tile: xr.Dataset,
    labels: gpd.GeoDataFrame,
    bands: list[str],
    norm_factors: dict[str, float],
    patch_size: int,
    overlap: int,
    exclude_nopositive: bool,
    exclude_nan: bool,
    device: Literal["cuda", "cpu"] | int,
    mask_erosion_size: int,
) -> Generator[tuple[torch.tensor, torch.tensor]]:
    """Create training patches from a tile and labels.

    The labels are rasterized onto the tile grid (1 = labelled, 0 = not labelled) and every pixel
    outside the eroded valid-data mask is set to class 2 (invalid). The selected bands are
    normalized, clipped to [0, 1] and cut into patches together with the labels. Patches are
    skipped when they contain no positives (if ``exclude_nopositive``), contain nan values
    (if ``exclude_nan``), have less than 10% non-invalid label pixels, or are nan everywhere.
    Remaining nan values of a yielded patch are replaced by 0.

    Note:
        The normalized and clipped bands are assigned back into the passed ``tile``.

    Args:
        tile (xr.Dataset): The input tile, containing preprocessed, harmonized data.
        labels (gpd.GeoDataFrame): The labels to be used for training.
        bands (list[str]): The bands to be used for training. Must be present in the tile.
        norm_factors (dict[str, float]): The normalization factors for the bands.
        patch_size (int): The size of the patches.
        overlap (int): The size of the overlap.
        exclude_nopositive (bool): Whether to exclude patches where the labels do not contain positives.
        exclude_nan (bool): Whether to exclude patches where the input data has nan values.
        device (Literal["cuda", "cpu"] | int): The device to use for the erosion.
        mask_erosion_size (int): The size of the disk to use for erosion.

    Yields:
        Generator[tuple[torch.tensor, torch.tensor]]: A tuple containing the input and the labels as pytorch tensors.
            The input has the format (C, H, W), the labels (H, W).

    Raises:
        ValueError: If a band is not found in the preprocessed data.

    """
    if len(labels) == 0 and exclude_nopositive:
        logger.warning("No labels found in the labels GeoDataFrame. Skipping.")
        return

    # Rasterize the labels
    if len(labels) > 0:
        labels_rasterized = 1 - make_geocube(labels, measurements=["id"], like=tile).id.isnull()
    else:
        labels_rasterized = xr.zeros_like(tile["valid_data_mask"])

    # Filter out the nodata values (class 2 -> invalid data).
    # The eroded mask is used here; previously it was immediately overwritten by the raw
    # valid-data mask, which made `mask_erosion_size` and `device` have no effect.
    mask = erode_mask(tile["valid_data_mask"], mask_erosion_size, device)
    labels_rasterized = xr.where(mask, labels_rasterized, 2)

    # Normalize the bands and clip the values
    for band in bands:
        if band not in tile:
            raise ValueError(f"Band '{band}' not found in the preprocessed data.")
        with xr.set_options(keep_attrs=True):
            tile[band] = tile[band] * norm_factors[band]
            tile[band] = tile[band].clip(0, 1)

    # Replace invalid values with nan (used for nan check later on)
    tile = xr.where(tile["valid_data_mask"], tile, float("nan"))

    # Convert to dataarray and select the bands (bands are now in specified order)
    tile = tile.to_dataarray(dim="band").sel(band=bands)

    # Transpose to (C, H, W)
    tile = tile.transpose("band", "y", "x")
    labels_rasterized = labels_rasterized.transpose("y", "x")

    # Convert to tensor
    tensor_tile = torch.tensor(tile.values).float()
    tensor_labels = torch.tensor(labels_rasterized.values).float()

    assert tensor_tile.dim() == 3, f"Expects tensor_tile to has shape (C, H, W), got {tensor_tile.shape}"
    assert tensor_labels.dim() == 2, f"Expects tensor_labels to has shape (H, W), got {tensor_labels.shape}"

    # Create patches
    tensor_patches = create_patches(tensor_tile.unsqueeze(0), patch_size, overlap)
    tensor_patches = tensor_patches.reshape(-1, len(bands), patch_size, patch_size)
    tensor_labels = create_patches(tensor_labels.unsqueeze(0).unsqueeze(0), patch_size, overlap)
    tensor_labels = tensor_labels.reshape(-1, patch_size, patch_size)

    # Walk over the (input, label) patch pairs and yield the ones passing all filters
    for i, (x, y) in enumerate(zip(tensor_patches, tensor_labels)):
        if exclude_nopositive and not (y == 1).any():
            continue

        # Computed once and reused by the nan filters and the final replacement
        nan_mask = torch.isnan(x)

        if exclude_nan and nan_mask.any():
            continue

        # Skip where there are less than 10% visible pixel
        if ((y != 2).sum() / y.numel()) < 0.1:
            continue

        # Skip patches where everything is nan
        if nan_mask.all():
            continue

        # Convert all nan values to 0
        x[nan_mask] = 0

        logger.debug(f"Yielding patch {i} with\n\t{x=}\n\t{y=}")
        yield x, y