From 00528b9858a0a6da2c32434f46fa6ab53149cebc Mon Sep 17 00:00:00 2001 From: jax authors Date: Thu, 27 Jun 2024 08:35:09 -0700 Subject: [PATCH] `libdevice.10.bc` is removed from JAX wheels bundle. The recommended source of JAX wheels is `pip`, and NVIDIA dependencies are installed automatically when JAX is installed via `pip install`. `libdevice` gets installed from `nvidia-cuda-nvcc-cu12` package. PiperOrigin-RevId: 647328834 --- CHANGELOG.md | 3 +++ docs/installation.md | 3 +++ jax/_src/lib/__init__.py | 5 ----- jaxlib/tools/build_gpu_kernels_wheel.py | 4 ---- jaxlib/tools/build_wheel.py | 4 ---- 5 files changed, 6 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 85ec11685..be8f6f8d1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ Remember to align the itemized text with the first line of an item within a list supported version until December 2024. * {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point. + * `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be + installed either as a part of local CUDA installation, or via NVIDIA's CUDA + pip wheels. ## jaxlib 0.4.31 diff --git a/docs/installation.md b/docs/installation.md index 82a0fde31..fa77d1fc2 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -173,6 +173,9 @@ JAX uses `LD_LIBRARY_PATH` to find CUDA libraries and `PATH` to find binaries (`ptxas`, `nvlink`). Please make sure that these paths point to the correct CUDA installation. +JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package. +Make sure that it is present in your CUDA installation. + Please let the JAX team know on [the GitHub issue tracker](https://github.com/google/jax/issues) if you run into any errors or problems with the pre-built wheels. diff --git a/jax/_src/lib/__init__.py b/jax/_src/lib/__init__.py index 3d426ff37..b2bcc53a5 100644 --- a/jax/_src/lib/__init__.py +++ b/jax/_src/lib/__init__.py @@ -133,11 +133,6 @@ def _cuda_path() -> str | None: # both of the things XLA looks for in the cuda path, namely bin/ptxas and # nvvm/libdevice/libdevice.10.bc path = _jaxlib_path.parent / "nvidia" / "cuda_nvcc" - if path.is_dir(): - return str(path) - # Failing that, we use the copy of libdevice.10.bc we include with jaxlib and - # hope that the user has ptxas in their PATH. - path = _jaxlib_path / "cuda" if path.is_dir(): return str(path) return None diff --git a/jaxlib/tools/build_gpu_kernels_wheel.py b/jaxlib/tools/build_gpu_kernels_wheel.py index f0da3d253..28d2806a7 100644 --- a/jaxlib/tools/build_gpu_kernels_wheel.py +++ b/jaxlib/tools/build_gpu_kernels_wheel.py @@ -98,10 +98,6 @@ def prepare_wheel_cuda( write_setup_cfg(sources_path, cpu) plugin_dir = sources_path / f"jax_cuda{cuda_version}_plugin" - copy_runfiles( - dst_dir=plugin_dir / "nvvm" / "libdevice", - src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"], - ) copy_runfiles( dst_dir=plugin_dir, src_files=[ diff --git a/jaxlib/tools/build_wheel.py b/jaxlib/tools/build_wheel.py index 68a0e0aba..62864f7ad 100644 --- a/jaxlib/tools/build_wheel.py +++ b/jaxlib/tools/build_wheel.py @@ -221,10 +221,6 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels): ) if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not skip_gpu_kernels: - copy_runfiles( - dst_dir=jaxlib_dir / "cuda" / "nvvm" / "libdevice", - src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"], - ) copy_runfiles( dst_dir=jaxlib_dir / "cuda", src_files=[