mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00

* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cuDNN-version-specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package and use it in jax's setup.py. This allows us to specify the CUDA extras in one place.
* Make a few small documentation improvements.
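The sketch below shows how listing the CUDA extras in one place can work. It assumes setuptools-style `extras_require` tables. The dictionaries, the NVIDIA package list, and the helper `resolve_extra` are hypothetical and illustrative. They are not the actual contents of either `setup.py`.

```python
# Hypothetical, simplified model of the two packages' setuptools metadata.
# Package names are illustrative and version pins are deliberately omitted.

# Base requirements of jax: jaxlib is now a direct dependency.
# (Other base dependencies are omitted in this sketch.)
JAX_INSTALL_REQUIRES = ["jaxlib"]

# The CUDA pip wheels are listed once, in the plugin's with_cuda extra.
PLUGIN_EXTRAS = {
    "with_cuda": [
        "nvidia-cublas-cu12",
        "nvidia-cudnn-cu12",
        "nvidia-cuda-runtime-cu12",
    ],
}

# jax's CUDA extra points at the plugin's extra instead of repeating the list.
JAX_EXTRAS = {
    "cuda12": ["jax-cuda12-plugin[with_cuda]"],
}


def resolve_extra(extra: str) -> list[str]:
    """Flatten jax's base requirements plus one extra into package names.

    This is a toy resolver for the tables above, not pip's real algorithm.
    """
    resolved = list(JAX_INSTALL_REQUIRES)
    for req in JAX_EXTRAS[extra]:
        # Split "name[extra]" into the distribution name and the extra name.
        name, _, rest = req.partition("[")
        resolved.append(name)
        if rest:
            plugin_extra = rest.rstrip("]")
            resolved.extend(PLUGIN_EXTRAS[plugin_extra])
    return resolved
```

Calling `resolve_extra("cuda12")` yields `jaxlib`, `jax-cuda12-plugin`, and then the three NVIDIA packages from `with_cuda`. Under this layout, changing a CUDA dependency means editing only the plugin's table. The jax side picks up the change through the `[with_cuda]` reference.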