Skip to content

darts_segmentation.predict_in_patches

Predict on a tensor.

Parameters:

  • model (torch.nn.Module) –

    The model to use for prediction.

  • tensor_tiles (torch.Tensor) –

    The input tensor. Shape: (BS, C, H, W).

  • patch_size (int) –

    The size of the patches.

  • overlap (int) –

    The size of the overlap.

  • batch_size (int) –

    The batch size for the prediction, NOT the batch size of the input tiles. The tensor is sliced into patches, and these patches are in turn inferred in batches.

  • reflection (int) –

    Reflection-Padding which will be applied to the edges of the tensor.

  • device (torch.device, default: torch.device ) –

    The device to use for the prediction.

  • return_weights (bool, default: False ) –

    Whether to return the weights. Can be used for debugging. Defaults to False.

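Returns:

  • torch.Tensor –

    The predicted tensor. If return_weights is True, a tuple (prediction, weights) is returned instead.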

Source code in darts-segmentation/src/darts_segmentation/utils.py
@torch.no_grad()
def predict_in_patches(
    model: nn.Module,
    tensor_tiles: torch.Tensor,
    patch_size: int,
    overlap: int,
    batch_size: int,
    reflection: int,
    device: "torch.device | str | None" = None,
    return_weights: bool = False,
) -> "torch.Tensor | tuple[torch.Tensor, torch.Tensor]":
    """Predict on a tensor by running the model on overlapping patches and blending the results.

    The input is reflection-padded, cut into patches, inferred in batches, and stitched back
    together using a soft (linearly ramped) weight per patch so that overlapping regions are
    averaged smoothly.

    Args:
        model: The model to use for prediction. Its output is squeezed on dim 1, so it is
            presumably expected to return a single channel of logits per patch.
        tensor_tiles: The input tensor. Shape: (BS, C, H, W).
        patch_size (int): The size of the patches.
        overlap (int): The size of the overlap.
        batch_size (int): The batch size for the prediction, NOT the batch size of the input tiles.
            The tensor is sliced into patches, and these patches are in turn inferred in batches.
        reflection (int): Reflection-Padding which will be applied to the edges of the tensor.
        device (torch.device, optional): The device to use for the prediction.
            Defaults to None, in which case the device of `tensor_tiles` is used.
        return_weights (bool, optional): Whether to return the weights. Can be used for debugging. Defaults to False.

    Returns:
        The predicted tensor of shape (BS, H, W). If `return_weights` is True, a tuple
        `(prediction, weights)` is returned instead, where `weights` is the accumulated soft-margin
        weight map. Note that `weights` still has the padded shape (it is not cropped) and that
        zero weights have already been replaced by ones.

    """
    # The previous default was the `torch.device` class itself, which cannot be passed to `.to()`.
    # Fall back to the device the input already lives on.
    if device is None:
        device = tensor_tiles.device

    logger.debug(
        f"Predicting on a tensor with shape {tensor_tiles.shape} "
        f"with patch_size {patch_size}, overlap {overlap} and batch_size {batch_size} on device {device}"
    )
    assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to has shape (BS, C, H, W), got {tensor_tiles.shape}"
    # Add a 1px + reflection border to avoid pixel loss when applying the soft margin and to reduce edge-artefacts
    # (the soft margin below starts and ends at 0, so the outermost pixel of every patch gets zero weight)
    p = 1 + reflection
    tensor_tiles = torch.nn.functional.pad(tensor_tiles, (p, p, p, p), mode="reflect")
    bs, c, h, w = tensor_tiles.shape
    step_size = patch_size - overlap
    nh, nw = math.ceil((h - overlap) / step_size), math.ceil((w - overlap) / step_size)

    # Create Patches of size (BS, N_h, N_w, C, patch_size, patch_size)
    patches = create_patches(tensor_tiles, patch_size=patch_size, overlap=overlap)

    # Flatten the patches so they fit to the model
    # (BS, N_h, N_w, C, patch_size, patch_size) -> (BS * N_h * N_w, C, patch_size, patch_size)
    # reshape instead of view: also works if create_patches hands back non-contiguous memory
    patches = patches.reshape(bs * nh * nw, c, patch_size, patch_size)

    # Create a soft margin for the patches: a 1D ramp 0 -> 1 -> 0, combined into a 2D weight via outer product
    margin_ramp = torch.cat(
        [
            torch.linspace(0, 1, overlap),
            torch.ones(patch_size - 2 * overlap),
            torch.linspace(1, 0, overlap),
        ]
    )
    soft_margin = margin_ramp.reshape(1, 1, patch_size) * margin_ramp.reshape(1, patch_size, 1)
    soft_margin = soft_margin.to(patches.device)

    # Infer logits with model and turn into probabilities with sigmoid in a batched manner
    # Only one batch at a time is moved to `device`; the results are collected on the patches' device.
    patched_probabilities = torch.zeros_like(patches[:, 0, :, :])
    patches = patches.split(batch_size)
    n_skipped = 0
    for i, batch in enumerate(patches):
        nan_mask = torch.isnan(batch)
        # If batch contains only nans, skip it.
        # The reduction must run over ALL elements: reducing over the batch axis only (as before)
        # skipped a batch as soon as a single pixel position was NaN in every patch of the batch.
        if nan_mask.all():
            patched_probabilities[i * batch_size : (i + 1) * batch_size] = 0
            n_skipped += 1
            continue
        # If batch contains some nans, replace them with zeros
        # (in-place on the padded copy / patch tensor, not on the caller's input)
        batch[nan_mask] = 0

        batch = batch.to(device)
        # logger.debug(f"Predicting on batch {i + 1}/{len(patches)}")
        patched_probabilities[i * batch_size : (i + 1) * batch_size] = (
            torch.sigmoid(model(batch)).squeeze(1).to(patched_probabilities.device)
        )
        # No transfer of `batch` back to the original device: the name is rebound on the next
        # iteration, so copying it back would only cost an extra device-to-host transfer.

    if n_skipped > 0:
        logger.debug(f"Skipped {n_skipped} batches because they only contained NaNs")

    patched_probabilities = patched_probabilities.view(bs, nh, nw, patch_size, patch_size)

    # Reconstruct the image from the patches: accumulate weighted predictions and the weights themselves
    prediction = torch.zeros(bs, h, w, device=tensor_tiles.device)
    weights = torch.zeros(bs, h, w, device=tensor_tiles.device)

    for y, x, patch_idx_h, patch_idx_w in patch_coords(h, w, patch_size, overlap):
        patch = patched_probabilities[:, patch_idx_h, patch_idx_w]
        prediction[:, y : y + patch_size, x : x + patch_size] += patch * soft_margin
        weights[:, y : y + patch_size, x : x + patch_size] += soft_margin

    # Avoid division by zero (pixels that only ever received zero weight keep a prediction of 0)
    weights = torch.where(weights == 0, torch.ones_like(weights), weights)
    prediction = prediction / weights

    # Remove the 1px border and the padding
    prediction = prediction[:, p:-p, p:-p]

    if return_weights:
        return prediction, weights
    return prediction