
data

darts_segmentation.training.data

Training script for DARTS segmentation.

Augmentation module-attribute

Augmentation = typing.Literal[
    "HorizontalFlip",
    "VerticalFlip",
    "RandomRotate90",
    "Blur",
    "RandomBrightnessContrast",
    "MultiplicativeNoise",
    "Posterize",
]

logger module-attribute

logger = logging.getLogger(
    __name__.replace("darts_", "darts.")
)

Bands

Bases: collections.UserList[darts_segmentation.utils.Band]

Wrapper for the list of bands.

factors property

factors: list[float]

Get the factors of the bands.

Returns:

  • list[float]

    list[float]: The factors of the bands.

names property

names: list[str]

Get the names of the bands.

Returns:

  • list[str]

    list[str]: The names of the bands.

offsets property

offsets: list[float]

Get the offsets of the bands.

Returns:

  • list[float]

    list[float]: The offsets of the bands.

__reduce__

__reduce__()
Source code in darts-segmentation/src/darts_segmentation/utils.py
def __reduce__(self):  # noqa: D105
    # This is needed to pickle (and unpickle) the Bands object as a dict
    # This is needed, because this way we don't need to have this class present when unpickling
    # a pytorch checkpoint
    return (dict, (self.to_config(),))
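A minimal sketch of the pickling behaviour described in the comment above, assuming Band and Bands are importable from darts_segmentation.utils:

```py
import pickle

from darts_segmentation.utils import Band, Bands  # assumed import path

bands = Bands([Band(name="red", factor=1.0, offset=0.0)])
restored = pickle.loads(pickle.dumps(bands))

# Because __reduce__ delegates reconstruction to dict, the unpickled object is a
# plain dict equal to bands.to_config() -- the Bands class is not needed to unpickle.
assert isinstance(restored, dict)
assert restored == bands.to_config()
```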

__repr__

__repr__() -> str
Source code in darts-segmentation/src/darts_segmentation/utils.py
def __repr__(self) -> str:  # noqa: D105
    band_info = ", ".join([f"{band.name}(*{band.factor:.5f}+{band.offset:.5f})" for band in self])
    return f"Bands({band_info})"

filter

filter(
    band_names: list[str],
) -> darts_segmentation.utils.Bands

Filter the bands by name.

Parameters:

  • band_names (list[str]) –

    The names of the bands to keep.

Returns:

  • Bands ( darts_segmentation.utils.Bands ) –

    The filtered Bands object.

Source code in darts-segmentation/src/darts_segmentation/utils.py
def filter(self, band_names: list[str]) -> "Bands":
    """Filter the bands by name.

    Args:
        band_names (list[str]): The names of the bands to keep.

    Returns:
        Bands: The filtered Bands object.

    """
    return Bands([band for band in self if band.name in band_names])
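A small usage sketch, assuming Band and Bands are importable from darts_segmentation.utils; the band names are illustrative only:

```py
from darts_segmentation.utils import Band, Bands  # assumed import path

bands = Bands(
    [
        Band(name="red", factor=1.0, offset=0.0),
        Band(name="nir", factor=1.0, offset=0.0),
        Band(name="ndvi", factor=2.0, offset=-1.0),
    ]
)
subset = bands.filter(["red", "nir"])
print(subset.names)  # ['red', 'nir']
```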

from_config classmethod

from_config(
    config: dict[
        typing.Literal[
            "bands", "band_factors", "band_offsets"
        ],
        list,
    ]
    | dict[str, tuple[float, float]],
) -> darts_segmentation.utils.Bands

Create a Bands object from a config dictionary.

Parameters:

  • config (dict) –

    The config dictionary containing the band information. Expects config to be a dictionary with keys "bands", "band_factors" and "band_offsets", with the values to be lists of the same length.

Returns:

  • Bands ( darts_segmentation.utils.Bands ) –

    The Bands object.

Source code in darts-segmentation/src/darts_segmentation/utils.py
@classmethod
def from_config(
    cls,
    config: dict[Literal["bands", "band_factors", "band_offsets"], list] | dict[str, tuple[float, float]],
) -> "Bands":
    """Create a Bands object from a config dictionary.

    Args:
        config (dict): The config dictionary containing the band information.
            Expects config to be a dictionary with keys "bands", "band_factors" and "band_offsets",
            with the values to be lists of the same length.

    Returns:
        Bands: The Bands object.

    """
    assert "bands" in config and "band_factors" in config and "band_offsets" in config, (
        f"Config must contain keys 'bands', 'band_factors' and 'band_offsets'.Got {config} instead."
    )
    return cls(
        [
            Band(name=name, factor=factor, offset=offset)
            for name, factor, offset in zip(config["bands"], config["band_factors"], config["band_offsets"])
        ]
    )

from_dict classmethod

from_dict(
    config: dict[str, tuple[float, float]],
) -> darts_segmentation.utils.Bands

Create a Bands object from a dictionary.

Parameters:

  • config (dict[str, tuple[float, float]]) –

    The dictionary containing the band information. Expects the keys to be the band names and the values to be tuples of (factor, offset). Example: {"band1": (1.0, 0.0), "band2": (2.0, 1.0)}

Returns:

  • Bands ( darts_segmentation.utils.Bands ) –

    The Bands object.

Source code in darts-segmentation/src/darts_segmentation/utils.py
@classmethod
def from_dict(cls, config: dict[str, tuple[float, float]]) -> "Bands":
    """Create a Bands object from a dictionary.

    Args:
        config (dict[str, tuple[float, float]]): The dictionary containing the band information.
            Expects the keys to be the band names and the values to be tuples of (factor, offset).
            Example: {"band1": (1.0, 0.0), "band2": (2.0, 1.0)}

    Returns:
        Bands: The Bands object.

    """
    return cls([Band(name=name, factor=factor, offset=offset) for name, (factor, offset) in config.items()])
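A small sketch of constructing a Bands object from such a dictionary (import path assumed, band names illustrative):

```py
from darts_segmentation.utils import Bands  # assumed import path

bands = Bands.from_dict({"red": (1.0, 0.0), "ndvi": (2.0, -1.0)})
print(bands.names)    # ['red', 'ndvi']
print(bands.factors)  # [1.0, 2.0]
print(bands.offsets)  # [0.0, -1.0]
```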

to_config

to_config() -> dict[
    typing.Literal["bands", "band_factors", "band_offsets"],
    list,
]

Convert the Bands object to a config dictionary.

Returns:

  • dict ( dict[typing.Literal['bands', 'band_factors', 'band_offsets'], list] ) –

    The config dictionary containing the band information.

Source code in darts-segmentation/src/darts_segmentation/utils.py
def to_config(self) -> dict[Literal["bands", "band_factors", "band_offsets"], list]:
    """Convert the Bands object to a config dictionary.

    Returns:
        dict: The config dictionary containing the band information.

    """
    return {
        "bands": [band.name for band in self],
        "band_factors": [band.factor for band in self],
        "band_offsets": [band.offset for band in self],
    }

to_dict

to_dict() -> dict[str, tuple[float, float]]

Convert the Bands object to a dictionary.

Returns:

  • dict[str, tuple[float, float]]

    dict[str, tuple[float, float]]: The dictionary containing the band information.

Source code in darts-segmentation/src/darts_segmentation/utils.py
def to_dict(self) -> dict[str, tuple[float, float]]:
    """Convert the Bands object to a dictionary.

    Returns:
        dict[str, tuple[float, float]]: The dictionary containing the band information.

    """
    return {band.name: (band.factor, band.offset) for band in self}
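The two representations round-trip through from_config/to_config and from_dict/to_dict; a minimal sketch (import path assumed):

```py
from darts_segmentation.utils import Bands  # assumed import path

bands = Bands.from_dict({"red": (1.0, 0.0), "ndvi": (2.0, -1.0)})
config = bands.to_config()
# config == {"bands": ["red", "ndvi"], "band_factors": [1.0, 2.0], "band_offsets": [0.0, -1.0]}
assert Bands.from_config(config).to_dict() == bands.to_dict()
```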

DartsDataModule

DartsDataModule(
    data_dir: pathlib.Path,
    batch_size: int,
    data_split_method: typing.Literal[
        "random", "region", "sample"
    ]
    | None = None,
    data_split_by: list[str | float] | None = None,
    fold_method: typing.Literal[
        "kfold",
        "shuffle",
        "stratified",
        "region",
        "region-stratified",
    ]
    | None = "kfold",
    total_folds: int = 5,
    fold: int = 0,
    subsample: int | None = None,
    bands: darts_segmentation.utils.Bands
    | list[str]
    | None = None,
    augment: list[
        darts_segmentation.training.augmentations.Augmentation
    ]
    | None = None,
    num_workers: int = 0,
    in_memory: bool = False,
)

Bases: lightning.LightningDataModule

Initialize the data module.

Supports splitting the data into train and test sets while also defining cross-validation folds. Folding only applies to the non-test set and splits it into a train and a validation set.

Example
  1. Normal train-validate. (Can also be used for testing on the complete dataset)

    dm = DartsDataModule(data_dir, batch_size)
    

  2. Specifying a random test split (20% of the data will be used for testing)

    dm = DartsDataModule(data_dir, batch_size, data_split_method="random")
    

  3. Specific fold for cross-validation (on the complete dataset, because data_split_method is None). This will take the third of a total of 7 folds to determine the validation set.

    dm = DartsDataModule(data_dir, batch_size, fold_method="region-stratified", fold=2, total_folds=7)
    

In general this should be used in combination with a cross-validation loop.

for fold in range(total_folds):
    dm = DartsDataModule(
        data_dir,
        batch_size,
        fold_method="region-stratified",
        fold=fold,
        total_folds=total_folds)
    ...

  4. Don't split anything -> only train
    dm = DartsDataModule(data_dir, batch_size, fold_method=None)
    

Parameters:

  • data_dir (pathlib.Path) –

    The path to the data to be used for training. Expects a directory containing:

    1. a zarr group called "data.zarr" containing an "x" and a "y" array
    2. a geoparquet file called "metadata.parquet" containing the metadata for the data. This metadata should contain at least the following columns:
       - "sample_id": The id of the sample
       - "region": The region the sample belongs to
       - "empty": Whether the image is empty
       The index should refer to the index of the sample in the zarr data.

    This directory should be created by a preprocessing script.

  • batch_size (int) –

    Batch size for training and validation.

  • data_split_method (typing.Literal['random', 'region', 'sample'] | None, default: None ) –

    The method to use for splitting the data into a train and a test set. "random" will split the data randomly; the seed is always 42, and the test size can be specified by providing a list with a single float between 0 and 1 to data_split_by. This will be the fraction of the data to be used for testing, e.g. [0.2] will use 20% of the data for testing. "region" will split the data by one or multiple regions, which can be specified by providing a str or list of str to data_split_by. "sample" will split the data by sample ids, which can also be specified similarly to "region". If None, no split is done and the complete dataset is used for both training and testing. The train split will further be split in the cross-validation process. Defaults to None.

  • data_split_by (list[str | float] | None, default: None ) –

    Select by which regions/samples to split or the size of test set. Defaults to None.

  • fold_method (typing.Literal['kfold', 'shuffle', 'stratified', 'region', 'region-stratified'] | None, default: 'kfold' ) –

    Method for cross-validation split. Defaults to "kfold".

  • total_folds (int, default: 5 ) –

    Total number of folds in cross-validation. Defaults to 5.

  • fold (int, default: 0 ) –

    Index of the current fold. Defaults to 0.

  • subsample (int | None, default: None ) –

    If set, will subsample the dataset to this number of samples. This is useful for debugging and testing. Defaults to None.

  • bands (darts_segmentation.utils.Bands | list[str] | None, default: None ) –

    List of bands to use. Expects the data_dir to contain a config.toml with a "darts.bands" key, which is used to map the band names to their indices in the data. Defaults to None.

  • augment (list[darts_segmentation.training.augmentations.Augmentation] | None, default: None ) –

    List of augmentations to apply to the training data. Not used for validation or testing. Defaults to None.

  • num_workers (int, default: 0 ) –

    Number of workers for data loading. See torch.utils.data.DataLoader. Defaults to 0.

  • in_memory (bool, default: False ) –

    Whether to load the data into memory. Defaults to False.

Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __init__(
    self,
    data_dir: Path,
    batch_size: int,
    # data_split is for the test split
    data_split_method: Literal["random", "region", "sample"] | None = None,
    data_split_by: list[str | float] | None = None,
    # fold is for cross-validation split (train/val)
    fold_method: Literal["kfold", "shuffle", "stratified", "region", "region-stratified"] | None = "kfold",
    total_folds: int = 5,
    fold: int = 0,
    subsample: int | None = None,
    bands: Bands | list[str] | None = None,
    augment: list[Augmentation] | None = None,  # Not used for val or test
    num_workers: int = 0,
    in_memory: bool = False,
):
    """Initialize the data module.

    Supports splitting the data into train and test set while also defining cv-folds.
    Folding only applies to the non-test set and splits this into a train and validation set.

    Example:
        1. Normal train-validate. (Can also be used for testing on the complete dataset)
        ```py
        dm = DartsDataModule(data_dir, batch_size)
        ```

        2. Specifying a test split by random (20% of the data will be used for testing)
        ```py
        dm = DartsDataModule(data_dir, batch_size, data_split_method="random")
        ```

        3. Specific fold for cross-validation (on the complete dataset, because data_split_method is None).
        This will take the third of a total of 7 folds to determine the validation set.
        ```py
        dm = DartsDataModule(data_dir, batch_size, fold_method="region-stratified", fold=2, total_folds=7)
        ```

        In general this should be used in combination with a cross-validation loop.
        ```py
        for fold in range(total_folds):
            dm = DartsDataModule(
                data_dir,
                batch_size,
                fold_method="region-stratified",
                fold=fold,
                total_folds=total_folds)
            ...
        ```

        4. Don't split anything -> only train
        ```py
        dm = DartsDataModule(data_dir, batch_size, fold_method=None)
        ```

    Args:
        data_dir (Path): The path to the data to be used for training.
            Expects a directory containing:
            1. a zarr group called "data.zarr" containing a "x" and "y" array
            2. a geoparquet file called "metadata.parquet" containing the metadata for the data.
                This metadata should contain at least the following columns:
                - "sample_id": The id of the sample
                - "region": The region the sample belongs to
                - "empty": Whether the image is empty
                The index should refer to the index of the sample in the zarr data.
            This directory should be created by a preprocessing script.
        batch_size (int): Batch size for training and validation.
        data_split_method (Literal["random", "region", "sample"] | None, optional):
            The method to use for splitting the data into a train and a test set.
            "random" will split the data randomly, the seed is always 42 and the test size can be specified
            by providing a list with a single float between 0 and 1 to data_split_by.
            This will be the fraction of the data to be used for testing.
            E.g. [0.2] will use 20% of the data for testing.
            "region" will split the data by one or multiple regions,
            which can be specified by providing a str or list of str to data_split_by.
            "sample" will split the data by sample ids, which can also be specified similar to "region".
            If None, no split is done and the complete dataset is used for both training and testing.
            The train split will further be split in the cross validation process.
            Defaults to None.
        data_split_by (list[str | float] | None, optional): Select by which regions/samples to split or
            the size of test set. Defaults to None.
        fold_method (Literal["kfold", "shuffle", "stratified", "region", "region-stratified"] | None, optional):
            Method for cross-validation split. Defaults to "kfold".
        total_folds (int, optional): Total number of folds in cross-validation. Defaults to 5.
        fold (int, optional): Index of the current fold. Defaults to 0.
        subsample (int | None, optional): If set, will subsample the dataset to this number of samples.
            This is useful for debugging and testing. Defaults to None.
        bands (Bands | list[str] | None, optional): List of bands to use.
            Expects the data_dir to contain a config.toml with a "darts.bands" key,
            with which the indices of the bands will be mapped to.
            Defaults to None.
        augment (list[Augmentation] | None, optional): List of augmentations to apply to the training data.
            Not used for validation or testing. Defaults to None.
        num_workers (int, optional): Number of workers for data loading. See torch.utils.data.DataLoader.
            Defaults to 0.
        in_memory (bool, optional): Whether to load the data into memory. Defaults to False.

    """
    super().__init__()
    self.save_hyperparameters(ignore=["num_workers", "in_memory"])
    self.data_dir = data_dir
    self.batch_size = batch_size

    self.fold = fold
    self.data_split_method = data_split_method
    self.data_split_by = data_split_by
    self.fold_method = fold_method
    self.total_folds = total_folds

    self.subsample = subsample
    self.augment = augment
    self.num_workers = num_workers
    self.in_memory = in_memory

    data_dir = Path(data_dir)

    metadata_file = data_dir / "metadata.parquet"
    assert metadata_file.exists(), f"Metadata file {metadata_file} not found!"

    config_file = data_dir / "config.toml"
    assert config_file.exists(), f"Config file {config_file} not found!"
    data_bands = toml.load(config_file)["darts"]["bands"]
    bands = bands.names if isinstance(bands, Bands) else bands
    self.bands = [data_bands.index(b) for b in bands] if bands else None

    zdir = data_dir / "data.zarr"
    assert zdir.exists(), f"Data directory {zdir} not found!"
    zroot = zarr.group(store=LocalStore(data_dir / "data.zarr"))
    self.nsamples = zroot["x"].shape[0]
    logger.debug(f"Data directory {zdir} found with {self.nsamples} samples.")

augment instance-attribute

bands instance-attribute

bands = [data_bands.index(b) for b in bands] if bands else None

batch_size instance-attribute

batch_size = batch_size

data_dir instance-attribute

data_split_by instance-attribute

data_split_by = data_split_by

data_split_method instance-attribute

data_split_method = data_split_method

fold instance-attribute

fold_method instance-attribute

fold_method = fold_method

in_memory instance-attribute

in_memory = in_memory

nsamples instance-attribute

nsamples = zroot['x'].shape[0]

num_workers instance-attribute

num_workers = num_workers

subsample instance-attribute

subsample = subsample

total_folds instance-attribute

total_folds = total_folds

setup

setup(
    stage: typing.Literal[
        "fit", "validate", "test", "predict"
    ]
    | None = None,
)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def setup(self, stage: Literal["fit", "validate", "test", "predict"] | None = None):
    if stage == "predict" or stage is None:
        return

    metadata = gpd.read_parquet(self.data_dir / "metadata.parquet")
    if self.subsample is not None:
        metadata = metadata.sample(n=self.subsample, random_state=42)
    train_metadata, test_metadata = _split_metadata(metadata, self.data_split_method, self.data_split_by)

    _log_stats(train_metadata, "train-split")
    _log_stats(test_metadata, "test-split")

    # Log stats about the data

    if stage in ["fit", "validate"]:
        train_index, val_index = _get_fold(train_metadata, self.fold_method, self.total_folds, self.fold)
        _log_stats(metadata.loc[train_index], "train-fold")
        _log_stats(metadata.loc[val_index], "val-fold")

        dsclass = DartsDatasetInMemory if self.in_memory else DartsDatasetZarr
        self.train = dsclass(self.data_dir / "data.zarr", self.augment, train_index, self.bands)
        self.val = dsclass(self.data_dir / "data.zarr", None, val_index, self.bands)
    if stage == "test":
        test_index = test_metadata.index.tolist()
        dsclass = DartsDatasetInMemory if self.in_memory else DartsDatasetZarr
        self.test = dsclass(self.data_dir / "data.zarr", None, test_index, self.bands)

test_dataloader

test_dataloader()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def test_dataloader(self):
    return DataLoader(
        self.test,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        persistent_workers=True,
    )

train_dataloader

train_dataloader()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def train_dataloader(self):
    return DataLoader(
        self.train,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        shuffle=True,
        drop_last=True,
        persistent_workers=True,
    )

val_dataloader

val_dataloader()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def val_dataloader(self):
    return DataLoader(
        self.val,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        persistent_workers=True,
    )
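A minimal sketch of driving the datamodule without a Trainer (path hypothetical; num_workers is set above zero because the dataloaders use persistent workers):

```py
from pathlib import Path

from darts_segmentation.training.data import DartsDataModule

dm = DartsDataModule(Path("data/train"), batch_size=4, num_workers=2)
dm.setup("fit")  # builds dm.train and dm.val for the "fit" stage
xb, yb = next(iter(dm.train_dataloader()))
print(xb.shape, yb.shape)  # e.g. (4, C, H, W) image batch and (4, H, W) mask batch
```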

DartsDatasetInMemory

DartsDatasetInMemory(
    data_dir: pathlib.Path | str,
    augment: list[
        darts_segmentation.training.augmentations.Augmentation
    ]
    | None = None,
    indices: list[int] | None = None,
    bands: list[int] | None = None,
)

Bases: torch.utils.data.Dataset

Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __init__(
    self,
    data_dir: Path | str,
    augment: list[Augmentation] | None = None,
    indices: list[int] | None = None,
    bands: list[int] | None = None,
):
    data_dir = Path(data_dir) if isinstance(data_dir, str) else data_dir

    store = zarr.storage.LocalStore(data_dir)
    self.zroot = zarr.group(store=store)

    assert "x" in self.zroot and "y" in self.zroot, (
        f"Dataset corrupted! {self.zroot.info=} must contain 'x' or 'y' arrays!"
    )

    self.x = []
    self.y = []
    indices = indices or list(range(self.zroot["x"].shape[0]))
    for i in indices:
        x = self.zroot["x"][i, bands] if bands else self.zroot["x"][i]
        y = self.zroot["y"][i]
        self.x.append(x)
        self.y.append(y)

    self.transform = get_augmentation(augment)

x instance-attribute

x = []

y instance-attribute

y = []

zroot instance-attribute

zroot = zarr.group(store=store)

__getitem__

__getitem__(idx)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __getitem__(self, idx):
    x = self.x[idx]
    y = self.y[idx]

    # Apply augmentations
    if self.transform is not None:
        augmented = self.transform(image=x.transpose(1, 2, 0), mask=y)
        x = augmented["image"].transpose(2, 0, 1)
        y = augmented["mask"]

    return x, y

__len__

__len__()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __len__(self):
    return len(self.x)

DartsDatasetZarr

DartsDatasetZarr(
    data_dir: pathlib.Path | str,
    augment: list[
        darts_segmentation.training.augmentations.Augmentation
    ]
    | None = None,
    indices: list[int] | None = None,
    bands: list[int] | None = None,
)

Bases: torch.utils.data.Dataset

Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __init__(
    self,
    data_dir: Path | str,
    augment: list[Augmentation] | None = None,
    indices: list[int] | None = None,
    bands: list[int] | None = None,
):
    data_dir = Path(data_dir) if isinstance(data_dir, str) else data_dir

    store = zarr.storage.LocalStore(data_dir)
    self.zroot = zarr.group(store=store)

    assert "x" in self.zroot and "y" in self.zroot, (
        f"Dataset corrupted! {self.zroot.info=} must contain 'x' or 'y' arrays!"
    )

    self.indices = indices if indices is not None else list(range(self.zroot["x"].shape[0]))
    self.bands = bands

    self.transform = get_augmentation(augment)

bands instance-attribute

indices instance-attribute

indices = indices if indices is not None else list(range(zroot['x'].shape[0]))

zroot instance-attribute

zroot = zarr.group(store=store)

__getitem__

__getitem__(idx)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __getitem__(self, idx):
    i = self.indices[idx]

    x = self.zroot["x"][i, self.bands] if self.bands else self.zroot["x"][i]
    y = self.zroot["y"][i]

    # Apply augmentations
    if self.transform is not None:
        augmented = self.transform(image=x.transpose(1, 2, 0), mask=y)
        x = augmented["image"].transpose(2, 0, 1)
        y = augmented["mask"]

    return x, y

__len__

__len__()
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def __len__(self):
    return len(self.indices)
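Both dataset classes share the same constructor arguments; DartsDatasetZarr reads samples lazily from the zarr store, while DartsDatasetInMemory loads them all into memory up front. A minimal sketch with a hypothetical path:

```py
from darts_segmentation.training.data import DartsDatasetZarr

ds = DartsDatasetZarr(
    "data/train/data.zarr",       # hypothetical preprocessed zarr group with "x" and "y" arrays
    augment=["HorizontalFlip"],   # applied on __getitem__
    indices=[0, 1, 2],            # optional subset of sample indices
    bands=[0, 2],                 # optional channel indices into the "x" array
)
x, y = ds[0]
print(len(ds), x.shape, y.shape)
```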

_get_fold

_get_fold(
    metadata: geopandas.GeoDataFrame,
    fold_method: typing.Literal[
        "kfold",
        "shuffle",
        "stratified",
        "region",
        "region-stratified",
        "none",
    ]
    | None,
    n_folds: int,
    fold: int,
) -> tuple[list[int], list[int]]
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def _get_fold(
    metadata: gpd.GeoDataFrame,
    fold_method: Literal["kfold", "shuffle", "stratified", "region", "region-stratified", "none"] | None,
    n_folds: int,
    fold: int,
) -> tuple[list[int], list[int]]:
    fold = fold if fold_method is not None else 0
    fold_method = fold_method or "none"
    match fold_method:
        case "none":
            foldgen = [(metadata.index.tolist(), metadata.index.tolist())]
        case "kfold":
            foldgen = KFold(n_folds).split(metadata)
        case "shuffle":
            foldgen = StratifiedShuffleSplit(n_splits=n_folds, random_state=42).split(metadata, ~metadata["empty"])
        case "stratified":
            foldgen = StratifiedKFold(n_folds, random_state=42, shuffle=True).split(metadata, ~metadata["empty"])
        case "region":
            foldgen = GroupShuffleSplit(n_folds).split(metadata, groups=metadata["region"])
        case "region-stratified":
            foldgen = StratifiedGroupKFold(n_folds, random_state=42, shuffle=True).split(
                metadata, ~metadata["empty"], groups=metadata["region"]
            )
        case _:
            raise ValueError(f"Unknown fold method: {fold_method}")

    for i, (train_index, val_index) in enumerate(foldgen):
        if i != fold:
            continue
        # Turn index into metadata index
        train_index = metadata.index[train_index].tolist()
        val_index = metadata.index[val_index].tolist()
        return train_index, val_index

    raise ValueError(f"Fold {fold} not found")

_log_stats

_log_stats(metadata: geopandas.GeoDataFrame, mode: str)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def _log_stats(metadata: gpd.GeoDataFrame, mode: str):
    n_pos = (~metadata["empty"]).sum()
    n_neg = metadata["empty"].sum()
    logger.debug(
        f"{mode} dataset: {n_pos} positive, {n_neg} negative ({len(metadata)} total)"
        f" with {metadata['region'].nunique()} unique regions and {metadata['sample_id'].nunique()} unique sample ids"
    )

_split_metadata

_split_metadata(
    metadata: geopandas.GeoDataFrame,
    data_split_method: typing.Literal[
        "random", "region", "sample", "none"
    ]
    | None,
    data_split_by: list[str | float] | None,
)
Source code in darts-segmentation/src/darts_segmentation/training/data.py
def _split_metadata(
    metadata: gpd.GeoDataFrame,
    data_split_method: Literal["random", "region", "sample", "none"] | None,
    data_split_by: list[str | float] | None,
):
    # Match statement doesn't like None
    data_split_method = data_split_method or "none"

    match data_split_method:
        case "none":
            return metadata, metadata
        case "random":
            assert isinstance(data_split_by, list) and len(data_split_by) == 1
            data_split_by = data_split_by[0]
            assert isinstance(data_split_by, float)
            for seed in range(100):
                train_metadata = metadata.sample(frac=data_split_by, random_state=seed)
                test_metadata = metadata.drop(train_metadata.index)
                if (~test_metadata["empty"]).sum() == 0:
                    logger.warning("Test set is empty, retrying with another random seed...")
                    continue
                return train_metadata, test_metadata
            else:
                raise ValueError("Could not split data randomly, please check your data.")
        case "region":
            assert isinstance(data_split_by, list) and len(data_split_by) > 0
            train_metadata = metadata[~metadata["region"].isin(data_split_by)]
            test_metadata = metadata[metadata["region"].isin(data_split_by)]
            return train_metadata, test_metadata
        case "sample":
            assert isinstance(data_split_by, list) and len(data_split_by) > 0
            train_metadata = metadata[~metadata["sample_id"].isin(data_split_by)]
            test_metadata = metadata[metadata["sample_id"].isin(data_split_by)]
            return train_metadata, test_metadata
        case _:
            raise ValueError(f"Invalid data split method: {data_split_method}")

get_augmentation

get_augmentation(
    augment: list[
        darts_segmentation.training.augmentations.Augmentation
    ]
    | None,
) -> albumentations.Compose | None

Get augmentations for segmentation tasks.

Parameters:

  • augment (list[darts_segmentation.training.augmentations.Augmentation] | None) –

    List of augmentations to apply. If None or empty, no augmentations are applied. If not empty, augmentations are applied in the order they are listed. Available augmentations:
    - HorizontalFlip
    - VerticalFlip
    - RandomRotate90
    - Blur
    - RandomBrightnessContrast
    - MultiplicativeNoise
    - Posterize

Raises:

  • ValueError

    If an unknown augmentation is provided.

Returns:

  • albumentations.Compose | None

    A.Compose | None: A Compose object containing the augmentations. If no augmentations are provided, returns None.

Source code in darts-segmentation/src/darts_segmentation/training/augmentations.py
def get_augmentation(augment: list[Augmentation] | None) -> "A.Compose | None":
    """Get augmentations for segmentation tasks.

    Args:
        augment (list[Augmentation] | None): List of augmentations to apply.
            If None or empty, no augmentations are applied.
            If not empty, augmentations are applied in the order they are listed.
            Available augmentations:
                - HorizontalFlip
                - VerticalFlip
                - RandomRotate90
                - Blur
                - RandomBrightnessContrast
                - MultiplicativeNoise
                - Posterize

    Raises:
        ValueError: If an unknown augmentation is provided.

    Returns:
        A.Compose | None: A Compose object containing the augmentations.
            If no augmentations are provided, returns None.

    """
    import albumentations as A  # noqa: N812

    if not isinstance(augment, list) or len(augment) == 0:
        return None
    transforms = []
    for aug in augment:
        match aug:
            case "HorizontalFlip":
                transforms.append(A.HorizontalFlip())
            case "VerticalFlip":
                transforms.append(A.VerticalFlip())
            case "RandomRotate90":
                transforms.append(A.RandomRotate90())
            case "Blur":
                transforms.append(A.Blur())
            case "RandomBrightnessContrast":
                transforms.append(A.RandomBrightnessContrast())
            case "MultiplicativeNoise":
                transforms.append(A.MultiplicativeNoise(per_channel=True, elementwise=True))
            case "Posterize":
                # First convert to uint8, then apply posterization, then convert back to float32
                # * Note: This does only work for float32 images.
                transforms += [
                    A.FromFloat(dtype="uint8"),
                    A.Posterize(num_bits=6, p=1.0),
                    A.ToFloat(),
                ]
            case _:
                raise ValueError(f"Unknown augmentation: {aug}")
    return A.Compose(transforms)
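A minimal sketch of applying the returned pipeline the same way the datasets above do, on a synthetic float32 channels-first image (import path assumed from the source location):

```py
import numpy as np

from darts_segmentation.training.augmentations import get_augmentation  # assumed import path

transform = get_augmentation(["HorizontalFlip", "RandomRotate90", "Posterize"])

image = np.random.rand(3, 64, 64).astype("float32")        # channels-first image in [0, 1)
mask = np.random.randint(0, 2, (64, 64)).astype("uint8")   # binary segmentation mask

augmented = transform(image=image.transpose(1, 2, 0), mask=mask)  # albumentations expects HWC
image_aug = augmented["image"].transpose(2, 0, 1)
mask_aug = augmented["mask"]
```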