darts_segmentation.training.create_training_patches
Create training patches from a tile and labels.
Parameters:
- tile (xarray.Dataset) – The input tile, containing preprocessed, harmonized data.
- labels (geopandas.GeoDataFrame) – The labels to be used for training.
- bands (list[str]) – The bands to be used for training. Each band must be present in the tile.
- norm_factors (dict[str, float]) – The normalization factor for each band.
- patch_size (int) – The size of the patches.
- overlap (int) – The size of the overlap between neighboring patches.
- exclude_nopositive (bool) – Whether to exclude patches whose labels contain no positives.
- exclude_nan (bool) – Whether to exclude patches whose input data contains NaN values.
- device (typing.Literal['cuda', 'cpu'] | int) – The device to use for the erosion.
- mask_erosion_size (int) – The size of the disk used for the erosion.
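The interplay of `patch_size` and `overlap` can be illustrated with a small sketch. It assumes the stride between patch origins is `patch_size - overlap` and that only windows lying fully inside the tile are kept; both are assumptions, not a description of the library's actual edge handling. The helpers `patch_starts` and `extract_patches` are hypothetical and work on a plain NumPy `(C, H, W)` array instead of an `xarray.Dataset`.

```python
import numpy as np


def patch_starts(length: int, patch_size: int, overlap: int) -> list[int]:
    """Hypothetical helper: patch origins along one axis.

    Assumes stride = patch_size - overlap and keeps only windows that fit entirely.
    """
    if not 0 <= overlap < patch_size:
        raise ValueError("overlap must satisfy 0 <= overlap < patch_size")
    stride = patch_size - overlap
    return list(range(0, length - patch_size + 1, stride))


def extract_patches(data: np.ndarray, patch_size: int, overlap: int):
    """Hypothetical helper: yield (C, patch_size, patch_size) windows from a (C, H, W) array."""
    _, height, width = data.shape
    for y in patch_starts(height, patch_size, overlap):
        for x in patch_starts(width, patch_size, overlap):
            yield data[:, y : y + patch_size, x : x + patch_size]
```

For a 10×10 tile with `patch_size=4` and `overlap=2`, the origins along each axis are `[0, 2, 4, 6]`, giving 16 patches of shape `(C, 4, 4)`.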
Yields:
- collections.abc.Generator[tuple[torch.tensor, torch.tensor]] – Tuples of the input and the labels as PyTorch tensors. The input has the format (C, H, W), the labels (H, W).
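The filtering implied by `exclude_nan` and `exclude_nopositive`, together with the `(C, H, W)` / `(H, W)` shape contract, might look roughly like the following sketch. The helper `keep_patch` is hypothetical, it uses NumPy arrays rather than the PyTorch tensors the function actually yields (`torch.from_numpy` could convert them), and it assumes positives are encoded as `1` in a binary label mask.

```python
import numpy as np


def keep_patch(x: np.ndarray, y: np.ndarray, exclude_nopositive: bool, exclude_nan: bool) -> bool:
    """Hypothetical filter deciding whether an (input, label) patch pair is yielded.

    x has shape (C, H, W) and y has shape (H, W); positives are assumed to be y == 1.
    """
    if x.ndim != 3 or y.ndim != 2 or x.shape[1:] != y.shape:
        raise ValueError("expected x of shape (C, H, W) and y of shape (H, W)")
    # Drop patches with any missing input value.
    if exclude_nan and np.isnan(x).any():
        return False
    # Drop patches whose labels contain no positive pixel.
    if exclude_nopositive and not (y == 1).any():
        return False
    return True
```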
Raises:
- ValueError – If a band is not found in the preprocessed data.
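The band check behind this `ValueError`, combined with per-band normalization, can be sketched as follows. A plain dict of NumPy arrays stands in for the `xarray.Dataset`, the helper `stack_bands` is hypothetical, and multiplying each band by its entry in `norm_factors` is an assumption about how the factors are applied.

```python
import numpy as np


def stack_bands(tile: dict[str, np.ndarray], bands: list[str], norm_factors: dict[str, float]) -> np.ndarray:
    """Hypothetical helper: return a (C, H, W) float32 array of normalized bands, in request order."""
    missing = [b for b in bands if b not in tile]
    if missing:
        raise ValueError(f"Band(s) {missing} not found in the preprocessed data")
    # Assumption: normalization multiplies each band by its factor.
    return np.stack([tile[b].astype(np.float32) * norm_factors[b] for b in bands])
```

Checking all requested bands before stacking reports every missing band at once instead of failing on the first one.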
Source code in darts-segmentation/src/darts_segmentation/training/prepare_training.py