(installation)=
# Installation

<!--* freshness: { reviewed: '2024-06-18' } *-->

Using JAX requires installing two packages: `jax`, which is pure Python and
cross-platform, and `jaxlib`, which contains compiled binaries and requires
different builds for different operating systems and accelerators.

**Summary:** For most users, a typical JAX installation may look something like this:

* **CPU-only (Linux/macOS/Windows)**
  ```
  pip install -U jax
  ```
* **GPU (NVIDIA, CUDA 12)**
  ```
  pip install -U "jax[cuda12]"
  ```
* **TPU (Google Cloud TPU VM)**
  ```
  pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  ```
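
If you script installations, for example in CI, you might wrap the three summary commands in a small selector. The sketch below is one way to do that: `jax_install_cmd` is a hypothetical helper, not part of JAX, and it only prints the command so you can review it before running it.

```bash
# Hypothetical helper: print the summary install command for a target.
# It only echoes the command; nothing is installed.
jax_install_cmd() {
  case "$1" in
    cpu)    echo 'pip install -U jax' ;;
    cuda12) echo 'pip install -U "jax[cuda12]"' ;;
    tpu)    echo 'pip install -U "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html' ;;
    *)      echo "unknown target: $1 (expected cpu, cuda12, or tpu)" >&2
            return 1 ;;
  esac
}
```

For example, `eval "$(jax_install_cmd cuda12)"` would run the CUDA 12 command. Printing rather than executing keeps the helper safe to call on machines where a given install would not apply.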

(install-supported-platforms)=
## Supported platforms

The table below shows all supported platforms and installation options. Check whether your setup is supported; if it says _"yes"_ or _"experimental"_, click the corresponding link to learn how to install JAX in greater detail.

| | Linux, x86_64 | Linux, aarch64 | macOS, Intel x86_64, AMD GPU | macOS, Apple Silicon, ARM-based | Windows, x86_64 | Windows WSL2, x86_64 |
|------------------|---------------------------------------|--------------------------------|----------------------------------------|----------------------------------------|-------------------------|-----------------------------------------|
| CPU | {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>` | {ref}`yes <install-cpu>` |
| NVIDIA GPU | {ref}`yes <install-nvidia-gpu>` | {ref}`yes <install-nvidia-gpu>` | no | n/a | no | {ref}`experimental <install-nvidia-gpu>` |
| Google Cloud TPU | {ref}`yes <install-google-tpu>` | n/a | n/a | n/a | n/a | n/a |
| AMD GPU | {ref}`experimental <install-amd-gpu>` | no | no | n/a | no | no |
| Apple GPU | n/a | no | {ref}`experimental <install-apple-gpu>` | {ref}`experimental <install-apple-gpu>` | n/a | n/a |
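
To look up the NVIDIA GPU row of the table for the current machine, a small lookup like the following could help. This is a sketch: the platform keys (`linux-x86_64`, `wsl2-x86_64`, and so on) and both function names are hypothetical, and detecting WSL2 by finding `microsoft` in `/proc/version` is a common heuristic rather than a guarantee.

```bash
# Hypothetical helper: return the "NVIDIA GPU" cell of the table above.
# The platform keys are our own shorthand, not official identifiers.
jax_nvidia_gpu_support() {
  case "$1" in
    linux-x86_64|linux-aarch64)  echo yes ;;
    wsl2-x86_64)                 echo experimental ;;
    macos-arm64)                 echo n/a ;;
    macos-x86_64|windows-x86_64) echo no ;;
    *)                           echo unknown ;;
  esac
}

# Best-effort detection of the current platform key (assumption: WSL2 is
# recognized by "microsoft" appearing in /proc/version).
detect_platform() {
  local os arch
  os=$(uname -s)
  arch=$(uname -m)
  case "$os" in
    Linux)
      if grep -qi microsoft /proc/version 2>/dev/null; then
        echo "wsl2-$arch"
      else
        echo "linux-$arch"
      fi ;;
    Darwin) echo "macos-$arch" ;;
    MINGW*|MSYS*|CYGWIN*) echo "windows-$arch" ;;
    *) echo "unknown-$arch" ;;
  esac
}
```

Running `jax_nvidia_gpu_support "$(detect_platform)"` prints `yes`, `experimental`, `no`, `n/a`, or `unknown`; the table remains the authoritative source.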

(install-cpu)=
## CPU

### pip installation: CPU

Currently, the JAX team releases `jaxlib` wheels for the following
operating systems and architectures:

- Linux, x86_64
- Linux, aarch64
- macOS, Intel
- macOS, Apple ARM-based
- Windows, x86_64 (*experimental*)

To install a CPU-only version of JAX, which might be useful for doing local
development on a laptop, you can run:

```bash
pip install --upgrade pip
pip install --upgrade jax
```

On Windows, you may also need to install the
[Microsoft Visual Studio 2019 Redistributable](https://learn.microsoft.com/en-US/cpp/windows/latest-supported-vc-redist?view=msvc-170#visual-studio-2015-2017-2019-and-2022)
if it is not already installed on your machine.

Other operating systems and architectures require building from source. Trying
to pip install on other operating systems and architectures may lead to `jaxlib`
not being installed alongside `jax`, although `jax` itself may install
successfully (and then fail at runtime).
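
Because a failed `jaxlib` install can leave `jax` looking installed, it can be worth confirming that both packages are importable from the interpreter that `pip` installed into. A minimal sketch, assuming that interpreter is `python3` (override with `PYTHON`); `py_has_module` and `check_jax_install` are hypothetical helper names, and finding a module does not prove it will run on your hardware.

```bash
# Python interpreter to check; override with PYTHON=/path/to/python.
PYTHON="${PYTHON:-python3}"

# Return 0 if $PYTHON can locate the top-level module named in $1.
py_has_module() {
  "$PYTHON" -c 'import importlib.util, sys; sys.exit(0 if importlib.util.find_spec(sys.argv[1]) else 1)' "$1"
}

# Hypothetical helper: report whether jax and jaxlib are both present.
check_jax_install() {
  local missing=0 mod
  for mod in jax jaxlib; do
    if py_has_module "$mod"; then
      echo "found: $mod"
    else
      echo "missing: $mod" >&2
      missing=1
    fi
  done
  return "$missing"
}
```

Run `check_jax_install` after installing. For a runtime check, `python3 -c "import jax; print(jax.devices())"` lists the devices JAX can see.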

(install-nvidia-gpu)=
## NVIDIA GPU

JAX supports NVIDIA GPUs that have SM version 5.2 (Maxwell) or newer.
Note that Kepler-series GPUs are no longer supported by JAX because
NVIDIA has dropped support for Kepler GPUs in its software.

You must first install the NVIDIA driver. It is recommended that you install the
newest driver available from NVIDIA, but the driver version must be >= 525.60.13
for CUDA 12 on Linux.
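
The two requirements above, SM version 5.2 or newer and a driver of at least 525.60.13 for CUDA 12 on Linux, can be checked with a short script. This sketch assumes GNU-style `sort -V` for comparing dotted versions; `version_ge` and `nvidia_requirements_ok` are hypothetical names.

```bash
# Return 0 if dotted version $1 >= $2 (relies on `sort -V`, as in GNU coreutils).
version_ge() {
  [ "$(printf '%s\n%s\n' "$2" "$1" | sort -V | head -n1)" = "$2" ]
}

# Minimums stated in the text above.
MIN_DRIVER=525.60.13  # NVIDIA driver, CUDA 12 on Linux
MIN_SM=5.2            # SM version (Maxwell)

# Hypothetical helper: check a driver version and an SM version (compute
# capability) against the minimums, printing a reason for each failure.
nvidia_requirements_ok() {
  local driver="$1" sm="$2" status=0
  if ! version_ge "$driver" "$MIN_DRIVER"; then
    echo "driver $driver is older than $MIN_DRIVER" >&2
    status=1
  fi
  if ! version_ge "$sm" "$MIN_SM"; then
    echo "SM version $sm is older than $MIN_SM" >&2
    status=1
  fi
  return "$status"
}
```

On a machine with NVIDIA GPUs you could feed it values from `nvidia-smi`, for example `nvidia_requirements_ok "$(nvidia-smi --query-gpu=driver_version --format=csv,noheader | head -n1)" "$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader | head -n1)"`, which checks the first GPU only. The `compute_cap` query field is available only in relatively recent `nvidia-smi` versions.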

If you need to use a newer CUDA toolkit with an older driver, for example
on a cluster where you cannot update the NVIDIA driver easily, you may be
able to use the
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
that NVIDIA provides for this purpose.

### pip installation: NVIDIA GPU (CUDA, installed via pip, easier)

There are two ways to install JAX with NVIDIA GPU support:

- Using NVIDIA CUDA and cuDNN installed from pip wheels
- Using a self-installed CUDA/cuDNN

The JAX team strongly recommends installing CUDA and cuDNN using the pip wheels,
since it is much easier!

NVIDIA has released CUDA pip packages only for x86_64 and aarch64; on other
platforms you must use a local installation of CUDA.

```bash
pip install --upgrade pip

# NVIDIA CUDA 12 installation
# Note: wheels only available on Linux.
pip install --upgrade "jax[cuda12]"
```

If JAX detects the wrong version of the NVIDIA CUDA libraries, there are several things
you need to check:

* Make sure that `LD_LIBRARY_PATH` is not set, since `LD_LIBRARY_PATH` can
  override the NVIDIA CUDA libraries.
* Make sure that the NVIDIA CUDA libraries installed are those requested by JAX.
  Rerunning the installation command above should work.
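
The first check can be scripted. In this sketch `check_ld_library_path` is a hypothetical helper name; it only reports the problem and does not change your environment.

```bash
# Hypothetical helper: warn if LD_LIBRARY_PATH is set, since it can override
# the NVIDIA CUDA libraries installed from pip. Returns 1 when it is set.
check_ld_library_path() {
  if [ -n "${LD_LIBRARY_PATH:-}" ]; then
    echo "LD_LIBRARY_PATH is set to: $LD_LIBRARY_PATH" >&2
    echo "Consider running: unset LD_LIBRARY_PATH" >&2
    return 1
  fi
  echo "LD_LIBRARY_PATH is not set"
}
```

If it reports a value, unset the variable in the shell you launch Python from and try again.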

### pip installation: NVIDIA GPU (CUDA, installed locally, harder)

If you prefer to use a preinstalled copy of NVIDIA CUDA, you must first
install NVIDIA [CUDA](https://developer.nvidia.com/cuda-downloads) and
[cuDNN](https://developer.nvidia.com/CUDNN).

JAX provides pre-built CUDA-compatible wheels for **Linux x86_64 and Linux aarch64 only**. Other
combinations of operating system and architecture are possible, but require
building from source (refer to {ref}`building-from-source` to learn more).

You should use an NVIDIA driver version that is at least as new as your
[NVIDIA CUDA toolkit's corresponding driver version](https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html#cuda-major-component-versions__table-cuda-toolkit-driver-versions).

If you need to use a newer CUDA toolkit with an older driver, for example
on a cluster where you cannot update the NVIDIA driver easily, you may be
able to use the
[CUDA forward compatibility packages](https://docs.nvidia.com/deploy/cuda-compatibility/)
that NVIDIA provides for this purpose.

JAX currently ships one CUDA wheel variant:

| Built with | Compatible with    |
|------------|--------------------|
| CUDA 12.3  | CUDA >=12.1        |
| cuDNN 9.1  | cuDNN >=9.1, <10.0 |
| NCCL 2.19  | NCCL >=2.18        |

JAX checks the versions of your libraries, and will report an error if they are
not sufficiently new.

Setting the `JAX_SKIP_CUDA_CONSTRAINTS_CHECK` environment variable will disable
the check, but using older versions of CUDA may lead to errors or incorrect
results.
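
To see ahead of time whether a locally installed library falls within the "Compatible with" column, a range check along these lines could help. It is a sketch, not JAX's own version check: `version_in_range` is a hypothetical helper, it relies on `sort -V`, and you supply the versions yourself.

```bash
# Return 0 if dotted version $1 >= $2 (relies on `sort -V`).
version_ge() {
  [ "$(printf '%s\n%s\n' "$2" "$1" | sort -V | head -n1)" = "$2" ]
}

# Hypothetical helper: return 0 if MIN <= VERSION, and VERSION < MAX if a
# third (exclusive) upper bound is given.
version_in_range() {
  local v="$1" min="$2" max="${3:-}"
  version_ge "$v" "$min" || return 1
  if [ -n "$max" ] && version_ge "$v" "$max"; then
    return 1
  fi
  return 0
}

# Bounds copied from the "Compatible with" column; the versions are examples.
version_in_range 12.4 12.1     && echo "CUDA 12.4: ok"
version_in_range 9.3 9.1 10.0  && echo "cuDNN 9.3: ok"
version_in_range 2.21 2.18     && echo "NCCL 2.21: ok"
```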

NCCL is an optional dependency, required only if you are performing multi-GPU
computations.

To install, run:

```bash
pip install --upgrade pip

# Installs the wheel compatible with NVIDIA CUDA 12 and cuDNN 9.1 or newer.
# Note: wheels only available on Linux.
pip install --upgrade "jax[cuda12_local]"
```

**These `pip` installations do not work with Windows, and may fail silently; refer to the table
[above](#supported-platforms).**

You can find your CUDA version with the command:

```bash
nvcc --version
```
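
If you want to compare that output against the table programmatically, a small parser could extract the release number. This sketch assumes the usual `nvcc --version` line of the form `Cuda compilation tools, release 12.3, V12.3.107`; `cuda_release_from_nvcc` and `cuda_release_ok` are hypothetical helpers.

```bash
# Extract MAJOR.MINOR from `nvcc --version` output read on stdin, e.g. from the
# line "Cuda compilation tools, release 12.3, V12.3.107".
cuda_release_from_nvcc() {
  sed -n 's/.*release \([0-9][0-9]*\.[0-9][0-9]*\).*/\1/p' | head -n1
}

# Return 0 if MAJOR.MINOR in $1 is at least 12.1, the minimum CUDA version in
# the "Compatible with" column of the table above.
cuda_release_ok() {
  awk -v v="$1" 'BEGIN {
    split(v, p, ".")
    if (p[1] + 0 > 12 || (p[1] + 0 == 12 && p[2] + 0 >= 1)) exit 0
    exit 1
  }'
}
```

For example, `cuda_release_ok "$(nvcc --version | cuda_release_from_nvcc)"` succeeds for a CUDA 12.1 or newer toolkit on your `PATH`.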

JAX uses `LD_LIBRARY_PATH` to find CUDA libraries and `PATH` to find binaries
(`ptxas`, `nvlink`). Please make sure that these paths point to the correct CUDA
installation.

JAX requires `libdevice.10.bc`, which typically comes from the `cuda-nvvm` package.
Make sure that it is present in your CUDA installation.
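
The path requirements above can be checked with a short script. In this minimal sketch `check_cuda_tools` is a hypothetical helper; its default root `/usr/local/cuda` is a common install location rather than one JAX requires, and the `nvvm/libdevice/libdevice.10.bc` path follows the usual CUDA toolkit layout.

```bash
# Hypothetical helper: check that ptxas and nvlink are on PATH and that
# libdevice.10.bc exists under a CUDA root (default: /usr/local/cuda).
check_cuda_tools() {
  local root="${1:-/usr/local/cuda}" status=0 tool
  local libdevice="$root/nvvm/libdevice/libdevice.10.bc"
  for tool in ptxas nvlink; do
    if command -v "$tool" >/dev/null 2>&1; then
      echo "found $tool: $(command -v "$tool")"
    else
      echo "missing from PATH: $tool" >&2
      status=1
    fi
  done
  if [ -f "$libdevice" ]; then
    echo "found libdevice: $libdevice"
  else
    echo "missing: $libdevice" >&2
    status=1
  fi
  return "$status"
}
```

Pass your toolkit's root if it lives elsewhere, for example `check_cuda_tools /opt/cuda`.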

Please let the JAX team know on [the GitHub issue tracker](https://github.com/jax-ml/jax/issues)
if you run into any errors or problems with the pre-built wheels.

(docker-containers-nvidia-gpu)=
### NVIDIA GPU Docker containers

NVIDIA provides the [JAX
Toolbox](https://github.com/NVIDIA/JAX-Toolbox) containers, which are
bleeding-edge containers with nightly releases of JAX and some
models/frameworks.

(install-google-tpu)=
## Google Cloud TPU

### pip installation: Google Cloud TPU

JAX provides pre-built wheels for
[Google Cloud TPU](https://cloud.google.com/tpu/docs/users-guide-tpu-vm).
To install JAX along with appropriate versions of `jaxlib` and `libtpu`, you can run
the following in your Cloud TPU VM:

```bash
pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

For users of [Colab](https://colab.research.google.com/), be sure you are
using *TPU v2* and not the older, deprecated TPU runtime.

(install-apple-gpu)=
## Apple Silicon GPU (ARM-based)

### pip installation: Apple ARM-based Silicon GPUs

Apple provides an experimental Metal plugin for Apple ARM-based GPU hardware. For details,
refer to
[Apple's JAX on Metal documentation](https://developer.apple.com/metal/jax/).

**Note:** There are several caveats with the Metal plugin:

* The Metal plugin is new and experimental and has a number of
  [known issues](https://github.com/jax-ml/jax/issues?q=is%3Aissue+is%3Aopen+label%3A%22Apple+GPU+%28Metal%29+plugin%22).
  Please report any issues on the JAX issue tracker.
* The Metal plugin currently requires very specific versions of `jax` and
  `jaxlib`. This restriction will be relaxed over time as the plugin API
  matures.

(install-amd-gpu)=
## AMD GPU

JAX has experimental ROCm support. There are two ways to install JAX:

* Use [AMD's Docker container](https://hub.docker.com/r/rocm/jax); or
* Build from source (refer to {ref}`building-from-source`, in particular the section _Additional notes for building a ROCm `jaxlib` for AMD GPUs_).

## Conda (community-supported)

### Conda installation

There is a community-supported Conda build of `jax`. To install it using `conda`,
simply run:

```bash
conda install jax -c conda-forge
```

To install it on a machine with an NVIDIA GPU, run:

```bash
conda install "jaxlib=*=*cuda*" jax cuda-nvcc -c conda-forge -c nvidia
```

Note that the `cudatoolkit` distributed by `conda-forge` is missing `ptxas`, which
JAX requires. You must therefore either install the `cuda-nvcc` package from
the `nvidia` channel, or install CUDA on your machine separately so that `ptxas`
is in your path. The channel order above is important (`conda-forge` before
`nvidia`).

If you would like to override which release of CUDA is used by JAX, or to
install the CUDA build on a machine without GPUs, follow the instructions in the
[Tips & tricks](https://conda-forge.org/docs/user/tipsandtricks.html#installing-cuda-enabled-packages-like-tensorflow-and-pytorch)
section of the `conda-forge` website.

Refer to the `conda-forge`
[jaxlib](https://github.com/conda-forge/jaxlib-feedstock#installing-jaxlib) and
[jax](https://github.com/conda-forge/jax-feedstock#installing-jax) repositories
for more details.

## JAX nightly installation

Nightly releases reflect the state of the main JAX repository at the time they are
built, and may not pass the full test suite.

Unlike the instructions for installing a JAX release, here we name all of JAX's
packages explicitly on the command line, so `pip` will upgrade them if a newer
version is available.

- CPU only:
  ```bash
  pip install -U --pre jax jaxlib -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
  ```
- Google Cloud TPU:
  ```bash
  pip install -U --pre jax jaxlib libtpu-nightly requests -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  ```
- NVIDIA GPU (CUDA 12):
  ```bash
  pip install -U --pre jax jaxlib "jax-cuda12-plugin[with_cuda]" jax-cuda12-pjrt -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
  ```
- NVIDIA GPU (CUDA 12) legacy:
  Use the following for historical nightly releases of monolithic CUDA jaxlibs.
  You most likely do not want this; no further monolithic CUDA jaxlibs will be
  built and those that exist will expire by Sep 2024. Use the "CUDA 12" option above.
  ```bash
  pip install -U --pre jaxlib -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_cuda12_releases.html
  ```

(building-jax-from-source)=
## Building JAX from source

Refer to {ref}`building-from-source`.

## Installing older `jaxlib` wheels

Due to storage limitations on the Python package index, the JAX team periodically removes
older `jaxlib` wheels from the releases on https://pypi.org/project/jax. These can
still be installed directly via the URLs here. For example:

```bash
# Install jaxlib on CPU via the wheel archive
pip install "jax[cpu]==0.3.25" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Install the jaxlib 0.3.25 CPU wheel directly
pip install jaxlib==0.3.25 -f https://storage.googleapis.com/jax-releases/jax_releases.html
```

For specific older GPU wheels, be sure to use the `jax_cuda_releases.html` URL; for example:

```bash
pip install jaxlib==0.3.25+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```