libdevice.10.bc is removed from JAX wheels bundle.

The recommended source of JAX wheels is `pip`, and NVIDIA dependencies are installed automatically when JAX is installed via `pip install`. `libdevice` gets installed from `nvidia-cuda-nvcc-cu12` package.

PiperOrigin-RevId: 647328834
This commit is contained in:
jax authors 2024-06-27 08:35:09 -07:00 committed by jax authors
parent 9df105c18f
commit 00528b9858
5 changed files with 6 additions and 13 deletions

View File

@ -15,6 +15,9 @@ Remember to align the itemized text with the first line of an item within a list
supported version until December 2024.
* {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output
of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point.
* `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be
installed either as a part of local CUDA installation, or via NVIDIA's CUDA
pip wheels.
## jaxlib 0.4.31

View File

@ -173,6 +173,9 @@ JAX uses `LD_LIBRARY_PATH` to find CUDA libraries and `PATH` to find binaries
(`ptxas`, `nvlink`). Please make sure that these paths point to the correct CUDA
installation.
JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package.
Make sure that it is present in your CUDA installation.
Please let the JAX team know on [the GitHub issue tracker](https://github.com/google/jax/issues)
if you run into any errors or problems with the pre-built wheels.

View File

@ -133,11 +133,6 @@ def _cuda_path() -> str | None:
# both of the things XLA looks for in the cuda path, namely bin/ptxas and
# nvvm/libdevice/libdevice.10.bc
path = _jaxlib_path.parent / "nvidia" / "cuda_nvcc"
if path.is_dir():
return str(path)
# Failing that, we use the copy of libdevice.10.bc we include with jaxlib and
# hope that the user has ptxas in their PATH.
path = _jaxlib_path / "cuda"
if path.is_dir():
return str(path)
return None

View File

@ -98,10 +98,6 @@ def prepare_wheel_cuda(
write_setup_cfg(sources_path, cpu)
plugin_dir = sources_path / f"jax_cuda{cuda_version}_plugin"
copy_runfiles(
dst_dir=plugin_dir / "nvvm" / "libdevice",
src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"],
)
copy_runfiles(
dst_dir=plugin_dir,
src_files=[

View File

@ -221,10 +221,6 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
)
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not skip_gpu_kernels:
copy_runfiles(
dst_dir=jaxlib_dir / "cuda" / "nvvm" / "libdevice",
src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"],
)
copy_runfiles(
dst_dir=jaxlib_dir / "cuda",
src_files=[