data
darts_segmentation.training.data
Training script for DARTS segmentation.
Augmentation
module-attribute
Augmentation = typing.Literal[
"HorizontalFlip",
"VerticalFlip",
"RandomRotate90",
"Blur",
"RandomBrightnessContrast",
"MultiplicativeNoise",
"Posterize",
]
Bands
Bases: collections.UserList[darts_segmentation.utils.Band]
Wrapper for the list of bands.
factors
property
names
property
offsets
property
__reduce__
Source code in darts-segmentation/src/darts_segmentation/utils.py
filter
filter(
band_names: list[str],
) -> darts_segmentation.utils.Bands
Filter the bands by name.
Parameters:
- band_names (list[str]) – The names of the bands to filter by.
Returns:
- Bands (darts_segmentation.utils.Bands) – The filtered Bands object.
Source code in darts-segmentation/src/darts_segmentation/utils.py
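A minimal sketch of name-based filtering. It assumes that each band carries a name, a factor and an offset (as the `names`, `factors` and `offsets` properties suggest) and that `filter` keeps only the bands whose names appear in `band_names`, in their original order. `Band` and `filter_bands` are hypothetical stand-ins, not the library's classes.

```python
from typing import NamedTuple


class Band(NamedTuple):
    """Hypothetical stand-in for darts_segmentation.utils.Band."""

    name: str
    factor: float
    offset: float


def filter_bands(bands: list[Band], band_names: list[str]) -> list[Band]:
    """Keep only the bands whose name is in band_names (assumed semantics)."""
    wanted = set(band_names)
    # Original band order is preserved, not the order of band_names.
    return [band for band in bands if band.name in wanted]


# Made-up bands for illustration.
bands = [Band("blue", 1.0, 0.0), Band("green", 1.0, 0.0), Band("ndvi", 0.5, 1.0)]
print([b.name for b in filter_bands(bands, ["ndvi", "blue"])])  # ['blue', 'ndvi']
```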
from_config
classmethod
from_config(
config: dict[
typing.Literal[
"bands", "band_factors", "band_offsets"
],
list,
]
| dict[str, tuple[float, float]],
) -> darts_segmentation.utils.Bands
Create a Bands object from a config dictionary.
Parameters:
- config (dict) – The config dictionary containing the band information. Expects a dictionary with the keys "bands", "band_factors" and "band_offsets", whose values are lists of the same length.
Returns:
- Bands (darts_segmentation.utils.Bands) – The Bands object.
Source code in darts-segmentation/src/darts_segmentation/utils.py
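The config format expected by `from_config` stores three parallel lists, one entry per band. The sketch below shows one way such a config can be zipped into per-band records. `Band` and `bands_from_config` are hypothetical stand-ins, and the explicit length check is an assumption about how mismatched lists might be handled, not the library's documented behaviour.

```python
from typing import NamedTuple


class Band(NamedTuple):
    """Hypothetical stand-in for darts_segmentation.utils.Band."""

    name: str
    factor: float
    offset: float


def bands_from_config(config: dict[str, list]) -> list[Band]:
    """Zip the parallel "bands", "band_factors" and "band_offsets" lists."""
    names = config["bands"]
    factors = config["band_factors"]
    offsets = config["band_offsets"]
    # zip() would silently truncate mismatched lists, so check explicitly.
    if not (len(names) == len(factors) == len(offsets)):
        raise ValueError("bands, band_factors and band_offsets must have the same length")
    return [Band(n, f, o) for n, f, o in zip(names, factors, offsets)]


# Made-up values for illustration.
config = {
    "bands": ["red", "ndvi"],
    "band_factors": [2.0, 0.5],
    "band_offsets": [0.0, 1.0],
}
print(bands_from_config(config))
```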
from_dict
classmethod
Create a Bands object from a dictionary.
Parameters:
- config (dict[str, tuple[float, float]]) – The dictionary containing the band information. Expects the keys to be the band names and the values to be (factor, offset) tuples, e.g. {"band1": (1.0, 0.0), "band2": (2.0, 1.0)}.
Returns:
- Bands (darts_segmentation.utils.Bands) – The Bands object.
Source code in darts-segmentation/src/darts_segmentation/utils.py
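Unlike `from_config`, `from_dict` takes one entry per band, so names, factors and offsets cannot get out of step. A sketch using the docstring's own example mapping; `Band` and `bands_from_dict` are hypothetical stand-ins. It relies on Python dicts preserving insertion order, so the bands come out in the order they were written.

```python
from typing import NamedTuple


class Band(NamedTuple):
    """Hypothetical stand-in for darts_segmentation.utils.Band."""

    name: str
    factor: float
    offset: float


def bands_from_dict(config: dict[str, tuple[float, float]]) -> list[Band]:
    """Build one Band per (name -> (factor, offset)) entry, in insertion order."""
    return [Band(name, factor, offset) for name, (factor, offset) in config.items()]


# Same example mapping as in the docstring above.
print(bands_from_dict({"band1": (1.0, 0.0), "band2": (2.0, 1.0)}))
```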
to_config
Convert the Bands object to a config dictionary.
Returns:
- dict (dict[typing.Literal['bands', 'band_factors', 'band_offsets'], list]) – The config dictionary containing the band information.
Source code in darts-segmentation/src/darts_segmentation/utils.py
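`to_config` produces the parallel-list format that `from_config` reads, so a round trip through both should presumably reproduce the original bands. A sketch of that inverse with the same hypothetical `Band` stand-in; `bands_to_config` is illustrative, not the library's method.

```python
from typing import NamedTuple


class Band(NamedTuple):
    """Hypothetical stand-in for darts_segmentation.utils.Band."""

    name: str
    factor: float
    offset: float


def bands_to_config(bands: list[Band]) -> dict[str, list]:
    """Split per-band records back into the three parallel config lists."""
    return {
        "bands": [b.name for b in bands],
        "band_factors": [b.factor for b in bands],
        "band_offsets": [b.offset for b in bands],
    }


bands = [Band("band1", 1.0, 0.0), Band("band2", 2.0, 1.0)]
config = bands_to_config(bands)
print(config)

# Zipping the lists back together recovers the original bands.
rebuilt = [Band(*t) for t in zip(config["bands"], config["band_factors"], config["band_offsets"])]
assert rebuilt == bands
```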
DartsDataModule
DartsDataModule(
data_dir: pathlib.Path,
batch_size: int,
data_split_method: typing.Literal[
"random", "region", "sample"
]
| None = None,
data_split_by: list[str | float] | None = None,
fold_method: typing.Literal[
"kfold",
"shuffle",
"stratified",
"region",
"region-stratified",
]
| None = "kfold",
total_folds: int = 5,
fold: int = 0,
subsample: int | None = None,
bands: darts_segmentation.utils.Bands
| list[str]
| None = None,
augment: list[
darts_segmentation.training.augmentations.Augmentation
]
| None = None,
num_workers: int = 0,
in_memory: bool = False,
)
Bases: lightning.LightningDataModule
Initialize the data module.
Supports splitting the data into a train and a test set while also defining cross-validation folds. Folding only applies to the non-test set, which it splits into a train and a validation set.
Example
- Normal train-validate. (Can also be used for testing on the complete dataset.)
- Specifying a random test split (20% of the data will be used for testing).
- Specific fold for cross-validation (on the complete dataset, because data_split_method is None). This takes the third of a total of 7 folds to determine the validation set.
In general, this should be used in combination with a cross-validation loop:
```python
from darts_segmentation.training.data import DartsDataModule

# data_dir, batch_size and total_folds are assumed to be defined already
for fold in range(total_folds):
    dm = DartsDataModule(
        data_dir,
        batch_size,
        fold_method="region-stratified",
        fold=fold,
        total_folds=total_folds,
    )
    ...
```
- Don't split anything: only train.
Parameters:
- data_dir (pathlib.Path) – The path to the data to be used for training. Expects a directory containing:
    1. a zarr group called "data.zarr" containing an "x" and a "y" array;
    2. a geoparquet file called "metadata.parquet" containing the metadata for the data. This metadata should contain at least the following columns:
        - "sample_id": the ID of the sample,
        - "region": the region the sample belongs to,
        - "empty": whether the image is empty.
    The index of the metadata should match the index of the sample in the zarr data. This directory should be created by a preprocessing script.
- batch_size (int) – Batch size for training and validation.
- data_split_method (typing.Literal['random', 'region', 'sample'] | None, default: None) – The method to use for splitting the data into a train and a test set. "random" splits the data randomly with a fixed seed of 42; the test size is set by passing a list with a single float between 0 and 1 to data_split_by, giving the fraction of the data used for testing (e.g. [0.2] uses 20% of the data for testing). "region" splits the data by one or multiple regions, specified by passing a str or a list of str to data_split_by. "sample" splits the data by sample IDs, specified the same way as for "region". If None, no split is done and the complete dataset is used for both training and testing. The train split is further split in the cross-validation process. Defaults to None.
- data_split_by (list[str | float] | None, default: None) – The regions or sample IDs to split by, or the size of the test set. Defaults to None.
- fold_method (typing.Literal['kfold', 'shuffle', 'stratified', 'region', 'region-stratified'] | None, default: 'kfold') – Method for the cross-validation split. Defaults to "kfold".
- total_folds (int, default: 5) – Total number of folds in cross-validation. Defaults to 5.
- fold (int, default: 0) – Index of the current fold. Defaults to 0.
- subsample (int | None, default: None) – If set, subsamples the dataset to this number of samples. This is useful for debugging and testing. Defaults to None.
- bands (darts_segmentation.utils.Bands | list[str] | None, default: None) – List of bands to use. Expects data_dir to contain a config.toml with a "darts.bands" key, which is used to map the band names to their indices in the data. Defaults to None.
- augment (list[darts_segmentation.training.augmentations.Augmentation] | None, default: None) – List of augmentations to apply (see get_augmentation). Does nothing for testing. Defaults to None.
- num_workers (int, default: 0) – Number of workers for data loading. See torch.utils.data.DataLoader. Defaults to 0.
- in_memory (bool, default: False) – Whether to load the data into memory. Defaults to False.
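The train/test split semantics described for data_split_method can be sketched on plain Python records standing in for the metadata GeoDataFrame. This is an illustration under assumptions, not the library's `_split_metadata`: `split_train_test` is a hypothetical helper, and its "random" branch uses Python's `random` module seeded with 42, so the concrete test indices will differ from the library's.

```python
from __future__ import annotations

import random


def split_train_test(
    records: list[dict],
    method: str | None,
    split_by: list | str | None,
) -> tuple[list[int], list[int]]:
    """Return (train_indices, test_indices) for a list of metadata records.

    Hypothetical helper following the documented semantics:
    - None: no split, all indices are used for both training and testing
    - "random": shuffle with seed 42 and hold out the fraction split_by[0]
    - "region" / "sample": hold out records whose region / sample_id is listed
    """
    all_indices = list(range(len(records)))
    if method is None:
        return all_indices, all_indices
    if method == "random":
        shuffled = all_indices.copy()
        random.Random(42).shuffle(shuffled)
        n_test = round(len(shuffled) * float(split_by[0]))
        return sorted(shuffled[n_test:]), sorted(shuffled[:n_test])
    if method == "region":
        column = "region"
    elif method == "sample":
        column = "sample_id"
    else:
        raise ValueError(f"Unknown split method: {method!r}")
    # A single region or sample ID may be given as a plain string.
    held_out = {split_by} if isinstance(split_by, str) else set(split_by or [])
    train = [i for i, r in enumerate(records) if r[column] not in held_out]
    test = [i for i, r in enumerate(records) if r[column] in held_out]
    return train, test


# Made-up metadata: ten samples in two regions, "A" and "B".
records = [{"sample_id": f"s{i}", "region": "A" if i < 6 else "B"} for i in range(10)]
train, test = split_train_test(records, "region", ["B"])
print(test)  # [6, 7, 8, 9]
```

Whatever the method, the resulting train indices are what cross-validation then folds further.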
Source code in darts-segmentation/src/darts_segmentation/training/data.py
bands
instance-attribute
self.bands = (
    [data_bands.index(b) for b in bands]
    if bands
    else None
)
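The expression above looks up each requested band name in `data_bands`, which, per the `bands` parameter description, comes from the "darts.bands" key of config.toml in data_dir. A standalone sketch of that lookup with a made-up band list; `band_indices` is a hypothetical helper. `list.index` raises ValueError for a name that is not in the data, and the resulting indices follow the requested order rather than the stored one.

```python
from __future__ import annotations


def band_indices(requested: list[str] | None, data_bands: list[str]) -> list[int] | None:
    """Hypothetical helper mirroring the `bands` expression above."""
    if not requested:
        # Empty or missing selection: no indices (assumed to mean "use all bands").
        return None
    # list.index raises ValueError if a requested band is not stored in the data.
    return [data_bands.index(name) for name in requested]


# Made-up band order, standing in for the "darts.bands" entry of config.toml.
data_bands = ["blue", "green", "red", "nir", "ndvi"]
print(band_indices(["nir", "red"], data_bands))  # [3, 2]
```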
batch_size
instance-attribute
self.batch_size = batch_size
data_dir
instance-attribute
self.data_dir = data_dir
data_split_by
instance-attribute
self.data_split_by = data_split_by
data_split_method
instance-attribute
self.data_split_method = data_split_method
fold_method
instance-attribute
self.fold_method = fold_method
in_memory
instance-attribute
self.in_memory = in_memory
num_workers
instance-attribute
self.num_workers = num_workers
subsample
instance-attribute
self.subsample = subsample
total_folds
instance-attribute
self.total_folds = total_folds
setup
Source code in darts-segmentation/src/darts_segmentation/training/data.py
test_dataloader
train_dataloader
DartsDatasetInMemory
DartsDatasetInMemory(
data_dir: pathlib.Path | str,
augment: list[
darts_segmentation.training.augmentations.Augmentation
]
| None = None,
indices: list[int] | None = None,
bands: list[int] | None = None,
)
Bases: torch.utils.data.Dataset
Source code in darts-segmentation/src/darts_segmentation/training/data.py
transform
instance-attribute
self.transform = darts_segmentation.training.augmentations.get_augmentation(augment)
__getitem__
Source code in darts-segmentation/src/darts_segmentation/training/data.py
DartsDatasetZarr
DartsDatasetZarr(
data_dir: pathlib.Path | str,
augment: list[
darts_segmentation.training.augmentations.Augmentation
]
| None = None,
indices: list[int] | None = None,
bands: list[int] | None = None,
)
Bases: torch.utils.data.Dataset
Source code in darts-segmentation/src/darts_segmentation/training/data.py
indices
instance-attribute
self.indices = (
    indices
    if indices is not None
    else list(range(self.zroot["x"].shape[0]))
)
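The `indices` attribute defaults to every sample in the zarr "x" array. Because the check is `is not None`, an explicit empty list yields an empty dataset rather than the full one. A sketch of that behaviour; `IndexedView` is hypothetical, a plain sequence stands in for the zarr array, and the position-to-index mapping in `__getitem__` is an assumption, since the real `__getitem__` is not documented here.

```python
from __future__ import annotations

from collections.abc import Sequence


class IndexedView:
    """Hypothetical sketch of the `indices` handling in DartsDatasetZarr.

    `samples` stands in for the zarr "x" array; only len() and integer
    indexing are used.
    """

    def __init__(self, samples: Sequence, indices: list[int] | None = None):
        self.samples = samples
        # Default to every sample, as the attribute expression above does.
        self.indices = indices if indices is not None else list(range(len(samples)))

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, i: int):
        # Assumed: dataset position i maps to stored sample self.indices[i].
        return self.samples[self.indices[i]]


view = IndexedView(["s0", "s1", "s2", "s3"], indices=[3, 1])
print(len(view), view[0])  # 2 s3
```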
transform
instance-attribute
self.transform = darts_segmentation.training.augmentations.get_augmentation(augment)
__getitem__
Source code in darts-segmentation/src/darts_segmentation/training/data.py
_get_fold
_get_fold(
metadata: geopandas.GeoDataFrame,
fold_method: typing.Literal[
"kfold",
"shuffle",
"stratified",
"region",
"region-stratified",
"none",
]
| None,
n_folds: int,
fold: int,
) -> tuple[list[int], list[int]]
Source code in darts-segmentation/src/darts_segmentation/training/data.py
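`_get_fold` returns a pair of index lists. For the plain "kfold" method, the train/validation partition for one fold could look like the sketch below. `kfold_split` is a hypothetical helper that only needs the sample count. It assumes an unshuffled split with zero-based fold indices (as the `range(total_folds)` loop in the DartsDataModule example suggests) and does not cover "shuffle", "stratified" or the region-based methods, which need metadata columns.

```python
def kfold_split(n_samples: int, n_folds: int, fold: int) -> tuple[list[int], list[int]]:
    """Return (train_indices, val_indices) for one fold of an unshuffled k-fold split.

    Fold sizes follow the common convention (as in scikit-learn's KFold):
    the first n_samples % n_folds folds get one extra sample.
    """
    if not 0 <= fold < n_folds:
        raise ValueError(f"fold must be in [0, {n_folds}), got {fold}")
    if n_folds > n_samples:
        raise ValueError("cannot have more folds than samples")
    base, extra = divmod(n_samples, n_folds)
    start = fold * base + min(fold, extra)
    stop = start + base + (1 if fold < extra else 0)
    val = list(range(start, stop))
    train = list(range(0, start)) + list(range(stop, n_samples))
    return train, val


for fold in range(3):
    train, val = kfold_split(10, 3, fold)
    print(fold, val)
```

With 10 samples and 3 folds, the validation sets are [0, 1, 2, 3], [4, 5, 6] and [7, 8, 9]; each sample is validated exactly once across the loop.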
_log_stats
_log_stats(metadata: geopandas.GeoDataFrame, mode: str)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
_split_metadata
_split_metadata(
metadata: geopandas.GeoDataFrame,
data_split_method: typing.Literal[
"random", "region", "sample", "none"
]
| None,
data_split_by: list[str | float] | None,
)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
get_augmentation
get_augmentation(
augment: list[
darts_segmentation.training.augmentations.Augmentation
]
| None,
) -> albumentations.Compose | None
Get augmentations for segmentation tasks.
Parameters:
- augment (list[darts_segmentation.training.augmentations.Augmentation] | None) – List of augmentations to apply. If None or empty, no augmentations are applied. If not empty, augmentations are applied in the order they are listed. Available augmentations:
    - HorizontalFlip
    - VerticalFlip
    - RandomRotate90
    - Blur
    - RandomBrightnessContrast
    - MultiplicativeNoise
Raises:
- ValueError – If an unknown augmentation is provided.
Returns:
- albumentations.Compose | None – A Compose object containing the augmentations. If no augmentations are provided, returns None.
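The contract above (None for an empty selection, order preserved, ValueError for unknown names) can be sketched without albumentations. `resolve_augmentations` is hypothetical: it only validates and orders names against a copy of the `Augmentation` literal from the top of this page. The real function additionally builds the corresponding albumentations transforms and wraps them in a Compose; that step is omitted so the sketch runs standalone.

```python
from __future__ import annotations

import typing

# Copy of the Augmentation literal documented at the top of this page.
Augmentation = typing.Literal[
    "HorizontalFlip",
    "VerticalFlip",
    "RandomRotate90",
    "Blur",
    "RandomBrightnessContrast",
    "MultiplicativeNoise",
    "Posterize",
]
KNOWN_AUGMENTATIONS = frozenset(typing.get_args(Augmentation))


def resolve_augmentations(augment: list[str] | None) -> tuple[str, ...] | None:
    """Validate augmentation names; None means "apply nothing"."""
    if not augment:
        return None
    unknown = [name for name in augment if name not in KNOWN_AUGMENTATIONS]
    if unknown:
        raise ValueError(f"Unknown augmentation(s): {unknown}")
    # Keep the caller's order: augmentations are applied as listed.
    return tuple(augment)


print(resolve_augmentations(["VerticalFlip", "Blur"]))
```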