mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add support for ROCm development builds
Use get_rocm.py changes in ci_build to pull in development builds for ROCm. Specify ROCM_BUILD_JOB and ROCM_BUILD_NUM for activating the development build path.
This commit is contained in:
parent
3175f13c59
commit
a1a0a4ecdd
@ -2,18 +2,21 @@
|
||||
FROM ubuntu:20.04 AS rocm_base
|
||||
################################################################################
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/apt \
|
||||
apt-get update && apt-get install -y python3
|
||||
|
||||
# Add target file to help determine which device(s) to build for
|
||||
ARG GPU_DEVICE_TARGETS="gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 gfx1100"
|
||||
ENV GPU_DEVICE_TARGETS=${GPU_DEVICE_TARGETS}
|
||||
|
||||
# Install ROCM
|
||||
ARG ROCM_VERSION=6.0.0
|
||||
ARG CUSTOM_INSTALL
|
||||
ARG ROCM_PATH=/opt/rocm-${ROCM_VERSION}
|
||||
ENV ROCM_PATH=${ROCM_PATH}
|
||||
#COPY ${CUSTOM_INSTALL} /${CUSTOM_INSTALL}
|
||||
RUN --mount=type=bind,source=build/rocm/setup.rocm.sh,target=/setup.rocm.sh \
|
||||
/setup.rocm.sh $ROCM_VERSION
|
||||
ARG ROCM_BUILD_JOB
|
||||
ARG ROCM_BUILD_NUM
|
||||
RUN --mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
|
||||
python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM
|
||||
|
||||
# Set up paths
|
||||
ENV HCC_HOME=$ROCM_PATH/hcc
|
||||
@ -24,6 +27,10 @@ ENV PATH="$ROCM_PATH/bin:${PATH}"
|
||||
ENV PATH="$OPENCL_ROOT/bin:${PATH}"
|
||||
ENV PATH="/root/bin:/root/.local/bin:$PATH"
|
||||
|
||||
# install pyenv dependencies
|
||||
RUN --mount=type=cache,target=/var/cache/apt \
|
||||
apt-get update && apt-get install -y git libssl-dev
|
||||
|
||||
# Install pyenv with different python versions
|
||||
ARG PYTHON_VERSION=3.10.14
|
||||
RUN git clone https://github.com/pyenv/pyenv.git /pyenv
|
||||
|
@ -1,7 +1,9 @@
|
||||
FROM quay.io/pypa/manylinux_2_28_x86_64
|
||||
|
||||
ARG ROCM_VERSION=6.1.1
|
||||
ARG ROCM_BUILD_JOB
|
||||
ARG ROCM_BUILD_NUM
|
||||
|
||||
RUN --mount=type=cache,target=/var/cache/dnf \
|
||||
--mount=type=bind,source=build/rocm/tools/get_rocm.py,target=get_rocm.py \
|
||||
python3 get_rocm.py --rocm-version $ROCM_VERSION
|
||||
python3 get_rocm.py --rocm-version=$ROCM_VERSION --job-name=$ROCM_BUILD_JOB --build-num=$ROCM_BUILD_NUM
|
||||
|
@ -19,8 +19,11 @@ def image_by_name(name):
|
||||
return image_id
|
||||
|
||||
|
||||
def dist_wheels(rocm_version, python_versions, xla_path):
|
||||
xla_path = os.path.abspath(xla_path)
|
||||
def dist_wheels(
|
||||
rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num=""
|
||||
):
|
||||
if xla_path:
|
||||
xla_path = os.path.abspath(xla_path)
|
||||
|
||||
# create manylinux image with requested ROCm installed
|
||||
image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
|
||||
@ -31,6 +34,8 @@ def dist_wheels(rocm_version, python_versions, xla_path):
|
||||
"-f",
|
||||
"build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm",
|
||||
"--build-arg=ROCM_VERSION=%s" % rocm_version,
|
||||
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
|
||||
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
|
||||
"--tag=%s" % image,
|
||||
".",
|
||||
]
|
||||
@ -116,6 +121,8 @@ def dist_docker(
|
||||
rocm_version,
|
||||
python_versions,
|
||||
xla_path,
|
||||
rocm_build_job="",
|
||||
rocm_build_num="",
|
||||
tag="rocm/jax-dev",
|
||||
dockerfile=None,
|
||||
keep_image=True,
|
||||
@ -135,6 +142,8 @@ def dist_docker(
|
||||
"--target",
|
||||
"rt_build",
|
||||
"--build-arg=ROCM_VERSION=%s" % rocm_version,
|
||||
"--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
|
||||
"--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
|
||||
"--build-arg=PYTHON_VERSION=%s" % python_version,
|
||||
"--build-arg=JAX_VERSION=%(jax_version)s" % md,
|
||||
"--build-arg=JAX_COMMIT=%(jax_commit)s" % md,
|
||||
@ -211,6 +220,18 @@ def parse_args():
|
||||
help="ROCm version used for building wheels, testing, and installing into Docker image",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--rocm-build-job",
|
||||
default="",
|
||||
help="ROCm build job for development ROCm builds",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--rocm-build-num",
|
||||
default="",
|
||||
help="ROCm build number for development ROCm builds",
|
||||
)
|
||||
|
||||
p.add_argument(
|
||||
"--xla-source-dir",
|
||||
help="Path to XLA source to use during jaxlib build, instead of builtin XLA",
|
||||
@ -235,17 +256,31 @@ def main():
|
||||
args = parse_args()
|
||||
|
||||
if args.action == "dist_wheels":
|
||||
dist_wheels(args.rocm_version, args.python_versions, args.xla_source_dir)
|
||||
dist_wheels(
|
||||
args.rocm_version,
|
||||
args.python_versions,
|
||||
args.xla_source_dir,
|
||||
args.rocm_build_job,
|
||||
args.rocm_build_num,
|
||||
)
|
||||
|
||||
elif args.action == "test":
|
||||
test(args.image_name)
|
||||
|
||||
elif args.action == "dist_docker":
|
||||
dist_wheels(args.rocm_version, args.python_versions, args.xla_source_dir)
|
||||
dist_wheels(
|
||||
args.rocm_version,
|
||||
args.python_versions,
|
||||
args.xla_source_dir,
|
||||
args.rocm_build_job,
|
||||
args.rocm_build_num,
|
||||
)
|
||||
dist_docker(
|
||||
args.rocm_version,
|
||||
args.python_versions,
|
||||
args.xla_source_dir,
|
||||
rocm_build_job=args.rocm_build_job,
|
||||
rocm_build_num=args.rocm_build_num,
|
||||
tag=args.image_tag,
|
||||
dockerfile=args.dockerfile,
|
||||
keep_image=args.keep_image,
|
||||
|
@ -79,6 +79,14 @@ while [[ $# -gt 0 ]]; do
|
||||
ROCM_VERSION="$2"
|
||||
shift 2
|
||||
;;
|
||||
--rocm_job)
|
||||
ROCM_BUILD_JOB="$2"
|
||||
shift 2
|
||||
;;
|
||||
--rocm_build)
|
||||
ROCM_BUILD_NUM="$2"
|
||||
shift 2
|
||||
;;
|
||||
*)
|
||||
POSITIONAL_ARGS+=("$1")
|
||||
shift
|
||||
@ -132,7 +140,9 @@ export XLA_CLONE_DIR="${XLA_CLONE_DIR:-}"
|
||||
./build/rocm/ci_build \
|
||||
--rocm-version $ROCM_VERSION \
|
||||
--python-versions $PYTHON_VERSION \
|
||||
--xla-source-dir $XLA_CLONE_DIR \
|
||||
--xla-source-dir=$XLA_CLONE_DIR \
|
||||
--rocm-build-job=$ROCM_BUILD_JOB \
|
||||
--rocm-build-num=$ROCM_BUILD_NUM \
|
||||
dist_docker \
|
||||
--dockerfile $DOCKERFILE_PATH \
|
||||
--image-tag $DOCKER_IMG_NAME
|
||||
|
@ -25,7 +25,11 @@ GPU_DEVICE_TARGETS = "gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030 g
|
||||
|
||||
|
||||
def build_rocm_path(rocm_version_str):
|
||||
return "/opt/rocm-%s" % rocm_version_str
|
||||
path = "/opt/rocm-%s" % rocm_version_str
|
||||
if os.path.exists(path):
|
||||
return path
|
||||
else:
|
||||
return os.path.realpath("/opt/rocm")
|
||||
|
||||
|
||||
def update_rocm_targets(rocm_path, targets):
|
||||
|
Loading…
x
Reference in New Issue
Block a user