@torch.no_grad()
def create_patches(
    tensor_tiles: torch.Tensor, patch_size: int, overlap: int, return_coords: bool = False
) -> "torch.Tensor | tuple[torch.Tensor, torch.Tensor]":
    """Create overlapping square patches from a batch of tiles.

    Patch positions come from ``patch_coords(h, w, patch_size, overlap)``. Each yielded
    ``(y, x, patch_idx_h, patch_idx_w)`` selects the window
    ``tensor_tiles[:, :, y:y + patch_size, x:x + patch_size]``, which is copied into
    ``patches[:, patch_idx_h, patch_idx_w]``. Grid cells that ``patch_coords`` never
    visits stay zero.

    Args:
        tensor_tiles (torch.Tensor): The input tensor. Shape: (BS, C, H, W).
        patch_size (int): The side length of the square patches. Must satisfy
            ``H > patch_size > overlap`` and ``W > patch_size > overlap``.
        overlap (int): The number of pixels shared by neighbouring patches.
        return_coords (bool, optional): Whether to also return the coordinates of the patches.
            Can be used for debugging. Defaults to False.

    Returns:
        torch.Tensor: The patches. Shape: (BS, N_h, N_w, C, patch_size, patch_size), where
            ``N_h = ceil((H - overlap) / (patch_size - overlap))`` and N_w is computed the
            same way from W. Same dtype and device as ``tensor_tiles``.
        If ``return_coords`` is True, a tuple ``(patches, coords)`` is returned instead.
        ``coords`` has shape (N_h, N_w, 5), uses torch's default dtype and device, and holds
        ``[i, y, x, patch_idx_h, patch_idx_w]`` per patch (``i`` is the enumeration index).

    Raises:
        AssertionError: If ``tensor_tiles`` is not 4-D or the size constraints are violated.
    """
    logger.debug(
        f"Creating patches from a tensor with shape {tensor_tiles.shape} "
        f"with patch_size {patch_size} and overlap {overlap}"
    )
    assert tensor_tiles.dim() == 4, f"Expects tensor_tiles to has shape (BS, C, H, W), got {tensor_tiles.shape}"
    bs, c, h, w = tensor_tiles.shape
    assert h > patch_size > overlap, (
        f"Expects H > patch_size > overlap, got H={h}, patch_size={patch_size}, overlap={overlap}"
    )
    assert w > patch_size > overlap, (
        f"Expects W > patch_size > overlap, got W={w}, patch_size={patch_size}, overlap={overlap}"
    )
    step_size = patch_size - overlap

    # Why not Tensor.unfold:
    # - It drops the trailing window when the size does not line up exactly with the step.
    # - Padding first would work around that, but the resulting strided view needs a reshape
    #   (a memory copy) before it matches the model input layout.
    # Copying each window explicitly avoids both problems.
    # NOTE(review): the grid size below assumes patch_coords yields indices in
    # [0, N_h) x [0, N_w) -- confirm against patch_coords.
    nh = math.ceil((h - overlap) / step_size)
    nw = math.ceil((w - overlap) / step_size)

    # Match the input's dtype as well as its device; without an explicit dtype torch.zeros
    # would silently cast the patches to the default float dtype.
    patches = torch.zeros(
        (bs, nh, nw, c, patch_size, patch_size),
        dtype=tensor_tiles.dtype,
        device=tensor_tiles.device,
    )
    # Coordinates are only a debugging aid, so skip the bookkeeping unless it was requested.
    coords = torch.zeros((nh, nw, 5)) if return_coords else None

    for i, (y, x, patch_idx_h, patch_idx_w) in enumerate(patch_coords(h, w, patch_size, overlap)):
        patches[:, patch_idx_h, patch_idx_w, :] = tensor_tiles[:, :, y : y + patch_size, x : x + patch_size]
        if coords is not None:
            coords[patch_idx_h, patch_idx_w, :] = torch.tensor([i, y, x, patch_idx_h, patch_idx_w])

    if return_coords:
        return patches, coords
    return patches