mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00

Use get_rocm.py changes in ci_build to pull in development builds for ROCm. Specify ROCM_BUILD_JOB and ROCM_BUILD_NUM for activating the development build path.
292 lines
6.9 KiB
Python
Executable File
292 lines
6.9 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
|
|
# NOTE(mrodden): This file is part of the ROCm build scripts, and
|
|
# needs to be compatible with Python 3.6. Please do not include these
|
|
# in any "upgrade" scripts
|
|
|
|
|
|
import argparse
import logging
import os
import shlex
import subprocess
import sys
|
|
|
|
|
|
def image_by_name(name):
    """Look up a local Docker image by reference.

    Runs ``docker images -q -f reference=<name>`` and returns the first line
    of its output, or ``None`` when the output is empty.

    Raises:
        subprocess.CalledProcessError: if the ``docker`` command exits non-zero.
    """
    listing = subprocess.check_output(
        ["docker", "images", "-q", "-f", "reference=%s" % name]
    )
    first_line = listing.decode("utf8").strip().split("\n")[0]
    if not first_line:
        return None
    return first_line
|
|
|
|
|
|
def dist_wheels(
    rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num=""
):
    """Build JAX/jaxlib wheels inside a manylinux container with ROCm installed.

    Builds the manylinux image (unless an image with the expected tag already
    exists), then runs ``build/rocm/tools/build_wheels.py`` inside it with the
    current directory mounted at ``/jax`` and ``./wheelhouse`` mounted at
    ``/wheelhouse``.

    Args:
        rocm_version: ROCm version string, e.g. ``"6.1.1"``.
        python_versions: list of CPython version strings to build wheels for.
        xla_path: optional path to an XLA checkout; when given it is mounted
            at ``/xla`` and passed to the build script via ``--xla-path``.
        rocm_build_job: forwarded as the ``ROCM_BUILD_JOB`` docker build arg.
        rocm_build_num: forwarded as the ``ROCM_BUILD_NUM`` docker build arg.

    Raises:
        subprocess.CalledProcessError: if a docker command exits non-zero.
    """
    if xla_path:
        xla_path = os.path.abspath(xla_path)

    # create manylinux image with requested ROCm installed
    image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")

    cmd = [
        "docker",
        "build",
        "-f",
        "build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm",
        "--build-arg=ROCM_VERSION=%s" % rocm_version,
        "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
        "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
        "--tag=%s" % image,
        ".",
    ]

    # NOTE(review): the tag only encodes rocm_version, so an existing image is
    # reused even if rocm_build_job / rocm_build_num differ from the ones it
    # was built with.
    if not image_by_name(image):
        subprocess.run(cmd, check=True)

    # use image to build JAX/jaxlib wheels
    os.makedirs("wheelhouse", exist_ok=True)

    pyver_string = ",".join(python_versions)

    container_xla_path = "/xla"

    bw_cmd = [
        "python3",
        "/jax/build/rocm/tools/build_wheels.py",
        "--rocm-version",
        rocm_version,
        "--python-versions",
        pyver_string,
    ]

    if xla_path:
        bw_cmd.extend(["--xla-path", container_xla_path])

    bw_cmd.append("/jax")

    cmd = ["docker", "run"]

    # Only request an interactive TTY when one is attached; otherwise docker
    # refuses to start (e.g. under CI) because the input device is not a TTY.
    if sys.stdin is not None and sys.stdin.isatty():
        cmd.append("-it")

    # Absolute host paths: relative bind-mount sources are not accepted by
    # every docker CLI version.
    mounts = [
        "-v",
        "%s:/jax" % os.path.abspath("."),
        "-v",
        "%s:/wheelhouse" % os.path.abspath("wheelhouse"),
    ]

    if xla_path:
        mounts.extend(["-v", "%s:%s" % (xla_path, container_xla_path)])

    cmd.extend(mounts)

    # The command is re-parsed by bash inside the container, so quote every
    # argument to keep user-supplied values from being split or interpreted.
    container_cmd = " ".join(shlex.quote(arg) for arg in bw_cmd)

    # NOTE(mrodden): bazel times out without --init, probably blocking on a zombie PID
    cmd.extend(
        [
            "--init",
            "--rm",
            image,
            "bash",
            "-c",
            container_cmd,
        ]
    )

    subprocess.run(cmd, check=True)
|
|
|
|
|
|
def _fetch_jax_metadata(xla_path):
    """Collect version and commit metadata for the JAX (and optional XLA) tree.

    Runs ``git rev-parse HEAD`` in the current directory, and in ``xla_path``
    when one is given, then ``python setup.py -V`` with ``JAX_RELEASE=1`` set
    in the environment.

    Args:
        xla_path: path to an XLA checkout, or a falsy value to skip the XLA
            commit lookup.

    Returns:
        dict with string values under ``"jax_version"``, ``"jax_commit"`` and
        ``"xla_commit"``. ``"xla_commit"`` is ``""`` when ``xla_path`` is not
        given or the lookup fails.

    Raises:
        subprocess.CalledProcessError: if the JAX ``git`` or ``setup.py``
            command exits non-zero.
    """
    log = logging.getLogger(__name__)

    rev_parse = ["git", "rev-parse", "HEAD"]
    jax_commit = subprocess.check_output(rev_parse)

    # Must be bytes, like check_output()'s return value: it is .decode()d
    # below, and str has no decode() in Python 3.
    xla_commit = b""

    if xla_path:
        try:
            xla_commit = subprocess.check_output(rev_parse, cwd=xla_path)
        except (OSError, subprocess.SubprocessError) as ex:
            # Best effort: a missing XLA commit should not abort the build.
            log.warning("Exception while retrieving xla_commit: %s", ex)

    env = dict(os.environ)
    env["JAX_RELEASE"] = "1"

    jax_version = subprocess.check_output(["python", "setup.py", "-V"], env=env)

    return {
        "jax_version": jax_version.decode("utf8").strip(),
        "jax_commit": jax_commit.decode("utf8").strip(),
        "xla_commit": xla_commit.decode("utf8").strip(),
    }
|
|
|
|
|
|
def dist_docker(
    rocm_version,
    python_versions,
    xla_path,
    rocm_build_job="",
    rocm_build_num="",
    tag="rocm/jax-dev",
    dockerfile=None,
    keep_image=True,
):
    """Build the ``rt_build`` target of the ROCm JAX Docker image.

    Args:
        rocm_version: value for the ``ROCM_VERSION`` build arg.
        python_versions: sequence of version strings; only the first entry is
            used, as the ``PYTHON_VERSION`` build arg.
        xla_path: forwarded to ``_fetch_jax_metadata`` for the XLA commit.
        rocm_build_job: value for the ``ROCM_BUILD_JOB`` build arg.
        rocm_build_num: value for the ``ROCM_BUILD_NUM`` build arg.
        tag: image tag passed via ``--tag``.
        dockerfile: Dockerfile path; a falsy value selects
            ``build/rocm/Dockerfile.ms``.
        keep_image: when false, ``--rm`` is added to the build command.

    Raises:
        subprocess.CalledProcessError: if ``docker build`` exits non-zero.
    """
    dockerfile = dockerfile or "build/rocm/Dockerfile.ms"

    python_version = python_versions[0]
    metadata = _fetch_jax_metadata(xla_path)

    build_args = [
        "--build-arg=ROCM_VERSION=%s" % rocm_version,
        "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
        "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
        "--build-arg=PYTHON_VERSION=%s" % python_version,
        "--build-arg=JAX_VERSION=%(jax_version)s" % metadata,
        "--build-arg=JAX_COMMIT=%(jax_commit)s" % metadata,
        "--build-arg=XLA_COMMIT=%(xla_commit)s" % metadata,
    ]

    docker_cmd = ["docker", "build", "-f", dockerfile, "--target", "rt_build"]
    docker_cmd += build_args
    docker_cmd.append("--tag=%s" % tag)

    if not keep_image:
        docker_cmd.append("--rm")

    # The build context directory goes last.
    docker_cmd.append(".")

    subprocess.check_call(docker_cmd)
|
|
|
|
|
|
def test(image_name):
    """Run unit tests like CI would inside a JAX image.

    Starts a container from ``image_name`` with the current directory mounted
    at ``/jax`` and the GPU device/permission flags below, then runs the ROCm
    build and single/multi GPU test scripts in sequence.

    Args:
        image_name: Docker image to run the tests in.

    Raises:
        subprocess.CalledProcessError: if ``docker run`` exits non-zero.
    """

    gpu_args = [
        "--device=/dev/kfd",
        "--device=/dev/dri",
        "--group-add",
        "video",
        "--cap-add=SYS_PTRACE",
        "--security-opt",
        "seccomp=unconfined",
        "--shm-size",
        "16G",
    ]

    cmd = [
        "docker",
        "run",
        "--rm",
    ]

    # Only request an interactive TTY when one is attached; otherwise docker
    # refuses to start (e.g. under CI) because the input device is not a TTY.
    if sys.stdin is not None and sys.stdin.isatty():
        cmd.append("-it")

    # NOTE(mrodden): we need jax source dir for the unit test code only,
    # JAX and jaxlib are already installed from wheels
    # Absolute host path: relative bind-mount sources are not accepted by
    # every docker CLI version.
    mounts = [
        "-v",
        "%s:/jax" % os.path.abspath("."),
    ]

    cmd.extend(mounts)
    cmd.extend(gpu_args)

    container_cmd = "cd /jax && ./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh"
    cmd.append(image_name)
    cmd.extend(
        [
            "bash",
            "-c",
            container_cmd,
        ]
    )

    subprocess.check_call(cmd)
|
|
|
|
|
|
def parse_args():
    """Parse the command line (``sys.argv``) for the ROCm CI build script.

    Global options come before the sub-command; one of the sub-commands
    ``dist_wheels``, ``test`` or ``dist_docker`` is mandatory and is stored
    in ``action``.

    Returns:
        argparse.Namespace with the parsed options. ``python_versions`` is a
        list of strings.
    """
    p = argparse.ArgumentParser()
    p.add_argument(
        "--python-versions",
        type=lambda x: x.split(","),
        # argparse runs string defaults through ``type``, so this is ["3.12"].
        default="3.12",
        help="Comma separated list of CPython versions to build wheels for",
    )

    p.add_argument(
        "--rocm-version",
        default="6.1.1",
        help="ROCm version used for building wheels, testing, and installing into Docker image",
    )

    p.add_argument(
        "--rocm-build-job",
        default="",
        help="ROCm build job for development ROCm builds",
    )

    p.add_argument(
        "--rocm-build-num",
        default="",
        help="ROCm build number for development ROCm builds",
    )

    p.add_argument(
        "--xla-source-dir",
        help="Path to XLA source to use during jaxlib build, instead of builtin XLA",
    )

    # add_subparsers() only accepts ``required=`` from Python 3.7 on, and this
    # file must stay 3.6 compatible, so set the attribute on the returned
    # action instead.
    subp = p.add_subparsers(dest="action")
    subp.required = True

    subp.add_parser("dist_wheels")

    testp = subp.add_parser("test")
    testp.add_argument("image_name")

    ddp = subp.add_parser("dist_docker")
    ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms")
    ddp.add_argument("--keep-image", action="store_true")
    ddp.add_argument("--image-tag", default="rocm/jax-dev")

    return p.parse_args()
|
|
|
|
|
|
def main():
    """Parse arguments and dispatch to the requested action.

    ``test`` runs the test container. ``dist_wheels`` builds wheels.
    ``dist_docker`` builds wheels and then the Docker image.
    """
    args = parse_args()
    action = args.action

    if action == "test":
        test(args.image_name)
        return

    if action not in ("dist_wheels", "dist_docker"):
        return

    # Both remaining actions start by building the wheels.
    dist_wheels(
        args.rocm_version,
        args.python_versions,
        args.xla_source_dir,
        args.rocm_build_job,
        args.rocm_build_num,
    )

    if action == "dist_docker":
        dist_docker(
            args.rocm_version,
            args.python_versions,
            args.xla_source_dir,
            rocm_build_job=args.rocm_build_job,
            rocm_build_num=args.rocm_build_num,
            tag=args.image_tag,
            dockerfile=args.dockerfile,
            keep_image=args.keep_image,
        )
|
|
|
|
|
|
# Dispatch only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|