#!/usr/bin/env python3 # Copyright 2024 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # NOTE(mrodden): This file is part of the ROCm build scripts, and # needs be compatible with Python 3.6. Please do not include these # in any "upgrade" scripts import argparse import logging import os import subprocess import sys LOG = logging.getLogger("ci_build") def image_by_name(name): cmd = ["docker", "images", "-q", "-f", "reference=%s" % name] out = subprocess.check_output(cmd) image_id = out.decode("utf8").strip().split("\n")[0] or None return image_id def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num): image_name = "jax-build-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace( ".", "" ) cmd = [ "docker", "build", "-f", "build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm", "--build-arg=ROCM_VERSION=%s" % rocm_version, "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, "--tag=%s" % image_name, ".", ] LOG.info("Creating manylinux build image. Running: %s", cmd) _ = subprocess.run(cmd, check=True) return image_name def dist_wheels( rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num="", compiler="gcc", ): # We want to make sure the wheels we build are manylinux compliant. We'll # do the build in a container. Build the image for this. image_name = create_manylinux_build_image( rocm_version, rocm_build_job, rocm_build_num ) if xla_path: xla_path = os.path.abspath(xla_path) # use image to build JAX/jaxlib wheels os.makedirs("wheelhouse", exist_ok=True) pyver_string = ",".join(python_versions) container_xla_path = "/xla" bw_cmd = [ "python3", "/jax/build/rocm/tools/build_wheels.py", "--rocm-version", rocm_version, "--python-versions", pyver_string, "--compiler", compiler, ] if xla_path: bw_cmd.extend(["--xla-path", container_xla_path]) bw_cmd.append("/jax") cmd = ["docker", "run"] mounts = [ "-v", os.path.abspath("./") + ":/jax", "-v", os.path.abspath("./wheelhouse") + ":/wheelhouse", ] if xla_path: mounts.extend(["-v", "%s:%s" % (xla_path, container_xla_path)]) cmd.extend(mounts) if os.isatty(sys.stdout.fileno()): cmd.append("-it") # NOTE(mrodden): bazel times out without --init, probably blocking on a zombie PID cmd.extend( [ "--init", "--rm", image_name, "bash", "-c", " ".join(bw_cmd), ] ) # Add command for unit tests cmd.extend( [ "&&", "bazel", "test", "-k", "--jobs=4", "--test_verbose_timeout_warnings=true", "--test_output=all", "--test_summary=detailed", "--local_test_jobs=1", "--test_env=JAX_ACCELERATOR_COUNT=%i" % 4, "--test_env=JAX_SKIP_SLOW_TESTS=0", "--verbose_failures=true", "--config=rocm", "--action_env=ROCM_PATH=/opt/rocm", "--action_env=TF_ROCM_AMDGPU_TARGETS=%s" % "gfx90a", "--test_tag_filters=-multiaccelerator", "--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform", "--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow", "//tests:gpu_tests", ] ) LOG.info("Running: %s", cmd) _ = subprocess.run(cmd, check=True) def _fetch_jax_metadata(xla_path): cmd = ["git", "rev-parse", "HEAD"] jax_commit = subprocess.check_output(cmd) xla_commit = b"" if xla_path: try: xla_commit = subprocess.check_output(cmd, cwd=xla_path) except Exception as ex: LOG.warning("Exception while retrieving xla_commit: %s" % ex) cmd = ["python3", "setup.py", "-V"] env = dict(os.environ) env["JAX_RELEASE"] = "1" jax_version = subprocess.check_output(cmd, env=env) def safe_decode(x): if isinstance(x, str): return x else: return x.decode("utf8") return { "jax_version": safe_decode(jax_version).strip(), "jax_commit": safe_decode(jax_commit).strip(), "xla_commit": safe_decode(xla_commit).strip(), } def dist_docker( rocm_version, base_docker, python_versions, xla_path, rocm_build_job="", rocm_build_num="", tag="rocm/jax-dev", dockerfile=None, keep_image=True, ): if not dockerfile: dockerfile = "build/rocm/Dockerfile.ms" python_version = python_versions[0] md = _fetch_jax_metadata(xla_path) cmd = [ "docker", "build", "-f", dockerfile, "--target", "rt_build", "--build-arg=ROCM_VERSION=%s" % rocm_version, "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, "--build-arg=BASE_DOCKER=%s" % base_docker, "--build-arg=PYTHON_VERSION=%s" % python_version, "--build-arg=JAX_VERSION=%(jax_version)s" % md, "--build-arg=JAX_COMMIT=%(jax_commit)s" % md, "--build-arg=XLA_COMMIT=%(xla_commit)s" % md, "--tag=%s" % tag, ] if not keep_image: cmd.append("--rm") # context dir cmd.append(".") subprocess.check_call(cmd) def test(image_name, test_cmd): """Run unit tests like CI would inside a JAX image.""" gpu_args = [ "--device=/dev/kfd", "--device=/dev/dri", "--group-add", "video", "--cap-add=SYS_PTRACE", "--security-opt", "seccomp=unconfined", "--shm-size", "16G", ] cmd = [ "docker", "run", "--rm", ] if os.isatty(sys.stdout.fileno()): cmd.append("-it") # NOTE(mrodden): we need jax source dir for the unit test code only, # JAX and jaxlib are already installed from wheels mounts = [ "-v", os.path.abspath("./") + ":/jax", ] cmd.extend(mounts) cmd.extend(gpu_args) container_cmd = "cd /jax && " + test_cmd cmd.append(image_name) cmd.extend( [ "bash", "-c", container_cmd, ] ) subprocess.check_call(cmd) def parse_args(): p = argparse.ArgumentParser() p.add_argument( "--base-docker", default="", help="Argument to override base docker in dockerfile", ) p.add_argument( "--python-versions", type=lambda x: x.split(","), default="3.12", help="Comma separated list of CPython versions to build wheels for", ) p.add_argument( "--rocm-version", default="6.1.1", help="ROCm version used for building wheels, testing, and installing into Docker image", ) p.add_argument( "--rocm-build-job", default="", help="ROCm build job for development ROCm builds", ) p.add_argument( "--rocm-build-num", default="", help="ROCm build number for development ROCm builds", ) p.add_argument( "--xla-source-dir", help="Path to XLA source to use during jaxlib build, instead of builtin XLA", ) p.add_argument( "--compiler", choices=["gcc", "clang"], help="Compiler backend to use when compiling jax/jaxlib", ) subp = p.add_subparsers(dest="action", required=True) dwp = subp.add_parser("dist_wheels") testp = subp.add_parser("test") testp.add_argument("image_name") testp.add_argument( "--test-cmd", default="./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh", ) ddp = subp.add_parser("dist_docker") ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms") ddp.add_argument("--keep-image", action="store_true") ddp.add_argument("--image-tag", default="rocm/jax-dev") return p.parse_args() def main(): logging.basicConfig(level=logging.INFO) args = parse_args() if args.action == "dist_wheels": dist_wheels( args.rocm_version, args.python_versions, args.xla_source_dir, args.rocm_build_job, args.rocm_build_num, compiler=args.compiler, ) elif args.action == "test": test(args.image_name, args.test_cmd) elif args.action == "dist_docker": dist_wheels( args.rocm_version, args.python_versions, args.xla_source_dir, args.rocm_build_job, args.rocm_build_num, compiler=args.compiler, ) dist_docker( args.rocm_version, args.base_docker, args.python_versions, args.xla_source_dir, rocm_build_job=args.rocm_build_job, rocm_build_num=args.rocm_build_num, tag=args.image_tag, dockerfile=args.dockerfile, keep_image=args.keep_image, ) if __name__ == "__main__": main()