Fix ROCm build README (#284)

This commit is contained in:
charleshofer 2025-03-18 14:35:36 -05:00 committed by GitHub
parent c46b4fc02b
commit dd7f96b27c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -174,7 +174,33 @@ Run the following command to verify that ROCm JAX is installed correctly:
Follow these steps to build JAX with ROCm support from source:
### Step 1: Clone the Repository
### Step 1: Install ROCm
Please follow [ROCm installation guide](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html) to install ROCm on your system.
Once installed, verify ROCm installation using:
```Bash
> rocm-smi
========================================== ROCm System Management Interface ==========================================
==================================================== Concise Info ====================================================
Device [Model : Revision] Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU%
Name (20 chars) (Junction) (Socket) (Mem, Compute)
======================================================================================================================
0 [0x74a1 : 0x00] 50.0°C 170.0W NPS1, SPX 131Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
1 [0x74a1 : 0x00] 51.0°C 176.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
2 [0x74a1 : 0x00] 50.0°C 177.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
3 [0x74a1 : 0x00] 53.0°C 176.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
======================================================================================================================
================================================ End of ROCm SMI Log =================================================
```
### Step 2: Clone the Repository
Clone the ROCm-specific fork of JAX for the desired branch:
@ -183,13 +209,15 @@ Clone the ROCm-specific fork of JAX for the desired branch:
> cd jax
```
### Step 2: Build the Wheels
### Step 3: Build the Wheels
Run the following command to build the necessary wheels:
```Bash
> python3 ./build/build.py build --wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt \
--rocm_version=60 --rocm_path=/opt/rocm-[version]
> python3 ./build/build.py build \
--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt \
--rocm_path=/opt/rocm-[version] \
--clang_path=/opt/rocm-[version]/lib/llvm/bin/clang
```
This will generate three wheels in the `dist/` directory:
@ -198,10 +226,10 @@ This will generate three wheels in the `dist/` directory:
* jax-rocm-plugin (ROCm-specific plugin)
* jax-rocm-pjrt (ROCm-specific runtime)
### Step 3: Then install custom JAX using:
### Step 4: Then install custom JAX using:
```Bash
> python3 setup.py develop --user && pip3 -m pip install dist/*.whl
> python3 setup.py develop --user && python3 -m pip install dist/*.whl
```
### Simplified Build Script