Skip to content

cuda

darts_utils.cuda

Utility functions around CUDA, e.g. GPU memory management and moving datasets between devices.

DEFAULT_DEVICE module-attribute

# Name of the default compute device for this module.
DEFAULT_DEVICE = "cuda"

logger module-attribute

# Rewrite the package prefix "darts_" to "darts." so that e.g. "darts_utils.cuda"
# logs under the "darts.utils.cuda" logger hierarchy.
logger = logging.getLogger(__name__.replace("darts_", "darts."))

free_cupy

free_cupy()

Free the CUDA memory of cupy.

Source code in darts-utils/src/darts_utils/cuda.py
def free_cupy():
    """Release the memory cached by CuPy's default device and pinned memory pools.

    If CuPy cannot be imported, this function returns immediately without doing anything.
    """
    try:
        import cupy as cp  # type: ignore
    except ImportError:
        return

    # Collect garbage first so buffers owned by unreachable objects can be
    # handed back to the pools before the pools are emptied.
    gc.collect()
    for get_pool in (cp.get_default_memory_pool, cp.get_default_pinned_memory_pool):
        get_pool().free_all_blocks()

free_torch

free_torch()

Free the CUDA memory of pytorch.

Source code in darts-utils/src/darts_utils/cuda.py
def free_torch():
    """Free the cached CUDA memory of PyTorch.

    Runs the garbage collector and, when CUDA is available, calls
    ``torch.cuda.empty_cache()`` to release unoccupied memory held by PyTorch's
    caching allocator.

    Like ``free_cupy``, this is a no-op when the library (here: PyTorch) cannot
    be imported, so it is safe to call in CPU-only / torch-less environments.
    """
    try:
        import torch
    except ImportError:
        # Nothing to free without PyTorch; mirror free_cupy's best-effort behavior.
        return

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

move_to_device

move_to_device(
    tile: xarray.Dataset,
    device: typing.Literal["cuda", "cpu"] | int,
)

Ensure a dataset is on the requested device and return it (a plain function, not a context manager).

Parameters:

  • tile (xarray.Dataset) –

    The xarray dataset to operate on.

  • device (typing.Literal['cuda', 'cpu'] | int) –

    The device to use for calculations (either "cuda", "cpu", or a specific GPU index).

Returns:

  • xr.Dataset: The xarray dataset on the specified device.

Source code in darts-utils/src/darts_utils/cuda.py
def move_to_device(
    tile: xr.Dataset,
    device: Literal["cuda", "cpu"] | int,
):
    """Ensure a dataset is on the requested device and return it.

    This is a plain function, not a context manager. If a GPU is requested and GPU
    acceleration is available, the dataset's data is converted to CuPy arrays on
    that GPU. Otherwise the dataset is returned unchanged. In particular, with
    ``device="cpu"`` data is not moved back to the host; use ``move_to_host`` for that.

    Args:
        tile: The xarray dataset to operate on.
        device: The device to use for calculations (either "cuda", "cpu", or a specific GPU index).
            "cuda" selects GPU 0.

    Returns:
        xr.Dataset: The xarray dataset on the specified device. If a GPU was requested
        but GPU acceleration is not available, a warning is logged and the dataset is
        returned unchanged.

    """
    use_gpu = device == "cuda" or isinstance(device, int)

    # Warn user if use_gpu is set but no GPU is available
    if use_gpu and not has_cuda_and_cupy():
        logger.warning(
            f"Device was set to {device}, but GPU acceleration is not available. Falling back to CPU."
        )
        use_gpu = False

    if not use_gpu:
        return tile

    # has_cuda_and_cupy() returned True above, so cupy is presumably importable.
    # Import it explicitly instead of relying on a module-level `cp` binding.
    import cupy as cp  # type: ignore

    device_nr = device if isinstance(device, int) else 0
    # Persist in case of dask - since cupy-dask is not supported.
    # Dataset.chunks is an (empty) mapping, not None, when nothing is dask-backed,
    # so test truthiness rather than `is not None`.
    if tile.chunks:
        logger.debug("Persisting dask array before moving to GPU.")
        tile = tile.persist()
    # Move and calculate on specified device
    logger.debug(f"Moving tile to GPU:{device_nr}.")
    with cp.cuda.Device(device_nr):
        tile = tile.cupy.as_cupy()
    return tile

move_to_host

move_to_host(tile: xarray.Dataset) -> xarray.Dataset

Move a dataset from GPU to CPU.

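Parameters:

  • tile (xarray.Dataset) –

    The xarray dataset to move.

Returns:

  • xr.Dataset: The dataset with its data converted to NumPy arrays (returned unchanged if it is not CuPy-backed).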

Source code in darts-utils/src/darts_utils/cuda.py
def move_to_host(tile: xr.Dataset) -> xr.Dataset:
    """Copy a dataset's data from GPU (CuPy) memory back to host (NumPy) memory.

    Args:
        tile: The xarray dataset to move.

    Returns:
        xr.Dataset: The dataset backed by NumPy arrays. A dataset that is not
        CuPy-backed is returned unchanged.

    """
    if not tile.cupy.is_cupy:
        # Already on the host: nothing to convert and no GPU pools to free.
        return tile
    host_tile = tile.cupy.as_numpy()
    # Ask CuPy to release its cached pool memory now that the data lives on the host.
    free_cupy()
    return host_tile