[ROCm] Change ROCm builds to manylinux wheels

This commit is contained in:
Mathew Odden 2024-07-17 19:09:07 -05:00
parent ee31e95ecd
commit 1e58d76772
12 changed files with 877 additions and 158 deletions

View File

@ -1,6 +1,5 @@
################################################################################
ARG BASE_DOCKER=ubuntu:20.04
FROM $BASE_DOCKER as rt_build
FROM ubuntu:20.04 AS rocm_base
################################################################################
# Add target file to help determine which device(s) to build for
@ -12,9 +11,9 @@ ARG ROCM_VERSION=6.0.0
ARG CUSTOM_INSTALL
ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION}
ENV ROCM_PATH=${ROCM_PATH}
COPY ${CUSTOM_INSTALL} /${CUSTOM_INSTALL}
COPY setup.rocm.sh /setup.rocm.sh
RUN /setup.rocm.sh $ROCM_VERSION
#COPY ${CUSTOM_INSTALL} /${CUSTOM_INSTALL}
RUN --mount=type=bind,source=build/rocm/setup.rocm.sh,target=/setup.rocm.sh \
/setup.rocm.sh $ROCM_VERSION
# Set up paths
ENV HCC_HOME=$ROCM_PATH/hcc
@ -25,13 +24,35 @@ ENV PATH="$ROCM_PATH/bin:${PATH}"
ENV PATH="$OPENCL_ROOT/bin:${PATH}"
ENV PATH="/root/bin:/root/.local/bin:$PATH"
# Install pyenv with different python versions
ARG PYTHON_VERSION=3.10.0
ARG PYTHON_VERSION=3.10.14
RUN git clone https://github.com/pyenv/pyenv.git /pyenv
# Make pyenv and its shims resolvable for every later RUN step and at runtime.
# key=value form is used because the space-separated ENV form is deprecated.
ENV PYENV_ROOT=/pyenv
ENV PATH=$PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH
RUN pyenv install $PYTHON_VERSION
RUN eval "$(pyenv init -)" && pyenv local ${PYTHON_VERSION} && pip3 install --upgrade --force-reinstall setuptools pip && pip install numpy setuptools build wheel six auditwheel scipy pytest pytest-html pytest_html_merger pytest-reportlog pytest-rerunfailures cloudpickle portpicker matplotlib absl-py flatbuffers hypothesis
RUN eval "$(pyenv init -)" && \
pyenv local ${PYTHON_VERSION} && \
pip3 install --upgrade --force-reinstall setuptools pip && \
pip install \
numpy setuptools build wheel six auditwheel scipy \
pytest pytest-html pytest_html_merger pytest-reportlog \
pytest-rerunfailures cloudpickle portpicker matplotlib absl-py \
flatbuffers hypothesis
################################################################################
FROM rocm_base AS rt_build
################################################################################
ARG JAX_VERSION
ARG JAX_COMMIT
ARG XLA_COMMIT
LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \
com.amdgpu.python_version="$PYTHON_VERSION" \
com.amdgpu.jax_version="$JAX_VERSION" \
com.amdgpu.jax_commit="$JAX_COMMIT" \
com.amdgpu.xla_commit="$XLA_COMMIT"
RUN --mount=type=bind,source=wheelhouse,target=/wheelhouse \
pip install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt

View File

@ -1,4 +1,5 @@
#!/usr/bin/env bash
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -13,57 +14,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Environment Var Notes
# XLA_CLONE_DIR -
# Specifies filepath to where XLA repo is cloned.
# NOTE:, if this is set then XLA repo is not cloned. Must clone repo before running this script.
# Also, if this is set then setting XLA_REPO and XLA_BRANCH have no effect.
# XLA_REPO
# XLA repo to clone from. Default is https://github.com/ROCmSoftwarePlatform/tensorflow-upstream
# XLA_BRANCH
# XLA branch in the XLA repo. Default is develop-upstream-jax
#
# NOTE(mrodden): ROCm JAX build and installs have moved to wheel based builds and installs,
# but some CI scripts still try to run this script. Nothing needs to be done here,
# but we print some debugging information for logs.
set -eux
python -V
#If XLA_REPO is not set, then use default
if [ ! -v XLA_REPO ]; then
XLA_REPO="https://github.com/openxla/xla.git"
XLA_BRANCH="main"
elif [ -z "$XLA_REPO" ]; then
XLA_REPO="https://github.com/openxla/xla.git"
XLA_BRANCH="main"
fi
#If XLA_CLONE_PATH is not set, then use default path.
#Note, setting XLA_CLONE_PATH makes setting XLA_REPO and XLA_BRANCH a no-op
#Set this when the XLA repository has already been cloned. This is useful in CI
#environments and when doing local development
if [ ! -v XLA_CLONE_DIR ]; then
XLA_CLONE_DIR=/tmp/xla
rm -rf /tmp/xla || true
git clone -b ${XLA_BRANCH} ${XLA_REPO} /tmp/xla
elif [ -z "$XLA_CLONE_DIR" ]; then
XLA_CLONE_DIR=/tmp/xla
rm -rf /tmp/xla || true
git clone -b ${XLA_BRANCH} ${XLA_REPO} /tmp/xla
fi
#Export JAX_ROCM_VERSION so that it is appended to the wheel name
export JAXLIB_RELEASE=1
rocm_version=$(cat /opt/rocm/.info/version | cut -d "-" -f 1)
export JAX_ROCM_VERSION=${rocm_version//./}
#Build and install wheel
python3 ./build/build.py --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=xla=${XLA_CLONE_DIR}
JAX_RELEASE=1 python -m build
pip3 install --force-reinstall dist/*.whl # installs jaxlib (includes XLA)
#This is for CI to read without having to start the container again
if [ -v CI_RUN ]; then
pip3 list | grep jaxlib | tr -s ' ' | cut -d " " -f 2 | cut -d "+" -f 1 > jax_version_installed
cat /opt/rocm/.info/version | cut -d "-" -f 1 > jax_rocm_version
fi
printf "Detected jaxlib version: %s\n" $(pip3 list | grep jaxlib | tr -s ' ' | cut -d " " -f 2 | cut -d "+" -f 1)
printf "Detected ROCm version: %s\n" $(cat /opt/rocm/.info/version | cut -d "-" -f 1)

View File

@ -0,0 +1,7 @@
# Build image for ROCm wheels: the PyPA manylinux_2_28 x86_64 base with ROCm
# installed on top by build/rocm/tools/get_rocm.py.
# NOTE(review): the base image is untagged, so builds follow whatever the
# registry currently serves; consider pinning a tag or digest.
FROM quay.io/pypa/manylinux_2_28_x86_64
# ROCm release to install; override with --build-arg ROCM_VERSION=<version>.
ARG ROCM_VERSION=6.1.1
# The installer script is bind-mounted from the build context rather than
# copied, so it does not end up in an image layer, and the dnf cache lives in
# a cache mount. RUN --mount requires BuildKit. The bind source path is
# relative to the build context, so the context must be the repository root.
RUN --mount=type=cache,target=/var/cache/dnf \
--mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
python3 get_rocm.py --rocm-version $ROCM_VERSION

256
build/rocm/ci_build Executable file
View File

@ -0,0 +1,256 @@
#!/usr/bin/env python3
# NOTE(mrodden): This file is part of the ROCm build scripts, and
# needs to be compatible with Python 3.6. Please do not include these
# in any "upgrade" scripts
import argparse
import os
import subprocess
import sys
def image_by_name(name):
    """Look up a local Docker image by reference.

    Runs ``docker images -q -f reference=<name>`` and returns the first line
    of its output (an image ID), or None when that line is empty.
    """
    listing = subprocess.check_output(
        ["docker", "images", "-q", "-f", "reference=%s" % name]
    )
    first_line = listing.decode("utf8").strip().split("\n")[0]
    if not first_line:
        return None
    return first_line
def dist_wheels(rocm_version, python_versions, xla_path):
    """Build the JAX wheels inside a manylinux container with ROCm installed.

    The manylinux+ROCm image is built first if no image with the expected
    name exists locally. Then ``build/rocm/tools/build_wheels.py`` is run
    inside it with the current directory mounted at ``/jax`` and
    ``./wheelhouse`` mounted at ``/wheelhouse``.

    Args:
      rocm_version: ROCm version string, e.g. "6.1.1".
      python_versions: list of CPython version strings; joined with "," and
        forwarded to build_wheels.py.
      xla_path: optional path to an XLA checkout to mount at ``/xla`` and
        build against. None or an empty string means "use the builtin XLA".

    Raises:
      subprocess.CalledProcessError: if ``docker build`` or ``docker run``
        exits with a nonzero status.
    """
    # Only normalize a path that was actually given: os.path.abspath(None)
    # raises TypeError, and abspath("") silently resolves to the current
    # directory, which would then be mounted as the XLA source tree.
    xla_path = os.path.abspath(xla_path) if xla_path else None

    # create manylinux image with requested ROCm installed
    image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")

    cmd = [
        "docker",
        "build",
        "-f",
        "build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm",
        "--build-arg=ROCM_VERSION=%s" % rocm_version,
        "--tag=%s" % image,
        ".",
    ]

    if not image_by_name(image):
        _ = subprocess.run(cmd, check=True)

    # use image to build JAX/jaxlib wheels
    os.makedirs("wheelhouse", exist_ok=True)

    pyver_string = ",".join(python_versions)

    container_xla_path = "/xla"

    bw_cmd = [
        "python3",
        "/jax/build/rocm/tools/build_wheels.py",
        "--rocm-version",
        rocm_version,
        "--python-versions",
        pyver_string,
    ]

    if xla_path:
        bw_cmd.extend(["--xla-path", container_xla_path])

    bw_cmd.append("/jax")

    cmd = ["docker", "run", "-it"]

    mounts = [
        "-v",
        "./:/jax",
        "-v",
        "./wheelhouse:/wheelhouse",
    ]

    if xla_path:
        mounts.extend(["-v", "%s:%s" % (xla_path, container_xla_path)])

    cmd.extend(mounts)

    # NOTE(mrodden): bazel times out without --init, probably blocking on a zombie PID
    cmd.extend(
        [
            "--init",
            "--rm",
            image,
            "bash",
            "-c",
            " ".join(bw_cmd),
        ]
    )

    _ = subprocess.run(cmd, check=True)
def _fetch_jax_metadata(xla_path):
cmd = ["git", "rev-parse", "HEAD"]
jax_commit = subprocess.check_output(cmd)
xla_commit = ""
if xla_path:
try:
xla_commit = subprocess.check_output(cmd, cwd=xla_path)
except Exception as ex:
LOG.warning("Exception while retrieving xla_commit: %s" % ex)
cmd = ["python", "setup.py", "-V"]
env = dict(os.environ)
env["JAX_RELEASE"] = "1"
jax_version = subprocess.check_output(cmd, env=env)
return {
"jax_version": jax_version.decode("utf8").strip(),
"jax_commit": jax_commit.decode("utf8").strip(),
"xla_commit": xla_commit.decode("utf8").strip(),
}
def dist_docker(
    rocm_version,
    python_versions,
    xla_path,
    tag="rocm/jax-dev",
    dockerfile=None,
    keep_image=True,
):
    """Build the ``rt_build`` target of the ROCm JAX Dockerfile.

    The first entry of ``python_versions`` is passed as PYTHON_VERSION, and
    the JAX/XLA metadata from ``_fetch_jax_metadata`` is passed as build args.
    ``--rm`` is added to the ``docker build`` command when ``keep_image`` is
    false. The build context is the current directory.
    """
    dockerfile = dockerfile or "build/rocm/Dockerfile.ms"

    first_python = python_versions[0]
    metadata = _fetch_jax_metadata(xla_path)

    build_args = [
        "--build-arg=ROCM_VERSION=%s" % rocm_version,
        "--build-arg=PYTHON_VERSION=%s" % first_python,
        "--build-arg=JAX_VERSION=%s" % metadata["jax_version"],
        "--build-arg=JAX_COMMIT=%s" % metadata["jax_commit"],
        "--build-arg=XLA_COMMIT=%s" % metadata["xla_commit"],
    ]

    docker_cmd = ["docker", "build", "-f", dockerfile, "--target", "rt_build"]
    docker_cmd += build_args
    docker_cmd.append("--tag=%s" % tag)

    if not keep_image:
        docker_cmd.append("--rm")

    # context dir
    docker_cmd.append(".")

    subprocess.check_call(docker_cmd)
def test(image_name):
"""Run unit tests like CI would inside a JAX image."""
gpu_args = [
"--device=/dev/kfd",
"--device=/dev/dri",
"--group-add",
"video",
"--cap-add=SYS_PTRACE",
"--security-opt",
"seccomp=unconfined",
"--shm-size",
"16G",
]
cmd = [
"docker",
"run",
"-it",
"--rm",
]
# NOTE(mrodden): we need jax source dir for the unit test code only,
# JAX and jaxlib are already installed from wheels
mounts = [
"-v",
"./:/jax",
]
cmd.extend(mounts)
cmd.extend(gpu_args)
container_cmd = "cd /jax && ./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh"
cmd.append(image_name)
cmd.extend(
[
"bash",
"-c",
container_cmd,
]
)
subprocess.check_call(cmd)
def parse_args():
    """Parse the ci_build command line.

    Global options: ``--python-versions`` (comma separated, parsed into a
    list), ``--rocm-version`` and ``--xla-source-dir``. A subcommand is
    mandatory and is stored in ``args.action``: ``dist_wheels``,
    ``test <image_name>`` or ``dist_docker [--dockerfile] [--keep-image]
    [--image-tag]``.

    Returns:
      argparse.Namespace with the parsed values.
    """
    p = argparse.ArgumentParser()
    p.add_argument(
        "--python-versions",
        type=lambda x: x.split(","),
        default="3.12",
        help="Comma separated list of CPython versions to build wheels for",
    )

    p.add_argument(
        "--rocm-version",
        default="6.1.1",
        help="ROCm version used for building wheels, testing, and installing into Docker image",
    )

    p.add_argument(
        "--xla-source-dir",
        help="Path to XLA source to use during jaxlib build, instead of builtin XLA",
    )

    # add_subparsers() only accepts required= from Python 3.7 onwards. Setting
    # the attribute afterwards has the same effect and also works on 3.6,
    # which this script must stay compatible with.
    subp = p.add_subparsers(dest="action")
    subp.required = True

    subp.add_parser("dist_wheels")

    testp = subp.add_parser("test")
    testp.add_argument("image_name")

    ddp = subp.add_parser("dist_docker")
    ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms")
    ddp.add_argument("--keep-image", action="store_true")
    ddp.add_argument("--image-tag", default="rocm/jax-dev")

    return p.parse_args()
def main():
    """Entry point: run the subcommand selected on the command line."""
    args = parse_args()
    action = args.action

    if action == "test":
        test(args.image_name)
        return

    # Both dist_wheels and dist_docker start by building the wheels.
    if action in ("dist_wheels", "dist_docker"):
        dist_wheels(args.rocm_version, args.python_versions, args.xla_source_dir)

    if action == "dist_docker":
        dist_docker(
            args.rocm_version,
            args.python_versions,
            args.xla_source_dir,
            tag=args.image_tag,
            dockerfile=args.dockerfile,
            keep_image=args.keep_image,
        )


if __name__ == "__main__":
    main()

View File

@ -1,4 +1,5 @@
#!/usr/bin/env bash
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@ -28,12 +29,8 @@
#
# ROCM_VERSION: ROCm repo version
#
# ROCM_PATH: ROCM path in the docker container
#
# Environment variables read by this script
# WORKSPACE
# XLA_REPO
# XLA_BRANCH
# XLA_CLONE_DIR
# BUILD_TAG
#
@ -47,14 +44,10 @@ CONTAINER_TYPE="rocm"
DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile.ms"
DOCKER_CONTEXT_PATH="${SCRIPT_DIR}"
KEEP_IMAGE="--rm"
KEEP_CONTAINER="--rm"
PYTHON_VERSION="3.10.0"
ROCM_VERSION="6.0.0" #Point to latest release
PYTHON_VERSION="3.10"
ROCM_VERSION="6.1.3"
BASE_DOCKER="ubuntu:20.04"
CUSTOM_INSTALL=""
#BASE_DOCKER="compute-artifactory.amd.com:5000/rocm-plus-docker/compute-rocm-rel-6.0:91-ubuntu-20.04-stg2"
#CUSTOM_INSTALL="custom_install_dummy.sh"
#ROCM_PATH="/opt/rocm-5.6.0"
POSITIONAL_ARGS=()
RUNTIME_FLAG=1
@ -86,11 +79,6 @@ while [[ $# -gt 0 ]]; do
ROCM_VERSION="$2"
shift 2
;;
#--rocm_path)
# ROCM_PATH="$2"
# shift 2
# ;;
*)
POSITIONAL_ARGS+=("$1")
shift
@ -102,9 +90,6 @@ if [[ ! -f "${DOCKERFILE_PATH}" ]]; then
die "Invalid Dockerfile path: \"${DOCKERFILE_PATH}\""
fi
ROCM_EXTRA_PARAMS="--device=/dev/kfd --device=/dev/dri --group-add video \
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G"
# Helper function to traverse directories up until given file is found.
function upsearch (){
test / == "$PWD" && return || \
@ -126,6 +111,7 @@ DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | sed -e 's/=/_/g' -e 's/,/-/g')
# Convert to all lower-case, as per requirement of Docker image names
DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | tr '[:upper:]' '[:lower:]')
# Print arguments.
echo "WORKSPACE: ${WORKSPACE}"
echo "COMMAND: ${POSITIONAL_ARGS[*]}"
@ -135,55 +121,25 @@ echo ""
echo "Building container (${DOCKER_IMG_NAME})..."
echo "Python Version (${PYTHON_VERSION})"
if [[ "${RUNTIME_FLAG}" -eq 1 ]]; then
echo "Building (runtime) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERFILE_PATH)..."
docker build --target rt_build --tag ${DOCKER_IMG_NAME} \
--build-arg PYTHON_VERSION=$PYTHON_VERSION --build-arg ROCM_VERSION=$ROCM_VERSION \
--build-arg CUSTOM_INSTALL=$CUSTOM_INSTALL \
--build-arg BASE_DOCKER=$BASE_DOCKER \
-f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}"
else
echo "Building (CI) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERFILE_PATH)..."
docker build --target ci_build --tag ${DOCKER_IMG_NAME} \
--build-arg PYTHON_VERSION=$PYTHON_VERSION \
--build-arg BASE_DOCKER=$BASE_DOCKER \
-f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}"
fi
# Check docker build status
export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}"
# ci_build.sh is mostly a compatibility wrapper for ci_build
# 'dist_docker' will run 'dist_wheels' followed by a Docker build to create the "JAX image",
# which is the ROCm image that is shipped for users to use (i.e. distributable).

# Only pass --xla-source-dir when a directory was actually provided. With an
# empty, unquoted value the option would consume 'dist_docker' as its argument
# and ci_build would fail to parse its command line.
XLA_SOURCE_ARGS=()
if [[ -n "${XLA_CLONE_DIR:-}" ]]; then
  XLA_SOURCE_ARGS=(--xla-source-dir "${XLA_CLONE_DIR}")
fi

# The ${arr[@]+...} form expands to nothing for an empty array, which also
# works on bash versions that treat empty arrays as unset under 'set -u'.
./build/rocm/ci_build \
  --rocm-version "${ROCM_VERSION}" \
  --python-versions "${PYTHON_VERSION}" \
  ${XLA_SOURCE_ARGS[@]+"${XLA_SOURCE_ARGS[@]}"} \
  dist_docker \
  --dockerfile "${DOCKERFILE_PATH}" \
  --image-tag "${DOCKER_IMG_NAME}"
# Check build status
if [[ $? != "0" ]]; then
die "ERROR: docker build failed. Dockerfile is at ${DOCKERFILE_PATH}"
fi
# Run the command inside the container.
echo "Running '${POSITIONAL_ARGS[*]}' inside ${DOCKER_IMG_NAME}..."
export XLA_REPO="${XLA_REPO:-}"
export XLA_BRANCH="${XLA_BRANCH:-}"
export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}"
export JAX_RENAME_WHL="${XLA_CLONE_DIR:-}"
if [ ! -z ${XLA_CLONE_DIR} ]; then
ROCM_EXTRA_PARAMS=${ROCM_EXTRA_PARAMS}" -v ${XLA_CLONE_DIR}:${XLA_CLONE_DIR}"
fi
docker run ${KEEP_IMAGE} --name ${DOCKER_IMG_NAME} --pid=host \
-v ${WORKSPACE}:/workspace \
-w /workspace \
-e XLA_REPO=${XLA_REPO} \
-e XLA_BRANCH=${XLA_BRANCH} \
-e XLA_CLONE_DIR=${XLA_CLONE_DIR} \
-e PYTHON_VERSION=$PYTHON_VERSION \
-e CI_RUN=1 \
${ROCM_EXTRA_PARAMS} \
"${DOCKER_IMG_NAME}" \
${POSITIONAL_ARGS[@]}
if [[ "${KEEP_IMAGE}" != "--rm" ]] && [[ $? == "0" ]]; then
echo "Committing the docker container as ${DOCKER_IMG_NAME}"
docker stop ${DOCKER_IMG_NAME}
docker commit ${DOCKER_IMG_NAME} ${DOCKER_IMG_NAME}
docker rm ${DOCKER_IMG_NAME} # remove this temp container
fi
echo "Jax-ROCm build was successful!"

View File

@ -0,0 +1,3 @@
#!/bin/bash
# Format the ROCm build tooling with black. The py36 target matches the
# Python 3.6 compatibility that the headers of those scripts ask for.
# The paths are relative, so run this from the repository root.
black -t py36 build/rocm/ci_build build/rocm/tools/*.py

View File

@ -0,0 +1,222 @@
#!/usr/bin/env python3
# NOTE(mrodden): This file is part of the ROCm build scripts, and
# needs to be compatible with Python 3.6. Please do not include these
# in any "upgrade" scripts
import argparse
from collections import deque
import fcntl
import logging
import os
import re
import select
import subprocess
import shutil
import sys
LOG = logging.getLogger(__name__)
GPU_DEVICE_TARGETS = "gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
def build_rocm_path(rocm_version_str):
    """Return the versioned ROCm install directory, e.g. "/opt/rocm-6.1.1"."""
    install_dir_template = "/opt/rocm-%s"
    return install_dir_template % rocm_version_str
def update_rocm_targets(rocm_path, targets):
    """Write the GPU target list into a ROCm install tree.

    Overwrites ``<rocm_path>/bin/target.lst`` with ``targets`` followed by a
    newline, and makes sure ``<rocm_path>/.info/version`` exists without
    changing its contents.
    """
    targets_file = os.path.join(rocm_path, "bin/target.lst")
    version_file = os.path.join(rocm_path, ".info/version")

    with open(targets_file, "w") as out:
        out.write("%s\n" % targets)

    # mimic touch
    with open(version_file, "a"):
        pass
def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None):
    """Run ``build/build.py`` with ROCm enabled for one CPython version.

    The matching ``/opt/python/<tag>-<tag>/bin`` directory is put first on
    PATH so that ``python`` resolves to the requested interpreter. stderr of
    the build is echoed and scanned for "Output wheel: <path>" lines.

    Returns:
      list of the strings captured from those lines.
    """
    build_cmd = [
        "python",
        "build/build.py",
        "--enable_rocm",
        "--build_gpu_plugin",
        "--gpu_plugin_rocm_version=60",
        "--rocm_path=%s" % rocm_path,
    ]
    if xla_path:
        build_cmd.append("--bazel_options=--override_repository=xla=%s" % xla_path)

    tag = to_cpy_ver(python_version)
    interpreter_dir = "/opt/python/%s-%s/bin" % (tag, tag)

    child_env = dict(os.environ)
    child_env.update({"JAX_RELEASE": "1", "JAXLIB_RELEASE": "1"})
    child_env["PATH"] = "%s:%s" % (interpreter_dir, child_env["PATH"])

    LOG.info("Running %r from cwd=%r" % (build_cmd, jax_path))

    wheel_line = re.compile("Output wheel: (.+)\n")
    return _run_scan_for_output(
        build_cmd, wheel_line, env=child_env, cwd=jax_path, capture="stderr"
    )
def build_jax_wheel(jax_path, python_version):
    """Build the pure-Python jax wheel with ``python -m build``.

    The matching ``/opt/python/<tag>-<tag>/bin`` directory is put first on
    PATH so that ``python`` resolves to the requested interpreter. stdout of
    the build is echoed and scanned for the "Successfully built ..." line.

    Returns:
      list of wheel paths under ``<jax_path>/dist``.
    """
    cmd = [
        "python",
        "-m",
        "build",
    ]

    cpy = to_cpy_ver(python_version)
    py_bin = "/opt/python/%s-%s/bin" % (cpy, cpy)

    env = dict(os.environ)
    env["JAX_RELEASE"] = str(1)
    env["JAXLIB_RELEASE"] = str(1)
    env["PATH"] = "%s:%s" % (py_bin, env["PATH"])

    LOG.info("Running %r from cwd=%r" % (cmd, jax_path))

    # Raw string: "\." in a normal string literal is an invalid escape
    # sequence that newer Python versions warn about. The regex is unchanged.
    pattern = re.compile(r"Successfully built jax-.+ and (jax-.+\.whl)\n")

    wheels = _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stdout")

    return [os.path.join(jax_path, "dist", name) for name in wheels]
def _run_scan_for_output(cmd, pattern, env=None, cwd=None, capture=None):
    """Run ``cmd``, echo one of its output streams and scan it with ``pattern``.

    Args:
      cmd: argument list for subprocess.Popen.
      pattern: compiled regex; ``findall`` is applied to the captured text.
      env: environment for the child process.
      cwd: working directory for the child process.
      capture: "stderr" to capture stderr; anything else captures stdout.
        The stream that is not captured is inherited from this process.

    Returns:
      list of ``pattern.findall`` results, searched in the last 20000
      characters of the captured stream.

    Raises:
      Exception: if the child exits with a nonzero status, or if the pattern
        does not match anything in the captured output.
    """
    import codecs

    # Keep only the tail of the output; the wheel names are printed at the end.
    buf = deque(maxlen=20000)

    if capture == "stderr":
        p = subprocess.Popen(cmd, env=env, cwd=cwd, stderr=subprocess.PIPE)
        redir = sys.stderr
        cap_fd = p.stderr
    else:
        p = subprocess.Popen(cmd, env=env, cwd=cwd, stdout=subprocess.PIPE)
        redir = sys.stdout
        cap_fd = p.stdout

    # Non-blocking reads, so read() returns what is available after select().
    flags = fcntl.fcntl(cap_fd, fcntl.F_GETFL)
    fcntl.fcntl(cap_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

    # A 512-byte read can end in the middle of a multi-byte UTF-8 character.
    # The incremental decoder carries the partial bytes over to the next chunk
    # instead of raising UnicodeDecodeError. Invalid bytes are replaced, so
    # odd log output cannot abort the build.
    decoder = codecs.getincrementaldecoder("utf8")(errors="replace")

    eof = False
    while not eof:
        r, _, _ = select.select([cap_fd], [], [])
        for fd in r:
            dat = fd.read(512)
            if dat is None:
                # readable was signalled but no data is available yet
                continue
            # an empty read means EOF: flush whatever the decoder still holds
            t = decoder.decode(dat, final=not dat)
            if t:
                redir.write(t)
                buf.extend(t)
            if not dat:
                eof = True

    # wait and drain pipes
    _, _ = p.communicate()

    if p.returncode != 0:
        raise Exception(
            "Child process exited with nonzero result: rc=%d" % p.returncode
        )

    text = "".join(buf)

    matches = pattern.findall(text)

    if not matches:
        LOG.error("No wheel name found in output: %r" % text)
        raise Exception("No wheel name found in output")

    wheels = []
    for match in matches:
        LOG.info("Found built wheel: %r" % match)
        wheels.append(match)

    return wheels
def to_cpy_ver(python_version):
    """Convert a dotted Python version ("3.10" or "3.10.14") to a "cp310" tag.

    Only the major and minor components are used.
    """
    parts = python_version.split(".")
    major = int(parts[0])
    minor = int(parts[1])
    return "cp%d%d" % (major, minor)
def fix_wheel(path, jax_path):
    """Run ``build/rocm/tools/fixwheel.py`` on the wheel at ``path``.

    auditwheel>=6 is installed first, and both commands run with the
    CPython 3.10 bin directory from /opt/python first on PATH.
    """
    # NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8
    # so use one of the CPythons in /opt to run
    tool_env = dict(os.environ)
    tool_env["PATH"] = "%s:%s" % ("/opt/python/cp310-cp310/bin", tool_env["PATH"])

    subprocess.run(["pip", "install", "auditwheel>=6"], check=True, env=tool_env)

    script = os.path.join(jax_path, "build/rocm/tools/fixwheel.py")
    subprocess.run(["python", script, path], check=True, env=tool_env)
def parse_args():
    """Parse the build_wheels.py command line.

    Returns:
      argparse.Namespace with ``rocm_version``, ``python_versions`` (a comma
      separated string, split by the caller), ``xla_path`` and ``jax_path``.
    """
    p = argparse.ArgumentParser()
    p.add_argument(
        "--rocm-version", default="6.1.1", help="ROCM Version to build JAX against"
    )
    p.add_argument(
        "--python-versions",
        # Must be a plain string like a value given on the command line: the
        # caller splits it on ",", which fails on a list default.
        default="3.10.19,3.12",
        help="Comma separated CPython versions that wheels will be built and output for",
    )
    p.add_argument(
        "--xla-path",
        type=str,
        default=None,
        help="Optional directory where XLA source is located to use instead of JAX builtin XLA",
    )

    p.add_argument("jax_path", help="Directory where JAX source directory is located")

    return p.parse_args()
def main():
    """Build and repair jaxlib/plugin wheels per Python version, then the jax wheel."""
    args = parse_args()
    python_versions = args.python_versions.split(",")

    print("ROCM_VERSION=%s" % args.rocm_version)
    print("PYTHON_VERSIONS=%r" % python_versions)
    print("JAX_PATH=%s" % args.jax_path)
    print("XLA_PATH=%s" % args.xla_path)

    rocm_path = build_rocm_path(args.rocm_version)

    update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS)

    for version in python_versions:
        built = build_jaxlib_wheel(args.jax_path, rocm_path, version, args.xla_path)
        for wheel in built:
            fix_wheel(wheel, args.jax_path)

    # build JAX wheel for completeness
    jax_wheels = build_jax_wheel(args.jax_path, python_versions[-1])

    # NOTE(mrodden): the jax wheel is a "non-platform wheel", so auditwheel will
    # do nothing, and in fact will throw an Exception. we just need to copy it
    # along with the jaxlib and plugin ones

    # copy jax wheel(s) to wheelhouse
    destination = "/wheelhouse/"
    for wheel in jax_wheels:
        LOG.info("Copying %s into %s" % (wheel, destination))
        shutil.copy(wheel, destination)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()

View File

@ -0,0 +1,97 @@
#!/usr/bin/env python3
# NOTE(mrodden): This file is part of the ROCm build scripts, and
# needs to be compatible with Python 3.6. Please do not include these
# in any "upgrade" scripts
import argparse
import logging
import os
from pprint import pprint
import subprocess
from auditwheel.lddtree import lddtree
from auditwheel.wheeltools import InWheelCtx
from auditwheel.elfutils import elf_file_filter
from auditwheel.policy import WheelPolicies
from auditwheel.wheel_abi import analyze_wheel_abi
LOG = logging.getLogger(__name__)
def tree(path):
    """Print the lddtree result for every ELF file found inside a wheel."""
    with InWheelCtx(path) as wheel_ctx:
        for so_name, _handle in elf_file_filter(wheel_ctx.iter_files()):
            LOG.info("found SO file: %s" % so_name)
            print(lddtree(so_name))
def parse_args():
    """Parse the command line; the single positional argument is ``wheel_path``."""
    parser = argparse.ArgumentParser()
    parser.add_argument("wheel_path")
    parsed = parser.parse_args()
    return parsed
def parse_wheel_name(path):
    """Split a wheel file name into its "-" separated components.

    The directory part and the last four characters (the ".whl" suffix) are
    dropped before splitting.
    """
    file_name = os.path.basename(path)
    stem = file_name[: -len(".whl")]
    return stem.split("-")
def fix_wheel(path):
    """Retag a wheel as manylinux_2_28_x86_64 using ``auditwheel repair``.

    If the wheel already carries a manylinux2014 platform tag, it is first
    rewritten with a plain ``linux_x86_64`` tag. Every external library that
    auditwheel reports for the target platform is then passed as
    ``--exclude`` so that it is not bundled into the wheel.

    Args:
      path: path to the wheel file to repair.

    Raises:
      subprocess.CalledProcessError: if ``auditwheel repair`` fails.
    """
    tup = parse_wheel_name(path)
    # The platform tag is always the last component of a wheel file name. A
    # fixed index would pick the ABI tag when an optional build tag is present.
    plat_tag = tup[-1]
    if "manylinux2014" in plat_tag:
        # strip any manylinux tags from the current wheel first
        from wheel.cli import tags

        plat_mod_str = "linux_x86_64"
        new_wheel = tags.tags(
            path,
            python_tags=None,
            abi_tags=None,
            platform_tags=plat_mod_str,
            build_tag=None,
        )

        new_path = os.path.join(os.path.dirname(path), new_wheel)
        LOG.info("Stripped broken tags and created new wheel at %r" % new_path)
        path = new_path

    # build excludes, using auditwheels lddtree to find them
    wheel_pol = WheelPolicies()
    exclude = frozenset()
    abi = analyze_wheel_abi(wheel_pol, path, exclude)

    plat = "manylinux_2_28_x86_64"
    # No entry for the platform, or no "libs" key, means there is nothing to exclude.
    ext_libs = abi.external_refs.get(plat, {}).get("libs") or {}
    exclude = list(ext_libs.keys())

    # call auditwheel repair with excludes
    cmd = ["auditwheel", "repair", "--plat", plat, "--only-plat"]

    for ex in exclude:
        cmd.append("--exclude")
        cmd.append(ex)

    cmd.append(path)

    LOG.info("running %r" % cmd)

    subprocess.run(cmd, check=True)
def main():
    """Entry point: repair the wheel named on the command line."""
    fix_wheel(parse_args().wheel_path)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()

View File

@ -0,0 +1,100 @@
#!/usr/bin/env python3
# NOTE(mrodden): This file is part of the ROCm build scripts, and
# needs to be compatible with Python 3.6. Please do not include these
# in any "upgrade" scripts
import argparse
import logging
import subprocess
LOG = logging.getLogger(__name__)
def which_linux():
    """Print the key/value pairs found in /etc/os-release.

    Does nothing if the file cannot be read. Blank lines, comment lines and
    lines without "=" are skipped. Surrounding double quotes are removed from
    the values.
    """
    try:
        # 'with' closes the file handle deterministically
        with open("/etc/os-release") as fd:
            os_rel = fd.read()
    except OSError:
        return

    kvs = {}
    for line in os_rel.split("\n"):
        line = line.strip()
        # os-release may contain comments; unpacking a line without "=" would raise
        if not line or line.startswith("#") or "=" not in line:
            continue
        k, v = line.split("=", 1)
        kvs[k] = v.strip('"')

    print(kvs)
# Packages handed to `dnf install` by install_rocm_el8. Each entry is passed
# as its own argument, so the names must not contain stray whitespace.
rocm_package_names = [
    "libdrm-amdgpu",
    "rocm-dev",
    "rocm-ml-sdk",
    "miopen-hip",
    "miopen-hip-devel",
    "rocblas",
    "rocblas-devel",
    "rocsolver-devel",
    "rocrand-devel",
    "rocfft-devel",
    "hipfft-devel",
    "hipblas-devel",
    "rocprim-devel",
    "hipcub-devel",
    "rccl-devel",
    "hipsparse-devel",
    "hipsolver-devel",
]
def install_rocm_el8(rocm_version_str):
    """Configure the ROCm and amdgpu dnf repositories and install ROCm.

    Writes /etc/yum.repos.d/rocm.repo (pointing at the given ROCm version)
    and /etc/yum.repos.d/amdgpu.repo, then runs ``dnf install -y`` on
    ``rocm_package_names``.

    Args:
      rocm_version_str: ROCm version used in the repository URL, e.g. "6.1.1".

    Raises:
      subprocess.CalledProcessError: if ``dnf install`` fails.
    """
    # Use https for the package repository, like the GPG key and the amdgpu
    # repository, so that repository metadata is not fetched in clear text.
    with open("/etc/yum.repos.d/rocm.repo", "w") as rfd:
        rfd.write(
            """
[ROCm]
name=ROCm
baseurl=https://repo.radeon.com/rocm/rhel8/%s/main
enabled=1
gpgcheck=1
gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key
"""
            % rocm_version_str
        )

    with open("/etc/yum.repos.d/amdgpu.repo", "w") as afd:
        afd.write(
            """
[amdgpu]
name=amdgpu
baseurl=https://repo.radeon.com/amdgpu/latest/rhel/8.8/main/x86_64/
enabled=1
gpgcheck=1
gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key
"""
        )

    cmd = ["dnf", "install", "-y"]
    cmd.extend(rocm_package_names)
    LOG.info("Running %r" % cmd)
    subprocess.run(cmd, check=True)
def parse_args():
    """Parse the command line; ``--rocm-version`` defaults to "6.1.1"."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--rocm-version",
        help="ROCm version to install",
        default="6.1.1",
    )
    parsed = parser.parse_args()
    return parsed
def main():
    """Entry point: install the ROCm version requested on the command line."""
    requested_version = parse_args().rocm_version
    install_rocm_el8(requested_version)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()

48
build/rocm/tools/libc.py Normal file
View File

@ -0,0 +1,48 @@
#!/usr/bin/env python3
# NOTE(mrodden): This file is part of the ROCm build scripts, and
# needs to be compatible with Python 3.6. Please do not include these
# in any "upgrade" scripts
import os
import sys
def get_libc_version():
    """
    Detect and return glibc version that the current Python is linked against.

    This mimics the detection behavior of the 'wheel' and 'auditwheel' projects,
    but without any PyPy or libmusl support.

    Returns:
      The string reported by ``os.confstr("CS_GNU_LIBC_VERSION")`` when that
      is available. Otherwise the string returned by the C function
      ``gnu_get_libc_version()``. None if neither lookup works.
    """
    try:
        version_str = os.confstr("CS_GNU_LIBC_VERSION")
        # confstr returns None for an undefined value; try the fallback then
        if version_str:
            return version_str
    except Exception:
        print("WARN: lookup by confstr failed", file=sys.stderr)

    try:
        import ctypes
    except ImportError:
        return None

    pn = ctypes.CDLL(None)

    try:
        gnu_get_libc_version = pn.gnu_get_libc_version
    except AttributeError:
        return None

    gnu_get_libc_version.restype = ctypes.c_char_p
    version_str = gnu_get_libc_version()
    # c_char_p results are bytes; return str like the confstr path does
    if isinstance(version_str, bytes):
        version_str = version_str.decode("ascii")
    return version_str
if __name__ == "__main__":
print(get_libc_version())

View File

@ -0,0 +1,53 @@
#!/usr/bin/env python3
# NOTE(mrodden): This file is part of the ROCm build scripts, and
# needs to be compatible with Python 3.6. Please do not include these
# in any "upgrade" scripts
import pprint
import re
import sys
import subprocess
"""
Utility for examining GLIBC versioned symbols
for an object file (shared object or ELF binary)
"""
def main():
    """Print the highest GLIBC symbol version used by the file given as argv[1]."""
    target = sys.argv[1]
    highest = highest_for_file(target)
    print("%s: %r" % (target, highest))
def highest_for_file(sofile):
    """Return the highest GLIBC symbol version referenced by an object file.

    Runs ``objdump -T`` on ``sofile`` and collects every "(GLIBC_x.y[.z])"
    version annotation in its output.

    Args:
      sofile: path to a shared object or ELF binary.

    Returns:
      The highest version as a tuple of ints, e.g. ``(2, 17)``, or None if
      the output contains no numeric GLIBC version.

    Raises:
      subprocess.CalledProcessError: if objdump exits with a nonzero status.
    """
    output = subprocess.check_output(["objdump", "-T", sofile])

    # Raw string avoids invalid "\(" escapes. Only numeric versions are
    # captured, so annotations such as "(GLIBC_PRIVATE)" are ignored instead
    # of crashing the int() conversion in parse().
    r = re.compile(r"\(GLIBC_([0-9]+(?:\.[0-9]+)*)\)")

    versions = set()
    for line in output.decode("utf-8").split("\n"):
        match = r.search(line.strip())
        if match:
            versions.add(parse(match.group(1)))

    if not versions:
        return None

    return max(versions)


def parse(version_str):
    """Convert a dotted version string such as "2.17" to a tuple of ints."""
    return tuple(map(int, version_str.split(".")))
if __name__ == "__main__":
main()