Skip to content

darts_segmentation.training.SMPSegmenter

Bases: lightning.LightningModule

Lightning module for training a segmentation model using the segmentation_models_pytorch library.

Initialize the SMPSegmenter.

Parameters:

  • config (darts_segmentation.segment.SMPSegmenterConfig) –

    Configuration for the segmentation model.

  • learning_rate (float, default: 1e-05 ) –

    Initial learning rate. Defaults to 1e-5.

  • gamma (float, default: 0.9 ) –

    Multiplicative factor of learning rate decay. Defaults to 0.9.

  • focal_loss_alpha (float, default: None ) –

    Weight factor to balance positive and negative samples. Alpha must be in [0...1] range, high values will give more weight to positive class. None will not weight samples. Defaults to None.

  • focal_loss_gamma (float, default: 2.0 ) –

    Focal loss power factor. Defaults to 2.0.

  • kwargs (dict[str, typing.Any], default: {} ) –

    Additional keyword arguments which should be saved to the hyperparameter file.

Source code in darts-segmentation/src/darts_segmentation/training/module.py
def __init__(
    self,
    config: SMPSegmenterConfig,
    learning_rate: float = 1e-5,
    gamma: float = 0.9,
    focal_loss_alpha: float | None = None,
    focal_loss_gamma: float = 2.0,
    **kwargs: dict[str, Any],
):
    """Initialize the SMPSegmenter.

    Args:
        config (SMPSegmenterConfig): Configuration for the segmentation model.
        learning_rate (float, optional): Initial learning rate. Defaults to 1e-5.
        gamma (float, optional): Multiplicative factor of learning rate decay. Defaults to 0.9.
        focal_loss_alpha (float, optional): Weight factor to balance positive and negative samples.
            Alpha must be in [0...1] range, high values will give more weight to positive class.
            None will not weight samples. Defaults to None.
        focal_loss_gamma (float, optional): Focal loss power factor. Defaults to 2.0.
        kwargs (dict[str, Any]): Additional keyword arguments which should be saved to the hyperparameter file.

    """
    super().__init__()

    # This saves config, learning_rate and gamma under self.hparams
    self.save_hyperparameters(ignore=["test_set", "val_set"])
    self.model = smp.create_model(**config["model"], activation="sigmoid")

    # Assumes that the training preparation was done with setting invalid pixels in the mask to 2
    self.loss_fn = smp.losses.FocalLoss(
        mode="binary", alpha=focal_loss_alpha, gamma=focal_loss_gamma, ignore_index=2
    )

loss_fn instance-attribute

loss_fn = segmentation_models_pytorch.losses.FocalLoss(
    mode="binary",
    alpha=darts_segmentation.training.module.SMPSegmenter(
        focal_loss_alpha
    ),
    gamma=darts_segmentation.training.module.SMPSegmenter(
        focal_loss_gamma
    ),
    ignore_index=2,
)

model instance-attribute

model = segmentation_models_pytorch.create_model(
    **darts_segmentation.training.module.SMPSegmenter(
        config
    )["model"],
    activation="sigmoid",
)

__repr__

__repr__()
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def __repr__(self):  # noqa: D105
    return f"SMPSegmenter({self.hparams['config']['model']})"

configure_optimizers

configure_optimizers()
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def configure_optimizers(self):  # noqa: D102
    optimizer = optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.hparams.gamma)
    return [optimizer], [scheduler]

on_train_epoch_end

on_train_epoch_end()
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def on_train_epoch_end(self):  # noqa: D102
    self.log("learning_rate", self.lr_schedulers().get_last_lr()[0])

test_step

test_step(batch, batch_idx)
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def test_step(self, batch, batch_idx):  # noqa: D102
    x, y = batch
    y_hat = self.model(x).squeeze(1)
    loss = self.loss_fn(y_hat, y.long())
    return {
        "loss": loss,
        "y_hat": y_hat,
    }

training_step

training_step(batch, batch_idx)
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def training_step(self, batch, batch_idx):  # noqa: D102
    x, y = batch
    y_hat = self.model(x).squeeze(1)
    loss = self.loss_fn(y_hat, y.long())
    return {
        "loss": loss,
        "y_hat": y_hat,
    }

validation_step

validation_step(batch, batch_idx)
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def validation_step(self, batch, batch_idx):  # noqa: D102
    x, y = batch
    y_hat = self.model(x).squeeze(1)
    loss = self.loss_fn(y_hat, y.long())
    return {
        "loss": loss,
        "y_hat": y_hat,
    }