mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[ROCm] Bring up clang support for JAX+XLA
* Add clang path * bazelrc env fixes * Fix wheelhouse installation and preserve wheels * dockerfile changes * Add target.lst * Change target architectures * Install bzip2 and sqlite packages
This commit is contained in:
parent
9dbbb3a391
commit
dfb7db0e75
16
.bazelrc
16
.bazelrc
@ -107,10 +107,18 @@ build:nvcc_clang --config=cuda_clang
|
||||
build:nvcc_clang --action_env=TF_NVCC_CLANG="1"
|
||||
build:nvcc_clang --@local_config_cuda//:cuda_compiler=nvcc
|
||||
|
||||
build:rocm --crosstool_top=@local_config_rocm//crosstool:toolchain
|
||||
build:rocm --define=using_rocm=true --define=using_rocm_hipcc=true
|
||||
build:rocm --repo_env TF_NEED_ROCM=1
|
||||
build:rocm --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx1030"
|
||||
build:rocm_base --crosstool_top=@local_config_rocm//crosstool:toolchain
|
||||
build:rocm_base --define=using_rocm=true --define=using_rocm_hipcc=true
|
||||
build:rocm_base --repo_env TF_NEED_ROCM=1
|
||||
build:rocm_base --action_env TF_ROCM_AMDGPU_TARGETS="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100"
|
||||
|
||||
# Build with hipcc for ROCm and clang for the host.
|
||||
build:rocm --config=rocm_base
|
||||
build:rocm --action_env=TF_ROCM_CLANG="1"
|
||||
build:rocm --action_env=CLANG_COMPILER_PATH="/usr/lib/llvm-18/bin/clang"
|
||||
build:rocm --copt=-Wno-gnu-offsetof-extensions
|
||||
build:rocm --copt=-Qunused-arguments
|
||||
build:rocm --action_env=TF_HIPCC_CLANG="1"
|
||||
|
||||
build:nonccl --define=no_nccl_support=true
|
||||
|
||||
|
@ -298,9 +298,12 @@ def write_bazelrc(*, remote_build,
|
||||
f.write(
|
||||
f'build:cuda --repo_env HERMETIC_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"\n')
|
||||
if enable_rocm:
|
||||
f.write("build --config=rocm\n")
|
||||
f.write("build --config=rocm_base\n")
|
||||
if not enable_nccl:
|
||||
f.write("build --config=nonccl\n")
|
||||
if use_clang:
|
||||
f.write("build --config=rocm\n")
|
||||
f.write(f"build --action_env=CLANG_COMPILER_PATH={clang_path}\n")
|
||||
if python_version:
|
||||
f.write(
|
||||
"build --repo_env HERMETIC_PYTHON_VERSION=\"{python_version}\"".format(
|
||||
@ -482,7 +485,7 @@ def main():
|
||||
help="A comma-separated list of CUDA compute capabilities to support.")
|
||||
parser.add_argument(
|
||||
"--rocm_amdgpu_targets",
|
||||
default="gfx900,gfx906,gfx908,gfx90a,gfx1030",
|
||||
default="gfx900,gfx906,gfx908,gfx90a,gfx940,gfx941,gfx942,gfx1030,gfx1100",
|
||||
help="A comma-separated list of ROCm amdgpu targets to support.")
|
||||
parser.add_argument(
|
||||
"--rocm_path",
|
||||
|
@ -5,8 +5,14 @@ FROM ubuntu:20.04 AS rocm_base
|
||||
RUN --mount=type=cache,target=/var/cache/apt \
|
||||
apt-get update && apt-get install -y python3 python-is-python3
|
||||
|
||||
# Install bzip2 and sqlite3 packages
|
||||
RUN apt-get update && apt-get install -y \
|
||||
sqlite3 libsqlite3-dev \
|
||||
libbz2-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Add target file to help determine which device(s) to build for
|
||||
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
|
||||
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
|
||||
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
|
||||
|
||||
# Install ROCm
|
||||
@ -70,6 +76,7 @@ FROM rocm_base AS rt_build
|
||||
ARG JAX_VERSION
|
||||
ARG JAX_COMMIT
|
||||
ARG XLA_COMMIT
|
||||
ARG JAX_USE_CLANG
|
||||
|
||||
LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \
|
||||
com.amdgpu.python_version="$PYTHON_VERSION" \
|
||||
@ -77,7 +84,15 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \
|
||||
com.amdgpu.jax_commit="$JAX_COMMIT" \
|
||||
com.amdgpu.xla_commit="$XLA_COMMIT"
|
||||
|
||||
|
||||
# Create a directory to copy and retain the wheels in the image.
|
||||
RUN mkdir -p /rocm_jax_wheels
|
||||
|
||||
RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
|
||||
--mount=type=bind,source=wheelhouse,target=/wheelhouse \
|
||||
pip install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt
|
||||
cp /wheelhouse/* /rocm_jax_wheels/ && \
|
||||
ls -lah /wheelhouse && \
|
||||
pip3 install wheelhouse/*none*.whl wheelhouse/*jaxlib*.whl && \
|
||||
pip3 install wheelhouse/*rocm60*.whl
|
||||
|
||||
|
||||
|
@ -7,3 +7,13 @@ ARG ROCM_BUILD_NUM
|
||||
RUN --mount=type=cache,target=/var/cache/dnf \
|
||||
--mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
|
||||
python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM
|
||||
|
||||
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
|
||||
RUN printf '%s\n' > /opt/rocm/bin/target.lst ${GPU_DEVICE_TARGETS}
|
||||
|
||||
# Install LLVM 18 and dependencies.
|
||||
RUN --mount=type=cache,target=/var/cache/dnf \
|
||||
dnf install -y wget && dnf clean all
|
||||
RUN mkdir /tmp/llvm-project && wget -qO - https://github.com/llvm/llvm-project/archive/refs/tags/llvmorg-18.1.8.tar.gz | tar -xz -C /tmp/llvm-project --strip-components 1 && \
|
||||
mkdir /tmp/llvm-project/build && cd /tmp/llvm-project/build && cmake -DLLVM_ENABLE_PROJECTS='clang;lld' -DCMAKE_BUILD_TYPE=Release -DCMAKE_INSTALL_PREFIX=/usr/lib/llvm-18/ ../llvm && \
|
||||
make -j$(nproc) && make -j$(nproc) install && rm -rf /tmp/llvm-project
|
||||
|
@ -3,8 +3,14 @@ FROM ubuntu:22.04
|
||||
RUN --mount=type=cache,target=/var/cache/apt \
|
||||
apt-get update && apt-get install -y python3 python-is-python3
|
||||
|
||||
# Install bzip2 and sqlite3 packages
|
||||
RUN apt-get update && apt-get install -y \
|
||||
sqlite3 libsqlite3-dev \
|
||||
libbz2-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Add target file to help determine which device(s) to build for
|
||||
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
|
||||
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
|
||||
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
|
||||
|
||||
# Install ROCM
|
||||
@ -61,4 +67,6 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \
|
||||
|
||||
RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
|
||||
--mount=type=bind,source=wheelhouse,target=/wheelhouse \
|
||||
pip3 install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt
|
||||
ls -lah /wheelhouse && \
|
||||
pip3 install wheelhouse/*none*.whl wheelhouse/*jaxlib*.whl && \
|
||||
pip3 install wheelhouse/*rocm60*.whl
|
||||
|
@ -3,6 +3,12 @@ FROM ubuntu:24.04
|
||||
RUN --mount=type=cache,target=/var/cache/apt \
|
||||
apt-get update && apt-get install -y python3 python-is-python3 python3-pip
|
||||
|
||||
# Install bzip2 and sqlite3 packages
|
||||
RUN apt-get update && apt-get install -y \
|
||||
sqlite3 libsqlite3-dev \
|
||||
libbz2-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
# Add target file to help determine which device(s) to build for
|
||||
ARG GPU_DEVICE_TARGETS="gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
|
||||
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
|
||||
@ -60,4 +66,6 @@ LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \
|
||||
|
||||
RUN --mount=type=cache,mode=0755,target=/root/.cache/pip \
|
||||
--mount=type=bind,source=wheelhouse,target=/wheelhouse \
|
||||
pip3 install --break-system-packages --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt
|
||||
ls -lah /wheelhouse && \
|
||||
pip3 install wheelhouse/*none*.whl wheelhouse/*jaxlib*.whl && \
|
||||
pip3 install wheelhouse/*rocm60*.whl
|
||||
|
@ -56,6 +56,36 @@ def update_rocm_targets(rocm_path, targets):
|
||||
open(version_fp, "a").close()
|
||||
|
||||
|
||||
def find_clang_path():
|
||||
llvm_base_path = "/usr/lib/"
|
||||
# Search for llvm directories and pick the highest version.
|
||||
llvm_dirs = [d for d in os.listdir(llvm_base_path) if d.startswith("llvm-")]
|
||||
if llvm_dirs:
|
||||
# Sort to get the highest llvm version.
|
||||
llvm_dirs.sort(reverse=True)
|
||||
clang_bin_dir = os.path.join(llvm_base_path, llvm_dirs[0], "bin")
|
||||
|
||||
# Prefer versioned clang binaries (e.g., clang-18).
|
||||
versioned_clang = None
|
||||
generic_clang = None
|
||||
|
||||
for f in os.listdir(clang_bin_dir):
|
||||
# Checks for versioned clang binaries.
|
||||
if f.startswith("clang-") and f[6:].isdigit():
|
||||
versioned_clang = os.path.join(clang_bin_dir, f)
|
||||
# Fallback to non-versioned clang.
|
||||
elif f == "clang":
|
||||
generic_clang = os.path.join(clang_bin_dir, f)
|
||||
|
||||
# Return versioned clang if available, otherwise return generic clang.
|
||||
if versioned_clang:
|
||||
return versioned_clang
|
||||
elif generic_clang:
|
||||
return generic_clang
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def build_jaxlib_wheel(
|
||||
jax_path, rocm_path, python_version, xla_path=None, compiler="gcc"
|
||||
):
|
||||
@ -70,6 +100,14 @@ def build_jaxlib_wheel(
|
||||
"--use_clang=%s" % use_clang,
|
||||
]
|
||||
|
||||
# Add clang path if clang is used.
|
||||
if compiler == "clang":
|
||||
clang_path = find_clang_path()
|
||||
if clang_path:
|
||||
cmd.append("--clang_path=%s" % clang_path)
|
||||
else:
|
||||
raise RuntimeError("Clang binary not found in /usr/lib/llvm-*")
|
||||
|
||||
if xla_path:
|
||||
cmd.append("--bazel_options=--override_repository=xla=%s" % xla_path)
|
||||
|
||||
@ -168,18 +206,26 @@ def to_cpy_ver(python_version):
|
||||
|
||||
|
||||
def fix_wheel(path, jax_path):
|
||||
# NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8
|
||||
# so use one of the CPythons in /opt to run
|
||||
env = dict(os.environ)
|
||||
py_bin = "/opt/python/cp310-cp310/bin"
|
||||
env["PATH"] = "%s:%s" % (py_bin, env["PATH"])
|
||||
try:
|
||||
# NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8
|
||||
# so use one of the CPythons in /opt to run
|
||||
env = dict(os.environ)
|
||||
py_bin = "/opt/python/cp310-cp310/bin"
|
||||
env["PATH"] = "%s:%s" % (py_bin, env["PATH"])
|
||||
|
||||
cmd = ["pip", "install", "auditwheel>=6"]
|
||||
subprocess.run(cmd, check=True, env=env)
|
||||
cmd = ["pip", "install", "auditwheel>=6"]
|
||||
subprocess.run(cmd, check=True, env=env)
|
||||
|
||||
fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py")
|
||||
cmd = ["python", fixwheel_path, path]
|
||||
subprocess.run(cmd, check=True, env=env)
|
||||
fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py")
|
||||
cmd = ["python", fixwheel_path, path]
|
||||
subprocess.run(cmd, check=True, env=env)
|
||||
LOG.info("Wheel fix completed successfully.")
|
||||
except subprocess.CalledProcessError as cpe:
|
||||
LOG.error(f"Subprocess failed with error: {cpe}")
|
||||
raise
|
||||
except Exception as e:
|
||||
LOG.error(f"An unexpected error occurred: {e}")
|
||||
raise
|
||||
|
||||
|
||||
def parse_args():
|
||||
|
@ -240,6 +240,7 @@ pybind_extension(
|
||||
"//jaxlib:kernel_nanobind_helpers",
|
||||
"@com_google_absl//absl/status",
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
"@local_config_rocm//rocm:hip",
|
||||
"@nanobind",
|
||||
"@xla//third_party/python_runtime:headers",
|
||||
"@xla//xla:status",
|
||||
|
Loading…
x
Reference in New Issue
Block a user