darts_segmentation.training.prepare_training
Functions to prepare the training data for the segmentation model training.
PatchCoords
dataclass
Wrapper which stores the coordinate information of a patch in the original image.
from_tensor
classmethod
from_tensor(
coords: torch.Tensor, patch_size: int
) -> (
darts_segmentation.training.prepare_training.PatchCoords
)
Create a PatchCoords object from the returned coord tensor of create_patches.
Parameters:
-
coords(torch.Tensor) –The coordinates of the patch in the original image, from create_patches.
-
patch_size(int) –The size of the patch.
Returns:
-
PatchCoords(darts_segmentation.training.prepare_training.PatchCoords) –The coordinates of the patch in the original image.
Source code in darts-segmentation/src/darts_segmentation/training/prepare_training.py
TrainDatasetBuilder
dataclass
TrainDatasetBuilder(
train_data_dir: pathlib.Path,
patch_size: int,
overlap: int,
bands: list[str],
exclude_nopositive: bool,
exclude_nan: bool,
device: typing.Literal["cuda", "cpu"] | int,
append: bool = False,
)
Helper class to create all necessary files for a DARTS training dataset.
This class manages the creation of a training dataset stored in Zarr format with associated metadata. It handles patch creation, quality filtering, and metadata tracking.
The dataset structure
- data.zarr/x: Input patches (N, C, H, W) as float32
- data.zarr/y: Label patches (N, H, W) as uint8 with values 0/1/2
- metadata.parquet: Patch metadata including coordinates and geometry
- config.toml: Dataset configuration and parameters
Attributes:
-
train_data_dir(pathlib.Path) –Directory where the dataset will be saved.
-
patch_size(int) –Size of each patch in pixels.
-
overlap(int) –Overlap between adjacent patches in pixels.
-
bands(list[str]) –List of band names to include in the dataset.
-
exclude_nopositive(bool) –Exclude patches without positive labels.
-
exclude_nan(bool) –Exclude patches with any NaN values.
-
device(typing.Literal['cuda', 'cpu'] | int) –Device for patch creation operations.
-
append(bool) –If True, append to existing dataset. Defaults to False.
__len__
__post_init__
Initialize the TrainDatasetBuilder class based on provided dataclass params.
This will set up everything needed to add patches to the dataset:
- Create the train_data_dir if it does not exist
- Create an empty zarr store
- Initialize the metadata list
Source code in darts-segmentation/src/darts_segmentation/training/prepare_training.py
add_tile
add_tile(
tile: xarray.Dataset,
labels: geopandas.GeoDataFrame,
region: str,
sample_id: str,
extent: geopandas.GeoDataFrame | None = None,
metadata: dict[str, str] | None = None,
)
Add a tile to the dataset by creating and appending patches.
This method processes a single tile by creating training patches and appending them to the Zarr arrays. Patch metadata including coordinates, geometry, and custom fields are tracked for later use.
Parameters:
-
tile(xarray.Dataset) –The input tile, containing preprocessed, harmonized data. Must contain the specified bands and a 'quality_data_mask' variable.
-
labels(geopandas.GeoDataFrame) –The labels to be used for training. Geometries will be rasterized as positive samples (class 1).
-
region(str) –The region identifier for this tile (e.g., "Alaska", "Canada"). Stored in metadata for tracking and filtering.
-
sample_id(str) –A unique identifier for this tile/sample. Stored in metadata for tracking and filtering.
-
extent(geopandas.GeoDataFrame | None, default:None) –The extent of the valid training area. Pixels outside this extent will be marked as invalid (class 2) in labels. If None, no extent masking is applied.
-
metadata(dict[str, str], default:None) –Additional metadata to be added to the metadata file. Will not be used for training, but can be used for debugging or reproducibility. Path values will be automatically converted to strings.
Source code in darts-segmentation/src/darts_segmentation/training/prepare_training.py
finalize
Finalize the dataset by saving the metadata and the config file.
Parameters:
-
data_config(dict[str, str], default:None) –The data config to be saved in the config file. This should contain all the information needed to recreate the dataset. It will be saved as a toml file, along with the configuration provided in this dataclass.
Raises:
-
ValueError–If no patches were found in the dataset.
Source code in darts-segmentation/src/darts_segmentation/training/prepare_training.py
create_labels
create_labels(
tile: xarray.Dataset,
labels: geopandas.GeoDataFrame,
extent: geopandas.GeoDataFrame | None = None,
) -> xarray.DataArray
Create rasterized labels from vector labels and quality mask.
This function rasterizes the provided labels and applies quality filtering based on the tile's quality_data_mask. Areas outside the extent or with low quality are marked as invalid (class 2).
Label encoding
- 0: Negative (no RTS)
- 1: Positive (RTS present)
- 2: Invalid/masked (outside extent, low quality, or no data)
Parameters:
-
tile(xarray.Dataset) –The input tile, containing preprocessed, harmonized data. Must contain a 'quality_data_mask' variable where value 2 indicates best quality.
-
labels(geopandas.GeoDataFrame) –The labels to be used for training. Geometries will be rasterized as positive samples (class 1).
-
extent(geopandas.GeoDataFrame | None, default:None) –The extent of the valid training area. Pixels outside this extent will be marked as invalid (class 2). If None, no extent masking is applied.
Returns:
-
xarray.DataArray–The rasterized labels with shape (y, x) and values 0, 1, or 2. Pixels with quality_data_mask != 2 are automatically marked as invalid (class 2).
Source code in darts-segmentation/src/darts_segmentation/training/prepare_training.py
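The label encoding described for create_labels can be illustrated with plain NumPy arrays. This is a minimal sketch rather than the library's implementation: the helper name encode_labels and its boolean inputs (positive, inside_extent) are hypothetical, and the real function works on xarray and geopandas objects.

```python
import numpy as np


def encode_labels(positive, quality_mask, inside_extent=None):
    """Combine a rasterized label mask with quality and extent masks.

    positive:      bool array (y, x), True where a label geometry was rasterized
    quality_mask:  int array (y, x), where 2 means best quality
    inside_extent: optional bool array (y, x), True inside the valid area
    (All three inputs are hypothetical stand-ins for the real tile data.)
    """
    labels = positive.astype(np.uint8)   # 0 = negative, 1 = positive
    labels[quality_mask != 2] = 2        # anything below best quality is invalid
    if inside_extent is not None:
        labels[~inside_extent] = 2       # pixels outside the extent are invalid
    return labels


positive = np.array([[True, False, False], [False, True, False]])
quality = np.array([[2, 2, 1], [0, 2, 2]])
extent = np.array([[True, True, True], [True, True, False]])

encoded = encode_labels(positive, quality, extent)
print(encoded)
```

In this sketch the invalid class is written last, so a pixel that is both labelled and masked ends up as class 2.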
create_patches
create_patches(
tensor_tiles: torch.Tensor,
patch_size: int,
overlap: int,
return_coords: bool = False,
) -> torch.Tensor
Create patches from a tensor.
Parameters:
-
tensor_tiles(torch.Tensor) –The input tensor. Shape: (BS, C, H, W).
-
patch_size(int) –The size of the patches.
-
overlap(int) –The size of the overlap.
-
return_coords(bool, default:False) –Whether to return the coordinates of the patches. Can be used for debugging. Defaults to False.
Returns:
Source code in darts-segmentation/src/darts_segmentation/inference.py
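To make the relationship between patch_size and overlap concrete, the sketch below lists the top-left corners of a sliding-window grid. It assumes a stride of patch_size - overlap and ignores any remainder at the image border; neither assumption is confirmed by this page, and the helper patch_origins is hypothetical, not part of the library.

```python
def patch_origins(height, width, patch_size, overlap):
    """Return the top-left (y, x) corner of every full patch in a sliding window."""
    if not 0 <= overlap < patch_size:
        raise ValueError("overlap must be non-negative and smaller than patch_size")
    step = patch_size - overlap  # assumed stride between neighbouring patches
    ys = range(0, height - patch_size + 1, step)
    xs = range(0, width - patch_size + 1, step)
    return [(y, x) for y in ys for x in xs]


# An 8x8 tile with 4x4 patches and 2 pixels of overlap yields a 3x3 grid.
origins = patch_origins(height=8, width=8, patch_size=4, overlap=2)
print(len(origins), origins[0], origins[-1])
```

A larger overlap shrinks the stride, so the number of patches per tile grows quickly.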
create_training_patches
create_training_patches(
tile: xarray.Dataset,
labels: geopandas.GeoDataFrame,
extent: geopandas.GeoDataFrame | None,
bands: list[str],
patch_size: int,
overlap: int,
exclude_nopositive: bool,
exclude_nan: bool,
device: typing.Literal["cuda", "cpu"] | int,
) -> tuple[
torch.tensor,
torch.tensor,
list[
darts_segmentation.training.prepare_training.PatchCoords
],
]
Create training patches from a tile and labels with quality filtering.
This function creates overlapping patches from the input tile and rasterized labels, applying several filtering criteria to ensure high-quality training data. Pixels with quality_data_mask == 0 are set to NaN in the input data.
Patch filtering
- Excludes patches with < 10% visible pixels (> 90% invalid/masked)
- Excludes patches where all bands are NaN
- Optionally excludes patches without positive labels (if exclude_nopositive=True)
- Optionally excludes patches with any NaN values (if exclude_nan=True)
Parameters:
-
tile(xarray.Dataset) –The input tile, containing preprocessed, harmonized data. Must contain a 'quality_data_mask' variable where 0=invalid and 2=best quality.
-
labels(geopandas.GeoDataFrame) –The labels to be used for training. Geometries will be rasterized as positive samples.
-
extent(geopandas.GeoDataFrame | None) –The extent of the valid training area. Pixels outside this extent will be marked as invalid in labels. If None, no extent masking is applied.
-
bands(list[str]) –The bands to extract and use for training. Will be normalized using the band manager.
-
patch_size(int) –The size of each patch in pixels (height and width).
-
overlap(int) –The overlap between adjacent patches in pixels.
-
exclude_nopositive(bool) –Whether to exclude patches where the labels do not contain any positive samples (class 1).
-
exclude_nan(bool) –Whether to exclude patches where the input data has any NaN values.
-
device(typing.Literal['cuda', 'cpu'] | int) –The device to use for tensor operations. Can be "cuda", "cpu", or an integer GPU index.
Returns:
-
tuple[torch.tensor, torch.tensor, list[darts_segmentation.training.prepare_training.PatchCoords]]–A tuple containing:
- Input patches with shape (N, C, H, W), NaN values replaced with 0
- Label patches with shape (N, H, W), values 0/1/2
- List of PatchCoords objects with coordinate information
Source code in darts-segmentation/src/darts_segmentation/training/prepare_training.py
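The patch filtering rules listed for create_training_patches can be expressed as a per-patch predicate. The following is a sketch under stated assumptions, not the library's code: keep_patch is a hypothetical helper, "visible" is taken to mean label pixels that are not class 2, and "all bands are NaN" is read as every value in the patch being NaN.

```python
import numpy as np


def keep_patch(x, y, exclude_nopositive, exclude_nan):
    """Decide whether one patch passes the documented filters.

    x: float array (C, H, W) with NaN for masked pixels
    y: uint8 array (H, W) with values 0/1/2, where 2 is invalid
    """
    visible_fraction = np.mean(y != 2)  # assumption: visible means not class 2
    if visible_fraction < 0.1:          # fewer than 10% visible pixels
        return False
    if np.isnan(x).all():               # every value in every band is NaN
        return False
    if exclude_nopositive and not (y == 1).any():
        return False
    if exclude_nan and np.isnan(x).any():
        return False
    return True


x_clean = np.ones((2, 4, 4), dtype=np.float32)
x_nan = x_clean.copy()
x_nan[0, 0, 0] = np.nan

y_pos = np.zeros((4, 4), dtype=np.uint8)
y_pos[0, 0] = 1
y_neg = np.zeros((4, 4), dtype=np.uint8)
y_masked = np.full((4, 4), 2, dtype=np.uint8)
y_masked[0, 0] = 1  # only 1 of 16 pixels is visible

print(keep_patch(x_clean, y_pos, exclude_nopositive=True, exclude_nan=True))
print(keep_patch(x_clean, y_masked, exclude_nopositive=False, exclude_nan=False))
```

The two unconditional checks run first, so the optional flags only ever remove additional patches.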