darts.legacy_training.train_smp
Run the training of an SMP (Segmentation Models PyTorch) model.
See https://smp.readthedocs.io/en/latest/index.html for available model configurations.
Each training run is assigned a unique name and id pair and, optionally, a trial name.
The name, which the user can provide, should be used to group runs with identical hyperparameters and code.
Hence, different versions of the same name should only differ by random state or run settings, such as logging.
Each version is assigned a unique id.
Artifacts (metrics & checkpoints) are then stored under `{artifact_dir}/{run_name}/{run_id}` in runs without cross-validation.
If `trial_name` is specified, the artifacts are stored under `{artifact_dir}/{trial_name}/{run_name}-{run_id}` instead.
Wandb logs are always stored under `{wandb_entity}/{wandb_project}/{run_name}`, regardless of `trial_name`.
However, if a `trial_name` is specified, the logs are further grouped by it (via `job_type`).
Both `run_name` and `run_id` are also stored in the hparams of each checkpoint.
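For illustration, this is how the artifact paths compose; the run name, run id, and trial name values below are hypothetical:

```python
from pathlib import Path

artifact_dir = Path("lightning_logs")
run_name, run_id = "smp-unet", "a1b2c3"  # hypothetical values

# Without a trial: {artifact_dir}/{run_name}/{run_id}
print(artifact_dir / run_name / run_id)  # lightning_logs/smp-unet/a1b2c3

# With a trial: {artifact_dir}/{trial_name}/{run_name}-{run_id}
trial_name = "cv-sweep"  # hypothetical
print(artifact_dir / trial_name / f"{run_name}-{run_id}")  # lightning_logs/cv-sweep/smp-unet-a1b2c3
```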
You can specify how frequently logs are written and validation is performed (see the sketch after this list):

- `log_every_n_steps` specifies how often train-logs are written. This does not affect validation.
- `check_val_every_n_epoch` specifies how often validation is performed. This also affects early stopping.
- `early_stopping_patience` specifies how many validation checks to wait for improvement before stopping. In epochs, this is `check_val_every_n_epoch * early_stopping_patience`.
- `plot_every_n_val_epochs` specifies how often validation samples are plotted. Since plotting is quite costly, you can reduce the frequency. It works similarly to early stopping: in epochs, this is `check_val_every_n_epoch * plot_every_n_val_epochs`.
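For example, with the defaults documented below (`check_val_every_n_epoch=3`, `early_stopping_patience=5`, `plot_every_n_val_epochs=5`), the effective epoch intervals work out as follows:

```python
check_val_every_n_epoch = 3
early_stopping_patience = 5
plot_every_n_val_epochs = 5

# Early stopping triggers after this many epochs without improvement:
print(check_val_every_n_epoch * early_stopping_patience)  # 15

# Validation samples are plotted every this many epochs:
print(check_val_every_n_epoch * plot_every_n_val_epochs)  # 15
```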
The training data is expected to have been produced by the "preprocessing" step beforehand, resulting in the following data structure:

preprocessed-data/ # the top-level directory
├── config.toml
├── cross-val.zarr/ # this zarr group contains the dataarrays x and y for training and validation
├── test.zarr/ # this zarr group contains the dataarrays x and y for the left-out-region test set
├── val-test.zarr/ # this zarr group contains the dataarrays x and y for the randomly selected validation set
└── labels.geojson
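A minimal sketch of how such a layout can be inspected, assuming xarray with zarr support is installed; the variable names `x` and `y` follow the comments in the tree above:

```python
import xarray as xr

data_dir = "preprocessed-data"

# Each zarr group contains the dataarrays "x" (inputs) and "y" (labels).
cross_val = xr.open_zarr(f"{data_dir}/cross-val.zarr")
print(cross_val["x"].dims, cross_val["y"].dims)
```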
Parameters:

- train_data_dir (pathlib.Path) – Path to the training data directory (top-level).
- artifact_dir (pathlib.Path, default: Path("lightning_logs")) – Path to the training output directory. Will contain checkpoints and metrics. Defaults to Path("lightning_logs").
- fold (int, default: 0) – The current fold to train on. Must be in [0, 4]. Defaults to 0.
- continue_from_checkpoint (pathlib.Path | None, default: None) – Path to a checkpoint to continue training from. Defaults to None.
- model_arch (str, default: "Unet") – Model architecture to use. Defaults to "Unet".
- model_encoder (str, default: "dpn107") – Encoder to use. Defaults to "dpn107".
- model_encoder_weights (str | None, default: None) – Path to the encoder weights. Defaults to None.
- augment (bool, default: True) – Whether to apply augmentations or not. Defaults to True.
- learning_rate (float, default: 0.001) – Learning rate. Defaults to 1e-3.
- gamma (float, default: 0.9) – Multiplicative factor of learning-rate decay. Defaults to 0.9.
- focal_loss_alpha (float | None, default: None) – Weight factor to balance positive and negative samples. Alpha must be in the [0, 1] range; higher values give more weight to the positive class. None applies no weighting. Defaults to None.
- focal_loss_gamma (float, default: 2.0) – Focal loss power factor. Defaults to 2.0.
- batch_size (int, default: 8) – Batch size. Defaults to 8.
- max_epochs (int, default: 100) – Maximum number of epochs to train. Defaults to 100.
- log_every_n_steps (int, default: 10) – Log every n steps. Defaults to 10.
- check_val_every_n_epoch (int, default: 3) – Check validation every n epochs. Defaults to 3.
- early_stopping_patience (int, default: 5) – Number of validation checks to wait for improvement before stopping. Defaults to 5.
- plot_every_n_val_epochs (int, default: 5) – Plot validation samples every n validation epochs. Defaults to 5.
- random_seed (int, default: 42) – Random seed for deterministic training. Defaults to 42.
- num_workers (int, default: 0) – Number of dataloader workers. Defaults to 0.
- device (int | str, default: "auto") – The device to run the model on. Defaults to "auto".
- wandb_entity (str | None, default: None) – Weights and Biases entity. Defaults to None.
- wandb_project (str | None, default: None) – Weights and Biases project. Defaults to None.
- wandb_group (str | None, default: None) – Wandb group. Useful for CV sweeps. Defaults to None.
- run_name (str | None, default: None) – Name of this run, used as a further grouping method for logs etc. If None, a random one is generated. Defaults to None.
- run_id (str | None, default: None) – ID of the run. If None, a random one is generated. Defaults to None.
- trial_name (str | None, default: None) – Name of the cross-validation run / trial. This primarily affects logging and artifact storage. If None, it has no effect. Defaults to None.
Returns:

- Trainer (pytorch_lightning.Trainer) – The trainer object used for training.
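A minimal invocation sketch, using only parameters documented above; the run name is a hypothetical example:

```python
from pathlib import Path

from darts.legacy_training import train_smp

trainer = train_smp(
    train_data_dir=Path("preprocessed-data"),
    artifact_dir=Path("lightning_logs"),
    fold=0,
    model_arch="Unet",
    model_encoder="dpn107",
    batch_size=8,
    max_epochs=100,
    run_name="smp-unet-baseline",  # hypothetical
)
```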
Source code in darts/src/darts/legacy_training/train.py