Merge pull request #17910 from hawkinsp:rocm

PiperOrigin-RevId: 570555612
This commit is contained in:
jax authors 2023-10-03 18:45:12 -07:00
commit 6c0924920e

View File

@ -137,30 +137,38 @@ To build with debug information, add the flag `--bazel_options='--copt=/Z7'`.
### Additional notes for building a ROCM `jaxlib` for AMD GPUs
You need several ROCM/HIP libraries installed to build for ROCM. For
example, on a Ubuntu machine with AMD's `apt` repositories available, you need
a number of packages installed:
example, on a Ubuntu machine with
[AMD's `apt` repositories available](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html),
you need a number of packages installed:
```
sudo apt install miopen-hip hipfft-dev rocrand-dev hipsparse-dev hipsolver-dev \
rccl-dev rccl hip-dev rocfft-dev roctracer-dev hipblas-dev rocm-device-libs
```
AMD's fork of the XLA repository may include fixes
not present in the upstream repository. To use AMD's fork, you should clone
their repository:
```
git clone https://github.com/ROCmSoftwarePlatform/tensorflow-upstream.git
```
To build jaxlib with ROCM support, you can run the following build command,
suitably adjusted for your paths and ROCM version.
```
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.3.0 \
--bazel_options=--override_repository=xla=/path/to/xla-upstream
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.7.0
```
AMD's fork of the XLA repository may include fixes not present in the upstream
XLA repository. If you experience problems with the upstream repository, you can
try AMD's fork, by cloning their repository:
```
git clone https://github.com/ROCmSoftwarePlatform/xla.git
```
and override the XLA repository with which JAX is built:
```
python build/build.py --enable_rocm --rocm_path=/opt/rocm-5.7.0 \
--bazel_options=--override_repository=xla=/path/to/xla-rocm
```
## Installing `jax`
Once `jaxlib` has been installed, you can install `jax` by running: