Skip to content

darts_acquisition.load_s2_from_gee

Load a Sentinel 2 scene from Google Earth Engine and return it as an xarray dataset.

Parameters:

  • img (str | ee.Image) –

    The Sentinel 2 image ID or the ee image object.

  • bands_mapping (dict[str, str], default: {'B2': 'blue', 'B3': 'green', 'B4': 'red', 'B8': 'nir'} ) –

    A mapping from the Sentinel 2 band names to load (keys) to the names they are renamed to in the returned dataset (values). Defaults to {"B2": "blue", "B3": "green", "B4": "red", "B8": "nir"}.

  • scale_and_offset (bool | tuple[float, float], default: True ) –

    Whether to apply the scale and offset to the bands. If a tuple is provided, it will be used as the (scale, offset) values with band * scale + offset. If True, use the default values of scale = 0.0001 and offset = 0, taken from ee_extra. Defaults to True.

  • cache (pathlib.Path | None, default: None ) –

    The path to the cache directory. If None, no caching will be done. Defaults to None.

Returns:

  • xr.Dataset – The loaded dataset.

Source code in darts-acquisition/src/darts_acquisition/s2.py
@stopuhr.funkuhr("Loading Sentinel 2 scene from GEE", printer=logger.debug, print_kwargs=["img"])
def load_s2_from_gee(
    img: str | ee.Image,
    bands_mapping: dict[str, str] = {"B2": "blue", "B3": "green", "B4": "red", "B8": "nir"},
    scale_and_offset: bool | tuple[float, float] = True,
    cache: Path | None = None,
) -> xr.Dataset:
    """Load a Sentinel 2 scene from Google Earth Engine and return it as an xarray dataset.

    The "SCL" band is always loaded in addition to the requested bands (renamed to "scl"
    unless the mapping already contains an "SCL" entry). The passed `bands_mapping` is
    never modified.

    Args:
        img (str | ee.Image): The Sentinel 2 image ID or the ee image object.
        bands_mapping (dict[str, str], optional): A mapping from the band names to load
            to the names they are renamed to in the returned dataset.
            Defaults to {"B2": "blue", "B3": "green", "B4": "red", "B8": "nir"}.
        scale_and_offset (bool | tuple[float, float], optional): Whether to apply the scale and offset to the bands.
            If a tuple is provided, it will be used as the (`scale`, `offset`) values with `band * scale + offset`.
            If True, use the default values of `scale` = 0.0001 and `offset` = 0, taken from ee_extra.
            Defaults to True.
        cache (Path | None, optional): The path to the cache directory. If None, no caching will be done.
            The directory is created if it does not exist yet.
            Defaults to None.

    Returns:
        xr.Dataset: The loaded dataset

    """
    if isinstance(img, str):
        s2id = img
        img = ee.Image(f"COPERNICUS/S2_SR_HARMONIZED/{s2id}")
    else:
        s2id = img.id().getInfo().split("/")[-1]
    logger.debug(f"Loading Sentinel 2 tile {s2id=} from GEE")

    # Work on a private copy: adding "SCL" in place would leak into the shared default
    # argument and into the caller's dict.
    bands_mapping = dict(bands_mapping)
    bands_mapping.setdefault("SCL", "scl")

    # The cache key contains the requested band names, so different band selections
    # of the same scene end up in different files.
    cache_file = None if cache is None else cache / f"gee-s2srh-{s2id}-{''.join(bands_mapping.keys())}.nc"
    if cache_file is not None and cache_file.exists():
        ds_s2 = xr.open_dataset(cache_file, engine="h5netcdf").set_coords("spatial_ref")
        ds_s2.load()
        logger.debug(f"Loaded {s2id=} from cache.")
    else:
        img = img.select(list(bands_mapping.keys()))
        ds_s2 = xr.open_dataset(
            img,
            engine="ee",
            geometry=img.geometry(),
            crs=img.select(0).projection().crs().getInfo(),
            scale=10,
        )
        # Keep the acquisition time as an attribute before the time dimension is dropped
        ds_s2.attrs["time"] = str(ds_s2.time.values[0])
        ds_s2 = ds_s2.isel(time=0).drop_vars("time").rename({"X": "x", "Y": "y"}).transpose("y", "x")
        ds_s2 = ds_s2.odc.assign_crs(ds_s2.attrs["crs"])
        logger.debug(
            f"Found dataset with shape {ds_s2.sizes} for tile {s2id=}. "
            "Start downloading data from GEE. This may take a while."
        )

        with stopuhr.stopuhr(f"Downloading data from GEE for {s2id=}", printer=logger.debug):
            ds_s2.load()
            if cache_file is not None:
                # Make sure the cache directory exists, otherwise the write would fail
                # after the (expensive) download has already finished.
                cache_file.parent.mkdir(parents=True, exist_ok=True)
                ds_s2.to_netcdf(cache_file, engine="h5netcdf")

    ds_s2 = ds_s2.rename_vars(bands_mapping)

    for var in ds_s2.data_vars:
        ds_s2[var].attrs["data_source"] = "s2-gee"
        ds_s2[var].attrs["long_name"] = f"Sentinel 2 {var.capitalize()}"
        ds_s2[var].attrs["units"] = "Reflectance"

    ds_s2 = convert_masks(ds_s2)

    # All renamed bands except the scene classification layer
    optical_bands = set(bands_mapping.values()) - {"scl"}

    # For some reason, there are some spatially random nan values in the data, not only at the borders
    # To workaround this, set all nan values to 0 and add this information to the quality_data_mask
    # This workaround is quite computational expensive, but it works for now
    # TODO: Find other solutions for this problem!
    with stopuhr.stopuhr(f"Fixing nan values in {s2id=}", printer=logger.debug):
        for band in optical_bands:
            ds_s2["quality_data_mask"] = xr.where(ds_s2[band].isnull(), 0, ds_s2["quality_data_mask"])
            ds_s2[band] = ds_s2[band].fillna(0)
            # Turn real nan values (scl is nan) into invalid data
            ds_s2[band] = ds_s2[band].where(~ds_s2["scl"].isnull())

    if scale_and_offset:
        if isinstance(scale_and_offset, tuple):
            scale, offset = scale_and_offset
        else:
            scale, offset = 0.0001, 0
        logger.debug(f"Applying {scale=} and {offset=} to {s2id=} optical data")
        for band in optical_bands:
            ds_s2[band] = ds_s2[band] * scale + offset

    ds_s2.attrs["s2_tile_id"] = s2id
    ds_s2.attrs["tile_id"] = s2id

    return ds_s2