Skip to content

segment

darts_segmentation.segment

Functionality for segmenting tiles.

DEFAULT_DEVICE module-attribute

# Use CUDA whenever torch reports it as available; otherwise fall back to the CPU.
DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

logger module-attribute

# Map the package module name onto the "darts." logger hierarchy,
# e.g. "darts_segmentation.segment" -> "darts.segmentation.segment".
logger = logging.getLogger(__name__.replace("darts_", "darts."))

Band dataclass

Band(name: str, factor: float = 1.0, offset: float = 0.0)

Wrapper for the band information.

factor class-attribute instance-attribute

factor: float = 1.0

name instance-attribute

name: str

offset class-attribute instance-attribute

offset: float = 0.0

Bands

Bases: collections.UserList[darts_segmentation.utils.Band]

Wrapper for the list of bands.

factors property

factors: list[float]

Get the factors of the bands.

Returns:

  • list[float]

    list[float]: The factors of the bands.

names property

names: list[str]

Get the names of the bands.

Returns:

  • list[str]

    list[str]: The names of the bands.

offsets property

offsets: list[float]

Get the offsets of the bands.

Returns:

  • list[float]

    list[float]: The offsets of the bands.

__reduce__

__reduce__()
Source code in darts-segmentation/src/darts_segmentation/utils.py
def __reduce__(self):  # noqa: D105
    # Pickle as a plain ``dict`` built from ``to_config()``: unpickling (e.g. of a PyTorch
    # checkpoint) then only needs the builtin ``dict`` and not this class.
    config = self.to_config()
    return dict, (config,)

__repr__

__repr__() -> str
Source code in darts-segmentation/src/darts_segmentation/utils.py
def __repr__(self) -> str:  # noqa: D105
    # Each band is rendered as ``name(*factor+offset)`` with five decimals.
    def describe(band) -> str:
        return f"{band.name}(*{band.factor:.5f}+{band.offset:.5f})"

    joined = ", ".join(map(describe, self))
    return f"Bands({joined})"

filter

filter(
    band_names: list[str],
) -> darts_segmentation.utils.Bands

Filter the bands by name.

Parameters:

  • band_names (list[str]) –

    The names of the bands to keep.

Returns:

  • Bands –

    The filtered Bands object.

Source code in darts-segmentation/src/darts_segmentation/utils.py
def filter(self, band_names: list[str]) -> "Bands":
    """Filter the bands by name.

    The returned bands keep the order of ``self``, not the order of ``band_names``.
    Names in ``band_names`` that match no band are ignored.

    Args:
        band_names (list[str]): The names of the bands to keep. A single string is treated as one
            band name (instead of being matched character-by-character / as a substring).

    Returns:
        Bands: The filtered Bands object (an instance of the same class as ``self``).

    """
    # A bare string would otherwise make ``band.name in band_names`` a substring test.
    if isinstance(band_names, str):
        band_names = [band_names]
    # Build the lookup set once so every membership test is O(1) instead of a list scan.
    keep = set(band_names)
    # ``type(self)`` preserves subclasses, mirroring how the alternate constructors use ``cls``.
    return type(self)([band for band in self if band.name in keep])

from_config classmethod

from_config(
    config: dict[
        typing.Literal[
            "bands", "band_factors", "band_offsets"
        ],
        list,
    ]
    | dict[str, tuple[float, float]],
) -> darts_segmentation.utils.Bands

Create a Bands object from a config dictionary.

Parameters:

  • config (dict) –

    The config dictionary containing the band information. Expects config to be a dictionary with keys "bands", "band_factors" and "band_offsets", with the values to be lists of the same length.

Returns:

  • Bands –

    The Bands object.

Source code in darts-segmentation/src/darts_segmentation/utils.py
@classmethod
def from_config(
    cls,
    config: dict[Literal["bands", "band_factors", "band_offsets"], list] | dict[str, tuple[float, float]],
) -> "Bands":
    """Create a Bands object from a config dictionary.

    Args:
        config (dict): The config dictionary containing the band information.
            Expects config to be a dictionary with keys "bands", "band_factors" and "band_offsets",
            with the values to be lists of the same length.

    Returns:
        Bands: The Bands object.

    """
    assert "bands" in config and "band_factors" in config and "band_offsets" in config, (
        f"Config must contain keys 'bands', 'band_factors' and 'band_offsets'.Got {config} instead."
    )
    return cls(
        [
            Band(name=name, factor=factor, offset=offset)
            for name, factor, offset in zip(config["bands"], config["band_factors"], config["band_offsets"])
        ]
    )

from_dict classmethod

from_dict(
    config: dict[str, tuple[float, float]],
) -> darts_segmentation.utils.Bands

Create a Bands object from a dictionary.

Parameters:

  • config (dict[str, tuple[float, float]]) –

    The dictionary containing the band information. Expects the keys to be the band names and the values to be tuples of (factor, offset). Example: {"band1": (1.0, 0.0), "band2": (2.0, 1.0)}

Returns:

  • Bands –

    The Bands object.

Source code in darts-segmentation/src/darts_segmentation/utils.py
@classmethod
def from_dict(cls, config: dict[str, tuple[float, float]]) -> "Bands":
    """Build a Bands object from a ``{name: (factor, offset)}`` mapping.

    Args:
        config (dict[str, tuple[float, float]]): Maps each band name to a ``(factor, offset)`` pair,
            e.g. {"band1": (1.0, 0.0), "band2": (2.0, 1.0)}. Bands are created in the mapping's
            iteration order.

    Returns:
        Bands: The Bands object.

    """
    bands = []
    for name, factor_offset in config.items():
        factor, offset = factor_offset
        bands.append(Band(name=name, factor=factor, offset=offset))
    return cls(bands)

to_config

to_config() -> dict[
    typing.Literal["bands", "band_factors", "band_offsets"],
    list,
]

Convert the Bands object to a config dictionary.

Returns:

  • dict ( dict[typing.Literal['bands', 'band_factors', 'band_offsets'], list] ) –

    The config dictionary containing the band information.

Source code in darts-segmentation/src/darts_segmentation/utils.py
def to_config(self) -> dict[Literal["bands", "band_factors", "band_offsets"], list]:
    """Serialize the bands into three parallel lists.

    Returns:
        dict: ``{"bands": names, "band_factors": factors, "band_offsets": offsets}``, where the i-th
        entry of every list belongs to the i-th band.

    """
    names, factors, offsets = [], [], []
    for band in self:
        names.append(band.name)
        factors.append(band.factor)
        offsets.append(band.offset)
    return {"bands": names, "band_factors": factors, "band_offsets": offsets}

to_dict

to_dict() -> dict[str, tuple[float, float]]

Convert the Bands object to a dictionary.

Returns:

  • dict[str, tuple[float, float]]

    dict[str, tuple[float, float]]: The dictionary containing the band information.

Source code in darts-segmentation/src/darts_segmentation/utils.py
def to_dict(self) -> dict[str, tuple[float, float]]:
    """Serialize the bands into a ``{name: (factor, offset)}`` mapping.

    If several bands share a name, the last one wins.

    Returns:
        dict[str, tuple[float, float]]: The mapping from band name to ``(factor, offset)``.

    """
    result: dict[str, tuple[float, float]] = {}
    for band in self:
        result[band.name] = (band.factor, band.offset)
    return result

SMPSegmenter

SMPSegmenter(
    model_checkpoint: pathlib.Path | str,
    device: torch.device = darts_segmentation.segment.DEFAULT_DEVICE,
)

An actor that keeps a model as its state and segments tiles.

Initialize the segmenter.

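Parameters:

  • model_checkpoint (pathlib.Path | str) –

    The path to the model checkpoint.

  • device (torch.device, default: DEFAULT_DEVICE ) –

    The device to run the model on. Defaults to torch.device("cuda") if cuda is available, else torch.device("cpu").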

Source code in darts-segmentation/src/darts_segmentation/segment.py
def __init__(self, model_checkpoint: Path | str, device: torch.device = DEFAULT_DEVICE):
    """Load a segmentation model from a checkpoint and prepare it for inference.

    Args:
        model_checkpoint (Path | str): Path to the model checkpoint. The checkpoint is expected to hold a
            ``"config"`` entry (validated by `SMPSegmenterConfig.from_ckpt`) and a ``"statedict"`` entry.
        device (torch.device): The device to run the model on.
            Defaults to torch.device("cuda") if cuda is available, else torch.device("cpu").

    """
    checkpoint_path = Path(model_checkpoint)
    self.device = device

    # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=self.device)
    self.config = SMPSegmenterConfig.from_ckpt(checkpoint["config"])

    # The weights come from the checkpoint's state dict below, so no encoder weights are requested from smp.
    self.config["model"] |= {"encoder_weights": None}
    self.model = smp.create_model(**self.config["model"])
    self.model.to(self.device)
    self.model.load_state_dict(checkpoint["statedict"])
    # Put the model in evaluation mode for inference.
    self.model.eval()

    logger.debug(f"Successfully loaded model from {checkpoint_path.resolve()} with inputs: {self.config['bands']}")

config instance-attribute

device instance-attribute

model instance-attribute

model: torch.nn.Module = segmentation_models_pytorch.create_model(
    **self.config["model"]
)

__call__

__call__(
    input: xarray.Dataset | list[xarray.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
) -> xarray.Dataset | list[xarray.Dataset]

Run inference on a single tile or a list of tiles.

Parameters:

  • input (xarray.Dataset | list[xarray.Dataset]) –

    A single tile or a list of tiles.

  • patch_size (int, default: 1024 ) –

    The size of the patches. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The size of the overlap. Defaults to 16.

  • batch_size (int, default: 8 ) –

    The batch size for the prediction, NOT the number of input tiles. The tensor is sliced into patches, which are then inferred in batches of this size. Defaults to 8.

  • reflection (int, default: 0 ) –

    Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

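Returns:

  • xarray.Dataset | list[xarray.Dataset] –

    A single tile or a list of tiles augmented by a predicted probabilities layer, depending on the input. Each probabilities layer has type float32 and range [0, 1].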

Raises:

  • ValueError

    in case the input is not an xr.Dataset or a list of xr.Dataset

Source code in darts-segmentation/src/darts_segmentation/segment.py
def __call__(
    self,
    input: xr.Dataset | list[xr.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
) -> xr.Dataset | list[xr.Dataset]:
    """Segment a single tile or a list of tiles.

    An ``xr.Dataset`` is forwarded to `segment_tile`, a ``list`` to `segment_tile_batched`;
    all keyword options are passed through unchanged.

    Args:
        input (xr.Dataset | list[xr.Dataset]): A single tile or a list of tiles.
        patch_size (int): The size of the patches. Defaults to 1024.
        overlap (int): The size of the overlap. Defaults to 16.
        batch_size (int): The batch size for the prediction, NOT the number of input tiles.
            The tensor is sliced into patches, which are then inferred in batches of this size. Defaults to 8.
        reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

    Returns:
        A single tile or a list of tiles augmented by a predicted `probabilities` layer, depending on the input.
        Each `probability` has type float32 and range [0, 1].

    Raises:
        ValueError: in case the input is not an xr.Dataset or a list of xr.Dataset

    """
    if isinstance(input, xr.Dataset):
        run = self.segment_tile
    elif isinstance(input, list):
        run = self.segment_tile_batched
    else:
        raise ValueError(f"Expected xr.Dataset or list of xr.Dataset, got {type(input)}")
    return run(input, patch_size=patch_size, overlap=overlap, batch_size=batch_size, reflection=reflection)

segment_tile

segment_tile(
    tile: xarray.Dataset,
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
) -> xarray.Dataset

Run inference on a tile.

Parameters:

  • tile (xarray.Dataset) –

    The input tile, containing preprocessed, harmonized data.

  • patch_size (int, default: 1024 ) –

    The size of the patches. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The size of the overlap. Defaults to 16.

  • batch_size (int, default: 8 ) –

    The batch size for the prediction, NOT the number of input tiles. The tensor is sliced into patches, which are then inferred in batches of this size. Defaults to 8.

  • reflection (int, default: 0 ) –

    Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

Returns:

  • xarray.Dataset

    Input tile augmented by a predicted probabilities layer with type float32 and range [0, 1].

Source code in darts-segmentation/src/darts_segmentation/segment.py
@stopwatch.f(
    "Segmenting tile",
    printer=logger.debug,
    print_kwargs=["patch_size", "overlap", "batch_size", "reflection"],
)
def segment_tile(
    self, tile: xr.Dataset, patch_size: int = 1024, overlap: int = 16, batch_size: int = 8, reflection: int = 0
) -> xr.Dataset:
    """Segment one tile and attach the result as a `probabilities` layer.

    The tile is modified in place (a `probabilities` variable is added) and also returned.

    Args:
        tile: The input tile, containing preprocessed, harmonized data.
        patch_size (int): The size of the patches. Defaults to 1024.
        overlap (int): The size of the overlap. Defaults to 16.
        batch_size (int): The batch size for the prediction, NOT the number of input tiles.
            The tensor is sliced into patches, which are then inferred in batches of this size. Defaults to 8.
        reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

    Returns:
        Input tile augmented by a predicted `probabilities` layer with type float32 and range [0, 1].

    """
    # predict_in_patches asserts a 4-D input, so prepend a batch axis of size one ...
    batched_input = self.tile2tensor(tile)[None]
    probabilities = predict_in_patches(
        self.model, batched_input, patch_size, overlap, batch_size, reflection, self.device
    )
    # ... and take the single batch entry back out.
    probabilities = probabilities[0]

    # The "red" band serves as a template, so the new layer shares its coordinates and dimensions.
    # TODO: is there a better way to pass metadata?
    tile["probabilities"] = tile["red"].copy(data=probabilities.cpu().numpy())
    tile["probabilities"].attrs = {"long_name": "Probabilities"}
    # NOTE(review): fillna(nan) does not change any values; write_nodata records NaN as the nodata value.
    tile["probabilities"] = tile["probabilities"].fillna(float("nan")).rio.write_nodata(float("nan"))

    # Release the tensors before asking torch to free cached (cuda) memory.
    del batched_input, probabilities
    free_torch()

    return tile

segment_tile_batched

segment_tile_batched(
    tiles: list[xarray.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
) -> list[xarray.Dataset]

Run inference on a list of tiles.

Parameters:

  • tiles (list[xarray.Dataset]) –

    The input tiles, containing preprocessed, harmonized data.

  • patch_size (int, default: 1024 ) –

    The size of the patches. Defaults to 1024.

  • overlap (int, default: 16 ) –

    The size of the overlap. Defaults to 16.

  • batch_size (int, default: 8 ) –

    The batch size for the prediction, NOT the number of input tiles. The tensor is sliced into patches, which are then inferred in batches of this size. Defaults to 8.

  • reflection (int, default: 0 ) –

    Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

Returns:

  • list[xarray.Dataset]

    A list of input tiles augmented by a predicted probabilities layer with type float32 and range [0, 1].

Source code in darts-segmentation/src/darts_segmentation/segment.py
@stopwatch.f(
    "Segmenting tiles",
    printer=logger.debug,
    print_kwargs=["patch_size", "overlap", "batch_size", "reflection"],
)
def segment_tile_batched(
    self,
    tiles: list[xr.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
) -> list[xr.Dataset]:
    """Run inference on a list of tiles.

    Every tile is modified in place (a `probabilities` variable is added) and the same list is returned.
    All tiles must have the same spatial shape, because they are stacked into one tensor.

    Args:
        tiles: The input tiles, containing preprocessed, harmonized data.
        patch_size (int): The size of the patches. Defaults to 1024.
        overlap (int): The size of the overlap. Defaults to 16.
        batch_size (int): The batch size for the prediction, NOT the number of input tiles.
            The tensor is sliced into patches, which are then inferred in batches of this size. Defaults to 8.
        reflection (int): Reflection-Padding which will be applied to the edges of the tensor. Defaults to 0.

    Returns:
        A list of input tiles augmented by a predicted `probabilities` layer with type float32 and range [0, 1].
        An empty list is returned unchanged.

    """
    if not tiles:
        # Nothing to segment; torch.stack would reject an empty sequence.
        return tiles

    # Stack the per-tile tensors (bands on their leading axis) into one batch along a new leading axis.
    # torch.stack needs a list of tensors, so the tiles are converted one by one with tile2tensor.
    tensor_tiles = torch.stack([self.tile2tensor(tile) for tile in tiles], dim=0)

    probabilities = predict_in_patches(
        self.model, tensor_tiles, patch_size, overlap, batch_size, reflection, self.device
    )

    # Iterating over the first axis yields one probability map per input tile.
    for tile, probs in zip(tiles, probabilities):
        # TODO: is there a better way to pass metadata?
        tile["probabilities"] = tile["red"].copy(data=probs.cpu().numpy())
        tile["probabilities"].attrs = {"long_name": "Probabilities"}
        tile["probabilities"] = tile["probabilities"].fillna(float("nan")).rio.write_nodata(float("nan"))

    # Release the tensors before asking torch to free cached (cuda) memory.
    del tensor_tiles, probabilities
    free_torch()

    return tiles

tile2tensor

tile2tensor(tile: xarray.Dataset) -> torch.Tensor

Take a tile and convert it to a pytorch tensor.

Respects the input combination from the config.

Returns:

  • torch.Tensor

    A torch tensor for the full tile, consisting of the bands specified in self.config["bands"].

Source code in darts-segmentation/src/darts_segmentation/segment.py
def tile2tensor(self, tile: xr.Dataset) -> torch.Tensor:
    """Convert a tile into a pytorch tensor with one channel per configured band.

    The bands are read from `self.config["bands"]` in their configured order. Each band is decoded
    following the CF scale/offset convention (``value * factor + offset``) and clipped to [0, 1].

    Args:
        tile: The input tile; it must contain a data variable for every configured band name.

    Returns:
        A float32 torch tensor with the configured bands stacked along a new leading axis.

    """
    # decoded = encoded * factor + offset (CF convention); the decoded values are clamped into [0, 1].
    decoded = (
        (tile[band.name] * band.factor + band.offset).clip(min=0, max=1) for band in self.config["bands"]
    )
    channels = [torch.from_numpy(channel.to_numpy().astype("float32")) for channel in decoded]
    return torch.stack(channels, dim=0)

tile2tensor_batched

tile2tensor_batched(
    tiles: list[xarray.Dataset],
) -> torch.Tensor

Take a list of tiles and convert them to a pytorch tensor.

Respects the input combination from the config.

Returns:

  • torch.Tensor

    A torch tensor for all tiles, consisting of the bands specified in self.config["bands"].

Source code in darts-segmentation/src/darts_segmentation/segment.py
def tile2tensor_batched(self, tiles: list[xr.Dataset]) -> torch.Tensor:
    """Take a list of tiles and convert them to a pytorch tensor.

    Respects the input combination from the config (`self.config["bands"]`): every band is scaled with
    its factor and offset and clipped to [0, 1], exactly like `tile2tensor`.

    Args:
        tiles: The input tiles. Every tile must contain all configured bands, and all tiles must share
            the same spatial shape.

    Returns:
        A float32 torch tensor of shape (N, C, *band_shape), where N is the number of tiles and C the number
        of configured bands; entry [i, j] holds band j of tile i.

    Raises:
        ValueError: If `tiles` is empty.

    """
    if not tiles:
        raise ValueError("Expected at least one tile, got an empty list.")
    bands = self.config["bands"]
    per_tile = []
    # Tiles outer, bands inner: each tile becomes a (C, ...) stack, so channels of different tiles never mix.
    for tile in tiles:
        channels = []
        for band in bands:
            # Normalize the band data (CF convention: decoded = encoded * factor + offset)
            band_data = tile[band.name] * band.factor + band.offset
            band_data = band_data.clip(min=0, max=1)
            channels.append(torch.from_numpy(band_data.to_numpy().astype("float32")))
        per_tile.append(torch.stack(channels, dim=0))
    return torch.stack(per_tile, dim=0)

SMPSegmenterConfig

Bases: typing.TypedDict

Configuration for the segmentor.

bands instance-attribute

model instance-attribute

model: dict[str, typing.Any]

from_ckpt classmethod

Validate the config for the segmentor.

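Parameters:

  • config (dict[str, Any]) –

    The configuration to validate.

Returns:

  • SMPSegmenterConfig –

    The validated configuration.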

Source code in darts-segmentation/src/darts_segmentation/segment.py
@classmethod
def from_ckpt(cls, config: dict[str, Any]) -> "SMPSegmenterConfig":
    """Validate the config for the segmentor.

    Legacy configs that store ``input_combination`` and ``norm_factors`` are converted into a ``bands``
    entry (factor taken from ``norm_factors``, default offset). A ``bands`` entry that is not already a
    ``Bands`` object is rebuilt with ``Bands.from_config``.

    The caller's dictionary is not modified: a shallow copy is validated and returned (nested values,
    e.g. the ``model`` dict, are shared with the input).

    Args:
        config: The configuration to validate.

    Returns:
        The validated configuration.

    Raises:
        KeyError: If a legacy ``input_combination`` band has no entry in ``norm_factors``.
        AssertionError: If the ``model`` or ``bands`` entry is missing (when assertions are enabled).

    """
    # Shallow copy so the checkpoint dict handed in by the caller keeps its original keys.
    config = dict(config)
    # Handling legacy case that the config contains the old keys
    if "input_combination" in config and "norm_factors" in config:
        input_combination = config.pop("input_combination")
        norm_factors = config.pop("norm_factors")
        # Check if all input_combination features are in norm_factors
        missing = [name for name in input_combination if name not in norm_factors]
        if missing:
            raise KeyError(f"Legacy config lists bands without a norm factor: {missing}")
        config["bands"] = Bands([Band(name, norm_factors[name]) for name in input_combination])

    assert "model" in config, "Model config is missing!"
    assert "bands" in config, "Bands config is missing!"
    # The Bands object is always pickled as a dict for interoperability, so we need to convert it back
    if not isinstance(config["bands"], Bands):
        config["bands"] = Bands.from_config(config["bands"])
    return config

predict_in_patches

predict_in_patches(
    model: torch.nn.Module,
    tensor_tiles: torch.Tensor,
    patch_size: int,
    overlap: int,
    batch_size: int,
    reflection: int,
    device=torch.device,
    return_weights: bool = False,
) -> torch.Tensor

Predict on a tensor.

Parameters:

  • model (torch.nn.Module) –

    The model to use for prediction.

  • tensor_tiles (torch.Tensor) –

    The input tensor. Shape: (BS, C, H, W).

  • patch_size (int) –

    The size of the patches.

  • overlap (int) –

    The size of the overlap.

  • batch_size (int) –

    The batch size for the prediction, NOT the number of input tiles. The tensor is sliced into patches, which are then inferred in batches of this size.

  • reflection (int) –

    Reflection-Padding which will be applied to the edges of the tensor.

  • device (torch.device, default: torch.device ) –

    The device to use for the prediction.

  • return_weights (bool, default: False ) –

    Whether to return the weights. Can be used for debugging. Defaults to False.

Returns:

  • torch.Tensor –

    The predicted tensor (together with the weights if return_weights is True).

Source code in darts-segmentation/src/darts_segmentation/utils.py
@torch.no_grad()
def predict_in_patches(
    model: nn.Module,
    tensor_tiles: torch.Tensor,
    patch_size: int,
    overlap: int,
    batch_size: int,
    reflection: int,
    device=torch.device,
    return_weights: bool = False,
) -> torch.Tensor:
    """Predict on a tensor.

    Args:
        model: The model to use for prediction.
        tensor_tiles: The input tensor. Shape: (BS, C, H, W).
        patch_size (int): The size of the patches.
        overlap (int): The size of the overlap.
        batch_size (int): The batch size for the prediction, NOT the batch_size of input tiles.
            Tensor will be sliced into patches and these again will be infered in batches.
        reflection (int): Reflection-Padding which will be applied to the edges of the tensor.
        device (torch.device): The device to use for the prediction.
        return_weights (bool, optional): Whether to return the weights. Can be used for debugging. Defaults to False.

    Returns:
        The predicted tensor.

    """
    logger.debug(
        f"Predicting on a tensor with shape {tensor_tiles.shape} "
        f"with patch_size {patch_size}, overlap {overlap} and batch_size {batch_size} on device {device}"
    )
    assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to has shape (BS, C, H, W), got {tensor_tiles.shape}"
    # Add a 1px + reflection border to avoid pixel loss when applying the soft margin and to reduce edge-artefacts
    p = 1 + reflection
    tensor_tiles = torch.nn.functional.pad(tensor_tiles, (p, p, p, p), mode="reflect")
    bs, c, h, w = tensor_tiles.shape
    step_size = patch_size - overlap
    nh, nw = math.ceil((h - overlap) / step_size), math.ceil((w - overlap) / step_size)

    # Create Patches of size (BS, N_h, N_w, C, patch_size, patch_size)
    patches = create_patches(tensor_tiles, patch_size=patch_size, overlap=overlap)

    # Flatten the patches so they fit to the model
    # (BS, N_h, N_w, C, patch_size, patch_size) -> (BS * N_h * N_w, C, patch_size, patch_size)
    patches = patches.view(bs * nh * nw, c, patch_size, patch_size)

    # Create a soft margin for the patches
    margin_ramp = torch.cat(
        [
            torch.linspace(0, 1, overlap),
            torch.ones(patch_size - 2 * overlap),
            torch.linspace(1, 0, overlap),
        ]
    )
    soft_margin = margin_ramp.reshape(1, 1, patch_size) * margin_ramp.reshape(1, patch_size, 1)
    soft_margin = soft_margin.to(patches.device)

    # Infer logits with model and turn into probabilities with sigmoid in a batched manner
    # TODO: check with ingmar and jonas if moving all patches to the device at the same time is a good idea
    patched_probabilities = torch.zeros_like(patches[:, 0, :, :])
    patches = patches.split(batch_size)
    n_skipped = 0
    for i, batch in enumerate(patches):
        # If batch contains only nans, skip it
        if torch.isnan(batch).all(axis=0).any():
            patched_probabilities[i * batch_size : (i + 1) * batch_size] = 0
            n_skipped += 1
            continue
        # If batch contains some nans, replace them with zeros
        batch[torch.isnan(batch)] = 0

        batch = batch.to(device)
        # logger.debug(f"Predicting on batch {i + 1}/{len(patches)}")
        patched_probabilities[i * batch_size : (i + 1) * batch_size] = (
            torch.sigmoid(model(batch)).squeeze(1).to(patched_probabilities.device)
        )
        batch = batch.to(patched_probabilities.device)  # Transfer back to the original device to avoid memory leaks

    if n_skipped > 0:
        logger.debug(f"Skipped {n_skipped} batches because they only contained NaNs")

    patched_probabilities = patched_probabilities.view(bs, nh, nw, patch_size, patch_size)

    # Reconstruct the image from the patches
    prediction = torch.zeros(bs, h, w, device=tensor_tiles.device)
    weights = torch.zeros(bs, h, w, device=tensor_tiles.device)

    for y, x, patch_idx_h, patch_idx_w in patch_coords(h, w, patch_size, overlap):
        patch = patched_probabilities[:, patch_idx_h, patch_idx_w]
        prediction[:, y : y + patch_size, x : x + patch_size] += patch * soft_margin
        weights[:, y : y + patch_size, x : x + patch_size] += soft_margin

    # Avoid division by zero
    weights = torch.where(weights == 0, torch.ones_like(weights), weights)
    prediction = prediction / weights

    # Remove the 1px border and the padding
    prediction = prediction[:, p:-p, p:-p]

    if return_weights:
        return prediction, weights
    else:
        return prediction