mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Update _cuda_path
- Remove jax-relative module path test - Use `$CUDA_ROOT` environment variable if available - Use `cuda_nvcc` module's path if installed
This commit is contained in:
parent
e8cea0d7a4
commit
1ea8e3c29d
@ -18,6 +18,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import gc
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
from typing import Any
|
||||
@ -128,13 +129,29 @@ mlir_api_version = xla_client.mlir_api_version
|
||||
# TODO(rocm): check if we need the same for rocm.
|
||||
|
||||
def _cuda_path() -> str | None:
|
||||
_jaxlib_path = pathlib.Path(jaxlib.__file__).parent
|
||||
# If the pip package nvidia-cuda-nvcc-cu11 is installed, it should have
|
||||
# both of the things XLA looks for in the cuda path, namely bin/ptxas and
|
||||
# nvvm/libdevice/libdevice.10.bc
|
||||
path = _jaxlib_path.parent / "nvidia" / "cuda_nvcc"
|
||||
if path.is_dir():
|
||||
return str(path)
|
||||
def _try_cuda_root_environment_variable() -> str | None:
|
||||
"""Use `CUDA_ROOT` environment variable if set."""
|
||||
return os.environ.get('CUDA_ROOT', None)
|
||||
|
||||
def _try_cuda_nvcc_import() -> str | None:
|
||||
"""Try to import `cuda_nvcc` and get its path directly.
|
||||
|
||||
If the pip package `nvidia-cuda-nvcc-cu11` is installed, it should have
|
||||
both of the things XLA looks for in the cuda path, namely `bin/ptxas` and
|
||||
`nvvm/libdevice/libdevice.10.bc`.
|
||||
"""
|
||||
try:
|
||||
from nvidia import cuda_nvcc
|
||||
except ImportError:
|
||||
return None
|
||||
cuda_nvcc_path = pathlib.Path(cuda_nvcc.__file__).parent
|
||||
return str(cuda_nvcc_path)
|
||||
|
||||
if (path := _try_cuda_root_environment_variable()) is not None:
|
||||
return path
|
||||
elif (path := _try_cuda_nvcc_import()) is not None:
|
||||
return path
|
||||
|
||||
return None
|
||||
|
||||
cuda_path = _cuda_path()
|
||||
|
@ -28,6 +28,7 @@ module = [
|
||||
"jraph.*",
|
||||
"libtpu.*",
|
||||
"matplotlib.*",
|
||||
"nvidia.*",
|
||||
"numpy.*",
|
||||
"opt_einsum.*",
|
||||
"optax.*",
|
||||
|
Loading…
x
Reference in New Issue
Block a user