Add build instructions to build jaxlib with cuda plugin from source.

PiperOrigin-RevId: 601231525
This commit is contained in:
Jieying Luo 2024-01-24 14:15:14 -08:00 committed by jax authors
parent 4646c64f54
commit cfb6250158

View File

@ -52,15 +52,19 @@ You can install the necessary Python dependencies using `pip`:
pip install numpy wheel build
```
To build `jaxlib` without CUDA GPU or TPU support (CPU only), you can run:
To build `jaxlib` for CPU or TPU, you can run:
```
python build/build.py
pip install dist/*.whl # installs jaxlib (includes XLA)
```
To build `jaxlib` with CUDA support, use `python build/build.py --enable_cuda`;
to build with TPU support, use `python build/build.py`.
There are two ways to build `jaxlib` with CUDA support: (1) use
`python build/build.py --enable_cuda` to generate a jaxlib wheel with cuda
support, or (2) use
`python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12`
to generate three wheels (jaxlib without cuda, jax-cuda-plugin,
and jax-cuda-pjrt). You can set `gpu_plugin_cuda_version` to 11 or 12.
See `python build/build.py --help` for configuration options, including ways to
specify the paths to CUDA and CUDNN, which you must have installed. Here