mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
[ROCm] Change ROCm builds to manylinux wheels
This commit is contained in:
parent
ee31e95ecd
commit
1e58d76772
@ -1,6 +1,5 @@
|
||||
################################################################################
|
||||
ARG BASE_DOCKER=ubuntu:20.04
|
||||
FROM $BASE_DOCKER as rt_build
|
||||
FROM ubuntu:20.04 AS rocm_base
|
||||
################################################################################
|
||||
|
||||
# Add target file to help determine which device(s) to build for
|
||||
@ -12,9 +11,9 @@ ARG ROCM_VERSION=6.0.0
|
||||
ARG CUSTOM_INSTALL
|
||||
ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION}
|
||||
ENV ROCM_PATH=${ROCM_PATH}
|
||||
COPY ${CUSTOM_INSTALL} /${CUSTOM_INSTALL}
|
||||
COPY setup.rocm.sh /setup.rocm.sh
|
||||
RUN /setup.rocm.sh $ROCM_VERSION
|
||||
#COPY ${CUSTOM_INSTALL} /${CUSTOM_INSTALL}
|
||||
RUN --mount=type=bind,source=build/rocm/setup.rocm.sh,target=/setup.rocm.sh \
|
||||
/setup.rocm.sh $ROCM_VERSION
|
||||
|
||||
# Set up paths
|
||||
ENV HCC_HOME=$ROCM_PATH/hcc
|
||||
@ -25,13 +24,35 @@ ENV PATH="$ROCM_PATH/bin:${PATH}"
|
||||
ENV PATH="$OPENCL_ROOT/bin:${PATH}"
|
||||
ENV PATH="/root/bin:/root/.local/bin:$PATH"
|
||||
|
||||
|
||||
# Install pyenv with different python versions
|
||||
ARG PYTHON_VERSION=3.10.0
|
||||
ARG PYTHON_VERSION=3.10.14
|
||||
RUN git clone https://github.com/pyenv/pyenv.git /pyenv
|
||||
ENV PYENV_ROOT /pyenv
|
||||
ENV PATH $PYENV_ROOT/shims:$PYENV_ROOT/bin:$PATH
|
||||
RUN pyenv install $PYTHON_VERSION
|
||||
RUN eval "$(pyenv init -)" && pyenv local ${PYTHON_VERSION} && pip3 install --upgrade --force-reinstall setuptools pip && pip install numpy setuptools build wheel six auditwheel scipy pytest pytest-html pytest_html_merger pytest-reportlog pytest-rerunfailures cloudpickle portpicker matplotlib absl-py flatbuffers hypothesis
|
||||
RUN eval "$(pyenv init -)" && \
|
||||
pyenv local ${PYTHON_VERSION} && \
|
||||
pip3 install --upgrade --force-reinstall setuptools pip && \
|
||||
pip install \
|
||||
numpy setuptools build wheel six auditwheel scipy \
|
||||
pytest pytest-html pytest_html_merger pytest-reportlog \
|
||||
pytest-rerunfailures cloudpickle portpicker matplotlib absl-py \
|
||||
flatbuffers hypothesis
|
||||
|
||||
|
||||
################################################################################
|
||||
FROM rocm_base AS rt_build
|
||||
################################################################################
|
||||
|
||||
ARG JAX_VERSION
|
||||
ARG JAX_COMMIT
|
||||
ARG XLA_COMMIT
|
||||
|
||||
LABEL com.amdgpu.rocm_version="$ROCM_VERSION" \
|
||||
com.amdgpu.python_version="$PYTHON_VERSION" \
|
||||
com.amdgpu.jax_version="$JAX_VERSION" \
|
||||
com.amdgpu.jax_commit="$JAX_COMMIT" \
|
||||
com.amdgpu.xla_commit="$XLA_COMMIT"
|
||||
|
||||
RUN --mount=type=bind,source=wheelhouse,target=/wheelhouse \
|
||||
pip install --find-links /wheelhouse jax jaxlib jax_rocm60_plugin jax_rocm60_pjrt
|
||||
|
@ -1,4 +1,5 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -13,57 +14,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
# Environment Var Notes
|
||||
# XLA_CLONE_DIR -
|
||||
# Specifies filepath to where XLA repo is cloned.
|
||||
# NOTE:, if this is set then XLA repo is not cloned. Must clone repo before running this script.
|
||||
# Also, if this is set then setting XLA_REPO and XLA_BRANCH have no effect.
|
||||
# XLA_REPO
|
||||
# XLA repo to clone from. Default is https://github.com/ROCmSoftwarePlatform/tensorflow-upstream
|
||||
# XLA_BRANCH
|
||||
# XLA branch in the XLA repo. Default is develop-upstream-jax
|
||||
#
|
||||
# NOTE(mrodden): ROCm JAX build and installs have moved to wheel based builds and installs,
|
||||
# but some CI scripts still try to run this script. Nothing needs to be done here,
|
||||
# but we print some debugging information for logs.
|
||||
|
||||
set -eux
|
||||
python -V
|
||||
|
||||
#If XLA_REPO is not set, then use default
|
||||
if [ ! -v XLA_REPO ]; then
|
||||
XLA_REPO="https://github.com/openxla/xla.git"
|
||||
XLA_BRANCH="main"
|
||||
elif [ -z "$XLA_REPO" ]; then
|
||||
XLA_REPO="https://github.com/openxla/xla.git"
|
||||
XLA_BRANCH="main"
|
||||
fi
|
||||
|
||||
#If XLA_CLONE_PATH is not set, then use default path.
|
||||
#Note, setting XLA_CLONE_PATH makes setting XLA_REPO and XLA_BRANCH a no-op
|
||||
#Set this when XLA repository has been already clone. This is useful in CI
|
||||
#environments and when doing local development
|
||||
if [ ! -v XLA_CLONE_DIR ]; then
|
||||
XLA_CLONE_DIR=/tmp/xla
|
||||
rm -rf /tmp/xla || true
|
||||
git clone -b ${XLA_BRANCH} ${XLA_REPO} /tmp/xla
|
||||
elif [ -z "$XLA_CLONE_DIR" ]; then
|
||||
XLA_CLONE_DIR=/tmp/xla
|
||||
rm -rf /tmp/xla || true
|
||||
git clone -b ${XLA_BRANCH} ${XLA_REPO} /tmp/xla
|
||||
fi
|
||||
|
||||
|
||||
#Export JAX_ROCM_VERSION so that it is appened in the wheel name
|
||||
export JAXLIB_RELEASE=1
|
||||
rocm_version=$(cat /opt/rocm/.info/version | cut -d "-" -f 1)
|
||||
export JAX_ROCM_VERSION=${rocm_version//./}
|
||||
|
||||
#Build and install wheel
|
||||
python3 ./build/build.py --enable_rocm --build_gpu_plugin --gpu_plugin_rocm_version=60 --rocm_path=${ROCM_PATH} --bazel_options=--override_repository=xla=${XLA_CLONE_DIR}
|
||||
|
||||
JAX_RELEASE=1 python -m build
|
||||
pip3 install --force-reinstall dist/*.whl # installs jaxlib (includes XLA)
|
||||
|
||||
#This is for CI to read without having to start the container again
|
||||
if [ -v CI_RUN ]; then
|
||||
pip3 list | grep jaxlib | tr -s ' ' | cut -d " " -f 2 | cut -d "+" -f 1 > jax_version_installed
|
||||
cat /opt/rocm/.info/version | cut -d "-" -f 1 > jax_rocm_version
|
||||
fi
|
||||
printf "Detected jaxlib version: %s\n" $(pip3 list | grep jaxlib | tr -s ' ' | cut -d " " -f 2 | cut -d "+" -f 1)
|
||||
printf "Detected ROCm version: %s\n" $(cat /opt/rocm/.info/version | cut -d "-" -f 1)
|
||||
|
@ -0,0 +1,7 @@
|
||||
FROM quay.io/pypa/manylinux_2_28_x86_64
|
||||
|
||||
ARG ROCM_VERSION=6.1.1
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/dnf \
|
||||
--mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
|
||||
python3 get_rocm.py --rocm-version $ROCM_VERSION
|
256
build/rocm/ci_build
Executable file
256
build/rocm/ci_build
Executable file
@ -0,0 +1,256 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
# NOTE(mrodden): This file is part of the ROCm build scripts, and
|
||||
# needs be compatible with Python 3.6. Please do not include these
|
||||
# in any "upgrade" scripts
|
||||
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
|
||||
def image_by_name(name):
|
||||
cmd = ["docker", "images", "-q", "-f", "reference=%s" % name]
|
||||
out = subprocess.check_output(cmd)
|
||||
image_id = out.decode("utf8").strip().split("\n")[0] or None
|
||||
return image_id
|
||||
|
||||
|
||||
def dist_wheels(rocm_version, python_versions, xla_path):
|
||||
xla_path = os.path.abspath(xla_path)
|
||||
|
||||
# create manylinux image with requested ROCm installed
|
||||
image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
|
||||
|
||||
cmd = [
|
||||
"docker",
|
||||
"build",
|
||||
"-f",
|
||||
"build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm",
|
||||
"--build-arg=ROCM_VERSION=%s" % rocm_version,
|
||||
"--tag=%s" % image,
|
||||
".",
|
||||
]
|
||||
|
||||
if not image_by_name(image):
|
||||
_ = subprocess.run(cmd, check=True)
|
||||
|
||||
# use image to build JAX/jaxlib wheels
|
||||
os.makedirs("wheelhouse", exist_ok=True)
|
||||
|
||||
pyver_string = ",".join(python_versions)
|
||||
|
||||
container_xla_path = "/xla"
|
||||
|
||||
bw_cmd = [
|
||||
"python3",
|
||||
"/jax/build/rocm/tools/build_wheels.py",
|
||||
"--rocm-version",
|
||||
rocm_version,
|
||||
"--python-versions",
|
||||
pyver_string,
|
||||
]
|
||||
|
||||
if xla_path:
|
||||
bw_cmd.extend(["--xla-path", container_xla_path])
|
||||
|
||||
bw_cmd.append("/jax")
|
||||
|
||||
cmd = ["docker", "run", "-it"]
|
||||
|
||||
mounts = [
|
||||
"-v",
|
||||
"./:/jax",
|
||||
"-v",
|
||||
"./wheelhouse:/wheelhouse",
|
||||
]
|
||||
|
||||
if xla_path:
|
||||
mounts.extend(["-v", "%s:%s" % (xla_path, container_xla_path)])
|
||||
|
||||
cmd.extend(mounts)
|
||||
|
||||
# NOTE(mrodden): bazel times out without --init, probably blocking on a zombie PID
|
||||
cmd.extend(
|
||||
[
|
||||
"--init",
|
||||
"--rm",
|
||||
image,
|
||||
"bash",
|
||||
"-c",
|
||||
" ".join(bw_cmd),
|
||||
]
|
||||
)
|
||||
|
||||
_ = subprocess.run(cmd, check=True)
|
||||
|
||||
|
||||
def _fetch_jax_metadata(xla_path):
|
||||
cmd = ["git", "rev-parse", "HEAD"]
|
||||
jax_commit = subprocess.check_output(cmd)
|
||||
xla_commit = ""
|
||||
|
||||
if xla_path:
|
||||
try:
|
||||
xla_commit = subprocess.check_output(cmd, cwd=xla_path)
|
||||
except Exception as ex:
|
||||
LOG.warning("Exception while retrieving xla_commit: %s" % ex)
|
||||
|
||||
cmd = ["python", "setup.py", "-V"]
|
||||
env = dict(os.environ)
|
||||
env["JAX_RELEASE"] = "1"
|
||||
|
||||
jax_version = subprocess.check_output(cmd, env=env)
|
||||
|
||||
return {
|
||||
"jax_version": jax_version.decode("utf8").strip(),
|
||||
"jax_commit": jax_commit.decode("utf8").strip(),
|
||||
"xla_commit": xla_commit.decode("utf8").strip(),
|
||||
}
|
||||
|
||||
|
||||
def dist_docker(
|
||||
rocm_version,
|
||||
python_versions,
|
||||
xla_path,
|
||||
tag="rocm/jax-dev",
|
||||
dockerfile=None,
|
||||
keep_image=True,
|
||||
):
|
||||
if not dockerfile:
|
||||
dockerfile = "build/rocm/Dockerfile.ms"
|
||||
|
||||
python_version = python_versions[0]
|
||||
|
||||
md = _fetch_jax_metadata(xla_path)
|
||||
|
||||
cmd = [
|
||||
"docker",
|
||||
"build",
|
||||
"-f",
|
||||
dockerfile,
|
||||
"--target",
|
||||
"rt_build",
|
||||
"--build-arg=ROCM_VERSION=%s" % rocm_version,
|
||||
"--build-arg=PYTHON_VERSION=%s" % python_version,
|
||||
"--build-arg=JAX_VERSION=%(jax_version)s" % md,
|
||||
"--build-arg=JAX_COMMIT=%(jax_commit)s" % md,
|
||||
"--build-arg=XLA_COMMIT=%(xla_commit)s" % md,
|
||||
"--tag=%s" % tag,
|
||||
]
|
||||
|
||||
if not keep_image:
|
||||
cmd.append("--rm")
|
||||
|
||||
# context dir
|
||||
cmd.append(".")
|
||||
|
||||
subprocess.check_call(cmd)
|
||||
|
||||
|
||||
def test(image_name):
|
||||
"""Run unit tests like CI would inside a JAX image."""
|
||||
|
||||
gpu_args = [
|
||||
"--device=/dev/kfd",
|
||||
"--device=/dev/dri",
|
||||
"--group-add",
|
||||
"video",
|
||||
"--cap-add=SYS_PTRACE",
|
||||
"--security-opt",
|
||||
"seccomp=unconfined",
|
||||
"--shm-size",
|
||||
"16G",
|
||||
]
|
||||
|
||||
cmd = [
|
||||
"docker",
|
||||
"run",
|
||||
"-it",
|
||||
"--rm",
|
||||
]
|
||||
|
||||
# NOTE(mrodden): we need jax source dir for the unit test code only,
|
||||
# JAX and jaxlib are already installed from wheels
|
||||
mounts = [
|
||||
"-v",
|
||||
"./:/jax",
|
||||
]
|
||||
|
||||
cmd.extend(mounts)
|
||||
cmd.extend(gpu_args)
|
||||
|
||||
container_cmd = "cd /jax && ./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh"
|
||||
cmd.append(image_name)
|
||||
cmd.extend(
|
||||
[
|
||||
"bash",
|
||||
"-c",
|
||||
container_cmd,
|
||||
]
|
||||
)
|
||||
|
||||
subprocess.check_call(cmd)
|
||||
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument(
|
||||
"--python-versions",
|
||||
type=lambda x: x.split(","),
|
||||
default="3.12",
|
||||
help="Comma separated list of CPython versions to build wheels for",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--rocm-version",
|
||||
default="6.1.1",
|
||||
help="ROCm version used for building wheels, testing, and installing into Docker image",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--xla-source-dir",
|
||||
help="Path to XLA source to use during jaxlib build, instead of builtin XLA",
|
||||
)
|
||||
|
||||
subp = p.add_subparsers(dest="action", required=True)
|
||||
|
||||
dwp = subp.add_parser("dist_wheels")
|
||||
|
||||
testp = subp.add_parser("test")
|
||||
testp.add_argument("image_name")
|
||||
|
||||
ddp = subp.add_parser("dist_docker")
|
||||
ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms")
|
||||
ddp.add_argument("--keep-image", action="store_true")
|
||||
ddp.add_argument("--image-tag", default="rocm/jax-dev")
|
||||
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
|
||||
if args.action == "dist_wheels":
|
||||
dist_wheels(args.rocm_version, args.python_versions, args.xla_source_dir)
|
||||
|
||||
elif args.action == "test":
|
||||
test(args.image_name)
|
||||
|
||||
elif args.action == "dist_docker":
|
||||
dist_wheels(args.rocm_version, args.python_versions, args.xla_source_dir)
|
||||
dist_docker(
|
||||
args.rocm_version,
|
||||
args.python_versions,
|
||||
args.xla_source_dir,
|
||||
tag=args.image_tag,
|
||||
dockerfile=args.dockerfile,
|
||||
keep_image=args.keep_image,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,4 +1,5 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@ -28,12 +29,8 @@
|
||||
#
|
||||
# ROCM_VERSION: ROCm repo version
|
||||
#
|
||||
# ROCM_PATH: ROCM path in the docker container
|
||||
#
|
||||
# Environment variables read by this script
|
||||
# WORKSPACE
|
||||
# XLA_REPO
|
||||
# XLA_BRANCH
|
||||
# XLA_CLONE_DIR
|
||||
# BUILD_TAG
|
||||
#
|
||||
@ -47,14 +44,10 @@ CONTAINER_TYPE="rocm"
|
||||
DOCKERFILE_PATH="${SCRIPT_DIR}/Dockerfile.ms"
|
||||
DOCKER_CONTEXT_PATH="${SCRIPT_DIR}"
|
||||
KEEP_IMAGE="--rm"
|
||||
KEEP_CONTAINER="--rm"
|
||||
PYTHON_VERSION="3.10.0"
|
||||
ROCM_VERSION="6.0.0" #Point to latest release
|
||||
PYTHON_VERSION="3.10"
|
||||
ROCM_VERSION="6.1.3"
|
||||
BASE_DOCKER="ubuntu:20.04"
|
||||
CUSTOM_INSTALL=""
|
||||
#BASE_DOCKER="compute-artifactory.amd.com:5000/rocm-plus-docker/compute-rocm-rel-6.0:91-ubuntu-20.04-stg2"
|
||||
#CUSTOM_INSTALL="custom_install_dummy.sh"
|
||||
#ROCM_PATH="/opt/rocm-5.6.0"
|
||||
POSITIONAL_ARGS=()
|
||||
|
||||
RUNTIME_FLAG=1
|
||||
@ -86,11 +79,6 @@ while [[ $# -gt 0 ]]; do
|
||||
ROCM_VERSION="$2"
|
||||
shift 2
|
||||
;;
|
||||
#--rocm_path)
|
||||
# ROCM_PATH="$2"
|
||||
# shift 2
|
||||
# ;;
|
||||
|
||||
*)
|
||||
POSITIONAL_ARGS+=("$1")
|
||||
shift
|
||||
@ -102,9 +90,6 @@ if [[ ! -f "${DOCKERFILE_PATH}" ]]; then
|
||||
die "Invalid Dockerfile path: \"${DOCKERFILE_PATH}\""
|
||||
fi
|
||||
|
||||
ROCM_EXTRA_PARAMS="--device=/dev/kfd --device=/dev/dri --group-add video \
|
||||
--cap-add=SYS_PTRACE --security-opt seccomp=unconfined --shm-size 16G"
|
||||
|
||||
# Helper function to traverse directories up until given file is found.
|
||||
function upsearch (){
|
||||
test / == "$PWD" && return || \
|
||||
@ -126,6 +111,7 @@ DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | sed -e 's/=/_/g' -e 's/,/-/g')
|
||||
# Convert to all lower-case, as per requirement of Docker image names
|
||||
DOCKER_IMG_NAME=$(echo "${DOCKER_IMG_NAME}" | tr '[:upper:]' '[:lower:]')
|
||||
|
||||
|
||||
# Print arguments.
|
||||
echo "WORKSPACE: ${WORKSPACE}"
|
||||
echo "COMMAND: ${POSITIONAL_ARGS[*]}"
|
||||
@ -135,55 +121,25 @@ echo ""
|
||||
|
||||
echo "Building container (${DOCKER_IMG_NAME})..."
|
||||
echo "Python Version (${PYTHON_VERSION})"
|
||||
if [[ "${RUNTIME_FLAG}" -eq 1 ]]; then
|
||||
echo "Building (runtime) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERFILE_PATH)..."
|
||||
docker build --target rt_build --tag ${DOCKER_IMG_NAME} \
|
||||
--build-arg PYTHON_VERSION=$PYTHON_VERSION --build-arg ROCM_VERSION=$ROCM_VERSION \
|
||||
--build-arg CUSTOM_INSTALL=$CUSTOM_INSTALL \
|
||||
--build-arg BASE_DOCKER=$BASE_DOCKER \
|
||||
-f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}"
|
||||
else
|
||||
echo "Building (CI) container (${DOCKER_IMG_NAME}) with Dockerfile($DOCKERFILE_PATH)..."
|
||||
docker build --target ci_build --tag ${DOCKER_IMG_NAME} \
|
||||
--build-arg PYTHON_VERSION=$PYTHON_VERSION \
|
||||
--build-arg BASE_DOCKER=$BASE_DOCKER \
|
||||
-f "${DOCKERFILE_PATH}" "${DOCKER_CONTEXT_PATH}"
|
||||
fi
|
||||
|
||||
# Check docker build status
|
||||
export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}"
|
||||
|
||||
# ci_build.sh is mostly a compatibility wrapper for ci_build
|
||||
|
||||
# 'dist_docker' will run 'dist_wheels' followed by a Docker build to create the "JAX image",
|
||||
# which is the ROCm image that is shipped for users to use (i.e. distributable).
|
||||
./build/rocm/ci_build \
|
||||
--rocm-version $ROCM_VERSION \
|
||||
--python-versions $PYTHON_VERSION \
|
||||
--xla-source-dir $XLA_CLONE_DIR \
|
||||
dist_docker \
|
||||
--dockerfile $DOCKERFILE_PATH \
|
||||
--image-tag $DOCKER_IMG_NAME
|
||||
|
||||
# Check build status
|
||||
if [[ $? != "0" ]]; then
|
||||
die "ERROR: docker build failed. Dockerfile is at ${DOCKERFILE_PATH}"
|
||||
fi
|
||||
|
||||
# Run the command inside the container.
|
||||
echo "Running '${POSITIONAL_ARGS[*]}' inside ${DOCKER_IMG_NAME}..."
|
||||
|
||||
export XLA_REPO="${XLA_REPO:-}"
|
||||
export XLA_BRANCH="${XLA_BRANCH:-}"
|
||||
export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}"
|
||||
export JAX_RENAME_WHL="${XLA_CLONE_DIR:-}"
|
||||
|
||||
if [ ! -z ${XLA_CLONE_DIR} ]; then
|
||||
ROCM_EXTRA_PARAMS=${ROCM_EXTRA_PARAMS}" -v ${XLA_CLONE_DIR}:${XLA_CLONE_DIR}"
|
||||
fi
|
||||
|
||||
docker run ${KEEP_IMAGE} --name ${DOCKER_IMG_NAME} --pid=host \
|
||||
-v ${WORKSPACE}:/workspace \
|
||||
-w /workspace \
|
||||
-e XLA_REPO=${XLA_REPO} \
|
||||
-e XLA_BRANCH=${XLA_BRANCH} \
|
||||
-e XLA_CLONE_DIR=${XLA_CLONE_DIR} \
|
||||
-e PYTHON_VERSION=$PYTHON_VERSION \
|
||||
-e CI_RUN=1 \
|
||||
${ROCM_EXTRA_PARAMS} \
|
||||
"${DOCKER_IMG_NAME}" \
|
||||
${POSITIONAL_ARGS[@]}
|
||||
|
||||
if [[ "${KEEP_IMAGE}" != "--rm" ]] && [[ $? == "0" ]]; then
|
||||
echo "Committing the docker container as ${DOCKER_IMG_NAME}"
|
||||
docker stop ${DOCKER_IMG_NAME}
|
||||
docker commit ${DOCKER_IMG_NAME} ${DOCKER_IMG_NAME}
|
||||
docker rm ${DOCKER_IMG_NAME} # remove this temp container
|
||||
fi
|
||||
|
||||
echo "Jax-ROCm build was successful!"
|
||||
|
3
build/rocm/tools/blacken.sh
Normal file
3
build/rocm/tools/blacken.sh
Normal file
@ -0,0 +1,3 @@
|
||||
#!/bin/bash
|
||||
|
||||
black -t py36 build/rocm/ci_build build/rocm/tools/*.py
|
222
build/rocm/tools/build_wheels.py
Normal file
222
build/rocm/tools/build_wheels.py
Normal file
@ -0,0 +1,222 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
# NOTE(mrodden): This file is part of the ROCm build scripts, and
|
||||
# needs be compatible with Python 3.6. Please do not include these
|
||||
# in any "upgrade" scripts
|
||||
|
||||
|
||||
import argparse
|
||||
from collections import deque
|
||||
import fcntl
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import select
|
||||
import subprocess
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
GPU_DEVICE_TARGETS = "gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
|
||||
|
||||
|
||||
def build_rocm_path(rocm_version_str):
|
||||
return "/opt/rocm-%s" % rocm_version_str
|
||||
|
||||
|
||||
def update_rocm_targets(rocm_path, targets):
|
||||
target_fp = os.path.join(rocm_path, "bin/target.lst")
|
||||
version_fp = os.path.join(rocm_path, ".info/version")
|
||||
with open(target_fp, "w") as fd:
|
||||
fd.write("%s\n" % targets)
|
||||
|
||||
# mimic touch
|
||||
open(version_fp, "a").close()
|
||||
|
||||
|
||||
def build_jaxlib_wheel(jax_path, rocm_path, python_version, xla_path=None):
|
||||
cmd = [
|
||||
"python",
|
||||
"build/build.py",
|
||||
"--enable_rocm",
|
||||
"--build_gpu_plugin",
|
||||
"--gpu_plugin_rocm_version=60",
|
||||
"--rocm_path=%s" % rocm_path,
|
||||
]
|
||||
|
||||
if xla_path:
|
||||
cmd.append("--bazel_options=--override_repository=xla=%s" % xla_path)
|
||||
|
||||
cpy = to_cpy_ver(python_version)
|
||||
py_bin = "/opt/python/%s-%s/bin" % (cpy, cpy)
|
||||
|
||||
env = dict(os.environ)
|
||||
env["JAX_RELEASE"] = str(1)
|
||||
env["JAXLIB_RELEASE"] = str(1)
|
||||
env["PATH"] = "%s:%s" % (py_bin, env["PATH"])
|
||||
|
||||
LOG.info("Running %r from cwd=%r" % (cmd, jax_path))
|
||||
pattern = re.compile("Output wheel: (.+)\n")
|
||||
|
||||
return _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stderr")
|
||||
|
||||
|
||||
def build_jax_wheel(jax_path, python_version):
|
||||
cmd = [
|
||||
"python",
|
||||
"-m",
|
||||
"build",
|
||||
]
|
||||
|
||||
cpy = to_cpy_ver(python_version)
|
||||
py_bin = "/opt/python/%s-%s/bin" % (cpy, cpy)
|
||||
|
||||
env = dict(os.environ)
|
||||
env["JAX_RELEASE"] = str(1)
|
||||
env["JAXLIB_RELEASE"] = str(1)
|
||||
env["PATH"] = "%s:%s" % (py_bin, env["PATH"])
|
||||
|
||||
LOG.info("Running %r from cwd=%r" % (cmd, jax_path))
|
||||
pattern = re.compile("Successfully built jax-.+ and (jax-.+\.whl)\n")
|
||||
|
||||
wheels = _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stdout")
|
||||
|
||||
paths = list(map(lambda x: os.path.join(jax_path, "dist", x), wheels))
|
||||
return paths
|
||||
|
||||
|
||||
def _run_scan_for_output(cmd, pattern, env=None, cwd=None, capture=None):
|
||||
|
||||
buf = deque(maxlen=20000)
|
||||
|
||||
if capture == "stderr":
|
||||
p = subprocess.Popen(cmd, env=env, cwd=cwd, stderr=subprocess.PIPE)
|
||||
redir = sys.stderr
|
||||
cap_fd = p.stderr
|
||||
else:
|
||||
p = subprocess.Popen(cmd, env=env, cwd=cwd, stdout=subprocess.PIPE)
|
||||
redir = sys.stdout
|
||||
cap_fd = p.stdout
|
||||
|
||||
flags = fcntl.fcntl(cap_fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(cap_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
|
||||
eof = False
|
||||
while not eof:
|
||||
r, _, _ = select.select([cap_fd], [], [])
|
||||
for fd in r:
|
||||
dat = fd.read(512)
|
||||
if dat is None:
|
||||
continue
|
||||
elif dat:
|
||||
t = dat.decode("utf8")
|
||||
redir.write(t)
|
||||
buf.extend(t)
|
||||
else:
|
||||
eof = True
|
||||
|
||||
# wait and drain pipes
|
||||
_, _ = p.communicate()
|
||||
|
||||
if p.returncode != 0:
|
||||
raise Exception(
|
||||
"Child process exited with nonzero result: rc=%d" % p.returncode
|
||||
)
|
||||
|
||||
text = "".join(buf)
|
||||
|
||||
matches = pattern.findall(text)
|
||||
|
||||
if not matches:
|
||||
LOG.error("No wheel name found in output: %r" % text)
|
||||
raise Exception("No wheel name found in output")
|
||||
|
||||
wheels = []
|
||||
for match in matches:
|
||||
LOG.info("Found built wheel: %r" % match)
|
||||
wheels.append(match)
|
||||
|
||||
return wheels
|
||||
|
||||
|
||||
def to_cpy_ver(python_version):
|
||||
tup = python_version.split(".")
|
||||
return "cp%d%d" % (int(tup[0]), int(tup[1]))
|
||||
|
||||
|
||||
def fix_wheel(path, jax_path):
|
||||
# NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8
|
||||
# so use one of the CPythons in /opt to run
|
||||
env = dict(os.environ)
|
||||
py_bin = "/opt/python/cp310-cp310/bin"
|
||||
env["PATH"] = "%s:%s" % (py_bin, env["PATH"])
|
||||
|
||||
cmd = ["pip", "install", "auditwheel>=6"]
|
||||
subprocess.run(cmd, check=True, env=env)
|
||||
|
||||
fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py")
|
||||
cmd = ["python", fixwheel_path, path]
|
||||
subprocess.run(cmd, check=True, env=env)
|
||||
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument(
|
||||
"--rocm-version", default="6.1.1", help="ROCM Version to build JAX against"
|
||||
)
|
||||
p.add_argument(
|
||||
"--python-versions",
|
||||
default=["3.10.19,3.12"],
|
||||
help="Comma separated CPython versions that wheels will be built and output for",
|
||||
)
|
||||
p.add_argument(
|
||||
"--xla-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Optional directory where XLA source is located to use instead of JAX builtin XLA",
|
||||
)
|
||||
|
||||
p.add_argument("jax_path", help="Directory where JAX source directory is located")
|
||||
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
python_versions = args.python_versions.split(",")
|
||||
|
||||
print("ROCM_VERSION=%s" % args.rocm_version)
|
||||
print("PYTHON_VERSIONS=%r" % python_versions)
|
||||
print("JAX_PATH=%s" % args.jax_path)
|
||||
print("XLA_PATH=%s" % args.xla_path)
|
||||
|
||||
rocm_path = build_rocm_path(args.rocm_version)
|
||||
|
||||
update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS)
|
||||
|
||||
for py in python_versions:
|
||||
wheel_paths = build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path)
|
||||
for wheel_path in wheel_paths:
|
||||
fix_wheel(wheel_path, args.jax_path)
|
||||
|
||||
# build JAX wheel for completeness
|
||||
jax_wheels = build_jax_wheel(args.jax_path, python_versions[-1])
|
||||
|
||||
# NOTE(mrodden): the jax wheel is a "non-platform wheel", so auditwheel will
|
||||
# do nothing, and in fact will throw an Exception. we just need to copy it
|
||||
# along with the jaxlib and plugin ones
|
||||
|
||||
# copy jax wheel(s) to wheelhouse
|
||||
wheelhouse_dir = "/wheelhouse/"
|
||||
for whl in jax_wheels:
|
||||
LOG.info("Copying %s into %s" % (whl, wheelhouse_dir))
|
||||
shutil.copy(whl, wheelhouse_dir)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
97
build/rocm/tools/fixwheel.py
Normal file
97
build/rocm/tools/fixwheel.py
Normal file
@ -0,0 +1,97 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
# NOTE(mrodden): This file is part of the ROCm build scripts, and
|
||||
# needs be compatible with Python 3.6. Please do not include these
|
||||
# in any "upgrade" scripts
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from pprint import pprint
|
||||
import subprocess
|
||||
|
||||
from auditwheel.lddtree import lddtree
|
||||
from auditwheel.wheeltools import InWheelCtx
|
||||
from auditwheel.elfutils import elf_file_filter
|
||||
from auditwheel.policy import WheelPolicies
|
||||
from auditwheel.wheel_abi import analyze_wheel_abi
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def tree(path):
|
||||
|
||||
with InWheelCtx(path) as ctx:
|
||||
for sofile, fd in elf_file_filter(ctx.iter_files()):
|
||||
|
||||
LOG.info("found SO file: %s" % sofile)
|
||||
elftree = lddtree(sofile)
|
||||
|
||||
print(elftree)
|
||||
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("wheel_path")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def parse_wheel_name(path):
|
||||
wheel_name = os.path.basename(path)
|
||||
return wheel_name[:-4].split("-")
|
||||
|
||||
|
||||
def fix_wheel(path):
|
||||
tup = parse_wheel_name(path)
|
||||
plat_tag = tup[4]
|
||||
if "manylinux2014" in plat_tag:
|
||||
# strip any manylinux tags from the current wheel first
|
||||
from wheel.cli import tags
|
||||
|
||||
plat_mod_str = "linux_x86_64"
|
||||
new_wheel = tags.tags(
|
||||
path,
|
||||
python_tags=None,
|
||||
abi_tags=None,
|
||||
platform_tags=plat_mod_str,
|
||||
build_tag=None,
|
||||
)
|
||||
new_path = os.path.join(os.path.dirname(path), new_wheel)
|
||||
LOG.info("Stripped broken tags and created new wheel at %r" % new_path)
|
||||
path = new_path
|
||||
|
||||
# build excludes, using auditwheels lddtree to find them
|
||||
wheel_pol = WheelPolicies()
|
||||
exclude = frozenset()
|
||||
abi = analyze_wheel_abi(wheel_pol, path, exclude)
|
||||
|
||||
plat = "manylinux_2_28_x86_64"
|
||||
ext_libs = abi.external_refs.get(plat, {}).get("libs")
|
||||
exclude = list(ext_libs.keys())
|
||||
|
||||
# call auditwheel repair with excludes
|
||||
cmd = ["auditwheel", "repair", "--plat", plat, "--only-plat"]
|
||||
|
||||
for ex in exclude:
|
||||
cmd.append("--exclude")
|
||||
cmd.append(ex)
|
||||
|
||||
cmd.append(path)
|
||||
|
||||
LOG.info("running %r" % cmd)
|
||||
|
||||
rc = subprocess.run(cmd, check=True)
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
path = args.wheel_path
|
||||
fix_wheel(path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
100
build/rocm/tools/get_rocm.py
Normal file
100
build/rocm/tools/get_rocm.py
Normal file
@ -0,0 +1,100 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
# NOTE(mrodden): This file is part of the ROCm build scripts, and
|
||||
# needs be compatible with Python 3.6. Please do not include these
|
||||
# in any "upgrade" scripts
|
||||
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import subprocess
|
||||
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def which_linux():
|
||||
try:
|
||||
os_rel = open("/etc/os-release").read()
|
||||
|
||||
kvs = {}
|
||||
for line in os_rel.split("\n"):
|
||||
if line.strip():
|
||||
k, v = line.strip().split("=", 1)
|
||||
v = v.strip('"')
|
||||
kvs[k] = v
|
||||
|
||||
print(kvs)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
rocm_package_names = [
|
||||
"libdrm-amdgpu",
|
||||
"rocm-dev",
|
||||
"rocm-ml-sdk",
|
||||
"miopen-hip ",
|
||||
"miopen-hip-devel",
|
||||
"rocblas",
|
||||
"rocblas-devel",
|
||||
"rocsolver-devel",
|
||||
"rocrand-devel",
|
||||
"rocfft-devel",
|
||||
"hipfft-devel",
|
||||
"hipblas-devel",
|
||||
"rocprim-devel",
|
||||
"hipcub-devel",
|
||||
"rccl-devel",
|
||||
"hipsparse-devel",
|
||||
"hipsolver-devel",
|
||||
]
|
||||
|
||||
|
||||
def install_rocm_el8(rocm_version_str):
|
||||
|
||||
with open("/etc/yum.repos.d/rocm.repo", "w") as rfd:
|
||||
rfd.write(
|
||||
"""
|
||||
[ROCm]
|
||||
name=ROCm
|
||||
baseurl=http://repo.radeon.com/rocm/rhel8/%s/main
|
||||
enabled=1
|
||||
gpgcheck=1
|
||||
gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key
|
||||
"""
|
||||
% rocm_version_str
|
||||
)
|
||||
|
||||
with open("/etc/yum.repos.d/amdgpu.repo", "w") as afd:
|
||||
afd.write(
|
||||
"""
|
||||
[amdgpu]
|
||||
name=amdgpu
|
||||
baseurl=https://repo.radeon.com/amdgpu/latest/rhel/8.8/main/x86_64/
|
||||
enabled=1
|
||||
gpgcheck=1
|
||||
gpgkey=https://repo.radeon.com/rocm/rocm.gpg.key
|
||||
"""
|
||||
)
|
||||
|
||||
cmd = ["dnf", "install", "-y"]
|
||||
cmd.extend(rocm_package_names)
|
||||
LOG.info("Running %r" % cmd)
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
|
||||
def parse_args():
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--rocm-version", help="ROCm version to install", default="6.1.1")
|
||||
return p.parse_args()
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
install_rocm_el8(args.rocm_version)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
main()
|
48
build/rocm/tools/libc.py
Normal file
48
build/rocm/tools/libc.py
Normal file
@ -0,0 +1,48 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
# NOTE(mrodden): This file is part of the ROCm build scripts, and
|
||||
# needs be compatible with Python 3.6. Please do not include these
|
||||
# in any "upgrade" scripts
|
||||
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
def get_libc_version():
|
||||
"""
|
||||
Detect and return glibc version that the current Python is linked against.
|
||||
|
||||
This mimics the detection behavior of the 'wheel' and 'auditwheel' projects,
|
||||
but without any PyPy or libmusl support.
|
||||
"""
|
||||
|
||||
try:
|
||||
version_str = os.confstr("CS_GNU_LIBC_VERSION")
|
||||
return version_str
|
||||
except Exception:
|
||||
print("WARN: lookup by confstr failed", file=sys.stderr)
|
||||
pass
|
||||
|
||||
try:
|
||||
import ctypes
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
pn = ctypes.CDLL(None)
|
||||
print(dir(pn))
|
||||
|
||||
try:
|
||||
gnu_get_libc_version = pn.gnu_get_libc_version
|
||||
except AttributeError:
|
||||
return None
|
||||
|
||||
gnu_get_libc_version.restype = ctypes.c_char_p
|
||||
version_str = gnu_get_libc_version()
|
||||
|
||||
return version_str
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(get_libc_version())
|
53
build/rocm/tools/symbols.py
Normal file
53
build/rocm/tools/symbols.py
Normal file
@ -0,0 +1,53 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
|
||||
# NOTE(mrodden): This file is part of the ROCm build scripts, and
|
||||
# needs be compatible with Python 3.6. Please do not include these
|
||||
# in any "upgrade" scripts
|
||||
|
||||
|
||||
import pprint
|
||||
import re
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
"""
|
||||
Utility for examining GLIBC versioned symbols
|
||||
for an object file (shared object or ELF binary)
|
||||
"""
|
||||
|
||||
|
||||
def main():
|
||||
sofile = sys.argv[1]
|
||||
|
||||
s = highest_for_file(sofile)
|
||||
|
||||
print("%s: %r" % (sofile, s))
|
||||
|
||||
|
||||
def highest_for_file(sofile):
|
||||
output = subprocess.check_output(["objdump", "-T", sofile])
|
||||
|
||||
r = re.compile("\(GLIBC_(.*)\)")
|
||||
versions = {}
|
||||
|
||||
for line in output.decode("utf-8").split("\n"):
|
||||
line = line.strip()
|
||||
match = r.search(line)
|
||||
if match:
|
||||
version_str = match.group(1)
|
||||
count = versions.get(version_str, 0)
|
||||
versions[version_str] = count + 1
|
||||
|
||||
vtups = list(map(lambda x: parse(x), versions.keys()))
|
||||
s = sorted(vtups)
|
||||
|
||||
return s[-1]
|
||||
|
||||
|
||||
def parse(version_str):
|
||||
return tuple(map(int, version_str.split(".")))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
x
Reference in New Issue
Block a user