From 1e58d76772afdcf006d03eeaf4743c08af6765cc Mon Sep 17 00:00:00 2001 From: Mathew Odden Date: Wed, 17 Jul 2024 19:09:07 -0500 Subject: [PATCH] [ROCm] Change ROCm builds to manylinux wheels --- build/rocm/Dockerfile.ms | 37 ++- build/rocm/build_rocm.sh | 56 +--- .../Dockerfile.manylinux_2_28_x86_64.rocm | 7 + build/rocm/ci_build | 256 ++++++++++++++++++ build/rocm/ci_build.sh | 154 ++++------- build/rocm/setup.rocm.sh | 2 +- build/rocm/tools/blacken.sh | 3 + build/rocm/tools/build_wheels.py | 222 +++++++++++++++ build/rocm/tools/fixwheel.py | 97 +++++++ build/rocm/tools/get_rocm.py | 100 +++++++ build/rocm/tools/libc.py | 48 ++++ build/rocm/tools/symbols.py | 53 ++++ 12 files changed, 877 insertions(+), 158 deletions(-) create mode 100644 build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm create mode 100755 build/rocm/ci_build create mode 100644 build/rocm/tools/blacken.sh create mode 100644 build/rocm/tools/build_wheels.py create mode 100644 build/rocm/tools/fixwheel.py create mode 100644 build/rocm/tools/get_rocm.py create mode 100644 build/rocm/tools/libc.py create mode 100644 build/rocm/tools/symbols.py diff --git a/build/rocm/Dockerfile.ms b/build/rocm/Dockerfile.ms index 5f831f111..899f29a14 100644 --- a/build/rocm/Dockerfile.ms +++ b/build/rocm/Dockerfile.ms @@ -1,6 +1,5 @@ ################################################################################ -ARG BASE_DOCKER=ubuntu:20.04 -FROM $BASE_DOCKER as rt_build +FROM ubuntu:20.04 AS rocm_base ################################################################################ # Add target file to help determine which device(s) to build for @@ -12,9 +11,9 @@ ARG ROCM_VERSION=6.0.0 ARG CUSTOM_INSTALL ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION} ENV ROCM_PATH=${ROCM_PATH} -COPY ${CUSTOM_INSTALL} /${CUSTOM_INSTALL} -COPY setup.rocm.sh /setup.rocm.sh -RUN /setup.rocm.sh $ROCM_VERSION +#COPY ${CUSTOM_INSTALL} /${CUSTOM_INSTALL} +RUN --mount=type=bind,source=build/rocm/setup.rocm.sh,target=/setup.rocm.sh \ + 
/setup.rocm.sh $ROCM_VERSION # Set up paths ENV HCC_HOME=$ROCM_PATH/hcc @@ -25,13 +24,35 @@ ENV PATH="$ROCM_PATH/bin:${PATH}" ENV PATH="$OPENCL_ROOT/bin:${PATH}" ENV PATH="/root/bin:/root/.local/bin:$PATH" - # Install pyenv with different python versions -ARG PYTHON_VERSION=3.10.0 +ARG PYTHON_VERSION=3.10.14 RUN git clone https://github.com/pyenv/pyenv.git /pyenv ENV PYENV_ROOT /pyenv ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH RUN pyenv install $PYTHON_VERSION -RUN eval "$(pyenv init -)" && pyenv local ${PYTHON_VERSION} && pip3 install --upgrade --force-reinstall setuptools pip && pip install numpy setuptools build wheel six auditwheel scipy pytest pytest-html pytest_html_merger pytest-reportlog pytest-rerunfailures cloudpickle portpicker matplotlib absl-py flatbuffers hypothesis +RUN eval "$(pyenv init -)" && \ + pyenv local ${PYTHON_VERSION} && \ + pip3 install --upgrade --force-reinstall setuptools pip && \ + pip install \ + numpy setuptools build wheel six auditwheel scipy \ + pytest pytest-html pytest_html_merger pytest-reportlog \ + pytest-rerunfailures cloudpickle portpicker matplotlib absl-py \ + flatbuffers hypothesis +################################################################################ +FROM rocm_base AS rt_build +################################################################################ + +ARG JAX_VERSION +ARG JAX_COMMIT +ARG XLA_COMMIT + +LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \ + com.amdgpu.python_version="$PYTHON_VERSION" \ + com.amdgpu.jax_version="$JAX_VERSION" \ + com.amdgpu.jax_commit="$JAX_COMMIT" \ + com.amdgpu.xla_commit="$XLA_COMMIT" + +RUN --mount=type=bind,source=wheelhouse,target=/wheelhouse \ + pip install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt diff --git a/build/rocm/build_rocm.sh b/build/rocm/build_rocm.sh index 6374a2a18..111998d35 100755 --- a/build/rocm/build_rocm.sh +++ b/build/rocm/build_rocm.sh @@ -1,4 +1,5 @@ #!/usr/bin/env bash + # Copyright 2022 The JAX Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,57 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Environment Var Notes -# XLA_CLONE_DIR - -# Specifies filepath to where XLA repo is cloned. -# NOTE:, if this is set then XLA repo is not cloned. Must clone repo before running this script. -# Also, if this is set then setting XLA_REPO and XLA_BRANCH have no effect. -# XLA_REPO -# XLA repo to clone from. Default is https://github.com/ROCmSoftwarePlatform/tensorflow-upstream -# XLA_BRANCH -# XLA branch in the XLA repo. Default is develop-upstream-jax -# +# NOTE(mrodden): ROCm JAX build and installs have moved to wheel based builds and installs, +# but some CI scripts still try to run this script. Nothing needs to be done here, +# but we print some debugging information for logs. set -eux python -V -#If XLA_REPO is not set, then use default -if [ ! -v XLA_REPO ]; then - XLA_REPO="https://github.com/openxla/xla.git" - XLA_BRANCH="main" -elif [ -z "$XLA_REPO" ]; then - XLA_REPO="https://github.com/openxla/xla.git" - XLA_BRANCH="main" -fi - -#If XLA_CLONE_PATH is not set, then use default path. -#Note, setting XLA_CLONE_PATH makes setting XLA_REPO and XLA_BRANCH a no-op -#Set this when XLA repository has been already clone. This is useful in CI -#environments and when doing local development -if [ ! 
-v XLA_CLONE_DIR ]; then - XLA_CLONE_DIR=/tmp/xla - rm -rf /tmp/xla || true - git clone -b ${XLA_BRANCH} ${XLA_REPO} /tmp/xla -elif [ -z "$XLA_CLONE_DIR" ]; then - XLA_CLONE_DIR=/tmp/xla - rm -rf /tmp/xla || true - git clone -b ${XLA_BRANCH} ${XLA_REPO} /tmp/xla -fi - - -#Export JAX_ROCM_VERSION so that it is appened in the wheel name -export JAXLIB_RELEASE=1 -rocm_version=$(cat /opt/rocm/.info/version | cut -d "-" -f 1) -export JAX_ROCM_VERSION=${rocm_version//./} - -#Build and install wheel -python3 ./build/build.py --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=xla=${XLA_CLONE_DIR} - -JAX_RELEASE=1 python -m build -pip3 install --force-reinstall dist/*.whl # installs jaxlib (includes XLA) - -#This is for CI to read without having to start the container again -if [ -v CI_RUN ]; then - pip3 list | grep jaxlib | tr -s ' ' | cut -d " " -f 2 | cut -d "+" -f 1 > jax_version_installed - cat /opt/rocm/.info/version | cut -d "-" -f 1 > jax_rocm_version -fi +printf "Detected jaxlib version: %s\n" $(pip3 list | grep jaxlib | tr -s ' ' | cut -d " " -f 2 | cut -d "+" -f 1) +printf "Detected ROCm version: %s\n" $(cat /opt/rocm/.info/version | cut -d "-" -f 1) diff --git a/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm new file mode 100644 index 000000000..fd2de6a0c --- /dev/null +++ b/build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm @@ -0,0 +1,7 @@ +FROM quay.io/pypa/manylinux_2_28_x86_64 + +ARG ROCM_VERSION=6.1.1 + +RUN --mount=type=cache,target=/var/cache/dnf \ + --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \ + python3 get_rocm.py --rocm-version $ROCM_VERSION diff --git a/build/rocm/ci_build b/build/rocm/ci_build new file mode 100755 index 000000000..a43bd26fd --- /dev/null +++ b/build/rocm/ci_build @@ -0,0 +1,256 @@ +#!/usr/bin/env python3 + + +# NOTE(mrodden): This file is part 
of the ROCm build scripts, and +# needs be compatible with Python 3.6. Please do not include these +# in any "upgrade" scripts + + +import argparse +import os +import subprocess +import sys + + +def image_by_name(name): + cmd = ["docker", "images", "-q", "-f", "reference=%s" % name] + out = subprocess.check_output(cmd) + image_id = out.decode("utf8").strip().split("\n")[0] or None + return image_id + + +def dist_wheels(rocm_version, python_versions, xla_path): + xla_path = os.path.abspath(xla_path) + + # create manylinux image with requested ROCm installed + image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "") + + cmd = [ + "docker", + "build", + "-f", + "build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm", + "--build-arg=ROCM_VERSION=%s" % rocm_version, + "--tag=%s" % image, + ".", + ] + + if not image_by_name(image): + _ = subprocess.run(cmd, check=True) + + # use image to build JAX/jaxlib wheels + os.makedirs("wheelhouse", exist_ok=True) + + pyver_string = ",".join(python_versions) + + container_xla_path = "/xla" + + bw_cmd = [ + "python3", + "/jax/build/rocm/tools/build_wheels.py", + "--rocm-version", + rocm_version, + "--python-versions", + pyver_string, + ] + + if xla_path: + bw_cmd.extend(["--xla-path", container_xla_path]) + + bw_cmd.append("/jax") + + cmd = ["docker", "run", "-it"] + + mounts = [ + "-v", + "./:/jax", + "-v", + "./wheelhouse:/wheelhouse", + ] + + if xla_path: + mounts.extend(["-v", "%s:%s" % (xla_path, container_xla_path)]) + + cmd.extend(mounts) + + # NOTE(mrodden): bazel times out without --init, probably blocking on a zombie PID + cmd.extend( + [ + "--init", + "--rm", + image, + "bash", + "-c", + " ".join(bw_cmd), + ] + ) + + _ = subprocess.run(cmd, check=True) + + +def _fetch_jax_metadata(xla_path): + cmd = ["git", "rev-parse", "HEAD"] + jax_commit = subprocess.check_output(cmd) + xla_commit = "" + + if xla_path: + try: + xla_commit = subprocess.check_output(cmd, cwd=xla_path) + except Exception as ex: + 
LOG.warning("Exception while retrieving xla_commit: %s" % ex) + + cmd = ["python", "setup.py", "-V"] + env = dict(os.environ) + env["JAX_RELEASE"] = "1" + + jax_version = subprocess.check_output(cmd, env=env) + + return { + "jax_version": jax_version.decode("utf8").strip(), + "jax_commit": jax_commit.decode("utf8").strip(), + "xla_commit": xla_commit.decode("utf8").strip(), + } + + +def dist_docker( + rocm_version, + python_versions, + xla_path, + tag="rocm/jax-dev", + dockerfile=None, + keep_image=True, +): + if not dockerfile: + dockerfile = "build/rocm/Dockerfile.ms" + + python_version = python_versions[0] + + md = _fetch_jax_metadata(xla_path) + + cmd = [ + "docker", + "build", + "-f", + dockerfile, + "--target", + "rt_build", + "--build-arg=ROCM_VERSION=%s" % rocm_version, + "--build-arg=PYTHON_VERSION=%s" % python_version, + "--build-arg=JAX_VERSION=%(jax_version)s" % md, + "--build-arg=JAX_COMMIT=%(jax_commit)s" % md, + "--build-arg=XLA_COMMIT=%(xla_commit)s" % md, + "--tag=%s" % tag, + ] + + if not keep_image: + cmd.append("--rm") + + # context dir + cmd.append(".") + + subprocess.check_call(cmd) + + +def test(image_name): + """Run unit tests like CI would inside a JAX image.""" + + gpu_args = [ + "--device=/dev/kfd", + "--device=/dev/dri", + "--group-add", + "video", + "--cap-add=SYS_PTRACE", + "--security-opt", + "seccomp=unconfined", + "--shm-size", + "16G", + ] + + cmd = [ + "docker", + "run", + "-it", + "--rm", + ] + + # NOTE(mrodden): we need jax source dir for the unit test code only, + # JAX and jaxlib are already installed from wheels + mounts = [ + "-v", + "./:/jax", + ] + + cmd.extend(mounts) + cmd.extend(gpu_args) + + container_cmd = "cd /jax && ./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh" + cmd.append(image_name) + cmd.extend( + [ + "bash", + "-c", + container_cmd, + ] + ) + + subprocess.check_call(cmd) + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument( + 
"--python-versions", + type=lambda x: x.split(","), + default="3.12", + help="Comma separated list of CPython versions to build wheels for", + ) + + p.add_argument( + "--rocm-version", + default="6.1.1", + help="ROCm version used for building wheels, testing, and installing into Docker image", + ) + + p.add_argument( + "--xla-source-dir", + help="Path to XLA source to use during jaxlib build, instead of builtin XLA", + ) + + subp = p.add_subparsers(dest="action", required=True) + + dwp = subp.add_parser("dist_wheels") + + testp = subp.add_parser("test") + testp.add_argument("image_name") + + ddp = subp.add_parser("dist_docker") + ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms") + ddp.add_argument("--keep-image", action="store_true") + ddp.add_argument("--image-tag", default="rocm/jax-dev") + + return p.parse_args() + + +def main(): + args = parse_args() + + if args.action == "dist_wheels": + dist_wheels(args.rocm_version, args.python_versions, args.xla_source_dir) + + elif args.action == "test": + test(args.image_name) + + elif args.action == "dist_docker": + dist_wheels(args.rocm_version, args.python_versions, args.xla_source_dir) + dist_docker( + args.rocm_version, + args.python_versions, + args.xla_source_dir, + tag=args.image_tag, + dockerfile=args.dockerfile, + keep_image=args.keep_image, + ) + + +if __name__ == "__main__": + main() diff --git a/build/rocm/ci_build.sh b/build/rocm/ci_build.sh index 9084651be..ab599e266 100755 --- a/build/rocm/ci_build.sh +++ b/build/rocm/ci_build.sh @@ -1,4 +1,5 @@ #!/usr/bin/env bash + # Copyright 2022 The JAX Authors. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -28,12 +29,8 @@ # # ROCM_VERSION: ROCm repo version # -# ROCM_PATH: ROCM path in the docker container -# # Environment variables read by this script # WORKSPACE -# XLA_REPO -# XLA_BRANCH # XLA_CLONE_DIR # BUILD_TAG # @@ -44,75 +41,63 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" source "${SCRIPT_DIR}/build_common.sh" CONTAINER_TYPE="rocm" -DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile.ms" +DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile.ms" DOCKER_CONTEXT_PATH="${SCRIPT_DIR}" KEEP_IMAGE="--rm" -KEEP_CONTAINER="--rm" -PYTHON_VERSION="3.10.0" -ROCM_VERSION="6.0.0" #Point to latest release +PYTHON_VERSION="3.10" +ROCM_VERSION="6.1.3" BASE_DOCKER="ubuntu:20.04" CUSTOM_INSTALL="" -#BASE_DOCKER="compute-artifactory.amd.com:5000/rocm-plus-docker/compute-rocm-rel-6.0:91-ubuntu-20.04-stg2" -#CUSTOM_INSTALL="custom_install_dummy.sh" -#ROCM_PATH="/opt/rocm-5.6.0" POSITIONAL_ARGS=() RUNTIME_FLAG=1 while [[ $# -gt 0 ]]; do - case $1 in - --py_version) - PYTHON_VERSION="$2" - shift 2 - ;; - --dockerfile) - DOCKERFILE_PATH="$2" - DOCKER_CONTEXT_PATH=$(dirname "${DOCKERFILE_PATH}") - shift 2 - ;; - --keep_image) - KEEP_IMAGE="" - shift 1 - ;; - --runtime) - RUNTIME_FLAG=1 - shift 1 - ;; - --keep_container) - KEEP_CONTAINER="" - shift 1 - ;; - --rocm_version) - ROCM_VERSION="$2" - shift 2 - ;; - #--rocm_path) - # ROCM_PATH="$2" - # shift 2 - # ;; - - *) - POSITIONAL_ARGS+=("$1") - shift - ;; - esac + case $1 in + --py_version) + PYTHON_VERSION="$2" + shift 2 + ;; + --dockerfile) + DOCKERFILE_PATH="$2" + DOCKER_CONTEXT_PATH=$(dirname "${DOCKERFILE_PATH}") + shift 2 + ;; + --keep_image) + KEEP_IMAGE="" + shift 1 + ;; + --runtime) + RUNTIME_FLAG=1 + shift 1 + ;; + --keep_container) + KEEP_CONTAINER="" + shift 1 + ;; + --rocm_version) + ROCM_VERSION="$2" + shift 2 + ;; + *) + POSITIONAL_ARGS+=("$1") + shift + ;; + esac done if [[ ! 
-f "${DOCKERFILE_PATH}" ]]; then - die "Invalid Dockerfile path: \"${DOCKERFILE_PATH}\"" + die "Invalid Dockerfile path: \"${DOCKERFILE_PATH}\"" fi -ROCM_EXTRA_PARAMS="--device=/dev/kfd --device=/dev/dri --group-add video \ - --cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G" - # Helper function to traverse directories up until given file is found. function upsearch (){ - test / == "$PWD" && return || \ - test -e "$1" && echo "$PWD" && return || \ - cd .. && upsearch "$1" + test / == "$PWD" && return || \ + test -e "$1" && echo "$PWD" && return || \ + cd .. && upsearch "$1" } -# Set up WORKSPACE. +# Set up WORKSPACE. WORKSPACE="${WORKSPACE:-$(upsearch WORKSPACE)}" BUILD_TAG="${BUILD_TAG:-jax}" @@ -126,6 +111,7 @@ DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | sed -e 's/=/_/g' -e 's/,/-/g') # Convert to all lower-case, as per requirement of Docker image names DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | tr '[:upper:]' '[:lower:]') + # Print arguments. echo "WORKSPACE: ${WORKSPACE}" echo "COMMAND: ${POSITIONAL_ARGS[*]}" @@ -135,55 +121,25 @@ echo "" echo "Building container (${DOCKER_IMG_NAME})..." echo "Python Version (${PYTHON_VERSION})" -if [[ "${RUNTIME_FLAG}" -eq 1 ]]; then - echo "Building (runtime) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERFILE_PATH)..." - docker build --target rt_build --tag ${DOCKER_IMG_NAME} \ - --build-arg PYTHON_VERSION=$PYTHON_VERSION --build-arg ROCM_VERSION=$ROCM_VERSION \ - --build-arg CUSTOM_INSTALL=$CUSTOM_INSTALL \ - --build-arg BASE_DOCKER=$BASE_DOCKER \ - -f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}" -else - echo "Building (CI) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERFILE_PATH)..." - docker build --target ci_build --tag ${DOCKER_IMG_NAME} \ - --build-arg PYTHON_VERSION=$PYTHON_VERSION \ - --build-arg BASE_DOCKER=$BASE_DOCKER \ - -f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}" -fi +echo "Building (runtime) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERFILE_PATH)..." 
-# Check docker build status -if [[ $? != "0" ]]; then - die "ERROR: docker build failed. Dockerfile is at ${DOCKERFILE_PATH}" -fi - -# Run the command inside the container. -echo "Running '${POSITIONAL_ARGS[*]}' inside ${DOCKER_IMG_NAME}..." - -export XLA_REPO="${XLA_REPO:-}" -export XLA_BRANCH="${XLA_BRANCH:-}" export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}" -export JAX_RENAME_WHL="${XLA_CLONE_DIR:-}" -if [ ! -z ${XLA_CLONE_DIR} ]; then - ROCM_EXTRA_PARAMS=${ROCM_EXTRA_PARAMS}" -v ${XLA_CLONE_DIR}:${XLA_CLONE_DIR}" -fi +# ci_build.sh is mostly a compatibility wrapper for ci_build -docker run ${KEEP_IMAGE} --name ${DOCKER_IMG_NAME} --pid=host \ - -v ${WORKSPACE}:/workspace \ - -w /workspace \ - -e XLA_REPO=${XLA_REPO} \ - -e XLA_BRANCH=${XLA_BRANCH} \ - -e XLA_CLONE_DIR=${XLA_CLONE_DIR} \ - -e PYTHON_VERSION=$PYTHON_VERSION \ - -e CI_RUN=1 \ - ${ROCM_EXTRA_PARAMS} \ - "${DOCKER_IMG_NAME}" \ - ${POSITIONAL_ARGS[@]} +# 'dist_docker' will run 'dist_wheels' followed by a Docker build to create the "JAX image", +# which is the ROCm image that is shipped for users to use (i.e. distributable). +./build/rocm/ci_build \ + --rocm-version $ROCM_VERSION \ + --python-versions $PYTHON_VERSION \ + --xla-source-dir $XLA_CLONE_DIR \ + dist_docker \ + --dockerfile $DOCKERFILE_PATH \ + --image-tag $DOCKER_IMG_NAME -if [[ "${KEEP_IMAGE}" != "--rm" ]] && [[ $? == "0" ]]; then - echo "Committing the docker container as ${DOCKER_IMG_NAME}" - docker stop ${DOCKER_IMG_NAME} - docker commit ${DOCKER_IMG_NAME} ${DOCKER_IMG_NAME} - docker rm ${DOCKER_IMG_NAME} # remove this temp container +# Check build status +if [[ $? != "0" ]]; then + die "ERROR: docker build failed. Dockerfile is at ${DOCKERFILE_PATH}" fi echo "Jax-ROCm build was successful!" diff --git a/build/rocm/setup.rocm.sh b/build/rocm/setup.rocm.sh index 1ade67b17..35c8f4c51 100755 --- a/build/rocm/setup.rocm.sh +++ b/build/rocm/setup.rocm.sh @@ -25,7 +25,7 @@ ROCM_DEB_REPO=${ROCM_DEB_REPO_HOME}${ROCM_VERS}/ if [ ! 
-f "/${CUSTOM_INSTALL}" ]; then # Add rocm repository chmod 1777 /tmp - DEBIAN_FRONTEND=noninteractive apt-get --allow-unauthenticated update + DEBIAN_FRONTEND=noninteractive apt-get --allow-unauthenticated update DEBIAN_FRONTEND=noninteractive apt install -y wget software-properties-common DEBIAN_FRONTEND=noninteractive apt-get clean all wget -qO - https://repo.radeon.com/rocm/rocm.gpg.key | apt-key add -; diff --git a/build/rocm/tools/blacken.sh b/build/rocm/tools/blacken.sh new file mode 100644 index 000000000..7b61cbdb9 --- /dev/null +++ b/build/rocm/tools/blacken.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +black -t py36 build/rocm/ci_build build/rocm/tools/*.py diff --git a/build/rocm/tools/build_wheels.py b/build/rocm/tools/build_wheels.py new file mode 100644 index 000000000..9b9ff7788 --- /dev/null +++ b/build/rocm/tools/build_wheels.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python3 + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs be compatible with Python 3.6. Please do not include these +# in any "upgrade" scripts + + +import argparse +from collections import deque +import fcntl +import logging +import os +import re +import select +import subprocess +import shutil +import sys + + +LOG = logging.getLogger(__name__) + + +GPU_DEVICE_TARGETS = "gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100" + + +def build_rocm_path(rocm_version_str): + return "/opt/rocm-%s" % rocm_version_str + + +def update_rocm_targets(rocm_path, targets): + target_fp = os.path.join(rocm_path, "bin/target.lst") + version_fp = os.path.join(rocm_path, ".info/version") + with open(target_fp, "w") as fd: + fd.write("%s\n" % targets) + + # mimic touch + open(version_fp, "a").close() + + +def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None): + cmd = [ + "python", + "build/build.py", + "--enable_rocm", + "--build_gpu_plugin", + "--gpu_plugin_rocm_version=60", + "--rocm_path=%s" % rocm_path, + ] + + if xla_path: + 
cmd.append("--bazel_options=--override_repository=xla=%s" % xla_path) + + cpy = to_cpy_ver(python_version) + py_bin = "/opt/python/%s-%s/bin" % (cpy, cpy) + + env = dict(os.environ) + env["JAX_RELEASE"] = str(1) + env["JAXLIB_RELEASE"] = str(1) + env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) + + LOG.info("Running %r from cwd=%r" % (cmd, jax_path)) + pattern = re.compile("Output wheel: (.+)\n") + + return _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stderr") + + +def build_jax_wheel(jax_path, python_version): + cmd = [ + "python", + "-m", + "build", + ] + + cpy = to_cpy_ver(python_version) + py_bin = "/opt/python/%s-%s/bin" % (cpy, cpy) + + env = dict(os.environ) + env["JAX_RELEASE"] = str(1) + env["JAXLIB_RELEASE"] = str(1) + env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) + + LOG.info("Running %r from cwd=%r" % (cmd, jax_path)) + pattern = re.compile("Successfully built jax-.+ and (jax-.+\.whl)\n") + + wheels = _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stdout") + + paths = list(map(lambda x: os.path.join(jax_path, "dist", x), wheels)) + return paths + + +def _run_scan_for_output(cmd, pattern, env=None, cwd=None, capture=None): + + buf = deque(maxlen=20000) + + if capture == "stderr": + p = subprocess.Popen(cmd, env=env, cwd=cwd, stderr=subprocess.PIPE) + redir = sys.stderr + cap_fd = p.stderr + else: + p = subprocess.Popen(cmd, env=env, cwd=cwd, stdout=subprocess.PIPE) + redir = sys.stdout + cap_fd = p.stdout + + flags = fcntl.fcntl(cap_fd, fcntl.F_GETFL) + fcntl.fcntl(cap_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + + eof = False + while not eof: + r, _, _ = select.select([cap_fd], [], []) + for fd in r: + dat = fd.read(512) + if dat is None: + continue + elif dat: + t = dat.decode("utf8") + redir.write(t) + buf.extend(t) + else: + eof = True + + # wait and drain pipes + _, _ = p.communicate() + + if p.returncode != 0: + raise Exception( + "Child process exited with nonzero result: rc=%d" % p.returncode + ) + + text = 
"".join(buf) + + matches = pattern.findall(text) + + if not matches: + LOG.error("No wheel name found in output: %r" % text) + raise Exception("No wheel name found in output") + + wheels = [] + for match in matches: + LOG.info("Found built wheel: %r" % match) + wheels.append(match) + + return wheels + + +def to_cpy_ver(python_version): + tup = python_version.split(".") + return "cp%d%d" % (int(tup[0]), int(tup[1])) + + +def fix_wheel(path, jax_path): + # NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8 + # so use one of the CPythons in /opt to run + env = dict(os.environ) + py_bin = "/opt/python/cp310-cp310/bin" + env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) + + cmd = ["pip", "install", "auditwheel>=6"] + subprocess.run(cmd, check=True, env=env) + + fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py") + cmd = ["python", fixwheel_path, path] + subprocess.run(cmd, check=True, env=env) + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument( + "--rocm-version", default="6.1.1", help="ROCM Version to build JAX against" + ) + p.add_argument( + "--python-versions", + default=["3.10.19,3.12"], + help="Comma separated CPython versions that wheels will be built and output for", + ) + p.add_argument( + "--xla-path", + type=str, + default=None, + help="Optional directory where XLA source is located to use instead of JAX builtin XLA", + ) + + p.add_argument("jax_path", help="Directory where JAX source directory is located") + + return p.parse_args() + + +def main(): + args = parse_args() + python_versions = args.python_versions.split(",") + + print("ROCM_VERSION=%s" % args.rocm_version) + print("PYTHON_VERSIONS=%r" % python_versions) + print("JAX_PATH=%s" % args.jax_path) + print("XLA_PATH=%s" % args.xla_path) + + rocm_path = build_rocm_path(args.rocm_version) + + update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS) + + for py in python_versions: + wheel_paths = build_jaxlib_wheel(args.jax_path, rocm_path, py, 
args.xla_path) + for wheel_path in wheel_paths: + fix_wheel(wheel_path, args.jax_path) + + # build JAX wheel for completeness + jax_wheels = build_jax_wheel(args.jax_path, python_versions[-1]) + + # NOTE(mrodden): the jax wheel is a "non-platform wheel", so auditwheel will + # do nothing, and in fact will throw an Exception. we just need to copy it + # along with the jaxlib and plugin ones + + # copy jax wheel(s) to wheelhouse + wheelhouse_dir = "/wheelhouse/" + for whl in jax_wheels: + LOG.info("Copying %s into %s" % (whl, wheelhouse_dir)) + shutil.copy(whl, wheelhouse_dir) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main() diff --git a/build/rocm/tools/fixwheel.py b/build/rocm/tools/fixwheel.py new file mode 100644 index 000000000..d5951cdd4 --- /dev/null +++ b/build/rocm/tools/fixwheel.py @@ -0,0 +1,97 @@ +#!/usr/bin/env python3 + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs be compatible with Python 3.6. Please do not include these +# in any "upgrade" scripts + + +import argparse +import logging +import os +from pprint import pprint +import subprocess + +from auditwheel.lddtree import lddtree +from auditwheel.wheeltools import InWheelCtx +from auditwheel.elfutils import elf_file_filter +from auditwheel.policy import WheelPolicies +from auditwheel.wheel_abi import analyze_wheel_abi + + +LOG = logging.getLogger(__name__) + + +def tree(path): + + with InWheelCtx(path) as ctx: + for sofile, fd in elf_file_filter(ctx.iter_files()): + + LOG.info("found SO file: %s" % sofile) + elftree = lddtree(sofile) + + print(elftree) + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("wheel_path") + return p.parse_args() + + +def parse_wheel_name(path): + wheel_name = os.path.basename(path) + return wheel_name[:-4].split("-") + + +def fix_wheel(path): + tup = parse_wheel_name(path) + plat_tag = tup[4] + if "manylinux2014" in plat_tag: + # strip any manylinux tags from the current wheel first + 
from wheel.cli import tags + + plat_mod_str = "linux_x86_64" + new_wheel = tags.tags( + path, + python_tags=None, + abi_tags=None, + platform_tags=plat_mod_str, + build_tag=None, + ) + new_path = os.path.join(os.path.dirname(path), new_wheel) + LOG.info("Stripped broken tags and created new wheel at %r" % new_path) + path = new_path + + # build excludes, using auditwheels lddtree to find them + wheel_pol = WheelPolicies() + exclude = frozenset() + abi = analyze_wheel_abi(wheel_pol, path, exclude) + + plat = "manylinux_2_28_x86_64" + ext_libs = abi.external_refs.get(plat, {}).get("libs") + exclude = list(ext_libs.keys()) + + # call auditwheel repair with excludes + cmd = ["auditwheel", "repair", "--plat", plat, "--only-plat"] + + for ex in exclude: + cmd.append("--exclude") + cmd.append(ex) + + cmd.append(path) + + LOG.info("running %r" % cmd) + + rc = subprocess.run(cmd, check=True) + + +def main(): + args = parse_args() + path = args.wheel_path + fix_wheel(path) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main() diff --git a/build/rocm/tools/get_rocm.py b/build/rocm/tools/get_rocm.py new file mode 100644 index 000000000..4cc4a4682 --- /dev/null +++ b/build/rocm/tools/get_rocm.py @@ -0,0 +1,100 @@ +#!/usr/bin/env python3 + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs be compatible with Python 3.6. 
Please do not include these +# in any "upgrade" scripts + + +import argparse +import logging +import subprocess + + +LOG = logging.getLogger(__name__) + + +def which_linux(): + try: + os_rel = open("/etc/os-release").read() + + kvs = {} + for line in os_rel.split("\n"): + if line.strip(): + k, v = line.strip().split("=", 1) + v = v.strip('"') + kvs[k] = v + + print(kvs) + except OSError: + pass + + +rocm_package_names = [ + "libdrm-amdgpu", + "rocm-dev", + "rocm-ml-sdk", + "miopen-hip ", + "miopen-hip-devel", + "rocblas", + "rocblas-devel", + "rocsolver-devel", + "rocrand-devel", + "rocfft-devel", + "hipfft-devel", + "hipblas-devel", + "rocprim-devel", + "hipcub-devel", + "rccl-devel", + "hipsparse-devel", + "hipsolver-devel", +] + + +def install_rocm_el8(rocm_version_str): + + with open("/etc/yum.repos.d/rocm.repo", "w") as rfd: + rfd.write( + """ +[ROCm] +name=ROCm +baseurl=http://repo.radeon.com/rocm/rhel8/%s/main +enabled=1 +gpgcheck=1 +gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key +""" + % rocm_version_str + ) + + with open("/etc/yum.repos.d/amdgpu.repo", "w") as afd: + afd.write( + """ +[amdgpu] +name=amdgpu +baseurl=https://repo.radeon.com/amdgpu/latest/rhel/8.8/main/x86_64/ +enabled=1 +gpgcheck=1 +gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key +""" + ) + + cmd = ["dnf", "install", "-y"] + cmd.extend(rocm_package_names) + LOG.info("Running %r" % cmd) + subprocess.run(cmd, check=True) + + +def parse_args(): + p = argparse.ArgumentParser() + p.add_argument("--rocm-version", help="ROCm version to install", default="6.1.1") + return p.parse_args() + + +def main(): + args = parse_args() + install_rocm_el8(args.rocm_version) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main() diff --git a/build/rocm/tools/libc.py b/build/rocm/tools/libc.py new file mode 100644 index 000000000..61983d6c2 --- /dev/null +++ b/build/rocm/tools/libc.py @@ -0,0 +1,48 @@ +#!/usr/bin/env python3 + + +# NOTE(mrodden): This file is part of the ROCm build 
scripts, and +# needs be compatible with Python 3.6. Please do not include these +# in any "upgrade" scripts + + +import os +import sys + + +def get_libc_version(): + """ + Detect and return glibc version that the current Python is linked against. + + This mimics the detection behavior of the 'wheel' and 'auditwheel' projects, + but without any PyPy or libmusl support. + """ + + try: + version_str = os.confstr("CS_GNU_LIBC_VERSION") + return version_str + except Exception: + print("WARN: lookup by confstr failed", file=sys.stderr) + pass + + try: + import ctypes + except ImportError: + return None + + pn = ctypes.CDLL(None) + print(dir(pn)) + + try: + gnu_get_libc_version = pn.gnu_get_libc_version + except AttributeError: + return None + + gnu_get_libc_version.restype = ctypes.c_char_p + version_str = gnu_get_libc_version() + + return version_str + + +if __name__ == "__main__": + print(get_libc_version()) diff --git a/build/rocm/tools/symbols.py b/build/rocm/tools/symbols.py new file mode 100644 index 000000000..dc74a0a9b --- /dev/null +++ b/build/rocm/tools/symbols.py @@ -0,0 +1,53 @@ +#!/usr/bin/env python3 + + +# NOTE(mrodden): This file is part of the ROCm build scripts, and +# needs be compatible with Python 3.6. 
# Please do not include these
# in any "upgrade" scripts


import pprint
import re
import sys
import subprocess

"""
Utility for examining GLIBC versioned symbols
for an object file (shared object or ELF binary)
"""

# Matches versioned-symbol references such as "(GLIBC_2.2.5)" in the output of
# `objdump -T`. Only dotted numeric versions are captured, so non-numeric tags
# like "(GLIBC_PRIVATE)" are skipped instead of crashing parse().
_GLIBC_VERSION_RE = re.compile(r"\(GLIBC_([0-9]+(?:\.[0-9]+)*)\)")


def main():
    """Print the highest GLIBC symbol version referenced by the file in argv[1]."""
    if len(sys.argv) != 2:
        # Fail with a readable message rather than a bare IndexError.
        sys.exit("usage: %s <shared-object-or-elf-binary>" % sys.argv[0])

    sofile = sys.argv[1]

    s = highest_for_file(sofile)

    print("%s: %r" % (sofile, s))


def highest_for_file(sofile):
    """Return the highest GLIBC version referenced by `sofile`.

    Runs `objdump -T` on the file and scans its dynamic symbol table.

    Returns a tuple of ints such as (2, 28), or None when the file references
    no numerically versioned GLIBC symbols.

    Raises subprocess.CalledProcessError if objdump exits non-zero, and
    OSError if objdump cannot be executed.
    """
    output = subprocess.check_output(["objdump", "-T", sofile])
    return highest_glibc_version(output.decode("utf-8"))


def highest_glibc_version(objdump_text):
    """Return the highest "(GLIBC_x.y[.z])" version found in `objdump_text`.

    Versions are compared numerically component by component, so 2.17 sorts
    above 2.9. Returns None when the text contains no such version.
    """
    versions = {parse(v) for v in _GLIBC_VERSION_RE.findall(objdump_text)}
    if not versions:
        return None
    return max(versions)


def parse(version_str):
    """Convert a dotted version string such as "2.2.5" into a tuple of ints."""
    return tuple(map(int, version_str.split(".")))


if __name__ == "__main__":
    main()