libdevice.10.bc is removed from JAX wheels bundle.

The recommended source of JAX wheels is `pip`, and NVIDIA dependencies are installed automatically when JAX is installed via `pip install`. `libdevice` gets installed from `nvidia-cuda-nvcc-cu12` package.

PiperOrigin-RevId: 647328834
This commit is contained in:
jax authors 2024-06-27 08:35:09 -07:00 committed by jax authors
parent 9df105c18f
commit 00528b9858
5 changed files with 6 additions and 13 deletions

View File

@ -15,6 +15,9 @@ Remember to align the itemized text with the first line of an item within a list
supported version until December 2024. supported version until December 2024.
* {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output * {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output
of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point. of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point.
* `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be
installed either as a part of local CUDA installation, or via NVIDIA's CUDA
pip wheels.
## jaxlib 0.4.31 ## jaxlib 0.4.31

View File

@ -173,6 +173,9 @@ JAX uses `LD_LIBRARY_PATH` to find CUDA libraries and `PATH` to find binaries
(`ptxas`, `nvlink`). Please make sure that these paths point to the correct CUDA (`ptxas`, `nvlink`). Please make sure that these paths point to the correct CUDA
installation. installation.
JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package.
Make sure that it is present in your CUDA installation.
Please let the JAX team know on [the GitHub issue tracker](https://github.com/google/jax/issues) Please let the JAX team know on [the GitHub issue tracker](https://github.com/google/jax/issues)
if you run into any errors or problems with the pre-built wheels. if you run into any errors or problems with the pre-built wheels.

View File

@ -133,11 +133,6 @@ def _cuda_path() -> str | None:
# both of the things XLA looks for in the cuda path, namely bin/ptxas and # both of the things XLA looks for in the cuda path, namely bin/ptxas and
# nvvm/libdevice/libdevice.10.bc # nvvm/libdevice/libdevice.10.bc
path = _jaxlib_path.parent / "nvidia" / "cuda_nvcc" path = _jaxlib_path.parent / "nvidia" / "cuda_nvcc"
if path.is_dir():
return str(path)
# Failing that, we use the copy of libdevice.10.bc we include with jaxlib and
# hope that the user has ptxas in their PATH.
path = _jaxlib_path / "cuda"
if path.is_dir(): if path.is_dir():
return str(path) return str(path)
return None return None

View File

@ -98,10 +98,6 @@ def prepare_wheel_cuda(
write_setup_cfg(sources_path, cpu) write_setup_cfg(sources_path, cpu)
plugin_dir = sources_path / f"jax_cuda{cuda_version}_plugin" plugin_dir = sources_path / f"jax_cuda{cuda_version}_plugin"
copy_runfiles(
dst_dir=plugin_dir / "nvvm" / "libdevice",
src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"],
)
copy_runfiles( copy_runfiles(
dst_dir=plugin_dir, dst_dir=plugin_dir,
src_files=[ src_files=[

View File

@ -221,10 +221,6 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
) )
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not skip_gpu_kernels: if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not skip_gpu_kernels:
copy_runfiles(
dst_dir=jaxlib_dir / "cuda" / "nvvm" / "libdevice",
src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"],
)
copy_runfiles( copy_runfiles(
dst_dir=jaxlib_dir / "cuda", dst_dir=jaxlib_dir / "cuda",
src_files=[ src_files=[