
jaxlib: support library for JAX
jaxlib is the support library for JAX. While JAX itself is a pure Python package, jaxlib contains the binary (C/C++) parts of the library, including Python bindings, the XLA compiler, the PJRT runtime, and a handful of handwritten kernels. For more information, including installation and build instructions, refer to the main JAX README (https://github.com/google/jax/).
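The split between the pure-Python `jax` package and the compiled `jaxlib` package can be seen from user code. The following is a minimal sketch, not taken from the JAX sources. It assumes `jax` and a matching `jaxlib` are installed and that the version string is exposed as `jaxlib.version.__version__`. The function `scaled_sum` is hypothetical and exists only for illustration. JAX traces the function in Python and lowers it to StableHLO text. XLA (shipped in jaxlib) then compiles that text into an executable, and the available devices are discovered through the PJRT runtime.

```python
import jax
import jax.numpy as jnp
import jaxlib.version  # assumption: jaxlib exposes its version string here


def scaled_sum(x):
    # Hypothetical example function written against the jax.numpy API.
    return jnp.sum(2.0 * x)


x = jnp.arange(4.0)  # [0., 1., 2., 3.]

# jax (pure Python) traces the function and lowers it to StableHLO text.
lowered = jax.jit(scaled_sum).lower(x)
hlo_text = lowered.as_text()

# The XLA compiler, shipped in jaxlib, turns the lowered module into an executable.
compiled = lowered.compile()
result = compiled(x)

# Devices (CPU, GPU, ...) are discovered through the PJRT runtime, also part of jaxlib.
devices = jax.devices()

print("jax", jax.__version__, "/ jaxlib", jaxlib.version.__version__)
print("devices:", devices)
print("result:", float(result))  # 2 * (0 + 1 + 2 + 3) = 12.0
```

Only the tracing and lowering steps run in Python. Compilation and execution happen in jaxlib's C/C++ code, which is why the two packages are versioned and installed together.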