@stopuhr.funkuhr("Loading Sentinel 2 scene from STAC", printer=logger.debug, print_kwargs=["s2id"])
def load_s2_from_stac(
    s2id: str,
    bands_mapping: dict = {"B02_10m": "blue", "B03_10m": "green", "B04_10m": "red", "B08_10m": "nir"},
    scale_and_offset: bool | tuple[float, float] = True,
    cache: Path | None = None,
) -> xr.Dataset:
    """Load a Sentinel 2 scene from the Copernicus STAC API and return it as an xarray dataset.

    The scene-classification asset ``SCL_20m`` is always requested; unless the caller maps it to
    another name it is renamed to ``scl``. The caller's ``bands_mapping`` is never modified.

    Args:
        s2id (str): The Sentinel 2 item ID in the ``sentinel-2-l2a`` STAC collection.
        bands_mapping (dict[str, str], optional): A mapping from the STAC band (asset) names to obtain
            to the variable names they are renamed to in the returned dataset.
            Defaults to {"B02_10m": "blue", "B03_10m": "green", "B04_10m": "red", "B08_10m": "nir"}.
        scale_and_offset (bool | tuple[float, float], optional): Whether to apply the scale and offset to the bands.
            If a tuple is provided, it will be used as the (`scale`, `offset`) values with `band * scale + offset`.
            If True, use the default values of `scale` = 0.0001 and `offset` = 0, taken from ee_extra.
            The SCL band is never scaled.
            Defaults to True.
        cache (Path | None, optional): The path to the cache directory. If None, no caching will be done.
            The cache holds the raw download (original band names, before renaming and scaling).
            Defaults to None.

    Returns:
        xr.Dataset: The loaded dataset for the first time step of the search result. The timestamp is
            stored in ``attrs["time"]``, and ``attrs["s2_tile_id"]`` / ``attrs["tile_id"]`` are set to `s2id`.

    """
    # Work on a copy: the default dict is shared between calls and the caller's dict must not change.
    # The key order is the same as before, so cache filenames below stay identical.
    bands = {**bands_mapping}
    bands.setdefault("SCL_20m", "scl")

    # NOTE: the "gee-s2srh-" prefix is kept as-is so that existing cache files are still found.
    cache_file = None if cache is None else cache / f"gee-s2srh-{s2id}-{''.join(bands.keys())}.nc"
    if cache_file is not None and cache_file.exists():
        # load_dataset reads everything into memory and closes the file handle right away.
        ds_s2 = xr.load_dataset(cache_file, engine="h5netcdf").set_coords("spatial_ref")
        logger.debug(f"Loaded {s2id=} from cache.")
    else:
        # Only query the catalog when the scene is not cached, so cached scenes load without network access.
        catalog = Client.open("https://stac.dataspace.copernicus.eu/v1/")
        search = catalog.search(
            collections=["sentinel-2-l2a"],
            ids=[s2id],
        )
        ds_s2 = xr.open_dataset(
            search,
            engine="stac",
            backend_kwargs={"crs": "utm", "resolution": 10, "bands": list(bands.keys())},
        )
        # Keep only the first time step (the search is by a single ID); remember its timestamp as an attribute.
        ds_s2.attrs["time"] = str(ds_s2.time.values[0])
        ds_s2 = ds_s2.isel(time=0).drop_vars("time")
        logger.debug(
            f"Found a dataset with shape {ds_s2.sizes} for tile {s2id=}. "
            "Start downloading data from STAC. This may take a while."
        )
        with stopuhr.stopuhr(f"Downloading data from STAC for {s2id=}", printer=logger.debug):
            # Need double loading since the first load transforms lazy-stac to dask and second actually downloads
            ds_s2.load().load()
        if cache_file is not None:
            ds_s2.to_netcdf(cache_file, engine="h5netcdf")

    ds_s2 = ds_s2.rename_vars(bands)
    # NOTE(review): "Reflectance" units are also set on the SCL band here; convert_masks may or may not
    # override this — confirm.
    for var in ds_s2.data_vars:
        ds_s2[var].attrs["data_source"] = "s2-stac"
        ds_s2[var].attrs["long_name"] = f"Sentinel 2 {var.capitalize()}"
        ds_s2[var].attrs["units"] = "Reflectance"
    ds_s2 = convert_masks(ds_s2)

    if scale_and_offset:
        if isinstance(scale_and_offset, tuple):
            scale, offset = scale_and_offset
        else:
            # NOTE(review): L2A products from processing baseline >= 04.00 may carry a non-zero
            # BOA offset; confirm whether the STAC backend already applies it.
            scale, offset = 0.0001, 0
        logger.debug(f"Applying {scale=} and {offset=} to {s2id=} optical data")
        # Exclude the classification band by its asset key, so a custom SCL name is not scaled either.
        optical_bands = [name for key, name in bands.items() if key != "SCL_20m"]
        for band in optical_bands:
            # xarray arithmetic drops attrs by default; re-attach them so units/long_name survive scaling.
            ds_s2[band] = (ds_s2[band] * scale + offset).assign_attrs(ds_s2[band].attrs)

    ds_s2.attrs["s2_tile_id"] = s2id
    ds_s2.attrs["tile_id"] = s2id
    return ds_s2