mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Add build instructions to build jaxlib with cuda plugin from source.
PiperOrigin-RevId: 601231525
This commit is contained in:
parent
4646c64f54
commit
cfb6250158
@ -52,15 +52,19 @@ You can install the necessary Python dependencies using `pip`:
|
||||
pip install numpy wheel build
|
||||
```
|
||||
|
||||
To build `jaxlib` without CUDA GPU or TPU support (CPU only), you can run:
|
||||
To build `jaxlib` for CPU or TPU, you can run:
|
||||
|
||||
```
|
||||
python build/build.py
|
||||
pip install dist/*.whl # installs jaxlib (includes XLA)
|
||||
```
|
||||
|
||||
To build `jaxlib` with CUDA support, use `python build/build.py --enable_cuda`;
|
||||
to build with TPU support, use `python build/build.py`.
|
||||
There are two ways to build `jaxlib` with CUDA support: (1) use
|
||||
`python build/build.py --enable_cuda` to generate a jaxlib wheel with cuda
|
||||
support, or (2) use
|
||||
`python build/build.py --enable_cuda --build_gpu_plugin --gpu_plugin_cuda_version=12`
|
||||
to generate three wheels (jaxlib without cuda, jax-cuda-plugin,
|
||||
and jax-cuda-pjrt). You can set `gpu_plugin_cuda_version` to 11 or 12.
|
||||
|
||||
See `python build/build.py --help` for configuration options, including ways to
|
||||
specify the paths to CUDA and CUDNN, which you must have installed. Here
|
||||
|
Loading…
x
Reference in New Issue
Block a user