Print debug information about the CUDA devices and library installations.
Source code in darts/src/darts/utils/cuda.py
def debug_info():  # noqa: C901
    """Print debug information about the CUDA devices and library installations.

    Everything is reported through the module-level ``logger``; nothing is returned.
    The report covers, in order:

    * the values of a fixed set of CUDA-related environment variables,
    * the driver versions and per-device memory usage as seen by ``pynvml``,
    * the PyTorch version, CUDA availability and CUDA runtime version,
    * the CuPy version plus the driver / runtime versions it reports,
    * whether Numba can use CUDA (and its runtime version if so),
    * the combined CuPy+Numba check from ``xrspatial``,
    * the cuCIM version.

    A missing optional library (``ImportError``) is logged and skipped so that the
    remaining sections are still reported. Any other exception raised by one of the
    libraries is not handled here and propagates to the caller.
    """
    import os

    def _as_text(value):
        """Return *value* as ``str``, decoding it first if it is ``bytes``."""
        # Depending on the installed NVML Python bindings, string getters hand back
        # either ``bytes`` (older releases) or ``str`` (newer releases). Calling
        # ``.decode()`` unconditionally fails with AttributeError on the latter.
        if isinstance(value, bytes):
            return value.decode()
        return str(value)

    logger.debug("===vvv CUDA DEBUG INFO vvv===")

    # Environment variables that influence which CUDA installation gets picked up.
    important_env_vars = [
        "CUDA_HOME",
        "CUDA_PATH",
        "LD_LIBRARY_PATH",
        "NUMBA_CUDA_DRIVER",
        "NUMBA_CUDA_INCLUDE_PATH",
    ]
    for v in important_env_vars:
        value = os.environ.get(v, "UNSET")
        logger.debug(f"{v}: {value}")

    logger.debug("Quicknote: CUDA driver is something different than CUDA runtime, hence versions can mismatch")

    # --- NVML: driver version and per-device memory usage ---
    try:
        from pynvml import (  # type: ignore
            nvmlDeviceGetCount,
            nvmlDeviceGetHandleByIndex,
            nvmlDeviceGetMemoryInfo,
            nvmlDeviceGetName,
            nvmlInit,
            nvmlShutdown,
            nvmlSystemGetCudaDriverVersion_v2,
            nvmlSystemGetDriverVersion,
        )

        nvmlInit()
        try:
            cuda_driver_version_legacy = _as_text(nvmlSystemGetDriverVersion())
            cuda_driver_version = nvmlSystemGetCudaDriverVersion_v2()
            logger.debug(f"CUDA driver version: {cuda_driver_version} ({cuda_driver_version_legacy})")
            ndevices = nvmlDeviceGetCount()
            logger.debug(f"Number of CUDA devices: {ndevices}")
            for i in range(ndevices):
                handle = nvmlDeviceGetHandleByIndex(i)
                device_name = _as_text(nvmlDeviceGetName(handle))
                meminfo = nvmlDeviceGetMemoryInfo(handle)
                logger.debug(f"Device {i} ({device_name}): {meminfo.used / meminfo.total:.2%} memory usage.")
        finally:
            # Always pair nvmlInit() with nvmlShutdown(), even if a query above raised.
            nvmlShutdown()
    except ImportError:
        logger.debug("Module 'pynvml' could not be imported. darts is probably installed without CUDA support.")

    # --- PyTorch ---
    try:
        import torch

        logger.debug(f"PyTorch version: {torch.__version__}")
        logger.debug(f"PyTorch CUDA available: {torch.cuda.is_available()}")
        logger.debug(f"PyTorch CUDA runtime version: {torch.version.cuda}")
    except ImportError as e:
        # Unlike the other libraries, a missing torch is reported at error level.
        logger.error("Module 'torch' could not be imported:")
        logger.exception(e, exc_info=True)

    # --- CuPy: driver version vs. dynamically / statically linked runtime ---
    try:
        import cupy  # type: ignore

        logger.debug(f"Cupy version: {cupy.__version__}")
        cupy_driver_version = cupy.cuda.runtime.driverGetVersion()
        logger.debug(f"Cupy CUDA driver version: {cupy_driver_version}")
        # This is the version which is installed (dynamically linked via PATH or LD_LIBRARY_PATH) in the environment
        env_runtime_version = cupy.cuda.get_local_runtime_version()
        logger.debug(f"Cupy CUDA runtime version: {env_runtime_version}")
        if cupy_driver_version < env_runtime_version:
            logger.warning(
                "CUDA runtime version is newer than CUDA driver version!"
                " The CUDA environment is probably not setup correctly!"
                " Consider linking CUDA to an older version with CUDA_HOME and LD_LIBRARY_PATH environment variables,"
                " or in case of a setup done by pixi choose a different environment with the -e flag."
            )
        # This is the version which was used when cupy was compiled (statically linked)
        cupy_runtime_version = cupy.cuda.runtime.runtimeGetVersion()
        if env_runtime_version != cupy_runtime_version:
            logger.debug(
                "Cupy CUDA runtime versions don't match!\n"
                f"Got {env_runtime_version} as local (dynamically linked) runtime version.\n"
                f"Got {cupy_runtime_version} as by cupy statically linked runtime version.\n"
                "This can happen if cupy was compiled using a different CUDA runtime version. "
                "Things should still work, note that Cupy will use the dynamically linked version."
            )
    except ImportError:
        logger.debug("Module 'cupy' not found, darts is probably installed without CUDA support.")

    # --- Numba ---
    try:
        import numba.cuda

        cuda_available = numba.cuda.is_available()
        logger.debug(f"Numba CUDA is available: {cuda_available}")
        if cuda_available:
            logger.debug(f"Numba CUDA runtime version: {numba.cuda.runtime.get_version()}")
            # logger.debug(f"Numba CUDA has supported devices: {numba.cuda.detect()}")
    except ImportError:
        logger.debug("Module 'numba.cuda' not found, darts is probably installed without CUDA support.")

    # --- xrspatial: combined CuPy+Numba check ---
    # Guarded like every other optional import in this function so that a missing
    # package does not abort the report before the remaining sections are logged.
    try:
        from xrspatial.utils import has_cuda_and_cupy

        logger.debug(f"Cupy+Numba CUDA available: {has_cuda_and_cupy()}")
    except ImportError:
        logger.debug("Module 'xrspatial' could not be imported, skipping the Cupy+Numba CUDA check.")

    # --- cuCIM ---
    try:
        import cucim  # type: ignore

        logger.debug(f"Cucim version: {cucim.__version__}")
    except ImportError:
        logger.debug("Module 'cucim' not found, darts is probably installed without CUDA support.")

    logger.debug("===^^^ CUDA DEBUG INFO ^^^===")
|