
[ROCm] Bring up clang support for JAX+XLA

* Add clang path

* bazelrc env fixes

* Fix wheelhouse installation and preserve wheels

* dockerfile changes

* Add target.lst

* Change target architectures

* Install bzip2 and sqlite packages
Authored by Ruturaj4 on 2024-08-19 16:32:26 -05:00; committed by Zahid Iqbal
parent 9dbbb3a391
commit dfb7db0e75
8 changed files with 120 additions and 21 deletions

@@ -107,10 +107,18 @@ build:nvcc_clang --config=cuda_clang
build:nvcc_clang --action_env=TF_NVCC_CLANG="1"
build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm --repo_env TF_NEED_ROCM=1
build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030"
build:rocm_base --crosstool_top=@local_config_rocm//crosstool:toolchain
build:rocm_base --define=using_rocm=true --define=using_rocm_hipcc=true
build:rocm_base --repo_env TF_NEED_ROCM=1
build:rocm_base --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100"
# Build with hipcc for ROCm and clang for the host.
build:rocm --config=rocm_base
build:rocm --action_env=TF_ROCM_CLANG="1"
build:rocm --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
build:rocm --copt=-Wno-gnu-offsetof-extensions
build:rocm --copt=-Qunused-arguments
build:rocm --action_env=TF_HIPCC_CLANG="1"
build:nonccl --define=no_nccl_support=true
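As an editor's aside (not part of the commit), below is a minimal Python sketch of how a wrapper script might sanity-check the toolchain this clang-on-host config assumes before invoking Bazel. The clang path and AMDGPU target list come from the .bazelrc lines above; the helper names and the bazel invocation itself are illustrative assumptions.

import os
import shutil
import subprocess

# Values pinned by the build:rocm / build:rocm_base configs above.
CLANG_COMPILER_PATH = "/usr/lib/llvm-18/bin/clang"
AMDGPU_TARGETS = "gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100"


def check_rocm_clang_toolchain():
    """Fail fast if the compiler the bazelrc points at is missing."""
    if not os.path.isfile(CLANG_COMPILER_PATH):
        raise FileNotFoundError(f"clang not found at {CLANG_COMPILER_PATH}")
    if shutil.which("bazel") is None:
        raise FileNotFoundError("bazel not found on PATH")


def bazel_build_rocm(target):
    """Run a Bazel build with the clang-on-host ROCm config defined above."""
    check_rocm_clang_toolchain()
    subprocess.run(
        [
            "bazel", "build", "--config=rocm",
            f"--action_env=TF_ROCM_AMDGPU_TARGETS={AMDGPU_TARGETS}",
            target,
        ],
        check=True,
    )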

@@ -298,9 +298,12 @@ def write_bazelrc(*, remote_build,
    f.write(
        f'build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"\n')
  if enable_rocm:
    f.write("build --config=rocm\n")
    f.write("build --config=rocm_base\n")
    if not enable_nccl:
      f.write("build --config=nonccl\n")
    if use_clang:
      f.write("build --config=rocm\n")
      f.write(f"build --action_env=CLANG_COMPILER_PATH={clang_path}\n")
  if python_version:
    f.write(
        "build --repo_env HERMETIC_PYTHON_VERSION=\"{python_version}\"".format(
@@ -482,7 +485,7 @@ def main():
help="A comma-separated list of CUDA compute capabilities to support.")
parser.add_argument(
"--rocm_amdgpu_targets",
default="gfx900,gfx906,gfx908,gfx90a,gfx1030",
default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100",
help="A comma-separated list of ROCm amdgpu targets to support.")
parser.add_argument(
"--rocm_path",

@@ -5,8 +5,14 @@ FROM ubuntu:20.04 AS rocm_base
RUN --mount=type=cache,target=/var/cache/apt \
apt-get update && apt-get install -y python3 python-is-python3
# Install bzip2 and sqlite3 packages
RUN apt-get update && apt-get install -y \
sqlite3 libsqlite3-dev \
libbz2-dev \
&& rm -rf /var/lib/apt/lists/*
# Add target file to help determine which device(s) to build for
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
# Install ROCm
@@ -70,6 +76,7 @@ FROM rocm_base AS rt_build
ARG JAX_VERSION
ARG JAX_COMMIT
ARG XLA_COMMIT
ARG JAX_USE_CLANG
LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \
com.amdgpu.python_version="$PYTHON_VERSION" \
@@ -77,7 +84,15 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \
com.amdgpu.jax_commit="$JAX_COMMIT" \
com.amdgpu.xla_commit="$XLA_COMMIT"
# Create a directory to copy and retain the wheels in the image.
RUN mkdir -p /rocm_jax_wheels
RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
--mount=type=bind,source=wheelhouse,target=/wheelhouse \
pip install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt
cp /wheelhouse/* /rocm_jax_wheels/ && \
ls -lah /wheelhouse && \
pip3 install wheelhouse/*none*.whl wheelhouse/*jaxlib*.whl && \
pip3 install wheelhouse/*rocm60*.whl
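Roughly what the new RUN step does, expressed as an illustrative Python sketch (the helper name and local paths are assumptions): preserve the wheels in /rocm_jax_wheels so they survive in the final image, then install the pure-Python jax wheel and jaxlib before the ROCm plugin/PJRT wheels.

import glob
import shutil
import subprocess


def install_and_keep_wheels(wheelhouse="/wheelhouse", keep_dir="/rocm_jax_wheels"):
    # Preserve the wheels inside the image so they can be copied out later.
    shutil.copytree(wheelhouse, keep_dir, dirs_exist_ok=True)
    # Pure-Python jax wheel (*none*) and jaxlib first, then the ROCm plugin/PJRT wheels.
    first = glob.glob(f"{wheelhouse}/*none*.whl") + glob.glob(f"{wheelhouse}/*jaxlib*.whl")
    second = glob.glob(f"{wheelhouse}/*rocm60*.whl")
    subprocess.run(["pip3", "install", *first], check=True)
    subprocess.run(["pip3", "install", *second], check=True)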

@@ -7,3 +7,13 @@ ARG ROCM_BUILD_NUM
RUN --mount=type=cache,target=/var/cache/dnf \
--mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
RUN printf '%s\n' > /opt/rocm/bin/target.lst ${GPU_DEVICE_TARGETS}
# Install LLVM 18 and dependencies.
RUN --mount=type=cache,target=/var/cache/dnf \
dnf install -y wget && dnf clean all
RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-18.1.8.tar.gz | tar -xz -C /tmp/llvm-project --strip-components 1 && \
mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \
make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project
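A small, purely illustrative check that this from-source LLVM 18 install lands where both CLANG_COMPILER_PATH in the .bazelrc and the find_clang_path() helper added later in this commit expect it; installed_clang_candidates is a hypothetical helper, and the install prefix comes from the cmake line above.

import os


def installed_clang_candidates(prefix="/usr/lib/llvm-18"):
    # The CMAKE_INSTALL_PREFIX above places binaries under <prefix>/bin.
    candidates = [os.path.join(prefix, "bin", "clang-18"),
                  os.path.join(prefix, "bin", "clang")]
    return [p for p in candidates if os.path.isfile(p)]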

@@ -3,8 +3,14 @@ FROM ubuntu:22.04
RUN --mount=type=cache,target=/var/cache/apt \
apt-get update && apt-get install -y python3 python-is-python3
# Install bzip2 and sqlite3 packages
RUN apt-get update && apt-get install -y \
sqlite3 libsqlite3-dev \
libbz2-dev \
&& rm -rf /var/lib/apt/lists/*
# Add target file to help determine which device(s) to build for
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
# Install ROCM
@@ -61,4 +67,6 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \
RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
--mount=type=bind,source=wheelhouse,target=/wheelhouse \
pip3 install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt
ls -lah /wheelhouse && \
pip3 install wheelhouse/*none*.whl wheelhouse/*jaxlib*.whl && \
pip3 install wheelhouse/*rocm60*.whl

@@ -3,6 +3,12 @@ FROM ubuntu:24.04
RUN --mount=type=cache,target=/var/cache/apt \
apt-get update && apt-get install -y python3 python-is-python3 python3-pip
# Install bzip2 and sqlite3 packages
RUN apt-get update && apt-get install -y \
sqlite3 libsqlite3-dev \
libbz2-dev \
&& rm -rf /var/lib/apt/lists/*
# Add target file to help determine which device(s) to build for
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
@@ -60,4 +66,6 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \
RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
--mount=type=bind,source=wheelhouse,target=/wheelhouse \
pip3 install --break-system-packages --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt
ls -lah /wheelhouse && \
pip3 install wheelhouse/*none*.whl wheelhouse/*jaxlib*.whl && \
pip3 install wheelhouse/*rocm60*.whl

@@ -56,6 +56,36 @@ def update_rocm_targets(rocm_path, targets):
    open(version_fp, "a").close()


def find_clang_path():
    llvm_base_path = "/usr/lib/"
    # Search for llvm directories and pick the highest version.
    llvm_dirs = [d for d in os.listdir(llvm_base_path) if d.startswith("llvm-")]
    if llvm_dirs:
        # Sort to get the highest llvm version.
        llvm_dirs.sort(reverse=True)
        clang_bin_dir = os.path.join(llvm_base_path, llvm_dirs[0], "bin")

        # Prefer versioned clang binaries (e.g., clang-18).
        versioned_clang = None
        generic_clang = None

        for f in os.listdir(clang_bin_dir):
            # Checks for versioned clang binaries.
            if f.startswith("clang-") and f[6:].isdigit():
                versioned_clang = os.path.join(clang_bin_dir, f)
            # Fallback to non-versioned clang.
            elif f == "clang":
                generic_clang = os.path.join(clang_bin_dir, f)

        # Return versioned clang if available, otherwise return generic clang.
        if versioned_clang:
            return versioned_clang
        elif generic_clang:
            return generic_clang

    return None


def build_jaxlib_wheel(
    jax_path, rocm_path, python_version, xla_path=None, compiler="gcc"
):
@@ -70,6 +100,14 @@ def build_jaxlib_wheel(
        "--use_clang=%s" % use_clang,
    ]

    # Add clang path if clang is used.
    if compiler == "clang":
        clang_path = find_clang_path()
        if clang_path:
            cmd.append("--clang_path=%s" % clang_path)
        else:
            raise RuntimeError("Clang binary not found in /usr/lib/llvm-*")

    if xla_path:
        cmd.append("--bazel_options=--override_repository=xla=%s" % xla_path)
@@ -168,18 +206,26 @@ def to_cpy_ver(python_version):
def fix_wheel(path, jax_path):
    # NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8
    # so use one of the CPythons in /opt to run
    env = dict(os.environ)
    py_bin = "/opt/python/cp310-cp310/bin"
    env["PATH"] = "%s:%s" % (py_bin, env["PATH"])
    try:
        # NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8
        # so use one of the CPythons in /opt to run
        env = dict(os.environ)
        py_bin = "/opt/python/cp310-cp310/bin"
        env["PATH"] = "%s:%s" % (py_bin, env["PATH"])

    cmd = ["pip", "install", "auditwheel>=6"]
    subprocess.run(cmd, check=True, env=env)
        cmd = ["pip", "install", "auditwheel>=6"]
        subprocess.run(cmd, check=True, env=env)

    fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py")
    cmd = ["python", fixwheel_path, path]
    subprocess.run(cmd, check=True, env=env)
        fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py")
        cmd = ["python", fixwheel_path, path]
        subprocess.run(cmd, check=True, env=env)

        LOG.info("Wheel fix completed successfully.")
    except subprocess.CalledProcessError as cpe:
        LOG.error(f"Subprocess failed with error: {cpe}")
        raise
    except Exception as e:
        LOG.error(f"An unexpected error occurred: {e}")
        raise
def parse_args():

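Finally, a minimal usage sketch for the hardened fix_wheel above: the wheel filename is hypothetical, and the logging setup is an assumption about what a standalone run would need in order to surface the new success/error messages.

import logging

logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger("build_wheels")  # stand-in for the module-level logger

# Hypothetical wheel path; in the real flow it comes from the jaxlib build output.
fix_wheel("/wheelhouse/jaxlib-0.4.31-cp310-cp310-linux_x86_64.whl", jax_path="/jax")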
@@ -240,6 +240,7 @@ pybind_extension(
        "//jaxlib:kernel_nanobind_helpers",
        "@com_google_absl//absl/status",
        "@local_config_rocm//rocm:rocm_headers",
        "@local_config_rocm//rocm:hip",
        "@nanobind",
        "@xla//third_party/python_runtime:headers",
        "@xla//xla:status",