mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
libdevice.10.bc
is removed from JAX wheels bundle.
The recommended source of JAX wheels is `pip`, and NVIDIA dependencies are installed automatically when JAX is installed via `pip install`. `libdevice` gets installed from `nvidia-cuda-nvcc-cu12` package. PiperOrigin-RevId: 647328834
This commit is contained in:
parent
9df105c18f
commit
00528b9858
@ -15,6 +15,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
supported version until December 2024.
|
||||
* {func}`jax.numpy.ceil`, {func}`jax.numpy.floor` and {func}`jax.numpy.trunc` now return the output
|
||||
of the same dtype as the input, i.e. no longer upcast integer or boolean inputs to floating point.
|
||||
* `libdevice.10.bc` is no longer bundled with CUDA wheels. It must be
|
||||
installed either as a part of local CUDA installation, or via NVIDIA's CUDA
|
||||
pip wheels.
|
||||
|
||||
## jaxlib 0.4.31
|
||||
|
||||
|
@ -173,6 +173,9 @@ JAX uses `LD_LIBRARY_PATH` to find CUDA libraries and `PATH` to find binaries
|
||||
(`ptxas`, `nvlink`). Please make sure that these paths point to the correct CUDA
|
||||
installation.
|
||||
|
||||
JAX requires libdevice10.bc, which typically comes from the cuda-nvvm package.
|
||||
Make sure that it is present in your CUDA installation.
|
||||
|
||||
Please let the JAX team know on [the GitHub issue tracker](https://github.com/google/jax/issues)
|
||||
if you run into any errors or problems with the pre-built wheels.
|
||||
|
||||
|
@ -133,11 +133,6 @@ def _cuda_path() -> str | None:
|
||||
# both of the things XLA looks for in the cuda path, namely bin/ptxas and
|
||||
# nvvm/libdevice/libdevice.10.bc
|
||||
path = _jaxlib_path.parent / "nvidia" / "cuda_nvcc"
|
||||
if path.is_dir():
|
||||
return str(path)
|
||||
# Failing that, we use the copy of libdevice.10.bc we include with jaxlib and
|
||||
# hope that the user has ptxas in their PATH.
|
||||
path = _jaxlib_path / "cuda"
|
||||
if path.is_dir():
|
||||
return str(path)
|
||||
return None
|
||||
|
@ -98,10 +98,6 @@ def prepare_wheel_cuda(
|
||||
write_setup_cfg(sources_path, cpu)
|
||||
|
||||
plugin_dir = sources_path / f"jax_cuda{cuda_version}_plugin"
|
||||
copy_runfiles(
|
||||
dst_dir=plugin_dir / "nvvm" / "libdevice",
|
||||
src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"],
|
||||
)
|
||||
copy_runfiles(
|
||||
dst_dir=plugin_dir,
|
||||
src_files=[
|
||||
|
@ -221,10 +221,6 @@ def prepare_wheel(sources_path: pathlib.Path, *, cpu, skip_gpu_kernels):
|
||||
)
|
||||
|
||||
if exists(f"__main__/jaxlib/cuda/_solver.{pyext}") and not skip_gpu_kernels:
|
||||
copy_runfiles(
|
||||
dst_dir=jaxlib_dir / "cuda" / "nvvm" / "libdevice",
|
||||
src_files=["local_config_cuda/cuda/cuda/nvvm/libdevice/libdevice.10.bc"],
|
||||
)
|
||||
copy_runfiles(
|
||||
dst_dir=jaxlib_dir / "cuda",
|
||||
src_files=[
|
||||
|
Loading…
x
Reference in New Issue
Block a user