diff --git a/build/rocm/dev_build_rocm.py b/build/rocm/dev_build_rocm.py index 2be64152f..aa5754b78 100755 --- a/build/rocm/dev_build_rocm.py +++ b/build/rocm/dev_build_rocm.py @@ -77,13 +77,14 @@ def build_jax_xla(xla_path, rocm_version, rocm_target, use_clang, clang_path): build_command = [ "python3", "./build/build.py", - "--enable_rocm", - "--build_gpu_plugin", - "--gpu_plugin_rocm_version=60", + "build" f"--use_clang={str(use_clang).lower()}", + "--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt" + "--rocm_path=%/opt/rocm-{rocm_version}/", + "--rocm_version=60", f"--rocm_amdgpu_targets={rocm_target}", - f"--rocm_path=/opt/rocm-{rocm_version}/", bazel_options, + "--verbose" ] if clang_option: diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index deb6ab703..ec825f40b 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -93,11 +93,12 @@ def build_jaxlib_wheel( cmd = [ "python", "build/build.py", - "--enable_rocm", - "--build_gpu_plugin", - "--gpu_plugin_rocm_version=60", + "build" + "--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt" "--rocm_path=%s" % rocm_path, + "--rocm_version=60", "--use_clang=%s" % use_clang, + "--verbose" ] # Add clang path if clang is used.