
darts_segmentation.training.create_training_patches

Create training patches from a tile and labels.

Parameters:

  • tile (xarray.Dataset) –

    The input tile, containing preprocessed, harmonized data.

  • labels (geopandas.GeoDataFrame) –

    The labels to be used for training.

  • bands (list[str]) –

    The bands to be used for training. Must be present in the tile.

  • norm_factors (dict[str, float]) –

    The normalization factors for the bands.

  • patch_size (int) –

    The side length of the square patches, in pixels.

  • overlap (int) –

    The size of the overlap between adjacent patches.

  • exclude_nopositive (bool) –

    Whether to exclude patches where the labels do not contain positives.

  • exclude_nan (bool) –

    Whether to exclude patches where the input data has nan values.

  • device (typing.Literal['cuda', 'cpu'] | int) –

    The device to use for the erosion.

  • mask_erosion_size (int) –

    The size of the disk to use for erosion.
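
How `patch_size` and `overlap` interact is not spelled out on this page, because the patching itself happens in `create_patches`, which is not shown here. The following is a minimal sketch of one common convention, assuming a fixed stride of `patch_size - overlap` and dropping windows that do not fit entirely inside the tile. The helper `patch_origins` is hypothetical and not part of the library.

```python
def patch_origins(size: int, patch_size: int, overlap: int) -> list[int]:
    """Start offsets of patches along one axis, assuming a fixed stride."""
    if not 0 <= overlap < patch_size:
        raise ValueError("overlap must satisfy 0 <= overlap < patch_size")
    # Assumption: consecutive patches are shifted by patch_size - overlap.
    stride = patch_size - overlap
    # Only windows that fit completely inside the axis are kept.
    return list(range(0, size - patch_size + 1, stride))


# A 10-pixel axis, 4-pixel patches, 2 pixels shared between neighbours.
origins = patch_origins(size=10, patch_size=4, overlap=2)
# Without overlap, the trailing 2 pixels are not covered in this sketch.
origins_no_overlap = patch_origins(size=10, patch_size=4, overlap=0)
```

Under this assumption a larger `overlap` yields more patches per tile, at the cost of more redundancy between neighbouring patches.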

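Yields:

  • tuple[torch.Tensor, torch.Tensor] –

    A tuple containing the input and the labels as PyTorch tensors. The input has the format (C, H, W), the labels (H, W).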

Raises:

  • ValueError

    If a band is not found in the preprocessed data.
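
The band check and the normalization can be illustrated without xarray. The sketch below mirrors the logic of the source listing (multiply each band by its factor, then clip to [0, 1], and raise `ValueError` for a missing band) on plain Python lists. The function `normalize_bands`, the toy data, and the factor values are hypothetical and chosen only for illustration.

```python
def normalize_bands(
    tile: dict[str, list[float]],
    bands: list[str],
    norm_factors: dict[str, float],
) -> dict[str, list[float]]:
    """Scale each requested band by its factor and clip the result to [0, 1]."""
    normalized = {}
    for band in bands:
        if band not in tile:
            raise ValueError(f"Band '{band}' not found in the preprocessed data.")
        factor = norm_factors[band]
        normalized[band] = [min(max(v * factor, 0.0), 1.0) for v in tile[band]]
    return normalized


# Hypothetical toy data; the factors are powers of two so the results are exact.
toy_tile = {"red": [0.0, 2048.0, 8192.0], "ndvi": [-1024.0, 1024.0, 4096.0]}
toy_factors = {"red": 1 / 4096, "ndvi": 1 / 4096}
result = normalize_bands(toy_tile, ["ndvi", "red"], toy_factors)
```

Values that scale below 0 or above 1 are clipped, and the output follows the order given in `bands`. Unlike the function documented here, which writes the normalized bands back into the tile, this sketch returns a new mapping.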

Source code in darts-segmentation/src/darts_segmentation/training/prepare_training.py
def create_training_patches(
    tile: xr.Dataset,
    labels: gpd.GeoDataFrame,
    bands: list[str],
    norm_factors: dict[str, float],
    patch_size: int,
    overlap: int,
    exclude_nopositive: bool,
    exclude_nan: bool,
    device: Literal["cuda", "cpu"] | int,
    mask_erosion_size: int,
) -> Generator[tuple[torch.Tensor, torch.Tensor]]:
    """Create training patches from a tile and labels.

    Args:
        tile (xr.Dataset): The input tile, containing preprocessed, harmonized data.
        labels (gpd.GeoDataFrame): The labels to be used for training.
        bands (list[str]): The bands to be used for training. Must be present in the tile.
        norm_factors (dict[str, float]): The normalization factors for the bands.
        patch_size (int): The size of the patches.
        overlap (int): The size of the overlap.
        exclude_nopositive (bool): Whether to exclude patches where the labels do not contain positives.
        exclude_nan (bool): Whether to exclude patches where the input data has nan values.
        device (Literal["cuda", "cpu"] | int): The device to use for the erosion.
        mask_erosion_size (int): The size of the disk to use for erosion.

    Yields:
        tuple[torch.Tensor, torch.Tensor]: A tuple containing the input and the labels as PyTorch tensors.
            The input has the format (C, H, W), the labels (H, W).

    Raises:
        ValueError: If a band is not found in the preprocessed data.

    """
    if len(labels) == 0 and exclude_nopositive:
        logger.warning("No labels found in the labels GeoDataFrame. Skipping.")
        return

    # Rasterize the labels
    if len(labels) > 0:
        labels_rasterized = 1 - make_geocube(labels, measurements=["id"], like=tile).id.isnull()
    else:
        labels_rasterized = xr.zeros_like(tile["valid_data_mask"])

    # Filter out the nodata values (class 2 -> invalid data).
    # Use the eroded mask so that pixels close to invalid data are also marked invalid.
    mask = erode_mask(tile["valid_data_mask"], mask_erosion_size, device)
    labels_rasterized = xr.where(mask, labels_rasterized, 2)

    # Normalize the bands and clip the values
    for band in bands:
        if band not in tile:
            raise ValueError(f"Band '{band}' not found in the preprocessed data.")
        with xr.set_options(keep_attrs=True):
            tile[band] = tile[band] * norm_factors[band]
            tile[band] = tile[band].clip(0, 1)

    # Replace invalid values with nan (used for nan check later on)
    tile = xr.where(tile["valid_data_mask"], tile, float("nan"))

    # Convert to a DataArray and select the bands (bands are now in the specified order)
    tile = tile.to_dataarray(dim="band").sel(band=bands)

    # Transpose to (C, H, W)
    tile = tile.transpose("band", "y", "x")
    labels_rasterized = labels_rasterized.transpose("y", "x")

    # Convert to tensor
    tensor_tile = torch.tensor(tile.values).float()
    tensor_labels = torch.tensor(labels_rasterized.values).float()

    assert tensor_tile.dim() == 3, f"Expects tensor_tile to have shape (C, H, W), got {tensor_tile.shape}"
    assert tensor_labels.dim() == 2, f"Expects tensor_labels to have shape (H, W), got {tensor_labels.shape}"

    # Create patches
    tensor_patches = create_patches(tensor_tile.unsqueeze(0), patch_size, overlap)
    tensor_patches = tensor_patches.reshape(-1, len(bands), patch_size, patch_size)
    tensor_labels = create_patches(tensor_labels.unsqueeze(0).unsqueeze(0), patch_size, overlap)
    tensor_labels = tensor_labels.reshape(-1, patch_size, patch_size)

    # Yield the patches one by one as (input, labels) tuples
    n_patches = tensor_patches.shape[0]
    for i in range(n_patches):
        x = tensor_patches[i]
        y = tensor_labels[i]

        if exclude_nopositive and not (y == 1).any():
            continue

        if exclude_nan and torch.isnan(x).any():
            continue

        # Skip patches where fewer than 10% of the pixels are visible (not class 2)
        if ((y != 2).sum() / y.numel()) < 0.1:
            continue

        # Skip patches where everything is nan
        if torch.isnan(x).all():
            continue

        # Convert all nan values to 0
        x[torch.isnan(x)] = 0

        logger.debug(f"Yielding patch {i} with\n\t{x=}\n\t{y=}")
        yield x, y
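
The filtering at the end of the listing relies on a three-class label encoding: 0 for background, 1 for a labelled (positive) pixel, and 2 for invalid data. As a rough illustration, the sketch below reproduces the same sequence of checks on flattened Python lists instead of tensors. The function `keep_patch` and its `min_visible` parameter are hypothetical names; the 0.1 default matches the threshold hard-coded in the listing.

```python
import math

INVALID = 2  # label class for pixels outside the valid-data mask


def keep_patch(
    x: list[float],
    y: list[int],
    exclude_nopositive: bool,
    exclude_nan: bool,
    min_visible: float = 0.1,
) -> bool:
    """Return True if a patch passes the checks, in the same order as the listing."""
    if exclude_nopositive and not any(v == 1 for v in y):
        return False  # no positive label in this patch
    if exclude_nan and any(math.isnan(v) for v in x):
        return False  # at least one invalid input value
    visible = sum(1 for v in y if v != INVALID) / len(y)
    if visible < min_visible:
        return False  # too few visible pixels
    if all(math.isnan(v) for v in x):
        return False  # nothing usable in the input
    return True


nan = float("nan")
ok = keep_patch([0.1, 0.2, 0.3, 0.4], [0, 1, 0, 0], True, True)
no_positive = keep_patch([0.1, 0.2, 0.3, 0.4], [0, 0, 0, 0], True, True)
has_nan = keep_patch([0.1, nan, 0.3, 0.4], [0, 1, 0, 0], True, True)
has_nan_allowed = keep_patch([0.1, nan, 0.3, 0.4], [0, 1, 0, 0], True, False)
mostly_invalid = keep_patch([0.5] * 20, [INVALID] * 19 + [1], True, False)
```

With `exclude_nan` disabled, a patch containing some NaN values is still kept; in the listing, those values are then set to 0 before the patch is yielded.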