darts_segmentation.training.create_training_patches
Create training patches from a tile and labels.
Parameters:
- tile (xarray.Dataset) – The input tile, containing preprocessed, harmonized data.
- labels (geopandas.GeoDataFrame) – The labels to use for training.
- bands (list[str]) – The bands to use for training. Each band must be present in the tile.
- norm_factors (dict[str, float]) – The normalization factor for each band, keyed by band name.
- patch_size (int) – The size of the patches.
- overlap (int) – The size of the overlap between adjacent patches.
- exclude_nopositive (bool) – Whether to exclude patches whose labels contain no positives.
- exclude_nan (bool) – Whether to exclude patches where the input data contains NaN values.
- device (typing.Literal['cuda', 'cpu'] | int) – The device to use for the erosion.
- mask_erosion_size (int) – The size of the disk used for the erosion.
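The patch_size, overlap and mask_erosion_size parameters each imply a small piece of arithmetic. The plain-Python sketch below illustrates both ideas under stated assumptions: patches are laid out with a stride of patch_size - overlap plus one final patch aligned to the tile edge, and mask_erosion_size is the radius of a disk like the one produced by skimage.morphology.disk. Neither assumption is confirmed by this page, and the helpers patch_origins, patch_grid, disk_offsets and erode are hypothetical; the real function works on tensors on the chosen device.

```python
def patch_origins(length: int, patch_size: int, overlap: int) -> list[int]:
    """Start offsets of patches along one axis (hypothetical helper).

    Assumption: stride = patch_size - overlap, and a final patch is shifted
    back so that it ends exactly at the tile border.
    """
    if patch_size <= 0 or not 0 <= overlap < patch_size:
        raise ValueError("need patch_size > 0 and 0 <= overlap < patch_size")
    if length < patch_size:
        raise ValueError("tile is smaller than one patch")
    stride = patch_size - overlap
    origins = list(range(0, length - patch_size + 1, stride))
    # Cover any remaining border pixels with one extra, edge-aligned patch.
    if origins[-1] + patch_size < length:
        origins.append(length - patch_size)
    return origins


def patch_grid(height: int, width: int, patch_size: int, overlap: int) -> list[tuple[int, int]]:
    """(row, col) upper-left corners of all patches in an H x W tile."""
    return [
        (r, c)
        for r in patch_origins(height, patch_size, overlap)
        for c in patch_origins(width, patch_size, overlap)
    ]


def disk_offsets(radius: int) -> list[tuple[int, int]]:
    """Pixel offsets inside a disk of the given radius (dy^2 + dx^2 <= r^2)."""
    return [
        (dy, dx)
        for dy in range(-radius, radius + 1)
        for dx in range(-radius, radius + 1)
        if dy * dy + dx * dx <= radius * radius
    ]


def erode(mask: list[list[bool]], radius: int) -> list[list[bool]]:
    """Binary erosion: a pixel stays True only if every pixel in the disk
    around it is True. Out-of-bounds pixels count as False, so the tile
    border is eroded as well (the scipy.ndimage.binary_erosion default)."""
    h, w = len(mask), len(mask[0])
    offsets = disk_offsets(radius)
    return [
        [
            all(
                0 <= y + dy < h and 0 <= x + dx < w and mask[y + dy][x + dx]
                for dy, dx in offsets
            )
            for x in range(w)
        ]
        for y in range(h)
    ]
```

For example, patch_origins(11, 4, 1) gives [0, 3, 6, 7]: the regular stride stops at 6, and one extra patch starting at 7 covers the last column.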
Yields:
- collections.abc.Generator[tuple[torch.Tensor, torch.Tensor]] – A tuple containing the input and the labels as PyTorch tensors. The input has the shape (C, H, W), the labels (H, W).
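The exclude_nan and exclude_nopositive flags decide which (input, labels) pairs the generator yields. A minimal sketch of that filtering, using nested lists as hypothetical stand-ins for the (C, H, W) input and (H, W) label tensors, and assuming a "positive" is any label value greater than 0; keep_patch, filter_patches and spatial_shapes_match are illustrative names, not part of darts_segmentation.

```python
import math
from collections.abc import Iterable, Iterator

# Hypothetical nested-list stand-ins: input is (C, H, W), labels are (H, W).
Patch = list[list[list[float]]]
Label = list[list[int]]


def spatial_shapes_match(x: Patch, y: Label) -> bool:
    """True if every input band has the same (H, W) as the labels."""
    h, w = len(y), len(y[0])
    return all(len(band) == h and all(len(row) == w for row in band) for band in x)


def keep_patch(x: Patch, y: Label, exclude_nopositive: bool, exclude_nan: bool) -> bool:
    """Apply the two exclusion rules to one (input, labels) pair."""
    if exclude_nan and any(math.isnan(v) for band in x for row in band for v in row):
        return False
    # Assumption: label values > 0 are positives.
    if exclude_nopositive and not any(v > 0 for row in y for v in row):
        return False
    return True


def filter_patches(
    pairs: Iterable[tuple[Patch, Label]], exclude_nopositive: bool, exclude_nan: bool
) -> Iterator[tuple[Patch, Label]]:
    """Lazily yield only the pairs that pass both filters."""
    for x, y in pairs:
        if keep_patch(x, y, exclude_nopositive, exclude_nan):
            yield x, y
```

Filtering inside a generator keeps memory flat: rejected patches are dropped as they are produced instead of being collected first.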
Raises:
- ValueError – If a band is not found in the preprocessed data.
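The ValueError is tied to band selection. A minimal sketch of that step, assuming a plain dict of 2-D lists stands in for the xarray.Dataset and assuming the normalization factors are applied by multiplication (this page does not say whether they multiply or divide); stack_bands is a hypothetical helper and its error message is illustrative, not the library's.

```python
def stack_bands(
    tile: dict[str, list[list[float]]],
    bands: list[str],
    norm_factors: dict[str, float],
) -> list[list[list[float]]]:
    """Select `bands` from `tile` in order and scale each by its factor.

    Returns a (C, H, W) nested list. Raises ValueError if any band is missing.
    """
    # Check every band up front so the error lists all missing bands at once.
    missing = [b for b in bands if b not in tile]
    if missing:
        raise ValueError(f"Band(s) {missing} not found in the preprocessed data")
    # Assumption: normalization multiplies each value by the band's factor.
    return [[[v * norm_factors[b] for v in row] for row in tile[b]] for b in bands]
```

Validating all bands before scaling any of them means a misconfigured band list fails fast, before any per-pixel work is done.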
Source code in darts-segmentation/src/darts_segmentation/training/prepare_training.py