#!/usr/bin/env python3 # Copyright 2024 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # NOTE(mrodden): This file is part of the ROCm build scripts, and # needs be compatible with Python 3.6. Please do not include these # in any "upgrade" scripts import argparse import os import subprocess import sys def image_by_name(name): cmd = ["docker", "images", "-q", "-f", "reference=%s" % name] out = subprocess.check_output(cmd) image_id = out.decode("utf8").strip().split("\n")[0] or None return image_id def dist_wheels( rocm_version, python_versions, xla_path, rocm_build_job="", rocm_build_num="", compiler="gcc", ): if xla_path: xla_path = os.path.abspath(xla_path) # create manylinux image with requested ROCm installed image = "jax-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "") cmd = [ "docker", "build", "-f", "build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm", "--build-arg=ROCM_VERSION=%s" % rocm_version, "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, "--tag=%s" % image, ".", ] if not image_by_name(image): _ = subprocess.run(cmd, check=True) # use image to build JAX/jaxlib wheels os.makedirs("wheelhouse", exist_ok=True) pyver_string = ",".join(python_versions) container_xla_path = "/xla" bw_cmd = [ "python3", "/jax/build/rocm/tools/build_wheels.py", "--rocm-version", rocm_version, "--python-versions", pyver_string, "--compiler", compiler, ] if xla_path: bw_cmd.extend(["--xla-path", container_xla_path]) bw_cmd.append("/jax") cmd = ["docker", "run"] mounts = [ "-v", "./:/jax", "-v", "./wheelhouse:/wheelhouse", ] if xla_path: mounts.extend(["-v", "%s:%s" % (xla_path, container_xla_path)]) cmd.extend(mounts) if os.isatty(sys.stdout.fileno()): cmd.append("-it") # NOTE(mrodden): bazel times out without --init, probably blocking on a zombie PID cmd.extend( [ "--init", "--rm", image, "bash", "-c", " ".join(bw_cmd), ] ) _ = subprocess.run(cmd, check=True) def _fetch_jax_metadata(xla_path): cmd = ["git", "rev-parse", "HEAD"] jax_commit = subprocess.check_output(cmd) xla_commit = b"" if xla_path: try: xla_commit = subprocess.check_output(cmd, cwd=xla_path) except Exception as ex: LOG.warning("Exception while retrieving xla_commit: %s" % ex) cmd = ["python3", "setup.py", "-V"] env = dict(os.environ) env["JAX_RELEASE"] = "1" jax_version = subprocess.check_output(cmd, env=env) return { "jax_version": jax_version.decode("utf8").strip(), "jax_commit": jax_commit.decode("utf8").strip(), "xla_commit": xla_commit.decode("utf8").strip(), } def dist_docker( rocm_version, base_docker, python_versions, xla_path, rocm_build_job="", rocm_build_num="", tag="rocm/jax-dev", dockerfile=None, keep_image=True, ): if not dockerfile: dockerfile = "build/rocm/Dockerfile.ms" python_version = python_versions[0] md = _fetch_jax_metadata(xla_path) cmd = [ "docker", "build", "-f", dockerfile, "--target", "rt_build", "--build-arg=ROCM_VERSION=%s" % rocm_version, "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job, "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num, "--build-arg=BASE_DOCKER=%s" % base_docker, "--build-arg=PYTHON_VERSION=%s" % python_version, "--build-arg=JAX_VERSION=%(jax_version)s" % md, "--build-arg=JAX_COMMIT=%(jax_commit)s" % md, "--build-arg=XLA_COMMIT=%(xla_commit)s" % md, "--tag=%s" % tag, ] if not keep_image: cmd.append("--rm") # context dir cmd.append(".") subprocess.check_call(cmd) def test(image_name): """Run unit tests like CI would inside a JAX image.""" gpu_args = [ "--device=/dev/kfd", "--device=/dev/dri", "--group-add", "video", "--cap-add=SYS_PTRACE", "--security-opt", "seccomp=unconfined", "--shm-size", "16G", ] cmd = [ "docker", "run", "-it", "--rm", ] # NOTE(mrodden): we need jax source dir for the unit test code only, # JAX and jaxlib are already installed from wheels mounts = [ "-v", "./:/jax", ] cmd.extend(mounts) cmd.extend(gpu_args) container_cmd = "cd /jax && ./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh" cmd.append(image_name) cmd.extend( [ "bash", "-c", container_cmd, ] ) subprocess.check_call(cmd) def parse_args(): p = argparse.ArgumentParser() p.add_argument( "--base-docker", default="", help="Argument to override base docker in dockerfile", ) p.add_argument( "--python-versions", type=lambda x: x.split(","), default="3.12", help="Comma separated list of CPython versions to build wheels for", ) p.add_argument( "--rocm-version", default="6.1.1", help="ROCm version used for building wheels, testing, and installing into Docker image", ) p.add_argument( "--rocm-build-job", default="", help="ROCm build job for development ROCm builds", ) p.add_argument( "--rocm-build-num", default="", help="ROCm build number for development ROCm builds", ) p.add_argument( "--xla-source-dir", help="Path to XLA source to use during jaxlib build, instead of builtin XLA", ) p.add_argument( "--compiler", choices=["gcc", "clang"], help="Compiler backend to use when compiling jax/jaxlib", ) subp = p.add_subparsers(dest="action", required=True) dwp = subp.add_parser("dist_wheels") testp = subp.add_parser("test") testp.add_argument("image_name") ddp = subp.add_parser("dist_docker") ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms") ddp.add_argument("--keep-image", action="store_true") ddp.add_argument("--image-tag", default="rocm/jax-dev") return p.parse_args() def main(): args = parse_args() if args.action == "dist_wheels": dist_wheels( args.rocm_version, args.python_versions, args.xla_source_dir, args.rocm_build_job, args.rocm_build_num, ) elif args.action == "test": test(args.image_name) elif args.action == "dist_docker": dist_wheels( args.rocm_version, args.python_versions, args.xla_source_dir, args.rocm_build_job, args.rocm_build_num, args.compiler, ) dist_docker( args.rocm_version, args.base_docker, args.python_versions, args.xla_source_dir, rocm_build_job=args.rocm_build_job, rocm_build_num=args.rocm_build_num, tag=args.image_tag, dockerfile=args.dockerfile, keep_image=args.keep_image, ) if __name__ == "__main__": main()