rocm_jax/build/rocm/ci_build
Mathew Odden ec4b8ee1ed Fixes for 0.5.0 build ported to rocm-main
(cherry picked from commit c23a81461192a2b6da3d364076a261714d2dc64f)
2025-03-25 17:51:30 -05:00

386 lines
9.7 KiB
Python
Executable File

#!/usr/bin/env python3
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE(mrodden): This file is part of the ROCm build scripts, and
# needs to be compatible with Python 3.6. Please do not include these
# in any "upgrade" scripts.
import argparse
import logging
import os
import subprocess
import sys
# Module-level logger used by the build helpers in this script.
LOG = logging.getLogger("ci_build")
def image_by_name(name):
    """Look up a local Docker image by reference name.

    Runs ``docker images -q -f reference=<name>`` and returns the first line
    of its (stripped) output, or ``None`` when that line is empty.

    Raises:
        subprocess.CalledProcessError: if the docker command exits non-zero.
    """
    listing = subprocess.check_output(
        ["docker", "images", "-q", "-f", "reference=%s" % name]
    )
    first_line = listing.decode("utf8").strip().split("\n")[0]
    if first_line:
        return first_line
    return None
def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num):
    """Build the manylinux_2_28 wheel-build Docker image and return its tag.

    The tag is ``jax-build-manylinux_2_28_x86_64_rocm`` followed by
    ``rocm_version`` with the dots removed. The three arguments are passed to
    ``docker build`` as the ROCM_VERSION, ROCM_BUILD_JOB and ROCM_BUILD_NUM
    build args; the current directory is used as the build context.

    Raises:
        subprocess.CalledProcessError: if ``docker build`` exits non-zero.
    """
    tag = "jax-build-manylinux_2_28_x86_64_rocm%s" % rocm_version.replace(".", "")

    docker_cmd = [
        "docker",
        "build",
        "-f",
        "build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm",
    ]
    docker_cmd += [
        "--build-arg=ROCM_VERSION=%s" % rocm_version,
        "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
        "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
    ]
    # Tag the result, then give the build context directory.
    docker_cmd += ["--tag=%s" % tag, "."]

    LOG.info("Creating manylinux build image. Running: %s", docker_cmd)
    subprocess.run(docker_cmd, check=True)
    return tag
def dist_wheels(
    rocm_version,
    python_versions,
    xla_path,
    rocm_build_job="",
    rocm_build_num="",
    compiler="gcc",
):
    """Build JAX/jaxlib wheels inside the manylinux build container.

    Builds the manylinux image, creates ``./wheelhouse`` if needed, and runs
    ``/jax/build/rocm/tools/build_wheels.py`` in a container with the current
    directory mounted at ``/jax`` and ``./wheelhouse`` at ``/wheelhouse``.

    Args:
        rocm_version: ROCm version string passed to the image build and to
            ``build_wheels.py``.
        python_versions: iterable of Python version strings; joined with
            commas for ``--python-versions``.
        xla_path: optional path to an XLA checkout. When truthy it is made
            absolute, mounted at ``/xla`` and passed as ``--xla-path``.
        rocm_build_job: forwarded to the image build.
        rocm_build_num: forwarded to the image build.
        compiler: value for ``--compiler``; a falsy value (e.g. ``None`` from
            an unset CLI option) falls back to ``"gcc"``.

    Raises:
        subprocess.CalledProcessError: if a docker command exits non-zero.
    """
    # Callers may forward an unset CLI option as None. Every element of the
    # command below must be a string for " ".join() to work, so fall back to
    # the documented default instead of failing with a TypeError.
    if not compiler:
        compiler = "gcc"

    # We want to make sure the wheels we build are manylinux compliant. We'll
    # do the build in a container. Build the image for this.
    image_name = create_manylinux_build_image(
        rocm_version, rocm_build_job, rocm_build_num
    )

    if xla_path:
        xla_path = os.path.abspath(xla_path)

    # use image to build JAX/jaxlib wheels
    os.makedirs("wheelhouse", exist_ok=True)

    pyver_string = ",".join(python_versions)

    container_xla_path = "/xla"

    bw_cmd = [
        "python3",
        "/jax/build/rocm/tools/build_wheels.py",
        "--rocm-version",
        rocm_version,
        "--python-versions",
        pyver_string,
        "--compiler",
        compiler,
    ]

    if xla_path:
        bw_cmd.extend(["--xla-path", container_xla_path])

    bw_cmd.append("/jax")

    cmd = ["docker", "run"]

    mounts = [
        "-v",
        os.path.abspath("./") + ":/jax",
        "-v",
        os.path.abspath("./wheelhouse") + ":/wheelhouse",
    ]

    if xla_path:
        mounts.extend(["-v", "%s:%s" % (xla_path, container_xla_path)])

    cmd.extend(mounts)

    # docker refuses -t when its stdin is not a terminal, so only request an
    # interactive TTY when both ends are terminals.
    if sys.stdin.isatty() and sys.stdout.isatty():
        cmd.append("-it")

    # NOTE(mrodden): bazel times out without --init, probably blocking on a zombie PID
    cmd.extend(
        [
            "--init",
            "--rm",
            image_name,
            "bash",
            "-c",
            " ".join(bw_cmd),
        ]
    )

    # Add command for unit tests
    # NOTE(review): these tokens are appended to the docker argv *after* the
    # `bash -c <string>` arguments, so bash receives them as $0, $1, ... and
    # the "&&" is never interpreted by a shell; the bazel command is not
    # executed as written. Left unchanged because the intended behaviour
    # (run tests in this container or not) needs confirming before folding it
    # into the command string.
    cmd.extend(
        [
            "&&",
            "bazel",
            "test",
            "-k",
            "--jobs=4",
            "--test_verbose_timeout_warnings=true",
            "--test_output=all",
            "--test_summary=detailed",
            "--local_test_jobs=1",
            "--test_env=JAX_ACCELERATOR_COUNT=%i" % 4,
            "--test_env=JAX_SKIP_SLOW_TESTS=0",
            "--verbose_failures=true",
            "--config=rocm",
            "--action_env=ROCM_PATH=/opt/rocm",
            "--action_env=TF_ROCM_AMDGPU_TARGETS=%s" % "gfx90a",
            "--test_tag_filters=-multiaccelerator",
            "--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform",
            "--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow",
            "//tests:gpu_tests",
        ]
    )

    LOG.info("Running: %s", cmd)
    subprocess.run(cmd, check=True)
def _fetch_jax_metadata(xla_path):
    """Collect the JAX version plus JAX and XLA commit hashes.

    Runs ``git rev-parse HEAD`` in the current directory (and, when
    ``xla_path`` is truthy, in ``xla_path``), and ``python3 setup.py -V`` with
    ``JAX_RELEASE=1`` in the environment.

    Args:
        xla_path: optional path to an XLA git checkout.

    Returns:
        dict with the string keys ``jax_version``, ``jax_commit`` and
        ``xla_commit``; values are decoded and stripped. ``xla_commit`` is
        the empty string when ``xla_path`` is falsy or its lookup failed.

    Raises:
        subprocess.CalledProcessError: if the JAX commit or version command
            exits non-zero.
    """
    rev_parse = ["git", "rev-parse", "HEAD"]

    jax_commit = subprocess.check_output(rev_parse)

    xla_commit = b""
    if xla_path:
        try:
            xla_commit = subprocess.check_output(rev_parse, cwd=xla_path)
        except (OSError, subprocess.SubprocessError) as ex:
            # Best-effort: a missing directory, missing git binary, or a
            # non-git path only costs us the XLA commit, not the whole build.
            LOG.warning("Exception while retrieving xla_commit: %s", ex)

    # Copy the environment so JAX_RELEASE is only set for the child process.
    env = dict(os.environ)
    env["JAX_RELEASE"] = "1"
    jax_version = subprocess.check_output(["python3", "setup.py", "-V"], env=env)

    def safe_decode(x):
        # Pass str through unchanged; decode anything else as UTF-8.
        if isinstance(x, str):
            return x
        return x.decode("utf8")

    return {
        "jax_version": safe_decode(jax_version).strip(),
        "jax_commit": safe_decode(jax_commit).strip(),
        "xla_commit": safe_decode(xla_commit).strip(),
    }
def dist_docker(
    rocm_version,
    base_docker,
    python_versions,
    xla_path,
    rocm_build_job="",
    rocm_build_num="",
    tag="rocm/jax-dev",
    dockerfile=None,
    keep_image=True,
):
    """Run ``docker build`` for the ``rt_build`` target of the JAX Dockerfile.

    Args:
        rocm_version: value for the ROCM_VERSION build arg.
        base_docker: value for the BASE_DOCKER build arg.
        python_versions: sequence of version strings; only the first entry is
            used (as the PYTHON_VERSION build arg).
        xla_path: forwarded to ``_fetch_jax_metadata``.
        rocm_build_job: value for the ROCM_BUILD_JOB build arg.
        rocm_build_num: value for the ROCM_BUILD_NUM build arg.
        tag: image tag passed via ``--tag``.
        dockerfile: Dockerfile path; a falsy value selects
            ``build/rocm/Dockerfile.ms``.
        keep_image: when falsy, ``--rm`` is added to the build command.

    Raises:
        subprocess.CalledProcessError: if ``docker build`` exits non-zero.
    """
    dockerfile = dockerfile or "build/rocm/Dockerfile.ms"
    first_python = python_versions[0]
    metadata = _fetch_jax_metadata(xla_path)

    # (build-arg name, value) pairs, in the order they appear on the command line.
    build_args = (
        ("ROCM_VERSION", rocm_version),
        ("ROCM_BUILD_JOB", rocm_build_job),
        ("ROCM_BUILD_NUM", rocm_build_num),
        ("BASE_DOCKER", base_docker),
        ("PYTHON_VERSION", first_python),
        ("JAX_VERSION", metadata["jax_version"]),
        ("JAX_COMMIT", metadata["jax_commit"]),
        ("XLA_COMMIT", metadata["xla_commit"]),
    )

    build_cmd = ["docker", "build", "-f", dockerfile, "--target", "rt_build"]
    for arg_name, arg_value in build_args:
        build_cmd.append("--build-arg=%s=%s" % (arg_name, arg_value))
    build_cmd.append("--tag=%s" % tag)

    if not keep_image:
        build_cmd.append("--rm")

    # context dir
    build_cmd.append(".")

    subprocess.check_call(build_cmd)
def test(image_name, test_cmd):
    """Run unit tests like CI would inside a JAX image.

    Starts ``image_name`` with ``docker run --rm``, mounts the current
    directory at ``/jax``, passes the GPU device/permission flags listed
    below, and executes ``cd /jax && <test_cmd>`` through ``bash -c``.

    Args:
        image_name: Docker image to run.
        test_cmd: shell command string executed inside the container after
            changing into ``/jax``.

    Raises:
        subprocess.CalledProcessError: if ``docker run`` exits non-zero.
    """
    gpu_args = [
        "--device=/dev/kfd",
        "--device=/dev/dri",
        "--group-add",
        "video",
        "--cap-add=SYS_PTRACE",
        "--security-opt",
        "seccomp=unconfined",
        "--shm-size",
        "16G",
    ]

    cmd = [
        "docker",
        "run",
        "--rm",
    ]

    # docker refuses -t when its stdin is not a terminal, so only request an
    # interactive TTY when both ends are terminals.
    if sys.stdin.isatty() and sys.stdout.isatty():
        cmd.append("-it")

    # NOTE(mrodden): we need jax source dir for the unit test code only,
    # JAX and jaxlib are already installed from wheels
    mounts = [
        "-v",
        os.path.abspath("./") + ":/jax",
    ]

    cmd.extend(mounts)
    cmd.extend(gpu_args)

    container_cmd = "cd /jax && " + test_cmd
    cmd.append(image_name)
    cmd.extend(
        [
            "bash",
            "-c",
            container_cmd,
        ]
    )

    subprocess.check_call(cmd)
def parse_args():
    """Parse the command line (``sys.argv``) for the ci_build script.

    Global options come before the sub-command. Exactly one sub-command is
    required: ``dist_wheels``, ``test`` or ``dist_docker``; it is stored in
    ``args.action``.

    Returns:
        argparse.Namespace with the parsed options. ``python_versions`` is a
        list of strings (the comma-separated value split on ``,``).
    """
    p = argparse.ArgumentParser()
    p.add_argument(
        "--base-docker",
        default="",
        help="Argument to override base docker in dockerfile",
    )
    p.add_argument(
        "--python-versions",
        type=lambda x: x.split(","),
        # argparse runs string defaults through `type`, so this yields ["3.12"].
        default="3.12",
        help="Comma separated list of CPython versions to build wheels for",
    )
    p.add_argument(
        "--rocm-version",
        default="6.1.1",
        help="ROCm version used for building wheels, testing, and installing into Docker image",
    )
    p.add_argument(
        "--rocm-build-job",
        default="",
        help="ROCm build job for development ROCm builds",
    )
    p.add_argument(
        "--rocm-build-num",
        default="",
        help="ROCm build number for development ROCm builds",
    )
    p.add_argument(
        "--xla-source-dir",
        help="Path to XLA source to use during jaxlib build, instead of builtin XLA",
    )
    p.add_argument(
        "--compiler",
        choices=["gcc", "clang"],
        # Without a default this is None when omitted, and the wheel build
        # command cannot be assembled from a None element.
        default="gcc",
        help="Compiler backend to use when compiling jax/jaxlib",
    )

    # add_subparsers() only accepts required= starting with Python 3.7; set
    # the attribute afterwards so this also works on Python 3.6.
    subp = p.add_subparsers(dest="action")
    subp.required = True

    subp.add_parser("dist_wheels")

    testp = subp.add_parser("test")
    testp.add_argument("image_name")
    testp.add_argument(
        "--test-cmd",
        default="./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh",
    )

    ddp = subp.add_parser("dist_docker")
    ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms")
    ddp.add_argument("--keep-image", action="store_true")
    ddp.add_argument("--image-tag", default="rocm/jax-dev")

    return p.parse_args()
def main():
    """Entry point: configure logging, parse arguments, run the chosen action.

    ``dist_wheels`` builds the wheels, ``test`` runs the test command in an
    image, and ``dist_docker`` builds the wheels and then the Docker image.
    """
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    action = args.action

    if action == "test":
        test(args.image_name, args.test_cmd)
        return

    if action not in ("dist_wheels", "dist_docker"):
        return

    # Both remaining actions start by building the wheels.
    dist_wheels(
        args.rocm_version,
        args.python_versions,
        args.xla_source_dir,
        args.rocm_build_job,
        args.rocm_build_num,
        compiler=args.compiler,
    )

    if action == "dist_docker":
        dist_docker(
            args.rocm_version,
            args.base_docker,
            args.python_versions,
            args.xla_source_dir,
            rocm_build_job=args.rocm_build_job,
            rocm_build_num=args.rocm_build_num,
            tag=args.image_tag,
            dockerfile=args.dockerfile,
            keep_image=args.keep_image,
        )
# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()