@torch.no_grad()
def predict_in_patches(
    model: nn.Module,
    tensor_tiles: torch.Tensor,
    patch_size: int,
    overlap: int,
    batch_size: int,
    reflection: int,
    device: "torch.device | str | None" = None,
    return_weights: bool = False,
) -> "torch.Tensor | tuple[torch.Tensor, torch.Tensor]":
    """Predict on a batch of tiles by running the model on overlapping patches.

    The tiles are reflect-padded, cut into overlapping square patches, run
    through ``model`` in mini-batches, passed through a sigmoid and blended
    back together with a linear soft margin that fades each patch out across
    its overlap region.

    Args:
        model: The model to use for prediction. Its output is squeezed on dim 1
            after the sigmoid.
        tensor_tiles: The input tensor. Shape: (BS, C, H, W).
        patch_size: The edge length of the square patches.
        overlap: The overlap between neighbouring patches, in pixels.
        batch_size: The number of patches inferred per forward pass, NOT the
            batch size of the input tiles. The tiles are sliced into patches
            and these are inferred in batches.
        reflection: Width of the reflection padding applied to every edge of
            the tiles (one extra pixel is always added on top of it).
        device: The device the patches are moved to for the forward pass.
            When omitted, the device of ``tensor_tiles`` is used.
        return_weights: Whether to also return the blending weights. Can be
            used for debugging. Defaults to False.

    Returns:
        The predicted probabilities, cropped back to the un-padded tile size
        (the leading size is BS). If ``return_weights`` is True, a tuple
        ``(prediction, weights)`` is returned instead; ``weights`` still has the
        padded height/width and has its zero entries replaced by ones.

    Raises:
        AssertionError: If ``tensor_tiles`` is not 4-dimensional.

    """
    # The previous signature used the `torch.device` class itself as the
    # default value, which is not a usable device. Treat that legacy sentinel
    # like "not given" and run the model where the input already lives.
    if device is None or device is torch.device:
        device = tensor_tiles.device

    logger.debug(
        f"Predicting on a tensor with shape {tensor_tiles.shape} "
        f"with patch_size {patch_size}, overlap {overlap} and batch_size {batch_size} on device {device}"
    )
    assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to has shape (BS, C, H, W), got {tensor_tiles.shape}"

    # Add a 1px + reflection border to avoid pixel loss when applying the soft
    # margin (its outermost weight is 0) and to reduce edge artefacts.
    p = 1 + reflection
    tensor_tiles = torch.nn.functional.pad(tensor_tiles, (p, p, p, p), mode="reflect")
    bs, c, h, w = tensor_tiles.shape
    step_size = patch_size - overlap
    nh, nw = math.ceil((h - overlap) / step_size), math.ceil((w - overlap) / step_size)

    # Create patches of size (BS, N_h, N_w, C, patch_size, patch_size)
    patches = create_patches(tensor_tiles, patch_size=patch_size, overlap=overlap)

    # Flatten the patches so they fit the model:
    # (BS, N_h, N_w, C, patch_size, patch_size) -> (BS * N_h * N_w, C, patch_size, patch_size)
    patches = patches.view(bs * nh * nw, c, patch_size, patch_size)

    # Soft margin: a linear ramp 0 -> 1 over the overlap, a plateau of ones and
    # a ramp 1 -> 0 again; the outer product turns the 1D ramp into a 2D window.
    margin_ramp = torch.cat(
        [
            torch.linspace(0, 1, overlap),
            torch.ones(patch_size - 2 * overlap),
            torch.linspace(1, 0, overlap),
        ]
    )
    soft_margin = margin_ramp.reshape(1, 1, patch_size) * margin_ramp.reshape(1, patch_size, 1)
    soft_margin = soft_margin.to(patches.device)

    # Infer logits with the model and turn them into probabilities with a
    # sigmoid, one mini-batch at a time. The output buffer starts out as zeros,
    # so skipped batches need no explicit write.
    patched_probabilities = torch.zeros_like(patches[:, 0, :, :])
    n_skipped = 0
    for i, batch in enumerate(patches.split(batch_size)):
        nan_mask = torch.isnan(batch)

        # Skip the batch only if *every* value is NaN. Reducing over the batch
        # axis first (`.all(axis=0).any()`) would also drop batches that merely
        # share one NaN pixel position across their patches.
        if nan_mask.all():
            n_skipped += 1
            continue

        # Replace remaining NaNs with zeros. Done out-of-place because `batch`
        # is a view produced by `split`; writing into it would modify `patches`.
        if nan_mask.any():
            batch = batch.masked_fill(nan_mask, 0)
        batch = batch.to(device)

        start = i * batch_size
        patched_probabilities[start : start + batch_size] = (
            torch.sigmoid(model(batch)).squeeze(1).to(patched_probabilities.device)
        )
        # No transfer of `batch` back to the source device: the name is rebound
        # on the next iteration, so the copy would only cost time.

    if n_skipped > 0:
        logger.debug(f"Skipped {n_skipped} batches because they only contained NaNs")

    patched_probabilities = patched_probabilities.view(bs, nh, nw, patch_size, patch_size)

    # Reconstruct the image from the patches: accumulate the weighted patches
    # and the weights separately, then normalise.
    prediction = torch.zeros(bs, h, w, device=tensor_tiles.device)
    weights = torch.zeros(bs, h, w, device=tensor_tiles.device)
    for y, x, patch_idx_h, patch_idx_w in patch_coords(h, w, patch_size, overlap):
        patch = patched_probabilities[:, patch_idx_h, patch_idx_w]
        prediction[:, y : y + patch_size, x : x + patch_size] += patch * soft_margin
        weights[:, y : y + patch_size, x : x + patch_size] += soft_margin

    # Avoid division by zero where no patch contributed any weight
    weights = torch.where(weights == 0, torch.ones_like(weights), weights)
    prediction = prediction / weights

    # Remove the 1px border and the reflection padding again
    prediction = prediction[:, p:-p, p:-p]

    if return_weights:
        return prediction, weights
    return prediction