darts_ensemble.ensemble_v1
DARTS v1 ensemble based on two models, one trained with TCVIS data and the other without.
DEFAULT_DEVICE
module-attribute

DEFAULT_DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
EnsembleV1

EnsembleV1(
    model_dict,
    device: torch.device = darts_ensemble.ensemble_v1.DEFAULT_DEVICE,
)
Model ensemble that averages predictions from multiple segmentation models.
This class manages multiple trained segmentation models and combines their predictions by averaging, providing more robust and stable predictions than any single model. It's particularly useful for combining models trained with different data sources (e.g., with and without TCVIS data).
Attributes:

- models (dict[str, darts_segmentation.segment.SMPSegmenter]) – Dictionary mapping model names to loaded segmenters.
Note
The ensemble automatically:
- Manages multiple model instances with separate configurations
- Handles band requirements across all models
- Averages probability predictions (simple arithmetic mean)
- Optionally preserves individual model outputs for analysis
Example
Create and use an ensemble:
from darts_ensemble import EnsembleV1
import torch

# Initialize ensemble with multiple models
ensemble = EnsembleV1(
    model_dict={
        "with_tcvis": "path/to/model_with_tcvis.ckpt",
        "without_tcvis": "path/to/model_without_tcvis.ckpt",
    },
    device=torch.device("cuda"),
)

# Check combined band requirements
print(ensemble.required_bands)
# {'blue', 'green', 'red', 'nir', 'ndvi', 'tc_brightness', ...}

# Run ensemble inference
result = ensemble.segment_tile(
    tile=preprocessed_tile,
    keep_inputs=True,  # Keep individual model predictions
)

# Access predictions
ensemble_probs = result["probabilities"]  # Averaged
model1_probs = result["probabilities-with_tcvis"]  # Individual
model2_probs = result["probabilities-without_tcvis"]  # Individual
Initialize the ensemble with multiple model checkpoints.
Parameters:

- model_dict (dict[str, str | pathlib.Path]) – Mapping of model identifiers to checkpoint paths. Keys are used to name individual model outputs (e.g., "with_tcvis", "without_tcvis"). Values are paths to model checkpoint files.
- device (torch.device, default: darts_ensemble.ensemble_v1.DEFAULT_DEVICE) – Device to load all models on. Defaults to CUDA if available, else CPU.
Note
All models are loaded on the same device. For multi-GPU ensembles, instantiate separate EnsembleV1 objects per device.
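For example, a minimal multi-GPU sketch along these lines (the checkpoint paths are the placeholders from the example above; the device IDs and the idea of processing different tiles per ensemble are assumptions, not library behavior):

import torch
from darts_ensemble import EnsembleV1

# One full ensemble per GPU, so different tiles can be processed on different devices in parallel.
model_dict = {
    "with_tcvis": "path/to/model_with_tcvis.ckpt",
    "without_tcvis": "path/to/model_without_tcvis.ckpt",
}
ensemble_gpu0 = EnsembleV1(model_dict, device=torch.device("cuda:0"))
ensemble_gpu1 = EnsembleV1(model_dict, device=torch.device("cuda:1"))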
Source code in darts-ensemble/src/darts_ensemble/ensemble_v1.py
models
instance-attribute

models = {
    k: darts_segmentation.segment.SMPSegmenter(v, device=device)
    for k, v in model_paths.items()
}
required_bands
property
The combined bands required by all models in this ensemble.
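A small sketch of how this property can be used to validate a tile before inference, assuming it behaves like a collection of band names (as the "combined" wording suggests) and that preprocessed_tile is an xarray.Dataset whose data variables are the bands:

# Ensure the tile provides every band needed by any ensemble member.
missing = set(ensemble.required_bands) - set(preprocessed_tile.data_vars)
if missing:
    raise ValueError(f"Tile is missing required bands: {sorted(missing)}")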
__call__

__call__(
    input: xarray.Dataset | list[xarray.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
    keep_inputs: bool = False,
) -> xarray.Dataset
Run the ensemble on the given tile.
Parameters:

- input (xarray.Dataset | list[xarray.Dataset]) – A single tile or a list of tiles.
- patch_size (int, default: 1024) – The size of the patches. Defaults to 1024.
- overlap (int, default: 16) – The size of the overlap. Defaults to 16.
- batch_size (int, default: 8) – The batch size for the prediction, NOT the batch size of input tiles. The tensor is sliced into patches, which are then inferred in batches. Defaults to 8.
- reflection (int, default: 0) – Reflection padding applied to the edges of the tensor. Defaults to 0.
- keep_inputs (bool, default: False) – Whether to keep the individual model probabilities in the output. Defaults to False.
Returns:

- xarray.Dataset – The input tile(s), augmented by a predicted probabilities layer with type float32 and range [0, 1].

Raises:

- ValueError – If the input is neither an xr.Dataset nor a list of xr.Dataset.
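A hedged usage sketch of the call interface, reusing the ensemble from the class example; preprocessed_tile, tile_a, and tile_b are assumed to be preprocessed xarray.Dataset tiles:

# A single tile returns a single xarray.Dataset with the ensemble probabilities.
result = ensemble(preprocessed_tile, patch_size=1024, overlap=16)

# A list of tiles is segmented tile by tile (see segment_tile_batched below).
results = ensemble([tile_a, tile_b], batch_size=8, keep_inputs=True)

# Any other input type raises a ValueError.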
Source code in darts-ensemble/src/darts_ensemble/ensemble_v1.py
segment_tile

segment_tile(
    tile: xarray.Dataset,
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
    keep_inputs: bool = False,
) -> xarray.Dataset
Run ensemble inference on a single tile by averaging multiple model predictions.
Each model in the ensemble processes the tile independently, then predictions are combined by simple arithmetic averaging to produce the final ensemble prediction.
Parameters:

- tile (xarray.Dataset) – Input tile containing preprocessed data. Must include all bands required by any model in the ensemble (union of all required_bands).
- patch_size (int, default: 1024) – Size of square patches for inference in pixels. Defaults to 1024.
- overlap (int, default: 16) – Overlap between adjacent patches in pixels. Defaults to 16.
- batch_size (int, default: 8) – Number of patches to process simultaneously per model. Defaults to 8.
- reflection (int, default: 0) – Reflection padding applied to tile edges in pixels. Defaults to 0.
- keep_inputs (bool, default: False) – If True, preserves individual model predictions as separate variables (e.g., "probabilities-with_tcvis"). Defaults to False.
Returns:

- xarray.Dataset – The input tile augmented with:
  - probabilities (float32): Ensemble-averaged predictions in range [0, 1]. Attributes: long_name="Probabilities"
  - probabilities-{model_name} (float32): Individual model predictions (only if keep_inputs=True)
Note
Averaging method: simple arithmetic mean across all models. For N models:
ensemble_prob = (prob_1 + prob_2 + ... + prob_N) / N
This approach assumes equal confidence in all models. Consider weighted averaging if models have different validation performances.
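As a sketch of that weighted alternative, assuming segment_tile was called with keep_inputs=True so the per-model probabilities are available, and using the model names from the examples above; the weights are illustrative and not part of the library:

# Hypothetical validation-based weights for the two models.
weights = {"with_tcvis": 0.6, "without_tcvis": 0.4}

weighted = sum(
    w * result[f"probabilities-{name}"] for name, w in weights.items()
) / sum(weights.values())
weighted.attrs["long_name"] = "Weighted probabilities"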
Example
Run ensemble with analysis of individual models:
result = ensemble.segment_tile(
    tile=preprocessed_tile,
    patch_size=1024,
    overlap=16,
    keep_inputs=True,  # Keep individual predictions
)

# Compare ensemble vs individual models
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3)
result["probabilities"].plot(ax=axes[0])
axes[0].set_title("Ensemble")
result["probabilities-with_tcvis"].plot(ax=axes[1])
axes[1].set_title("Model 1")
result["probabilities-without_tcvis"].plot(ax=axes[2])
axes[2].set_title("Model 2")
Source code in darts-ensemble/src/darts_ensemble/ensemble_v1.py
segment_tile_batched

segment_tile_batched(
    tiles: list[xarray.Dataset],
    patch_size: int = 1024,
    overlap: int = 16,
    batch_size: int = 8,
    reflection: int = 0,
    keep_inputs: bool = False,
) -> list[xarray.Dataset]
Run inference on a list of tiles.
Parameters:

- tiles (list[xarray.Dataset]) – The input tiles, containing preprocessed, harmonized data.
- patch_size (int, default: 1024) – The size of the patches. Defaults to 1024.
- overlap (int, default: 16) – The size of the overlap. Defaults to 16.
- batch_size (int, default: 8) – The batch size for the prediction, NOT the batch size of input tiles. Each tensor is sliced into patches, which are then inferred in batches. Defaults to 8.
- reflection (int, default: 0) – Reflection padding applied to the edges of the tensor. Defaults to 0.
- keep_inputs (bool, default: False) – Whether to keep the individual model probabilities in the output. Defaults to False.
Returns:

- list[xarray.Dataset] – A list of input tiles, each augmented by a predicted probabilities layer with type float32 and range [0, 1].
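A minimal usage sketch, assuming tile_a, tile_b, and tile_c are preprocessed xarray.Dataset tiles:

results = ensemble.segment_tile_batched(
    tiles=[tile_a, tile_b, tile_c],
    patch_size=1024,
    overlap=16,
)
for tile in results:
    # Print the mean ensemble probability of each tile as a quick sanity check.
    print(float(tile["probabilities"].mean()))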