module
darts_segmentation.training.module
Training script for DARTS segmentation.
LitSMP
LitSMP(
    config: darts_segmentation.segment.SMPSegmenterConfig,
    learning_rate: float = 1e-05,
    gamma: float = 0.9,
    focal_loss_alpha: float | None = None,
    focal_loss_gamma: float = 2.0,
    **kwargs: dict[str, typing.Any],
)
Bases: lightning.LightningModule
Lightning module for training a segmentation model using the segmentation_models_pytorch library.
Initialize the LitSMP.
Parameters:
- config (darts_segmentation.segment.SMPSegmenterConfig): Configuration for the segmentation model.
- learning_rate (float, default: 1e-05): Initial learning rate.
- gamma (float, default: 0.9): Multiplicative factor of learning-rate decay.
- focal_loss_alpha (float | None, default: None): Weight factor that balances positive and negative samples. Must be in the range [0, 1]; higher values give more weight to the positive class. None disables the weighting.
- focal_loss_gamma (float, default: 2.0): Focal loss power (focusing) factor.
- kwargs (dict[str, typing.Any], default: {}): Additional keyword arguments to save to the hyperparameter file.
Source code in darts-segmentation/src/darts_segmentation/training/module.py
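The learning_rate and gamma parameters together define a decaying learning rate. The body of configure_optimizers is not shown on this page, so the sketch below assumes an exponential schedule that multiplies the rate by gamma once per epoch (the way torch.optim.lr_scheduler.ExponentialLR uses its gamma). The helper names lr_at_epoch and epochs_to_halve are hypothetical.

```python
import math


def lr_at_epoch(learning_rate: float, gamma: float, epoch: int) -> float:
    """Learning rate after `epoch` decay steps of an exponential schedule.

    Assumption: the rate is multiplied by `gamma` once per epoch
    (ExponentialLR-style); the real schedule is not shown on this page.
    """
    if epoch < 0:
        raise ValueError("epoch must be non-negative")
    return learning_rate * gamma**epoch


def epochs_to_halve(gamma: float) -> float:
    """Number of epochs after which the learning rate has halved."""
    return math.log(0.5) / math.log(gamma)


# Print the first few learning rates for the documented defaults.
for epoch in range(4):
    print(epoch, lr_at_epoch(1e-5, 0.9, epoch))
print(round(epochs_to_halve(0.9), 1))  # 6.6
```

Under this assumption, the defaults (1e-05, 0.9) halve the learning rate roughly every 6.6 epochs.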
loss_fn
instance-attribute
loss_fn = segmentation_models_pytorch.losses.FocalLoss(
    mode="binary",
    alpha=focal_loss_alpha,
    gamma=focal_loss_gamma,
    ignore_index=2,
)
model
instance-attribute
model = segmentation_models_pytorch.create_model(
    **config["model"],
    activation="sigmoid",
)
__repr__
configure_optimizers
on_train_epoch_end
test_step
training_step
validation_step
SMPSegmenterConfig
Configuration for the segmenter.
from_ckpt
classmethod
from_ckpt(
    config: dict[str, typing.Any],
) -> darts_segmentation.segment.SMPSegmenterConfig
Validate a configuration dictionary for the segmenter.
Parameters:
- config (dict[str, typing.Any]): The configuration dictionary to validate.
Returns:
- darts_segmentation.segment.SMPSegmenterConfig: The validated configuration.
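The fields of SMPSegmenterConfig are not documented on this page, so the following is only a hypothetical sketch of the kind of check a from_ckpt-style validator could perform. The only requirement taken from this page is that config["model"] must hold keyword arguments for segmentation_models_pytorch.create_model, whose arch argument is required; the function name validate_segmenter_config and its error messages are assumptions.

```python
from typing import Any


def validate_segmenter_config(config: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical check of a checkpoint config before building a model."""
    if not isinstance(config, dict):
        raise TypeError(f"config must be a dict, got {type(config).__name__}")
    model_cfg = config.get("model")
    if not isinstance(model_cfg, dict):
        raise ValueError('config["model"] must be a dict of create_model arguments')
    if "arch" not in model_cfg:
        raise ValueError('config["model"] must name an architecture via "arch"')
    # Shallow copy: top-level keys can be changed without touching the input.
    return dict(config)


validated = validate_segmenter_config({"model": {"arch": "Unet", "classes": 1}})
try:
    validate_segmenter_config({"model": {"classes": 1}})
    missing_arch_rejected = False
except ValueError:
    missing_arch_rejected = True
print(validated["model"]["arch"], missing_arch_rejected)  # Unet True
```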