@stopuhr.funkuhr("Loading Sentinel 2 scene from GEE", printer=logger.debug, print_kwargs=["img"])
def load_s2_from_gee(
img: str | ee.Image,
bands_mapping: dict = {"B2": "blue", "B3": "green", "B4": "red", "B8": "nir"},
scale_and_offset: bool | tuple[float, float] = True,
cache: Path | None = None,
) -> xr.Dataset:
"""Load a Sentinel 2 scene from Google Earth Engine and return it as an xarray dataset.
Args:
img (str | ee.Image): The Sentinel 2 image ID or the ee image object.
bands_mapping (dict[str, str], optional): A mapping from bands to obtain.
Will be renamed to the corresponding band names.
Defaults to {"B2": "blue", "B3": "green", "B4": "red", "B8": "nir"}.
scale_and_offset (bool | tuple[float, float], optional): Whether to apply the scale and offset to the bands.
If a tuple is provided, it will be used as the (`scale`, `offset`) values with `band * scale + offset`.
If True, use the default values of `scale` = 0.0001 and `offset` = 0, taken from ee_extra.
Defaults to True.
cache (Path | None, optional): The path to the cache directory. If None, no caching will be done.
Defaults to None.
Returns:
xr.Dataset: The loaded dataset
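
    Examples:
        A minimal usage sketch; the scene ID below is made up for illustration, and `ee`
        must already be authenticated and initialized (e.g. via `ee.Initialize()`):

            from pathlib import Path

            ds = load_s2_from_gee(
                "20230701T103629_20230701T103631_T32UNE",
                cache=Path("cache/s2"),
            )
            red = ds["red"]  # scaled reflectance, renamed from band "B4"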
"""
if isinstance(img, str):
s2id = img
img = ee.Image(f"COPERNICUS/S2_SR_HARMONIZED/{s2id}")
else:
s2id = img.id().getInfo().split("/")[-1]
logger.debug(f"Loading Sentinel 2 tile {s2id=} from GEE")
if "SCL" not in bands_mapping.keys():
bands_mapping["SCL"] = "scl"
cache_file = None if cache is None else cache / f"gee-s2srh-{s2id}-{''.join(bands_mapping.keys())}.nc"
if cache_file is not None and cache_file.exists():
ds_s2 = xr.open_dataset(cache_file, engine="h5netcdf").set_coords("spatial_ref")
ds_s2.load()
logger.debug(f"Loaded {s2id=} from cache.")
else:
img = img.select(list(bands_mapping.keys()))
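        # Opening through the "ee" (xee) engine is lazy: only metadata is requested here,
        # the pixel data itself is transferred by the explicit .load() call further down.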
ds_s2 = xr.open_dataset(
img,
engine="ee",
geometry=img.geometry(),
crs=img.select(0).projection().crs().getInfo(),
scale=10,
)
ds_s2.attrs["time"] = str(ds_s2.time.values[0])
ds_s2 = ds_s2.isel(time=0).drop_vars("time").rename({"X": "x", "Y": "y"}).transpose("y", "x")
ds_s2 = ds_s2.odc.assign_crs(ds_s2.attrs["crs"])
logger.debug(
f"Found dataset with shape {ds_s2.sizes} for tile {s2id=}."
"Start downloading data from GEE. This may take a while."
)
with stopuhr.stopuhr(f"Downloading data from GEE for {s2id=}", printer=logger.debug):
ds_s2.load()
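        # Persist the freshly downloaded (still unrenamed) dataset so later calls can take the cache path above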
if cache_file is not None:
ds_s2.to_netcdf(cache_file, engine="h5netcdf")
ds_s2 = ds_s2.rename_vars(bands_mapping)
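    # Attach provenance and display metadata to each variable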
for var in ds_s2.data_vars:
ds_s2[var].attrs["data_source"] = "s2-gee"
ds_s2[var].attrs["long_name"] = f"Sentinel 2 {var.capitalize()}"
ds_s2[var].attrs["units"] = "Reflectance"
ds_s2 = convert_masks(ds_s2)
    # For some reason, there are some spatially random nan values in the data, not only at the borders.
    # To work around this, set all nan values to 0 and add this information to the quality_data_mask.
    # This workaround is quite computationally expensive, but it works for now.
# TODO: Find other solutions for this problem!
with stopuhr.stopuhr(f"Fixing nan values in {s2id=}", printer=logger.debug):
for band in set(bands_mapping.values()) - {"scl"}:
ds_s2["quality_data_mask"] = xr.where(ds_s2[band].isnull(), 0, ds_s2["quality_data_mask"])
ds_s2[band] = ds_s2[band].fillna(0)
            # Where the scene has no data at all (scl is also nan), set the band back to nan (invalid data)
ds_s2[band] = ds_s2[band].where(~ds_s2["scl"].isnull())
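    # Optionally convert the integer digital numbers to reflectance values via band * scale + offset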
if scale_and_offset:
if isinstance(scale_and_offset, tuple):
scale, offset = scale_and_offset
else:
scale, offset = 0.0001, 0
logger.debug(f"Applying {scale=} and {offset=} to {s2id=} optical data")
for band in set(bands_mapping.values()) - {"scl"}:
ds_s2[band] = ds_s2[band] * scale + offset
ds_s2.attrs["s2_tile_id"] = s2id
ds_s2.attrs["tile_id"] = s2id
return ds_s2