diff --git a/CHANGELOG.md b/CHANGELOG.md index 83be5f478..25f9df635 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -127,6 +127,8 @@ Remember to align the itemized text with the first line of an item within a list * Changes + * JAX now supports CUDA 12.3 and CUDA 11.8. Support for CUDA 12.2 has been + dropped. * `cost_analysis` now works with cross-compiled `Compiled` objects (i.e. when using `.lower().compile()` with a topology object, e.g., to compile for Cloud TPU from a non-TPU computer). diff --git a/docs/installation.md b/docs/installation.md index 5dba6757e..a6fd97948 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -114,7 +114,7 @@ able to use the that NVIDIA provides for this purpose. JAX currently ships two CUDA wheel variants: -* CUDA 12.2, cuDNN 8.9, NCCL 2.16 +* CUDA 12.3, cuDNN 8.9, NCCL 2.16 * CUDA 11.8, cuDNN 8.6, NCCL 2.16 You may use a JAX wheel provided the major version of your CUDA, cuDNN, and NCCL