rocm_jax/build/rocm/ci_build
Mathew Odden a1a0a4ecdd Add support for ROCm development builds
Use get_rocm.py changes in ci_build to pull in
development builds for ROCm.

Specify ROCM_BUILD_JOB and ROCM_BUILD_NUM for
activating the development build path.
2024-08-12 15:01:34 -05:00

292 lines
6.9 KiB
Python
Executable File

#!/usr/bin/env python3
# NOTE(mrodden): This file is part of the ROCm build scripts, and
# needs to be compatible with Python 3.6. Please do not include these
# in any "upgrade" scripts
import argparse
import os
import subprocess
import sys
def image_by_name(name):
    """Look up a local Docker image by reference name.

    Runs ``docker images -q -f reference=<name>`` and returns the first
    image ID printed, or ``None`` when the command prints nothing.

    Raises:
        subprocess.CalledProcessError: if the ``docker`` command exits
            with a non-zero status.
    """
    listing = subprocess.check_output(
        ["docker", "images", "-q", "-f", "reference=%s" % name]
    )
    # Only the first line matters; several images may match the reference.
    first_id = listing.decode("utf8").strip().split("\n")[0]
    if not first_id:
        return None
    return first_id
def dist_wheels(
    rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num=""
):
    """Build JAX/jaxlib wheels inside a manylinux container with ROCm.

    A ``jax-manylinux_2_28_x86_64_rocm<version>`` image is built first unless
    an image with that tag already exists locally (an existing image is
    reused as-is). The image is then run with the current directory mounted
    at ``/jax`` and ``./wheelhouse`` mounted at ``/wheelhouse``, executing
    ``/jax/build/rocm/tools/build_wheels.py`` in the container.

    Args:
        rocm_version: ROCm version string, e.g. ``"6.1.1"``.
        python_versions: sequence of CPython version strings; joined with
            commas and forwarded as ``--python-versions``.
        xla_path: optional host path to an XLA checkout. When given it is
            mounted at ``/xla`` and passed as ``--xla-path``.
        rocm_build_job: forwarded as the ``ROCM_BUILD_JOB`` Docker build arg.
        rocm_build_num: forwarded as the ``ROCM_BUILD_NUM`` Docker build arg.

    Raises:
        subprocess.CalledProcessError: if a ``docker`` invocation fails.
    """
    if xla_path:
        xla_path = os.path.abspath(xla_path)

    # create manylinux image with requested ROCm installed
    image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")
    if not image_by_name(image):
        build_cmd = [
            "docker",
            "build",
            "-f",
            "build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm",
            "--build-arg=ROCM_VERSION=%s" % rocm_version,
            "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
            "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
            "--tag=%s" % image,
            ".",
        ]
        subprocess.run(build_cmd, check=True)

    # use image to build JAX/jaxlib wheels
    os.makedirs("wheelhouse", exist_ok=True)

    container_xla_path = "/xla"
    bw_cmd = [
        "python3",
        "/jax/build/rocm/tools/build_wheels.py",
        "--rocm-version",
        rocm_version,
        "--python-versions",
        ",".join(python_versions),
    ]
    if xla_path:
        bw_cmd.extend(["--xla-path", container_xla_path])
    bw_cmd.append("/jax")

    cmd = ["docker", "run", "-i"]
    # Only request a pseudo-TTY when we actually have one: `docker run -t`
    # aborts with "the input device is not a TTY" under non-interactive CI.
    if sys.stdin is not None and sys.stdin.isatty():
        cmd.append("-t")

    # Bind-mount sources are made absolute; relative host paths are only
    # accepted by newer Docker releases.
    mounts = [
        "-v",
        "%s:/jax" % os.path.abspath("."),
        "-v",
        "%s:/wheelhouse" % os.path.abspath("wheelhouse"),
    ]
    if xla_path:
        mounts.extend(["-v", "%s:%s" % (xla_path, container_xla_path)])
    cmd.extend(mounts)

    # NOTE(mrodden): bazel times out without --init, probably blocking on a zombie PID
    cmd.extend(
        [
            "--init",
            "--rm",
            image,
            "bash",
            "-c",
            " ".join(bw_cmd),
        ]
    )
    subprocess.run(cmd, check=True)
def _fetch_jax_metadata(xla_path):
    """Collect version and commit metadata for the JAX (and XLA) sources.

    Runs ``git rev-parse HEAD`` in the current directory for the JAX commit
    and, when ``xla_path`` is given, in that directory for the XLA commit.
    The JAX version comes from ``python setup.py -V`` run with
    ``JAX_RELEASE=1`` in the environment.

    Args:
        xla_path: optional path to an XLA checkout; may be ``None``/empty.

    Returns:
        dict with string values under ``jax_version``, ``jax_commit`` and
        ``xla_commit``. ``xla_commit`` is ``""`` when no XLA path was given
        or its commit could not be determined.

    Raises:
        subprocess.CalledProcessError: if the JAX ``git`` or ``setup.py``
            command fails. Failures of the XLA lookup are only logged.
    """
    # Local import: this file has no module-level logger set up.
    import logging

    rev_parse = ["git", "rev-parse", "HEAD"]
    jax_commit = subprocess.check_output(rev_parse)

    # Must be bytes (not str) so the .decode() below also works when there
    # is no XLA checkout or the lookup fails.
    xla_commit = b""
    if xla_path:
        try:
            xla_commit = subprocess.check_output(rev_parse, cwd=xla_path)
        except Exception as ex:
            # Best effort: a missing XLA commit should not abort the build.
            logging.getLogger(__name__).warning(
                "Exception while retrieving xla_commit: %s", ex
            )

    env = dict(os.environ)
    env["JAX_RELEASE"] = "1"
    jax_version = subprocess.check_output(["python", "setup.py", "-V"], env=env)

    return {
        "jax_version": jax_version.decode("utf8").strip(),
        "jax_commit": jax_commit.decode("utf8").strip(),
        "xla_commit": xla_commit.decode("utf8").strip(),
    }
def dist_docker(
    rocm_version,
    python_versions,
    xla_path,
    rocm_build_job="",
    rocm_build_num="",
    tag="rocm/jax-dev",
    dockerfile=None,
    keep_image=True,
):
    """Build the JAX Docker image via ``docker build --target rt_build``.

    Args:
        rocm_version: value for the ``ROCM_VERSION`` build arg.
        python_versions: sequence of CPython versions; only the first entry
            is used, as the ``PYTHON_VERSION`` build arg.
        xla_path: optional XLA checkout path, used only to look up the XLA
            commit recorded in the ``XLA_COMMIT`` build arg.
        rocm_build_job: value for the ``ROCM_BUILD_JOB`` build arg.
        rocm_build_num: value for the ``ROCM_BUILD_NUM`` build arg.
        tag: tag applied to the resulting image.
        dockerfile: Dockerfile path; falls back to
            ``build/rocm/Dockerfile.ms`` when empty or ``None``.
        keep_image: when falsy, ``--rm`` is added to the build command.

    Raises:
        subprocess.CalledProcessError: if ``docker build`` fails.
    """
    dockerfile = dockerfile or "build/rocm/Dockerfile.ms"
    python_version = python_versions[0]
    md = _fetch_jax_metadata(xla_path)

    build_args = [
        "--build-arg=ROCM_VERSION=%s" % rocm_version,
        "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
        "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
        "--build-arg=PYTHON_VERSION=%s" % python_version,
        "--build-arg=JAX_VERSION=%(jax_version)s" % md,
        "--build-arg=JAX_COMMIT=%(jax_commit)s" % md,
        "--build-arg=XLA_COMMIT=%(xla_commit)s" % md,
    ]

    cmd = ["docker", "build", "-f", dockerfile, "--target", "rt_build"]
    cmd += build_args
    cmd.append("--tag=%s" % tag)
    if not keep_image:
        cmd.append("--rm")
    # The build context directory goes last.
    cmd.append(".")

    subprocess.check_call(cmd)
def test(image_name):
    """Run unit tests like CI would inside a JAX image.

    Starts ``image_name`` with the ROCm GPU devices passed through and the
    current directory mounted at ``/jax``, then runs the ROCm build script
    followed by the single-GPU and multi-GPU test runners.

    Args:
        image_name: name of the Docker image to run the tests in.

    Raises:
        subprocess.CalledProcessError: if the ``docker run`` command fails.
    """
    gpu_args = [
        "--device=/dev/kfd",
        "--device=/dev/dri",
        "--group-add",
        "video",
        "--cap-add=SYS_PTRACE",
        "--security-opt",
        "seccomp=unconfined",
        "--shm-size",
        "16G",
    ]

    cmd = ["docker", "run", "-i"]
    # Only request a pseudo-TTY when we actually have one: `docker run -t`
    # aborts with "the input device is not a TTY" under non-interactive CI.
    if sys.stdin is not None and sys.stdin.isatty():
        cmd.append("-t")
    cmd.append("--rm")

    # NOTE(mrodden): we need jax source dir for the unit test code only,
    # JAX and jaxlib are already installed from wheels
    # The host path is made absolute; relative bind-mount sources are only
    # accepted by newer Docker releases.
    mounts = [
        "-v",
        "%s:/jax" % os.path.abspath("."),
    ]
    cmd.extend(mounts)
    cmd.extend(gpu_args)

    container_cmd = "cd /jax && ./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh"
    cmd.append(image_name)
    cmd.extend(
        [
            "bash",
            "-c",
            container_cmd,
        ]
    )

    subprocess.check_call(cmd)
def parse_args():
    """Parse the command line of the ROCm CI build helper.

    Global options select the Python versions, the ROCm version/development
    build and an optional XLA source directory. One sub-command is
    mandatory and is stored on the result as ``action``:

    * ``dist_wheels`` - no extra options.
    * ``test`` - takes a positional ``image_name``.
    * ``dist_docker`` - accepts ``--dockerfile``, ``--keep-image`` and
      ``--image-tag``.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    p = argparse.ArgumentParser()
    p.add_argument(
        "--python-versions",
        # argparse also runs string defaults through ``type``, so the
        # default ends up as the list ["3.12"].
        type=lambda x: x.split(","),
        default="3.12",
        help="Comma separated list of CPython versions to build wheels for",
    )
    p.add_argument(
        "--rocm-version",
        default="6.1.1",
        help="ROCm version used for building wheels, testing, and installing into Docker image",
    )
    p.add_argument(
        "--rocm-build-job",
        default="",
        help="ROCm build job for development ROCm builds",
    )
    p.add_argument(
        "--rocm-build-num",
        default="",
        help="ROCm build number for development ROCm builds",
    )
    p.add_argument(
        "--xla-source-dir",
        help="Path to XLA source to use during jaxlib build, instead of builtin XLA",
    )

    # The ``required`` keyword of add_subparsers() only exists on Python
    # 3.7+; setting the attribute keeps this script usable on Python 3.6.
    subp = p.add_subparsers(dest="action")
    subp.required = True

    subp.add_parser("dist_wheels")

    testp = subp.add_parser("test")
    testp.add_argument("image_name")

    ddp = subp.add_parser("dist_docker")
    ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms")
    ddp.add_argument("--keep-image", action="store_true")
    ddp.add_argument("--image-tag", default="rocm/jax-dev")

    return p.parse_args()
def main():
    """Entry point: parse the command line and run the selected action.

    ``test`` runs the unit tests in the given image. ``dist_wheels`` builds
    the wheels. ``dist_docker`` builds the wheels first and then the Docker
    image.
    """
    args = parse_args()
    action = args.action

    if action == "test":
        test(args.image_name)
        return

    if action not in ("dist_wheels", "dist_docker"):
        return

    # Both remaining actions start by building the wheels.
    dist_wheels(
        args.rocm_version,
        args.python_versions,
        args.xla_source_dir,
        args.rocm_build_job,
        args.rocm_build_num,
    )

    if action == "dist_docker":
        dist_docker(
            args.rocm_version,
            args.python_versions,
            args.xla_source_dir,
            rocm_build_job=args.rocm_build_job,
            rocm_build_num=args.rocm_build_num,
            tag=args.image_tag,
            dockerfile=args.dockerfile,
            keep_image=args.keep_image,
        )
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()