[ROCm] Remove broken legacy env vars

These env vars are no longer used or need and were
being set incorrectly.

[ROCm] Use specific amdgpu version for EL8 systems

We were always installing the latest driver versions
but this had some side effects when yum would try
to download index files from a URL with changing content.

[ROCm] Fix formatting on python files

Reformatted with black
This commit is contained in:
Mathew Odden 2024-09-24 17:41:15 -05:00 committed by Ruturaj4
parent aa73aa0021
commit 9ff891dfa1
4 changed files with 14 additions and 12 deletions

View File

@ -9,7 +9,7 @@ RUN --mount=type=cache,target=/var/cache/apt \
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
# Install ROCM
# Install ROCm
ARG ROCM_VERSION=6.0.0
ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION}
ENV ROCM_PATH=${ROCM_PATH}
@ -19,13 +19,8 @@ RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
--mount=type=cache,target=/var/cache/apt \
python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM
# Set up paths
ENV HCC_HOME=$ROCM_PATH/hcc
ENV HIP_PATH=$ROCM_PATH/
ENV OPENCL_ROOT=$ROCM_PATH/opencl
ENV PATH="$HCC_HOME/bin:$HIP_PATH/bin:${PATH}"
# add ROCm bins to PATH
ENV PATH="$ROCM_PATH/bin:${PATH}"
ENV PATH="$OPENCL_ROOT/bin:${PATH}"
ENV PATH="/root/bin:/root/.local/bin:$PATH"
# install pyenv and python build dependencies

View File

@ -34,8 +34,12 @@ def image_by_name(name):
def dist_wheels(
rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num="",
compiler="gcc"
rocm_version,
python_versions,
xla_path,
rocm_build_job="",
rocm_build_num="",
compiler="gcc",
):
if xla_path:
xla_path = os.path.abspath(xla_path)
@ -260,7 +264,7 @@ def parse_args():
p.add_argument(
"--compiler",
choices=["gcc", "clang"],
help="Compiler backend to use when compiling jax/jaxlib"
help="Compiler backend to use when compiling jax/jaxlib",
)
subp = p.add_subparsers(dest="action", required=True)

View File

@ -56,7 +56,9 @@ def update_rocm_targets(rocm_path, targets):
open(version_fp, "a").close()
def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None, compiler="gcc"):
def build_jaxlib_wheel(
jax_path, rocm_path, python_version, xla_path=None, compiler="gcc"
):
use_clang = "true" if compiler == "clang" else "false"
cmd = [
"python",

View File

@ -320,11 +320,12 @@ gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key
"""
[amdgpu]
name=amdgpu
baseurl=https://repo.radeon.com/amdgpu/latest/rhel/8.8/main/x86_64/
baseurl=https://repo.radeon.com/amdgpu/%s/rhel/8.8/main/x86_64/
enabled=1
gpgcheck=1
gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key
"""
% rocm_version_str
)