train
darts_segmentation.training.train
¶
Training scripts for DARTS.
DataConfig
dataclass
¶
DataConfig(
train_data_dir: pathlib.Path = pathlib.Path("train"),
data_split_method: typing.Literal[
"random", "region", "sample"
]
| None = None,
data_split_by: list[str | float] | None = None,
fold_method: typing.Literal[
"kfold",
"shuffle",
"stratified",
"region",
"region-stratified",
] = "kfold",
total_folds: int = 5,
subsample: int | None = None,
)
Data related parameters for training.
Defines the script inputs for the training script and can be propagated by the cross-validation and tuning scripts.
Attributes:
-
train_data_dir
(pathlib.Path
) –The top-level path to the data to be used for training. Expects a directory containing:
1. a zarr group called "data.zarr" containing an "x" and a "y" array
2. a geoparquet file called "metadata.parquet" containing the metadata for the data. This metadata should contain at least the following columns:
- "sample_id": The id of the sample
- "region": The region the sample belongs to
- "empty": Whether the image is empty
The index should refer to the index of the sample in the zarr data. This directory should be created by a preprocessing script. Defaults to "train".
-
data_split_method
(typing.Literal['random', 'region', 'sample'] | None
) –The method to use for splitting the data into a train and a test set. "random" will split the data randomly; the seed is always 42 and the test size can be specified by providing a list with a single float between 0 and 1 to data_split_by. This will be the fraction of the data to be used for testing, e.g. [0.2] will use 20% of the data for testing. "region" will split the data by one or multiple regions, which can be specified by providing a str or list of str to data_split_by. "sample" will split the data by sample ids, which can also be specified similarly to "region". If None, no split is done and the complete dataset is used for both training and testing. The train split will further be split in the cross-validation process. Defaults to None.
-
data_split_by
(list[str | float] | None
) –Select which regions/samples to split by, or the size of the test set. Defaults to None.
-
fold_method
(typing.Literal['kfold', 'shuffle', 'stratified', 'region', 'region-stratified']
) –Method for cross-validation split. Defaults to "kfold".
-
total_folds
(int
) –Total number of folds in cross-validation. Defaults to 5.
-
subsample
(int | None
) –If set, will subsample the dataset to this number of samples. This is useful for debugging and testing. Defaults to None.
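For illustration, a minimal sketch of a DataConfig for a random 80/20 train/test split with 5-fold cross-validation on the train portion (the import path follows the qualified names shown above; all values are illustrative):

from pathlib import Path

from darts_segmentation.training.train import DataConfig

# Hold out 20% of the samples for testing (the seed is fixed to 42 internally)
# and split the remaining 80% into 5 folds for cross-validation.
data_config = DataConfig(
    train_data_dir=Path("preprocessed-data"),
    data_split_method="random",
    data_split_by=[0.2],
    fold_method="kfold",
    total_folds=5,
)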
DeviceConfig
dataclass
¶
DeviceConfig(
accelerator: typing.Literal[
"auto", "cpu", "gpu", "mps", "tpu"
] = "auto",
strategy: typing.Literal[
"auto",
"ddp",
"ddp_fork",
"ddp_notebook",
"fsdp",
"cv-parallel",
"tune-parallel",
] = "auto",
devices: list[int | str] = ["auto"],
num_nodes: int = 1,
)
Device and Distributed Strategy related parameters.
Attributes:
-
accelerator
(typing.Literal['auto', 'cpu', 'gpu', 'mps', 'tpu']
) –Accelerator to use. Defaults to "auto".
-
strategy
(typing.Literal['auto', 'ddp', 'ddp_fork', 'ddp_notebook', 'fsdp', 'cv-parallel', 'tune-parallel']
) –Distributed strategy to use. Defaults to "auto".
-
devices
(list[int | str]
) –List of devices to use. Defaults to ["auto"].
-
num_nodes
(int
) –Number of nodes to use for distributed training. Defaults to 1.
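As a sketch (assuming two local GPUs are available), a DeviceConfig for distributed training with DDP could look like this:

from darts_segmentation.training.train import DeviceConfig

# Use the first two GPUs of a single node with PyTorch Lightning's DDP strategy.
device_config = DeviceConfig(
    accelerator="gpu",
    strategy="ddp",
    devices=[0, 1],
    num_nodes=1,
)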
accelerator
class-attribute
instance-attribute
¶
devices
class-attribute
instance-attribute
¶
devices: list[int | str] = dataclasses.field(
default_factory=lambda: ["auto"]
)
lightning_strategy
property
¶
lightning_strategy: str
Get the Lightning strategy for the current configuration.
Returns:
-
str
(str
) –The Lightning strategy to use.
strategy
class-attribute
instance-attribute
¶
strategy: typing.Literal[
"auto",
"ddp",
"ddp_fork",
"ddp_notebook",
"fsdp",
"cv-parallel",
"tune-parallel",
] = "auto"
in_parallel
¶
in_parallel(
device: int | str | None = None,
) -> darts_segmentation.training.train.DeviceConfig
Turn the current configuration into a suitable configuration for parallel training.
Parameters:
-
device
(int | str | None
, default:None
) –The device to use for parallel training. If None, assumes non-multiprocessing parallel training and propagates all devices. Defaults to None.
Returns:
-
DeviceConfig
(darts_segmentation.training.train.DeviceConfig
) –A new DeviceConfig instance that is suitable for parallel training.
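A short usage sketch, continuing the device_config from above (the worker semantics are an assumption based on the description):

# Pin one worker of a parallel cross-validation to GPU 0 ...
worker_config = device_config.in_parallel(device=0)
# ... or keep all devices for non-multiprocessing parallel training.
shared_config = device_config.in_parallel()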
Source code in darts-segmentation/src/darts_segmentation/training/train.py
Hyperparameters
dataclass
¶
Hyperparameters(
model_arch: str = "Unet",
model_encoder: str = "dpn107",
model_encoder_weights: str | None = None,
augment: list[
darts_segmentation.training.augmentations.Augmentation
]
| None = None,
learning_rate: float = 0.001,
gamma: float = 0.9,
focal_loss_alpha: float | None = None,
focal_loss_gamma: float = 2.0,
batch_size: int = 8,
bands: list[str] | None = None,
)
Hyperparameters for Cyclopts CLI.
Attributes:
-
model_arch
(str
) –Architecture of the model to use.
-
model_encoder
(str
) –Encoder type for the model.
-
model_encoder_weights
(str | None
) –Weights for the encoder, if any.
-
augment
(list[darts_segmentation.training.augmentations.Augmentation] | None
) –List of augmentations to apply.
-
learning_rate
(float
) –Learning rate for training.
-
gamma
(float
) –Decay factor for learning rate.
-
focal_loss_alpha
(float | None
) –Alpha parameter for focal loss, if using.
-
focal_loss_gamma
(float
) –Gamma parameter for focal loss.
-
batch_size
(int
) –Batch size for training.
-
bands
(list[str] | None
) –List of bands to use. Defaults to None.
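A hedged example (the import path darts_segmentation.training.hparams is taken from the train_smp signature below; the encoder choice and band names are purely illustrative):

from darts_segmentation.training.hparams import Hyperparameters

hparams = Hyperparameters(
    model_arch="Unet",
    model_encoder="resnet34",
    learning_rate=1e-3,
    gamma=0.9,
    batch_size=8,
    bands=["red", "green", "blue", "nir"],  # hypothetical band names
)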
augment
class-attribute
instance-attribute
¶
augment: (
list[
darts_segmentation.training.augmentations.Augmentation
]
| None
) = None
LoggingConfig
dataclass
¶
LoggingConfig(
artifact_dir: pathlib.Path = pathlib.Path("artifacts"),
log_every_n_steps: int = 10,
check_val_every_n_epoch: int = 3,
plot_every_n_val_epochs: int = 5,
wandb_entity: str | None = None,
wandb_project: str | None = None,
)
Logging related parameters for training.
Defines the script inputs for the training script and can be propagated by the cross-validation and tuning scripts.
Attributes:
-
artifact_dir
(pathlib.Path
) –Top-level path to the training output directory. Will contain checkpoints and metrics. Defaults to Path("artifacts").
-
log_every_n_steps
(int
) –Log every n steps. Defaults to 10.
-
check_val_every_n_epoch
(int
) –Check validation every n epochs. Defaults to 3.
-
plot_every_n_val_epochs
(int
) –Plot validation samples every n epochs. Defaults to 5.
-
wandb_entity
(str | None
) –Weights and Biases Entity. Defaults to None.
-
wandb_project
(str | None
) –Weights and Biases Project. Defaults to None.
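For example, to log to a local artifacts directory and to Weights and Biases (the entity and project names below are placeholders):

from pathlib import Path

from darts_segmentation.training.train import LoggingConfig

logging_config = LoggingConfig(
    artifact_dir=Path("artifacts"),
    log_every_n_steps=10,
    check_val_every_n_epoch=3,
    plot_every_n_val_epochs=5,
    wandb_entity="my-entity",    # placeholder
    wandb_project="my-project",  # placeholder
)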
artifact_dir
class-attribute
instance-attribute
¶
artifact_dir_at_cv
¶
Nest the artifact directory for cross-validation runs.
Similar to parse_artifact_dir_for_run, but meant to be used by the cross-validation script.
Also creates the directory if it does not exist.
Parameters:
-
tune_name
(str | None
) –Name of the tuning, if applicable.
Returns:
Source code in darts-segmentation/src/darts_segmentation/training/train.py
artifact_dir_at_run
¶
Nest the artifact directory to avoid cluttering the root directory.
For cross-validation it is expected that the cv function already nests the artifact directory. This means that for cv, the artifact_dir passed to this function should be either {artifact_dir}/_cross_validations/{cv_name} or {artifact_dir}/{tune_name}/{cv_name}.
Also creates the directory if it does not exist.
Parameters:
Raises:
-
ValueError
–If tune_name is specified, but cv_name is not, which is invalid.
Returns:
Source code in darts-segmentation/src/darts_segmentation/training/train.py
TrainRunConfig
dataclass
¶
TrainRunConfig(
name: str | None = None,
cv_name: str | None = None,
tune_name: str | None = None,
fold: int = 0,
random_seed: int = 42,
)
Run related parameters for training.
Defines the script inputs for the training script. Must be built by the cross-validation and tuning scripts.
Attributes:
-
name
(str | None
) –Name of the run. If None, it is generated automatically. Defaults to None.
-
cv_name
(str | None
) –Name of the cross-validation. Should only be specified by a cross-validation script. Defaults to None.
-
tune_name
(str | None
) –Name of the tuning. Should only be specified by a tuning script. Defaults to None.
-
fold
(int
) –Index of the current fold. Defaults to 0.
-
random_seed
(int
) –Random seed for deterministic training. Defaults to 42.
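For a single (non-cv, non-tune) run, a sketch could be:

from darts_segmentation.training.train import TrainRunConfig

# cv_name and tune_name stay None for single runs; they are set by the
# cross-validation and tuning scripts, respectively.
run = TrainRunConfig(name="baseline", fold=0, random_seed=42)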
TrainingConfig
dataclass
¶
TrainingConfig(
continue_from_checkpoint: pathlib.Path | None = None,
max_epochs: int = 100,
early_stopping_patience: int = 5,
num_workers: int = 0,
)
Training related parameters for training.
Defines the script inputs for the training script and can be propagated by the cross-validation and tuning scripts.
Attributes:
-
continue_from_checkpoint
(pathlib.Path | None
) –Path to a checkpoint to continue training from. Defaults to None.
-
max_epochs
(int
) –Maximum number of epochs to train. Defaults to 100.
-
early_stopping_patience
(int
) –Number of epochs to wait for improvement before stopping. Defaults to 5.
-
num_workers
(int
) –Number of Dataloader workers. Defaults to 0.
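A minimal sketch:

from darts_segmentation.training.train import TrainingConfig

training_config = TrainingConfig(
    max_epochs=100,
    early_stopping_patience=5,
    num_workers=4,  # assumption: a machine with spare CPU cores
)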
continue_from_checkpoint
class-attribute
instance-attribute
¶
convert_lightning_checkpoint
¶
convert_lightning_checkpoint(
*,
lightning_checkpoint: pathlib.Path,
out_directory: pathlib.Path,
checkpoint_name: str,
framework: str = "smp",
)
Convert a lightning checkpoint to our own format.
The final checkpoint will contain the model configuration and the state dict. It will be saved to:
Parameters:
-
lightning_checkpoint
(pathlib.Path
) –Path to the lightning checkpoint.
-
out_directory
(pathlib.Path
) –Output directory for the converted checkpoint.
-
checkpoint_name
(str
) –A unique name of the new checkpoint.
-
framework
(str
, default:'smp'
) –The framework used for the model. Defaults to "smp".
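A usage sketch (the checkpoint path below is hypothetical):

from pathlib import Path

from darts_segmentation.training.train import convert_lightning_checkpoint

convert_lightning_checkpoint(
    lightning_checkpoint=Path("artifacts/_runs/baseline-abc123/checkpoints/last.ckpt"),  # hypothetical
    out_directory=Path("exported-checkpoints"),
    checkpoint_name="baseline-v1",
    framework="smp",
)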
Source code in darts-segmentation/src/darts_segmentation/training/train.py
test_smp
¶
test_smp(
*,
train_data_dir: pathlib.Path,
run_id: str,
run_name: str,
model_ckp: pathlib.Path | None = None,
batch_size: int = 8,
data_split_method: typing.Literal[
"random", "region", "sample"
]
| None = None,
data_split_by: list[str] | str | float | None = None,
bands: list[str] | None = None,
artifact_dir: pathlib.Path = pathlib.Path("artifacts"),
num_workers: int = 0,
device_config: darts_segmentation.training.train.DeviceConfig = darts_segmentation.training.train.DeviceConfig(),
wandb_entity: str | None = None,
wandb_project: str | None = None,
) -> pytorch_lightning.Trainer
Run the testing of the SMP model.
The data structure of the training data expects the "preprocessing" step to be done beforehand, which results in the following data structure:
preprocessed-data/ # the top-level directory
├── config.toml
├── data.zarr/ # this zarr group contains the dataarrays x and y
├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
└── labels.geojson
Parameters:
-
train_data_dir
(pathlib.Path
) –The top-level path to the data to be used for training. Expects a directory containing:
1. a zarr group called "data.zarr" containing an "x" and a "y" array
2. a geoparquet file called "metadata.parquet" containing the metadata for the data. This metadata should contain at least the following columns:
- "sample_id": The id of the sample
- "region": The region the sample belongs to
- "empty": Whether the image is empty
The index should refer to the index of the sample in the zarr data. This directory should be created by a preprocessing script.
-
run_id
(str
) –ID of the run.
-
run_name
(str
) –Name of the run.
-
model_ckp
(pathlib.Path | None
, default:None
) –Path to the model checkpoint. If None, tries to find the latest checkpoint in artifact_dir / run_name / run_id / checkpoints. Defaults to None.
-
batch_size
(int
, default:8
) –Batch size for training and validation.
-
data_split_method
(typing.Literal['random', 'region', 'sample'] | None
, default:None
) –The method to use for splitting the data into a train and a test set. "random" will split the data randomly, the seed is always 42 and the size of the test set can be specified by providing a float between 0 and 1 to data_split_by. "region" will split the data by one or multiple regions, which can be specified by providing a str or list of str to data_split_by. "sample" will split the data by sample ids, which can also be specified similar to "region". If None, no split is done and the complete dataset is used for both training and testing. The train split will further be split in the cross validation process. Defaults to None.
-
data_split_by
(list[str] | str | float | None
, default:None
) –Select which regions/samples to split by, or the size of the test set. Defaults to None.
-
bands
(list[str] | None
, default:None
) –List of bands to use. Defaults to None.
-
artifact_dir
(pathlib.Path
, default:pathlib.Path('artifacts')
) –Directory to save artifacts. Defaults to Path("artifacts").
-
num_workers
(int
, default:0
) –Number of workers for the DataLoader. Defaults to 0.
-
device_config
(darts_segmentation.training.train.DeviceConfig
, default:darts_segmentation.training.train.DeviceConfig()
) –Device and distributed strategy related parameters.
-
wandb_entity
(str | None
, default:None
) –WandB entity. Defaults to None.
-
wandb_project
(str | None
, default:None
) –WandB project. Defaults to None.
Returns:
-
Trainer
(pytorch_lightning.Trainer
) –The trainer object used for training.
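A hedged invocation sketch (run_id and run_name are placeholders; with model_ckp=None the latest checkpoint under artifact_dir / run_name / run_id / checkpoints is used):

from pathlib import Path

from darts_segmentation.training.train import test_smp

trainer = test_smp(
    train_data_dir=Path("preprocessed-data"),
    run_id="abc123",      # placeholder
    run_name="baseline",  # placeholder
    batch_size=8,
    data_split_method="random",
    data_split_by=0.2,    # hold out 20% of the data for testing
    artifact_dir=Path("artifacts"),
)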
Source code in darts-segmentation/src/darts_segmentation/training/train.py
train_smp
¶
train_smp(
*,
run: darts_segmentation.training.train.TrainRunConfig = darts_segmentation.training.train.TrainRunConfig(),
training_config: darts_segmentation.training.train.TrainingConfig = darts_segmentation.training.train.TrainingConfig(),
data_config: darts_segmentation.training.train.DataConfig = darts_segmentation.training.train.DataConfig(),
logging_config: darts_segmentation.training.train.LoggingConfig = darts_segmentation.training.train.LoggingConfig(),
device_config: darts_segmentation.training.train.DeviceConfig = darts_segmentation.training.train.DeviceConfig(),
hparams: darts_segmentation.training.hparams.Hyperparameters = darts_segmentation.training.hparams.Hyperparameters(),
)
Run the training of the SMP model, specifically binary segmentation.
Please see https://smp.readthedocs.io/en/latest/index.html for model configurations of architecture and encoder.
Please also consider reading our training guide (docs/guides/training.md).
This training function is meant for single training runs but is also used for cross-validation and hyperparameter tuning by cv.py and tune.py. This strongly affects where artifacts are stored:
- Run was created by a tune: {artifact_dir}/{tune_name}/{cv_name}/{run_name}-{run_id}
- Run was created by a cross-validation: {artifact_dir}/_cross_validations/{cv_name}/{run_name}-{run_id}
- Single runs: {artifact_dir}/_runs/{run_name}-{run_id}
run_name can be specified by the user, else it is generated automatically.
In case of cross-validation, the run name is generated automatically by the cross-validation.
run_id is generated automatically by the training function.
Both are saved to the final checkpoint.
You can specify how often logs will be written and validation will be performed.
- log_every_n_steps specifies how often train-logs will be written. This does not affect validation.
- check_val_every_n_epoch specifies how often validation will be performed. This will also affect early stopping.
- early_stopping_patience specifies how many epochs to wait for improvement before stopping. In epochs, this would be check_val_every_n_epoch * early_stopping_patience.
- plot_every_n_val_epochs specifies how often validation samples will be plotted. Since plotting is quite costly, you can reduce the frequency. Works similarly to early stopping. In epochs, this would be check_val_every_n_epoch * plot_every_n_val_epochs.
Example: There are 400 training samples and the batch size is 2, resulting in 200 training steps per epoch.
If log_every_n_steps is set to 50, then the training logs and metrics will be logged 4 times per epoch.
If check_val_every_n_epoch is set to 5, then validation will be performed every 5 epochs.
If plot_every_n_val_epochs is set to 2, then validation samples will be plotted every 10 epochs.
If early_stopping_patience is set to 3, then early stopping will be performed after 15 epochs without improvement.
The data structure of the training data expects the "preprocessing" step to be done beforehand, which results in the following data structure:
preprocessed-data/ # the top-level directory
├── config.toml
├── data.zarr/ # this zarr group contains the dataarrays x and y
├── metadata.parquet # this contains information necessary to split the data into train, val, and test sets.
└── labels.geojson
Parameters:
-
data_config
(darts_segmentation.training.train.DataConfig
, default:darts_segmentation.training.train.DataConfig()
) –Data related parameters for training.
-
run
(darts_segmentation.training.train.TrainRunConfig
, default:darts_segmentation.training.train.TrainRunConfig()
) –Run related parameters for training.
-
logging_config
(darts_segmentation.training.train.LoggingConfig
, default:darts_segmentation.training.train.LoggingConfig()
) –Logging related parameters for training.
-
device_config
(darts_segmentation.training.train.DeviceConfig
, default:darts_segmentation.training.train.DeviceConfig()
) –Device and distributed strategy related parameters.
-
training_config
(darts_segmentation.training.train.TrainingConfig
, default:darts_segmentation.training.train.TrainingConfig()
) –Training related parameters for training.
-
hparams
(darts_segmentation.training.hparams.Hyperparameters
, default:darts_segmentation.training.hparams.Hyperparameters()
) –Hyperparameters for the model.
Returns:
-
–
pl.Trainer: The trainer object used for training. Also contains metrics.
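Putting the pieces together, a single training run could be launched like this (reusing the config sketches from above; all values are illustrative):

from darts_segmentation.training.train import train_smp

trainer = train_smp(
    run=run,
    training_config=training_config,
    data_config=data_config,
    logging_config=logging_config,
    device_config=device_config,
    hparams=hparams,
)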
Source code in darts-segmentation/src/darts_segmentation/training/train.py