
module

darts_segmentation.training.module

Training script for DARTS segmentation.

logger module-attribute

logger = logging.getLogger(
    __name__.replace("darts_", "darts.")
)

LitSMP

LitSMP(
    config: darts_segmentation.segment.SMPSegmenterConfig,
    learning_rate: float = 1e-05,
    gamma: float = 0.9,
    focal_loss_alpha: float | None = None,
    focal_loss_gamma: float = 2.0,
    **kwargs: dict[str, typing.Any],
)

Bases: lightning.LightningModule

Lightning module for training a segmentation model using the segmentation_models_pytorch library.

Initialize the LitSMP.

Parameters:

  • config (darts_segmentation.segment.SMPSegmenterConfig) –

    Configuration for the segmentation model.

  • learning_rate (float, default: 1e-05 ) –

    Initial learning rate. Defaults to 1e-5.

  • gamma (float, default: 0.9 ) –

    Multiplicative factor of learning rate decay. Defaults to 0.9.

  • focal_loss_alpha (float | None, default: None ) –

    Weight factor to balance positive and negative samples. Alpha must be in the [0, 1] range; higher values give more weight to the positive class. If None, samples are not weighted. Defaults to None.

  • focal_loss_gamma (float, default: 2.0 ) –

    Focal loss power factor. Defaults to 2.0.

  • kwargs (dict[str, typing.Any], default: {} ) –

    Additional keyword arguments which should be saved to the hyperparameter file.

Source code in darts-segmentation/src/darts_segmentation/training/module.py
def __init__(
    self,
    config: SMPSegmenterConfig,
    learning_rate: float = 1e-5,
    gamma: float = 0.9,
    focal_loss_alpha: float | None = None,
    focal_loss_gamma: float = 2.0,
    **kwargs: dict[str, Any],
):
    """Initialize the LitSMP.

    Args:
        config (SMPSegmenterConfig): Configuration for the segmentation model.
        learning_rate (float, optional): Initial learning rate. Defaults to 1e-5.
        gamma (float, optional): Multiplicative factor of learning rate decay. Defaults to 0.9.
        focal_loss_alpha (float | None, optional): Weight factor to balance positive and negative samples.
            Alpha must be in the [0, 1] range; higher values give more weight to the positive class.
            If None, samples are not weighted. Defaults to None.
        focal_loss_gamma (float, optional): Focal loss power factor. Defaults to 2.0.
        kwargs (dict[str, Any]): Additional keyword arguments which should be saved to the hyperparameter file.

    """
    super().__init__()

    # This saves config, learning_rate and gamma under self.hparams
    self.save_hyperparameters(ignore=["test_set", "val_set"])
    self.model = smp.create_model(**config["model"], activation="sigmoid")

    # Assumes that the training preparation was done with setting invalid pixels in the mask to 2
    self.loss_fn = smp.losses.FocalLoss(
        mode="binary", alpha=focal_loss_alpha, gamma=focal_loss_gamma, ignore_index=2
    )
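
Example usage (a minimal sketch, not taken from the library): the encoder name, channel count, and trainer settings below are assumptions and should be adapted to the actual data. Only config["model"] is consumed by the constructor; a full darts_segmentation.segment.SMPSegmenterConfig additionally carries a "bands" entry.

import lightning as L

from darts_segmentation.training.module import LitSMP

config = {
    "model": {
        # Forwarded to segmentation_models_pytorch.create_model
        "arch": "unet",
        "encoder_name": "resnet34",
        "in_channels": 4,
        "classes": 1,
    },
}

module = LitSMP(config, learning_rate=1e-5, gamma=0.9, focal_loss_gamma=2.0)
trainer = L.Trainer(max_epochs=10)
# trainer.fit(module, train_dataloaders=..., val_dataloaders=...)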

loss_fn instance-attribute

loss_fn = segmentation_models_pytorch.losses.FocalLoss(
    mode="binary",
    alpha=focal_loss_alpha,
    gamma=focal_loss_gamma,
    ignore_index=2,
)

model instance-attribute

model = segmentation_models_pytorch.create_model(
    **config["model"],
    activation="sigmoid",
)

__repr__

__repr__()
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def __repr__(self):  # noqa: D105
    return f"LitSMP({self.hparams['config']['model']})"

configure_optimizers

configure_optimizers()
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def configure_optimizers(self):  # noqa: D102
    optimizer = optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=self.hparams.gamma)
    return [optimizer], [scheduler]
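
With the defaults above, the optimizer is AdamW and ExponentialLR multiplies the learning rate by gamma once per epoch, i.e. lr_t = learning_rate * gamma ** t. A quick check of how the default schedule decays:

for epoch in range(11):
    print(f"epoch {epoch:2d}: lr = {1e-5 * 0.9 ** epoch:.2e}")
# after 10 epochs the learning rate is roughly 3.5e-6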

on_train_epoch_end

on_train_epoch_end()
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def on_train_epoch_end(self):  # noqa: D102
    self.log("learning_rate", self.lr_schedulers().get_last_lr()[0])

test_step

test_step(batch, batch_idx)
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def test_step(self, batch, batch_idx):  # noqa: D102
    x, y = batch
    y_hat = self.model(x).squeeze(1)
    loss = self.loss_fn(y_hat, y.long())
    return {
        "loss": loss,
        "y_hat": y_hat,
    }

training_step

training_step(batch, batch_idx)
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def training_step(self, batch, batch_idx):  # noqa: D102
    x, y = batch
    y_hat = self.model(x).squeeze(1)
    loss = self.loss_fn(y_hat, y.long())
    return {
        "loss": loss,
        "y_hat": y_hat,
    }
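
The training, validation, and test steps all expect batches of (image, mask) pairs; the model output of shape (B, 1, H, W) is squeezed to (B, H, W) to match the mask, and mask pixels set to 2 are ignored by the focal loss. A shape sketch (band count and tile size are assumptions, not fixed by the module):

import torch

x = torch.rand(8, 4, 256, 256)          # (B, C, H, W) float input tiles
y = torch.randint(0, 2, (8, 256, 256))  # (B, H, W) mask: 0 = background, 1 = target
y[:, :16, :] = 2                        # 2 marks invalid pixels (ignore_index=2)

# Inside training_step: y_hat = self.model(x).squeeze(1); loss = self.loss_fn(y_hat, y.long())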

validation_step

validation_step(batch, batch_idx)
Source code in darts-segmentation/src/darts_segmentation/training/module.py
def validation_step(self, batch, batch_idx):  # noqa: D102
    x, y = batch
    y_hat = self.model(x).squeeze(1)
    loss = self.loss_fn(y_hat, y.long())
    return {
        "loss": loss,
        "y_hat": y_hat,
    }

SMPSegmenterConfig

Bases: typing.TypedDict

Configuration for the segmentor.

bands instance-attribute

bands: Bands

model instance-attribute

model: dict[str, typing.Any]

from_ckpt classmethod

from_ckpt(
    config: dict[str, typing.Any],
) -> darts_segmentation.segment.SMPSegmenterConfig

Validate the config for the segmentor.

Parameters:

  • config (dict[str, typing.Any]) –

    The configuration to validate.

Returns:

  • darts_segmentation.segment.SMPSegmenterConfig –

    The validated configuration.

Source code in darts-segmentation/src/darts_segmentation/segment.py
@classmethod
def from_ckpt(cls, config: dict[str, Any]) -> "SMPSegmenterConfig":
    """Validate the config for the segmentor.

    Args:
        config: The configuration to validate.

    Returns:
        The validated configuration.

    """
    # Handling legacy case that the config contains the old keys
    if "input_combination" in config and "norm_factors" in config:
        # Check if all input_combination features are in norm_factors
        config["bands"] = Bands([Band(name, config["norm_factors"][name]) for name in config["input_combination"]])
        config.pop("norm_factors")
        config.pop("input_combination")

    assert "model" in config, "Model config is missing!"
    assert "bands" in config, "Bands config is missing!"
    # The Bands object is always pickled as a dict for interoperability, so we need to convert it back
    if not isinstance(config["bands"], Bands):
        config["bands"] = Bands.from_config(config["bands"])
    return config
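
Example (a sketch; band names and normalization factors are made up for illustration): validating a legacy checkpoint config converts the old "input_combination" and "norm_factors" keys into a single Bands entry.

from darts_segmentation.segment import SMPSegmenterConfig

legacy_config = {
    "model": {"arch": "unet", "encoder_name": "resnet34", "in_channels": 2, "classes": 1},
    "input_combination": ["red", "ndvi"],
    "norm_factors": {"red": 1 / 3000, "ndvi": 0.5},
}

config = SMPSegmenterConfig.from_ckpt(legacy_config)
# config["bands"] is now a Bands object; the legacy keys have been removed.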