darts_segmentation.SMPSegmenter
An actor that keeps a model as its state and segments tiles.
Initialize the segmenter.
Parameters:
- model_checkpoint (pathlib.Path) – The path to the model checkpoint.
- device (torch.device, default: darts_segmentation.segment.DEFAULT_DEVICE) – The device to run the model on. Defaults to torch.device("cuda") if CUDA is available, else torch.device("cpu").
Source code in darts-segmentation/src/darts_segmentation/segment.py
config (instance-attribute)
config: darts_segmentation.segment.SMPSegmenterConfig = (
darts_segmentation.segment.validate_config(
ckpt["config"]
)
)
device (instance-attribute)
device: torch.device = device
model (instance-attribute)
model: torch.nn.Module = (
segmentation_models_pytorch.create_model(
**self.config["model"]
)
)
__call__
__call__(
input: xarray.Dataset | list[xarray.Dataset],
patch_size: int = 1024,
overlap: int = 16,
batch_size: int = 8,
reflection: int = 0,
) -> xarray.Dataset | list[xarray.Dataset]
Run inference on a single tile or a list of tiles.
Parameters:
- input (xarray.Dataset | list[xarray.Dataset]) – A single tile or a list of tiles.
- patch_size (int, default: 1024) – The size of the patches. Defaults to 1024.
- overlap (int, default: 16) – The size of the overlap between neighbouring patches. Defaults to 16.
- batch_size (int, default: 8) – The batch size for the prediction, NOT the batch size of input tiles. The tensor is sliced into patches, which are then inferred in batches of this size. Defaults to 8.
- reflection (int, default: 0) – Reflection padding applied to the edges of the tensor. Defaults to 0.
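To illustrate how patch_size, overlap and batch_size interact, here is a minimal sketch, assuming patches lie on a regular grid with stride patch_size - overlap and that the last patch is shifted back to end exactly at the tile border. The helpers patch_starts and batched are hypothetical and not part of darts_segmentation; the library's actual patching logic may differ.

```python
from collections.abc import Iterator


def patch_starts(length: int, patch_size: int, overlap: int) -> list[int]:
    """Start offsets of patches along one axis (hypothetical helper).

    Consecutive patches advance by patch_size - overlap; the last patch is
    shifted back so that it ends exactly at the border of the axis.
    """
    if overlap >= patch_size:
        raise ValueError("overlap must be smaller than patch_size")
    if length <= patch_size:
        return [0]  # an axis no longer than one patch gets a single patch at 0
    stride = patch_size - overlap
    starts = list(range(0, length - patch_size, stride))
    starts.append(length - patch_size)  # last patch flush with the border
    return starts


def batched(items: list, batch_size: int) -> Iterator[list]:
    """Yield consecutive chunks of at most batch_size items."""
    for i in range(0, len(items), batch_size):
        yield items[i : i + batch_size]


# A 3000 x 2500 pixel tile with the default patch_size=1024 and overlap=16.
rows = patch_starts(3000, patch_size=1024, overlap=16)  # [0, 1008, 1976]
cols = patch_starts(2500, patch_size=1024, overlap=16)  # [0, 1008, 1476]
patches = [(r, c) for r in rows for c in cols]  # 9 patch origins
sizes = [len(b) for b in batched(patches, batch_size=8)]  # [8, 1]
```

With the defaults, this hypothetical 3000 x 2500 tile yields 9 patches, which batch_size=8 splits into two forward passes of 8 and 1 patches.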
Returns:
- xarray.Dataset | list[xarray.Dataset] – A single tile or a list of tiles, matching the input, augmented by a predicted probabilities layer. Each probability has type float32 and range [0, 1].
Raises:
- ValueError – If the input is not an xarray.Dataset or a list of xarray.Dataset.
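The input handling described for __call__ (one tile in, one tile out; a list in, a list out; anything else raises ValueError) can be sketched as a small dispatcher. This is an illustration only: the stand-in Dataset class replaces xarray.Dataset so the sketch runs without xarray, and tag_one / tag_many are hypothetical placeholders for segment_tile and segment_tile_batched.

```python
from typing import Any, Callable


class Dataset:
    """Stand-in for xarray.Dataset so the sketch runs without xarray (assumption)."""

    def __init__(self, name: str) -> None:
        self.name = name


def dispatch(data: Any, segment_one: Callable, segment_many: Callable) -> Any:
    """Send a single Dataset to segment_one and a list of Datasets to segment_many."""
    if isinstance(data, Dataset):
        return segment_one(data)
    if isinstance(data, list) and all(isinstance(d, Dataset) for d in data):
        return segment_many(data)
    raise ValueError(
        f"Expected a Dataset or a list of Datasets, got {type(data).__name__}"
    )


# Hypothetical placeholders for segment_tile / segment_tile_batched.
def tag_one(tile: Dataset) -> str:
    return f"single:{tile.name}"


def tag_many(tiles: list[Dataset]) -> list[str]:
    return [f"batched:{t.name}" for t in tiles]
```

This sketch also checks every list element; whether the library inspects the elements or only the container type is not stated above.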
segment_tile
segment_tile(
tile: xarray.Dataset,
patch_size: int = 1024,
overlap: int = 16,
batch_size: int = 8,
reflection: int = 0,
) -> xarray.Dataset
Run inference on a tile.
Parameters:
- tile (xarray.Dataset) – The input tile, containing preprocessed, harmonized data.
- patch_size (int, default: 1024) – The size of the patches. Defaults to 1024.
- overlap (int, default: 16) – The size of the overlap between neighbouring patches. Defaults to 16.
- batch_size (int, default: 8) – The batch size for the prediction, NOT the batch size of input tiles. The tensor is sliced into patches, which are then inferred in batches of this size. Defaults to 8.
- reflection (int, default: 0) – Reflection padding applied to the edges of the tensor. Defaults to 0.
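The reflection parameter mirrors the tensor at its edges. As a dependency-free sketch of what reflection padding does, assuming the mirror excludes the edge pixel itself (the convention of PyTorch's torch.nn.functional.pad with mode="reflect"), here is a version for a 2D grid. The helpers reflect_indices and reflect_pad are hypothetical; that darts_segmentation pads exactly this way is an assumption.

```python
def reflect_indices(length: int, pad: int) -> list[int]:
    """Source indices for one axis padded by pad cells on both sides.

    The mirror excludes the edge element (1 2 3 4 -> 3 2 | 1 2 3 4 | 3 2),
    so pad must be smaller than length.
    """
    if pad >= length:
        raise ValueError("pad must be smaller than the axis length")
    indices = []
    for i in range(-pad, length + pad):
        if i < 0:
            indices.append(-i)  # mirror across the first element
        elif i >= length:
            indices.append(2 * (length - 1) - i)  # mirror across the last element
        else:
            indices.append(i)
    return indices


def reflect_pad(grid: list[list[float]], pad: int) -> list[list[float]]:
    """Reflection-pad a 2D grid (rows x cols) by pad cells on every edge."""
    rows = reflect_indices(len(grid), pad)
    cols = reflect_indices(len(grid[0]), pad)
    return [[grid[r][c] for c in cols] for r in rows]


grid = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
padded = reflect_pad(grid, 1)  # 5 x 5; first row is [5, 4, 5, 6, 5]
```

With reflection=0 (the default), the indices are the identity and the grid is returned unchanged.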
Returns:
- xarray.Dataset – The input tile augmented by a predicted probabilities layer with type float32 and range [0, 1].
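Because neighbouring patches overlap, a pixel of the output probabilities layer can receive several predictions. One common way to merge them is to average every prediction covering a pixel, sketched below on a small grid. It is an assumption that segment_tile merges overlaps this way (it may, for example, weight or crop patch borders instead), and stitch_mean is a hypothetical helper.

```python
def stitch_mean(
    patches: list[tuple[int, int, list[list[float]]]],
    height: int,
    width: int,
) -> list[list[float]]:
    """Average overlapping patch predictions into a height x width grid.

    Each patch is (row_offset, col_offset, values). Pixels covered by several
    patches get the mean of their predictions; uncovered pixels stay 0.0.
    """
    total = [[0.0] * width for _ in range(height)]
    count = [[0] * width for _ in range(height)]
    for r0, c0, values in patches:
        for dr, row in enumerate(values):
            for dc, v in enumerate(row):
                total[r0 + dr][c0 + dc] += v
                count[r0 + dr][c0 + dc] += 1
    return [
        [t / n if n else 0.0 for t, n in zip(trow, nrow)]
        for trow, nrow in zip(total, count)
    ]


# Two 1 x 3 patches on a 1 x 4 tile that overlap in columns 1 and 2.
merged = stitch_mean(
    [(0, 0, [[0.25, 0.5, 0.75]]), (0, 1, [[0.75, 1.0, 0.5]])], height=1, width=4
)
# merged == [[0.25, 0.625, 0.875, 0.5]]
```

Averaging values that each lie in [0, 1] keeps the merged probabilities in [0, 1].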
segment_tile_batched
segment_tile_batched(
tiles: list[xarray.Dataset],
patch_size: int = 1024,
overlap: int = 16,
batch_size: int = 8,
reflection: int = 0,
) -> list[xarray.Dataset]
Run inference on a list of tiles.
Parameters:
- tiles (list[xarray.Dataset]) – The input tiles, containing preprocessed, harmonized data.
- patch_size (int, default: 1024) – The size of the patches. Defaults to 1024.
- overlap (int, default: 16) – The size of the overlap between neighbouring patches. Defaults to 16.
- batch_size (int, default: 8) – The batch size for the prediction, NOT the batch size of input tiles. The tensors are sliced into patches, which are then inferred in batches of this size. Defaults to 8.
- reflection (int, default: 0) – Reflection padding applied to the edges of the tensor. Defaults to 0.
Returns:
- list[xarray.Dataset] – A list of input tiles, each augmented by a predicted probabilities layer with type float32 and range [0, 1].
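Since batch_size counts patches rather than tiles, one plausible way to batch across several tiles is to queue the patches of all tiles, predict them in shared batches, and split the predictions back per tile. The sketch below shows only that bookkeeping; run_across_tiles and toy_model are hypothetical, and it is an assumption that segment_tile_batched mixes patches of different tiles in one batch.

```python
from typing import Any, Callable


def run_across_tiles(
    patches_per_tile: list[list[Any]],
    predict_batch: Callable[[list[Any]], list[Any]],
    batch_size: int = 8,
) -> list[list[Any]]:
    """Predict the patches of several tiles in shared batches.

    The patches of all tiles are flattened into one queue, predicted in chunks
    of at most batch_size patches, and the predictions are then split back so
    that result i belongs to tile i.
    """
    flat = [patch for tile in patches_per_tile for patch in tile]
    preds: list[Any] = []
    for i in range(0, len(flat), batch_size):
        preds.extend(predict_batch(flat[i : i + batch_size]))
    results, pos = [], 0
    for tile in patches_per_tile:
        results.append(preds[pos : pos + len(tile)])
        pos += len(tile)
    return results


# A toy "model" that records how many patches each batch holds.
seen_batch_sizes: list[int] = []


def toy_model(batch: list[int]) -> list[int]:
    seen_batch_sizes.append(len(batch))
    return [p * 10 for p in batch]


out = run_across_tiles([[1, 2, 3], [4], [5, 6]], toy_model, batch_size=4)
# out == [[10, 20, 30], [40], [50, 60]]; seen_batch_sizes == [4, 2]
```

Three tiles with six patches in total need two forward passes (4 + 2 patches), and the output list keeps the order of the input tiles.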
tile2tensor
Take a tile and convert it to a PyTorch tensor.
Respects the input combination from the config.
Returns:
- torch.Tensor – A torch tensor for the full tile, consisting of the bands specified in self.band_combination.
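A minimal sketch of the band selection, assuming the tile is a mapping from band name to a 2D grid and the configured band combination lists band names in the desired channel order. Nested lists stand in for xarray and torch; the band names and the helper tile_to_channels are hypothetical. With xarray and PyTorch, the stacking step would look roughly like torch.stack([torch.as_tensor(tile[b].values) for b in band_combination]).

```python
def tile_to_channels(
    tile: dict[str, list[list[float]]],
    band_combination: list[str],
) -> list[list[list[float]]]:
    """Stack the requested bands into a (channels, height, width) nested list.

    Channel order follows band_combination; bands not listed are dropped.
    """
    if not band_combination:
        raise ValueError("band_combination must name at least one band")
    missing = [band for band in band_combination if band not in tile]
    if missing:
        raise KeyError(f"Tile is missing bands: {missing}")
    channels = [tile[band] for band in band_combination]
    height, width = len(channels[0]), len(channels[0][0])
    if any(len(c) != height or len(c[0]) != width for c in channels):
        raise ValueError("All bands must share the same height and width")
    return channels


tile = {
    "red": [[1.0, 2.0], [3.0, 4.0]],
    "nir": [[5.0, 6.0], [7.0, 8.0]],
    "blue": [[0.0, 0.0], [0.0, 0.0]],
}
stacked = tile_to_channels(tile, ["nir", "red"])  # shape (2, 2, 2), "blue" dropped
```

Keeping the order fixed by the config matters because the model's input channels are positional.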
tile2tensor_batched
Take a list of tiles and convert them to a PyTorch tensor.
Respects the input combination from the config.
Returns:
- torch.Tensor – A single torch tensor for all tiles, consisting of the bands specified in self.band_combination.
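A sketch of the batched variant under the same assumptions (tiles as mappings from band name to 2D grid, nested lists standing in for tensors), additionally assuming the tiles are stacked directly along a new batch axis, which requires every tile to have the same height and width. The helper tiles_to_batch is hypothetical; in PyTorch the final step would correspond to torch.stack over the per-tile tensors.

```python
def tiles_to_batch(
    tiles: list[dict[str, list[list[float]]]],
    band_combination: list[str],
) -> list[list[list[list[float]]]]:
    """Stack several tiles into a (batch, channels, height, width) nested list.

    Every tile must provide all bands in band_combination (a missing band
    raises KeyError), and all grids must share one height and width, as a
    stacked tensor requires.
    """
    batch = [[tile[band] for band in band_combination] for tile in tiles]
    shapes = {(len(grid), len(grid[0])) for sample in batch for grid in sample}
    if len(shapes) > 1:
        raise ValueError(f"Tiles have mismatching shapes: {sorted(shapes)}")
    return batch


tiles = [
    {"red": [[1.0, 2.0]], "nir": [[3.0, 4.0]]},
    {"red": [[5.0, 6.0]], "nir": [[7.0, 8.0]]},
]
batch = tiles_to_batch(tiles, ["nir", "red"])  # shape (2, 2, 1, 2)
```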