rocm_jax/jax_plugins
Peter Hawkins b13733c13f Update JAX dependencies, extras, and documentation for plugins.
* Make jaxlib a direct dependency of jax.
* Remove mentions of monolithic CUDA installations from the JAX documentation.
* Drop the cuda12_pip extra and the cudnn version specific extras.
* Add a with_cuda extra to the jax-cuda12-plugin package and use it in jax's setup.py. This allows us to specify CUDA extras in one place.
* Make a few small doc improvements.
2024-06-13 11:36:23 -04:00