Update _cuda_path

- Remove jax-relative module path test
- Use `$CUDA_ROOT` environment variable if available
- Use `cuda_nvcc` module's path if installed
This commit is contained in:
Kristian Hartikainen 2024-10-05 16:49:28 +03:00
parent e8cea0d7a4
commit 1ea8e3c29d
2 changed files with 25 additions and 7 deletions

View File

@ -18,6 +18,7 @@
from __future__ import annotations
import gc
import os
import pathlib
import re
from typing import Any
@ -128,13 +129,29 @@ mlir_api_version = xla_client.mlir_api_version
# TODO(rocm): check if we need the same for rocm.
def _cuda_path() -> str | None:
_jaxlib_path = pathlib.Path(jaxlib.__file__).parent
# If the pip package nvidia-cuda-nvcc-cu11 is installed, it should have
# both of the things XLA looks for in the cuda path, namely bin/ptxas and
# nvvm/libdevice/libdevice.10.bc
path = _jaxlib_path.parent / "nvidia" / "cuda_nvcc"
if path.is_dir():
return str(path)
def _try_cuda_root_environment_variable() -> str | None:
"""Use `CUDA_ROOT` environment variable if set."""
return os.environ.get('CUDA_ROOT', None)
def _try_cuda_nvcc_import() -> str | None:
"""Try to import `cuda_nvcc` and get its path directly.
If the pip package `nvidia-cuda-nvcc-cu11` is installed, it should have
both of the things XLA looks for in the cuda path, namely `bin/ptxas` and
`nvvm/libdevice/libdevice.10.bc`.
"""
try:
from nvidia import cuda_nvcc
except ImportError:
return None
cuda_nvcc_path = pathlib.Path(cuda_nvcc.__file__).parent
return str(cuda_nvcc_path)
if (path := _try_cuda_root_environment_variable()) is not None:
return path
elif (path := _try_cuda_nvcc_import()) is not None:
return path
return None
cuda_path = _cuda_path()

View File

@ -28,6 +28,7 @@ module = [
"jraph.*",
"libtpu.*",
"matplotlib.*",
"nvidia.*",
"numpy.*",
"opt_einsum.*",
"optax.*",