diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index 0bcc89f49..d1867e6a5 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -9,7 +9,7 @@ RUN --mount=type=cache,target=/var/cache/apt \ ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS} -# Install ROCM +# Install ROCm ARG ROCM_VERSION=6.0.0 ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION} ENV ROCM_PATH=${ROCM_PATH} @@ -19,13 +19,8 @@ RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ --mount=type=cache,target=/var/cache/apt \ python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM -# Set up paths -ENV HCC_HOME=$ROCM_PATH/hcc -ENV HIP_PATH=$ROCM_PATH/ -ENV OPENCL_ROOT=$ROCM_PATH/opencl -ENV PATH="$HCC_HOME/bin:$HIP_PATH/bin:${PATH}" +# add ROCm bins to PATH ENV PATH="$ROCM_PATH/bin:${PATH}" -ENV PATH="$OPENCL_ROOT/bin:${PATH}" ENV PATH="/root/bin:/root/.local/bin:$PATH" # install pyenv and python build dependencies diff --git a/build/rocm/ci_build b/build/rocm/ci_build index aeb0201e2..9fb0ebd77 100755 --- a/build/rocm/ci_build +++ b/build/rocm/ci_build @@ -34,8 +34,12 @@ def image_by_name(name): def dist_wheels( - rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num="", - compiler="gcc" + rocm_version, + python_versions, + xla_path, + rocm_build_job="", + rocm_build_num="", + compiler="gcc", ): if xla_path: xla_path = os.path.abspath(xla_path) @@ -260,7 +264,7 @@ def parse_args(): p.add_argument( "--compiler", choices=["gcc", "clang"], - help="Compiler backend to use when compiling jax/jaxlib" + help="Compiler backend to use when compiling jax/jaxlib", ) subp = p.add_subparsers(dest="action", required=True) diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py index b6dd1256e..f0631f099 100644 --- a/build/rocm/tools/build_wheels.py +++ b/build/rocm/tools/build_wheels.py @@ -56,7 +56,9 @@ def update_rocm_targets(rocm_path, targets): open(version_fp, "a").close() -def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None, compiler="gcc"): +def build_jaxlib_wheel( + jax_path, rocm_path, python_version, xla_path=None, compiler="gcc" +): use_clang = "true" if compiler == "clang" else "false" cmd = [ "python", diff --git a/build/rocm/tools/get_rocm.py b/build/rocm/tools/get_rocm.py index 5334bf40e..2bcae5f90 100644 --- a/build/rocm/tools/get_rocm.py +++ b/build/rocm/tools/get_rocm.py @@ -320,11 +320,12 @@ gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key """ [amdgpu] name=amdgpu -baseurl=https://repo.radeon.com/amdgpu/latest/rhel/8.8/main/x86_64/ +baseurl=https://repo.radeon.com/amdgpu/%s/rhel/8.8/main/x86_64/ enabled=1 gpgcheck=1 gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key """ + % rocm_version_str )