darts_segmentation.segment

Functionality for segmenting tiles.

DEFAULT_DEVICE module-attribute

DEFAULT_DEVICE = torch.device(
    "cuda" if torch.cuda.is_available() else "cpu"
)

logger module-attribute

logger = logging.getLogger(
    __name__.replace("darts_", "darts.")
)

SMPSegmenter

SMPSegmenter(
    model_checkpoint: pathlib.Path | str,
    device: torch.device = darts_segmentation.segment.DEFAULT_DEVICE,
)

Semantic segmentation model wrapper for RTS detection using Segmentation Models PyTorch.

This class provides a stateful inference interface for semantic segmentation models trained with the DARTS pipeline. It handles model loading, normalization, patch-based inference, and memory management.

Attributes:

  • config (SMPSegmenterConfig) –

    The validated model configuration loaded from the checkpoint.

  • device (torch.device) –

    The device the model is loaded on.

  • model (torch.nn.Module) –

    The segmentation model with loaded weights, in eval mode.

  • required_bands (set[str]) –

    The bands required by this model.

Note

The segmenter automatically:

- Loads model weights from PyTorch Lightning or legacy checkpoints
- Normalizes input data using band-specific statistics from darts_utils.bands
- Handles memory cleanup after inference to prevent GPU memory leaks

Example

Basic segmentation workflow:

from darts_segmentation import SMPSegmenter
import torch

# Initialize segmenter
segmenter = SMPSegmenter(
    model_checkpoint="path/to/model.ckpt",
    device=torch.device("cuda")
)

# Check required bands
print(segmenter.required_bands)
# {'blue', 'green', 'red', 'nir', 'ndvi', 'slope', 'hillshade', ...}

# Run inference on preprocessed tile
result = segmenter.segment_tile(
    tile=preprocessed_tile,
    patch_size=1024,
    overlap=16,
    batch_size=8
)

# Access predictions
probabilities = result["probabilities"]  # float32, range [0, 1]

Initialize the segmenter with a trained model checkpoint.

Parameters:

  • model_checkpoint (pathlib.Path | str) –

    Path to the model checkpoint file (.ckpt). Supports both PyTorch Lightning checkpoints and legacy formats.

  • device (torch.device, default: DEFAULT_DEVICE ) –

    Device to load the model on. Defaults to CUDA if available, else CPU.

Note

The checkpoint must contain:

- Model architecture configuration (config or hyper_parameters)
- Trained weights (state_dict or statedict)
- Required input bands list

Using lightning checkpoints from our training pipeline is recommended.

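For orientation, a sketch of the two checkpoint layouts the loader accepts, inferred from from_ckpt and __init__ below; the band names are illustrative placeholders:

# New-style PyTorch Lightning checkpoint (recommended)
lightning_ckpt = {
    "state_dict": {...},  # trained weights, keys prefixed with "model."
    "hyper_parameters": {
        "config": {
            "model": {...},  # kwargs passed to smp.create_model
            "bands": ["red", "green", "blue", "nir"],
        },
    },
}

# Legacy checkpoint
legacy_ckpt = {
    "statedict": {...},  # trained weights, no prefix
    "config": {
        "model": {...},
        "bands": ["red", "green", "blue", "nir"],
    },
}
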
Source code in darts-segmentation/src/darts_segmentation/segment.py
def __init__(self, model_checkpoint: Path | str, device: torch.device = DEFAULT_DEVICE):
    """Initialize the segmenter with a trained model checkpoint.

    Args:
        model_checkpoint (Path | str): Path to the model checkpoint file (.ckpt).
            Supports both PyTorch Lightning checkpoints and legacy formats.
        device (torch.device, optional): Device to load the model on.
            Defaults to CUDA if available, else CPU.

    Note:
        The checkpoint must contain:
        - Model architecture configuration (config or hyper_parameters)
        - Trained weights (state_dict or statedict)
        - Required input bands list
        Using lightning checkpoints from our training pipeline is recommended.

    """
    if isinstance(model_checkpoint, str):
        model_checkpoint = Path(model_checkpoint)
    self.device = device
    ckpt = torch.load(model_checkpoint, map_location=self.device, weights_only=False)
    self.config = SMPSegmenterConfig.from_ckpt(ckpt)
    # Overwrite the encoder weights with None, because we load our own
    self.config["model"] |= {"encoder_weights": None}
    self.model = smp.create_model(**self.config["model"])
    self.model.to(self.device)

    # Legacy version
    if "statedict" in ckpt.keys():
        statedict = ckpt["statedict"]
    else:
        statedict = ckpt["state_dict"]
        # Lightning Checkpoints are prefixed with "model." -> we need to remove them. This is an in-place function
        torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(statedict, "model.")
    self.model.load_state_dict(statedict)
    self.model.eval()

    logger.debug(f"Successfully loaded model from {model_checkpoint.resolve()} with inputs: {self.config['bands']}")

device instance-attribute

device: torch.device

model instance-attribute

model: torch.nn.Module = segmentation_models_pytorch.create_model(
    **self.config["model"]
)

required_bands property

required_bands: set[str]

The bands required by this model.

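A small validation sketch to run before inference; tile is assumed to be an xarray.Dataset with one data variable per band:

missing = segmenter.required_bands - set(tile.data_vars)
if missing:
    raise ValueError(f"Tile is missing required bands: {sorted(missing)}")
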
__call__

__call__(
    input: xarray.Dataset | list[xarray.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
) -> xarray.Dataset | list[xarray.Dataset]

Run inference on a single tile or a list of tiles.

Parameters:

  • input (xarray.Dataset | list[xarray.Dataset]) –

    A single tile or a list of tiles.

  • patch_size (int, default: 1024 ) –

    The size of the patches. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The size of the overlap. Defaults to 16.

  • batch_size (int, default: 8 ) –

    The batch size for the prediction, NOT the batch size of the input tiles. The tensor will be sliced into patches, which are then inferred in batches. Defaults to 8.

  • reflection (int, default: 0 ) –

    Reflection padding applied to the edges of the tensor. Defaults to 0.

Returns:

  • xarray.Dataset | list[xarray.Dataset] –

    A single tile or a list of tiles augmented by a predicted probabilities layer, depending on the input. Each probabilities layer has dtype float32 and range [0, 1].

Raises:

  • ValueError

    If the input is neither an xr.Dataset nor a list of xr.Dataset.

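A minimal usage sketch (assumes the segmenter and preprocessed_tile from the class example above):

result = segmenter(preprocessed_tile, patch_size=1024, overlap=16, batch_size=8)
probabilities = result["probabilities"]  # float32, range [0, 1]
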
Source code in darts-segmentation/src/darts_segmentation/segment.py
def __call__(
    self,
    input: xr.Dataset | list[xr.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
) -> xr.Dataset | list[xr.Dataset]:
    """Run inference on a single tile or a list of tiles.

    Args:
        input (xr.Dataset | list[xr.Dataset]): A single tile or a list of tiles.
        patch_size (int): The size of the patches. Defaults to 1024.
        overlap (int): The size of the overlap. Defaults to 16.
        batch_size (int): The batch size for the prediction, NOT the batch size of the input tiles.
            The tensor will be sliced into patches, which are then inferred in batches. Defaults to 8.
        reflection (int): Reflection padding applied to the edges of the tensor. Defaults to 0.

    Returns:
        A single tile or a list of tiles augmented by a predicted `probabilities` layer, depending on the input.
        Each `probabilities` layer has dtype float32 and range [0, 1].

    Raises:
        ValueError: If the input is neither an xr.Dataset nor a list of xr.Dataset.

    """
    if isinstance(input, xr.Dataset):
        return self.segment_tile(
            input, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection
        )
    elif isinstance(input, list):
        raise NotImplementedError("Currently passing multiple datasets at once is not supported.")
    else:
        raise ValueError(f"Expected xr.Dataset or list of xr.Dataset, got {type(input)}")

segment_tile

segment_tile(
    tile: xarray.Dataset,
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
) -> xarray.Dataset

Run semantic segmentation inference on a single tile.

This method performs patch-based inference with optional overlap and reflection padding to handle edge artifacts. The tile is automatically normalized using band-specific statistics before inference.

Parameters:

  • tile (xarray.Dataset) –

    Input tile containing preprocessed data. Must include all bands specified in self.required_bands. Variables should be float32 reflectance or normalized feature values.

  • patch_size (int, default: 1024 ) –

    Size of square patches for inference in pixels. Larger patches use more memory but may be faster. Defaults to 1024.

  • overlap (int, default: 16 ) –

    Overlap between adjacent patches in pixels. Helps reduce edge artifacts. Defaults to 16.

  • batch_size (int, default: 8 ) –

    Number of patches to process simultaneously. Higher values use more GPU memory but may be faster. Defaults to 8.

  • reflection (int, default: 0 ) –

    Reflection padding applied to tile edges in pixels. Reduces edge effects. Defaults to 0.

Returns:

  • xarray.Dataset –

    Input tile augmented with a new data variable:

    - probabilities (float32): Segmentation probabilities in range [0, 1].
      Attributes: long_name="Probabilities"

Note

Processing pipeline:

1. Extract and reorder bands according to model requirements
2. Normalize using darts_utils.bands.manager
3. Convert to torch tensor
4. Run patch-based inference with overlap blending
5. Convert predictions back to xarray

Memory management:

- Automatically frees GPU memory after inference
- Predictions are moved to CPU before returning

Example

Run inference with custom parameters:

result = segmenter.segment_tile(
    tile=preprocessed_tile,
    patch_size=512,  # Smaller patches for limited GPU memory
    overlap=32,      # More overlap for smoother predictions
    batch_size=4,    # Smaller batches for memory constraints
    reflection=16    # Add padding to reduce edge artifacts
)

# Extract probabilities
probs = result["probabilities"]
Source code in darts-segmentation/src/darts_segmentation/segment.py
@stopwatch.f(
    "Segmenting tile",
    printer=logger.debug,
    print_kwargs=["patch_size", "overlap", "batch_size", "reflection"],
)
def segment_tile(
    self, tile: xr.Dataset, patch_size: int = 1024, overlap: int = 16, batch_size: int = 8, reflection: int = 0
) -> xr.Dataset:
    """Run semantic segmentation inference on a single tile.

    This method performs patch-based inference with optional overlap and reflection padding
    to handle edge artifacts. The tile is automatically normalized using band-specific
    statistics before inference.

    Args:
        tile (xr.Dataset): Input tile containing preprocessed data. Must include all bands
            specified in `self.required_bands`. Variables should be float32 reflectance
            or normalized feature values.
        patch_size (int, optional): Size of square patches for inference in pixels.
            Larger patches use more memory but may be faster. Defaults to 1024.
        overlap (int, optional): Overlap between adjacent patches in pixels. Helps reduce
            edge artifacts. Defaults to 16.
        batch_size (int, optional): Number of patches to process simultaneously. Higher
            values use more GPU memory but may be faster. Defaults to 8.
        reflection (int, optional): Reflection padding applied to tile edges in pixels.
            Reduces edge effects. Defaults to 0.

    Returns:
        xr.Dataset: Input tile augmented with a new data variable:
            - probabilities (float32): Segmentation probabilities in range [0, 1].
              Attributes: long_name="Probabilities"

    Note:
        Processing pipeline:
        1. Extract and reorder bands according to model requirements
        2. Normalize using darts_utils.bands.manager
        3. Convert to torch tensor
        4. Run patch-based inference with overlap blending
        5. Convert predictions back to xarray

        Memory management:
        - Automatically frees GPU memory after inference
        - Predictions are moved to CPU before returning

    Example:
        Run inference with custom parameters:

        ```python
        result = segmenter.segment_tile(
            tile=preprocessed_tile,
            patch_size=512,  # Smaller patches for limited GPU memory
            overlap=32,      # More overlap for smoother predictions
            batch_size=4,    # Smaller batches for memory constraints
            reflection=16    # Add padding to reduce edge artifacts
        )

        # Extract probabilities
        probs = result["probabilities"]
        ```

    """
    # Convert the tile to a tensor
    tile = tile[self.config["bands"]].transpose("y", "x")
    tile = manager.normalize(tile)
    # ? The heavy operation is .to_dataarray()
    tensor_tile = torch.as_tensor(tile.to_dataarray().data)

    # Create a batch dimension, because predict expects it
    tensor_tile = tensor_tile.unsqueeze(0)

    probabilities = predict_in_patches(
        self.model, tensor_tile, patch_size, overlap, batch_size, reflection, self.device
    ).squeeze(0)

    # Attach the predicted probabilities to the tile
    tile["probabilities"] = (("y", "x"), probabilities.cpu().numpy())
    tile["probabilities"].attrs = {"long_name": "Probabilities"}

    # Cleanup cuda memory
    del tensor_tile, probabilities
    free_torch()

    return tile

SMPSegmenterConfig

Bases: typing.TypedDict

Configuration for the segmenter.

bands instance-attribute

bands: list[str]

model instance-attribute

model: dict[str, typing.Any]

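For illustration, a plausible config value. The model keys are whatever segmentation_models_pytorch.create_model accepts; the concrete values below are assumptions, not defaults of this package:

from darts_segmentation.segment import SMPSegmenterConfig

config: SMPSegmenterConfig = {
    "model": {
        "arch": "Unet",  # passed through to smp.create_model
        "encoder_name": "resnet34",
        "encoder_weights": None,  # overwritten with None by SMPSegmenter anyway
        "in_channels": 4,
        "classes": 1,
    },
    "bands": ["red", "green", "blue", "nir"],
}
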
from_ckpt classmethod

from_ckpt(
    ckpt: dict[str, typing.Any],
) -> darts_segmentation.segment.SMPSegmenterConfig

Load and validate the config from a checkpoint for the segmenter.

Parameters:

  • ckpt (dict[str, typing.Any]) –

    The checkpoint to load.

Returns:

  • SMPSegmenterConfig –

    The validated configuration.

Source code in darts-segmentation/src/darts_segmentation/segment.py
@classmethod
def from_ckpt(cls, ckpt: dict[str, Any]) -> "SMPSegmenterConfig":
    """Load and validate the config from a checkpoint for the segmentor.

    Args:
        ckpt: The checkpoint to load.

    Returns:
        The configuration.

    """
    # Legacy version: config and directly in ckpt
    if "config" in ckpt:
        config = ckpt["config"]
        # Handling legacy case that the config contains the old keys
        if "input_combination" in config and "norm_factors" in config:
            # Check if all input_combination features are in norm_factors
            config["bands"] = config["input_combination"]
            config.pop("norm_factors")
            config.pop("input_combination")
        # Another legacy case uses a deprecated "Bands" class, which is pickled into the config as dict
        if isinstance(config["bands"], dict):
            config["bands"] = config["bands"]["bands"]
    # New version: load directly from lightning checkpoint
    else:
        config = ckpt["hyper_parameters"]["config"]

    assert "model" in config, "Model config is missing!"
    assert "bands" in config, "Bands config is missing!"
    return config
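
A minimal sketch of loading a config straight from a checkpoint file, mirroring what SMPSegmenter.__init__ does internally (the path is a placeholder):

import torch

from darts_segmentation.segment import SMPSegmenterConfig

ckpt = torch.load("path/to/model.ckpt", map_location="cpu", weights_only=False)
config = SMPSegmenterConfig.from_ckpt(ckpt)
print(config["bands"])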

predict_in_patches

predict_in_patches(
    model: torch.nn.Module,
    tensor_tiles: torch.Tensor,
    patch_size: int,
    overlap: int,
    batch_size: int,
    reflection: int,
    device: torch.device,
    return_weights: bool = False,
) -> torch.Tensor

Predict on a tensor.

Parameters:

  • model (torch.nn.Module) –

    The model to use for prediction.

  • tensor_tiles (torch.Tensor) –

    The input tensor. Shape: (BS, C, H, W).

  • patch_size (int) –

    The size of the patches.

  • overlap (int) –

    The size of the overlap.

  • batch_size (int) –

    The batch size for the prediction, NOT the batch size of the input tiles. The tensor will be sliced into patches, which are then inferred in batches.

  • reflection (int) –

    Reflection padding applied to the edges of the tensor.

  • device (torch.device) –

    The device to use for the prediction.

  • return_weights (bool, default: False ) –

    Whether to return the weights. Can be used for debugging. Defaults to False.

Returns:

  • torch.Tensor –

    The predicted probabilities with shape (BS, H, W). If return_weights is True, a tuple of (prediction, weights) is returned instead.
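
A minimal sketch of calling predict_in_patches directly. The 1x1 convolution stands in for a real segmentation model, and the import path is assumed from the source location below:

import torch
from torch import nn

from darts_segmentation.inference import predict_in_patches

model = nn.Conv2d(4, 1, kernel_size=1)  # stand-in producing (N, 1, H, W) logits
model.eval()
tiles = torch.rand(1, 4, 2048, 2048)  # (BS, C, H, W)
probs = predict_in_patches(
    model, tiles, patch_size=1024, overlap=16,
    batch_size=8, reflection=0, device=torch.device("cpu"),
)
print(probs.shape)  # torch.Size([1, 2048, 2048])
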
Source code in darts-segmentation/src/darts_segmentation/inference.py
@torch.no_grad()
def predict_in_patches(
    model: nn.Module,
    tensor_tiles: torch.Tensor,
    patch_size: int,
    overlap: int,
    batch_size: int,
    reflection: int,
    device: torch.device,
    return_weights: bool = False,
) -> torch.Tensor:
    """Predict on a tensor.

    Args:
        model: The model to use for prediction.
        tensor_tiles: The input tensor. Shape: (BS, C, H, W).
        patch_size (int): The size of the patches.
        overlap (int): The size of the overlap.
        batch_size (int): The batch size for the prediction, NOT the batch size of the input tiles.
            The tensor will be sliced into patches, which are then inferred in batches.
        reflection (int): Reflection padding applied to the edges of the tensor.
        device (torch.device): The device to use for the prediction.
        return_weights (bool, optional): Whether to return the weights. Can be used for debugging. Defaults to False.

    Returns:
        The predicted tensor.

    """
    logger.debug(
        f"Predicting on a tensor with shape {tensor_tiles.shape} "
        f"with patch_size {patch_size}, overlap {overlap} and batch_size {batch_size} on device {device}"
    )
    assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to have shape (BS, C, H, W), got {tensor_tiles.shape}"
    # Add a 1px + reflection border to avoid pixel loss when applying the soft margin and to reduce edge-artefacts
    p = 1 + reflection
    tensor_tiles = torch.nn.functional.pad(tensor_tiles, (p, p, p, p), mode="reflect")
    bs, c, h, w = tensor_tiles.shape
    step_size = patch_size - overlap
    nh, nw = math.ceil((h - overlap) / step_size), math.ceil((w - overlap) / step_size)

    # Create Patches of size (BS, N_h, N_w, C, patch_size, patch_size)
    patches = create_patches(tensor_tiles, patch_size=patch_size, overlap=overlap)

    # Flatten the patches so they fit to the model
    # (BS, N_h, N_w, C, patch_size, patch_size) -> (BS * N_h * N_w, C, patch_size, patch_size)
    patches = patches.view(bs * nh * nw, c, patch_size, patch_size)

    # Create a soft margin for the patches
    margin_ramp = torch.cat(
        [
            torch.linspace(0, 1, overlap),
            torch.ones(patch_size - 2 * overlap),
            torch.linspace(1, 0, overlap),
        ]
    )
    soft_margin = margin_ramp.reshape(1, 1, patch_size) * margin_ramp.reshape(1, patch_size, 1)
    soft_margin = soft_margin.to(patches.device)

    # Infer logits with model and turn into probabilities with sigmoid in a batched manner
    # TODO: check with ingmar and jonas if moving all patches to the device at the same time is a good idea
    patched_probabilities = torch.zeros_like(patches[:, 0, :, :])
    patches = patches.split(batch_size)
    n_skipped = 0
    for i, batch in enumerate(patches):
        # If batch contains only nans, skip it
        if torch.isnan(batch).all(axis=0).any():
            patched_probabilities[i * batch_size : (i + 1) * batch_size] = 0
            n_skipped += 1
            continue
        # If batch contains some nans, replace them with zeros
        batch[torch.isnan(batch)] = 0

        batch = batch.to(device)
        # logger.debug(f"Predicting on batch {i + 1}/{len(patches)}")
        patched_probabilities[i * batch_size : (i + 1) * batch_size] = (
            torch.sigmoid(model(batch)).squeeze(1).to(patched_probabilities.device)
        )
        batch = batch.to(patched_probabilities.device)  # Transfer back to the original device to avoid memory leaks

    if n_skipped > 0:
        logger.debug(f"Skipped {n_skipped} batches because they only contained NaNs")

    patched_probabilities = patched_probabilities.view(bs, nh, nw, patch_size, patch_size)

    # Reconstruct the image from the patches
    prediction = torch.zeros(bs, h, w, device=tensor_tiles.device)
    weights = torch.zeros(bs, h, w, device=tensor_tiles.device)

    for y, x, patch_idx_h, patch_idx_w in patch_coords(h, w, patch_size, overlap):
        patch = patched_probabilities[:, patch_idx_h, patch_idx_w]
        prediction[:, y : y + patch_size, x : x + patch_size] += patch * soft_margin
        weights[:, y : y + patch_size, x : x + patch_size] += soft_margin

    # Avoid division by zero
    weights = torch.where(weights == 0, torch.ones_like(weights), weights)
    prediction = prediction / weights

    # Remove the 1px border and the padding
    prediction = prediction[:, p:-p, p:-p]

    if return_weights:
        return prediction, weights
    else:
        return prediction
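
As a standalone illustration of the soft-margin blending above, here is the ramp with tiny numbers (patch_size=8, overlap=2), dropping the batch dimension for readability:

import torch

patch_size, overlap = 8, 2
margin_ramp = torch.cat([
    torch.linspace(0, 1, overlap),         # ramp up across the overlap
    torch.ones(patch_size - 2 * overlap),  # full weight in the interior
    torch.linspace(1, 0, overlap),         # ramp down across the overlap
])
print(margin_ramp)  # tensor([0., 1., 1., 1., 1., 1., 1., 0.])

# The outer product gives a 2-D weight that fades to zero at all four patch
# edges, so overlapping patches blend smoothly when the weighted predictions
# are summed and divided by the accumulated weights.
soft_margin = margin_ramp.reshape(patch_size, 1) * margin_ramp.reshape(1, patch_size)
print(soft_margin[0].sum())  # tensor(0.) -- the border row carries no weight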