1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-26 06:06:07 +00:00

Fix ROCm build README ()

This commit is contained in:
charleshofer 2025-03-18 14:35:36 -05:00 committed by GitHub
parent c46b4fc02b
commit dd7f96b27c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -174,7 +174,33 @@ Run the following command to verify that ROCm JAX is installed correctly:
Follow these steps to build JAX with ROCm support from source: Follow these steps to build JAX with ROCm support from source:
### Step 1: Clone the Repository ### Step 1: Install ROCm
Please follow [ROCm installation guide](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html) to install ROCm on your system.
Once installed, verify ROCm installation using:
```Bash
> rocm-smi
========================================== ROCm System Management Interface ==========================================
==================================================== Concise Info ====================================================
Device [Model : Revision] Temp Power Partitions SCLK MCLK Fan Perf PwrCap VRAM% GPU%
Name (20 chars) (Junction) (Socket) (Mem, Compute)
======================================================================================================================
0 [0x74a1 : 0x00] 50.0°C 170.0W NPS1, SPX 131Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
1 [0x74a1 : 0x00] 51.0°C 176.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
2 [0x74a1 : 0x00] 50.0°C 177.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
3 [0x74a1 : 0x00] 53.0°C 176.0W NPS1, SPX 132Mhz 900Mhz 0% auto 750.0W 0% 0%
AMD Instinct MI300X
======================================================================================================================
================================================ End of ROCm SMI Log =================================================
```
### Step 2: Clone the Repository
Clone the ROCm-specific fork of JAX for the desired branch: Clone the ROCm-specific fork of JAX for the desired branch:
@ -183,13 +209,15 @@ Clone the ROCm-specific fork of JAX for the desired branch:
> cd jax > cd jax
``` ```
### Step 2: Build the Wheels ### Step 3: Build the Wheels
Run the following command to build the necessary wheels: Run the following command to build the necessary wheels:
```Bash ```Bash
> python3 ./build/build.py build --wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt \ > python3 ./build/build.py build \
--rocm_version=60 --rocm_path=/opt/rocm-[version] --wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt \
--rocm_path=/opt/rocm-[version] \
--clang_path=/opt/rocm-[version]/lib/llvm/bin/clang
``` ```
This will generate three wheels in the `dist/` directory: This will generate three wheels in the `dist/` directory:
@ -198,10 +226,10 @@ This will generate three wheels in the `dist/` directory:
* jax-rocm-plugin (ROCm-specific plugin) * jax-rocm-plugin (ROCm-specific plugin)
* jax-rocm-pjrt (ROCm-specific runtime) * jax-rocm-pjrt (ROCm-specific runtime)
### Step 3: Then install custom JAX using: ### Step 4: Then install custom JAX using:
```Bash ```Bash
> python3 setup.py develop --user && pip3 -m pip install dist/*.whl > python3 setup.py develop --user && python3 -m pip install dist/*.whl
``` ```
### Simplified Build Script ### Simplified Build Script