mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
386 lines
9.7 KiB
Python
Executable File
386 lines
9.7 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
# Copyright 2024 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
|
|
# NOTE(mrodden): This file is part of the ROCm build scripts, and
|
|
# needs to be compatible with Python 3.6. Please do not include these
|
|
# in any "upgrade" scripts
|
|
|
|
|
|
import argparse
|
|
import logging
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
|
|
|
|
# Module-level logger used by the build steps in this script.
LOG = logging.getLogger("ci_build")
|
|
|
|
|
|
def image_by_name(name):
    """Look up a local Docker image ID by reference name.

    Runs ``docker images -q -f reference=<name>`` and returns the first line
    of its output. Returns ``None`` when that first line is empty.
    """
    listing = subprocess.check_output(
        ["docker", "images", "-q", "-f", "reference=%s" % name]
    )
    first_line = listing.decode("utf8").strip().split("\n")[0]
    if not first_line:
        return None
    return first_line
|
|
|
|
|
|
def create_manylinux_build_image(rocm_version, rocm_build_job, rocm_build_num):
    """Build the manylinux Docker image used for wheel builds.

    The image is tagged ``jax-build-manylinux_2_28_x86_64_rocm<version>``,
    where ``<version>`` is ``rocm_version`` with the dots removed. The three
    arguments are forwarded to ``docker build`` as the ``ROCM_VERSION``,
    ``ROCM_BUILD_JOB`` and ``ROCM_BUILD_NUM`` build args, and the current
    directory is used as the build context.

    Returns:
        The tag of the image that was built.

    Raises:
        subprocess.CalledProcessError: If ``docker build`` exits non-zero.
    """
    version_suffix = rocm_version.replace(".", "")
    image_name = "jax-build-manylinux_2_28_x86_64_rocm" + version_suffix

    build_args = [
        "--build-arg=ROCM_VERSION=%s" % rocm_version,
        "--build-arg=ROCM_BUILD_JOB=%s" % rocm_build_job,
        "--build-arg=ROCM_BUILD_NUM=%s" % rocm_build_num,
    ]

    cmd = [
        "docker",
        "build",
        "-f",
        "build/rocm/build_wheels/Dockerfile.manylinux_2_28_x86_64.rocm",
    ]
    cmd += build_args
    cmd += ["--tag=%s" % image_name, "."]

    LOG.info("Creating manylinux build image. Running: %s", cmd)
    subprocess.run(cmd, check=True)
    return image_name
|
|
|
|
|
|
def dist_wheels(
    rocm_version,
    python_versions,
    xla_path,
    rocm_build_job="",
    rocm_build_num="",
    compiler="gcc",
):
    """Build JAX/jaxlib wheels inside a manylinux Docker container.

    Builds the manylinux image via ``create_manylinux_build_image``, creates
    ``./wheelhouse`` if needed, and runs
    ``/jax/build/rocm/tools/build_wheels.py`` in a container that has the
    current directory mounted at ``/jax`` and ``./wheelhouse`` mounted at
    ``/wheelhouse``.

    Args:
        rocm_version: ROCm version string, passed to the image build and to
            ``build_wheels.py --rocm-version``.
        python_versions: Iterable of version strings; joined with commas for
            ``--python-versions``.
        xla_path: Optional path to an XLA checkout. When truthy it is made
            absolute, mounted at ``/xla`` and passed as ``--xla-path``.
        rocm_build_job: Forwarded to the image build as ``ROCM_BUILD_JOB``.
        rocm_build_num: Forwarded to the image build as ``ROCM_BUILD_NUM``.
        compiler: Value for ``--compiler``. ``None`` is treated as ``"gcc"``.

    Raises:
        subprocess.CalledProcessError: If the image build or the container
            run exits non-zero.
    """
    # The CLI passes None when --compiler is omitted; a None element would
    # make the " ".join() below raise TypeError, so fall back to the default.
    if compiler is None:
        compiler = "gcc"

    # We want to make sure the wheels we build are manylinux compliant. We'll
    # do the build in a container. Build the image for this.
    image_name = create_manylinux_build_image(
        rocm_version, rocm_build_job, rocm_build_num
    )

    if xla_path:
        xla_path = os.path.abspath(xla_path)

    # use image to build JAX/jaxlib wheels
    os.makedirs("wheelhouse", exist_ok=True)

    pyver_string = ",".join(python_versions)

    container_xla_path = "/xla"

    bw_cmd = [
        "python3",
        "/jax/build/rocm/tools/build_wheels.py",
        "--rocm-version",
        rocm_version,
        "--python-versions",
        pyver_string,
        "--compiler",
        compiler,
    ]

    if xla_path:
        bw_cmd.extend(["--xla-path", container_xla_path])

    bw_cmd.append("/jax")

    cmd = ["docker", "run"]

    mounts = [
        "-v",
        os.path.abspath("./") + ":/jax",
        "-v",
        os.path.abspath("./wheelhouse") + ":/wheelhouse",
    ]

    if xla_path:
        mounts.extend(["-v", "%s:%s" % (xla_path, container_xla_path)])

    cmd.extend(mounts)

    # Only request an interactive TTY when we actually have one.
    if os.isatty(sys.stdout.fileno()):
        cmd.append("-it")

    # NOTE(mrodden): bazel times out without --init, probably blocking on a zombie PID
    cmd.extend(
        [
            "--init",
            "--rm",
            image_name,
            "bash",
            "-c",
            " ".join(bw_cmd),
        ]
    )

    # Add command for unit tests
    #
    # NOTE(review): these arguments are appended *after* the `bash -c <script>`
    # string and the list is run without a shell, so bash receives "&&",
    # "bazel", ... as positional parameters ($0, $1, ...) and never executes
    # them. The bazel test step therefore does not run. The argv is left
    # unchanged because enabling it changes what CI executes; confirm the
    # intent before folding it into the -c script.
    cmd.extend(
        [
            "&&",
            "bazel",
            "test",
            "-k",
            "--jobs=4",
            "--test_verbose_timeout_warnings=true",
            "--test_output=all",
            "--test_summary=detailed",
            "--local_test_jobs=1",
            "--test_env=JAX_ACCELERATOR_COUNT=%i" % 4,
            "--test_env=JAX_SKIP_SLOW_TESTS=0",
            "--verbose_failures=true",
            "--config=rocm",
            "--action_env=ROCM_PATH=/opt/rocm",
            "--action_env=TF_ROCM_AMDGPU_TARGETS=%s" % "gfx90a",
            "--test_tag_filters=-multiaccelerator",
            "--test_env=XLA_PYTHON_CLIENT_ALLOCATOR=platform",
            "--test_env=JAX_EXCLUDE_TEST_TARGETS=PmapTest.testSizeOverflow",
            "//tests:gpu_tests",
        ]
    )

    LOG.info("Running: %s", cmd)
    subprocess.run(cmd, check=True)
|
|
|
|
|
|
def _fetch_jax_metadata(xla_path):
    """Collect the JAX version plus the JAX and XLA git commits.

    The JAX commit comes from ``git rev-parse HEAD`` in the current directory
    and the version from ``python3 setup.py -V`` run with ``JAX_RELEASE=1``.
    When ``xla_path`` is truthy, the XLA commit is read the same way from that
    directory; a failure there is logged as a warning and the commit is left
    empty.

    Returns:
        A dict with the keys ``jax_version``, ``jax_commit`` and
        ``xla_commit``, each a stripped string.
    """

    def as_text(raw):
        # Accept either str or bytes and always hand back str.
        return raw if isinstance(raw, str) else raw.decode("utf8")

    rev_parse = ["git", "rev-parse", "HEAD"]
    jax_commit = subprocess.check_output(rev_parse)

    xla_commit = b""
    if xla_path:
        try:
            xla_commit = subprocess.check_output(rev_parse, cwd=xla_path)
        except Exception as ex:
            # Best effort: a missing XLA commit must not fail the build.
            LOG.warning("Exception while retrieving xla_commit: %s", ex)

    release_env = dict(os.environ, JAX_RELEASE="1")
    jax_version = subprocess.check_output(
        ["python3", "setup.py", "-V"], env=release_env
    )

    raw_fields = (
        ("jax_version", jax_version),
        ("jax_commit", jax_commit),
        ("xla_commit", xla_commit),
    )
    return {key: as_text(value).strip() for key, value in raw_fields}
|
|
|
|
|
|
def dist_docker(
    rocm_version,
    base_docker,
    python_versions,
    xla_path,
    rocm_build_job="",
    rocm_build_num="",
    tag="rocm/jax-dev",
    dockerfile=None,
    keep_image=True,
):
    """Build the JAX Docker image with ``docker build --target rt_build``.

    Args:
        rocm_version: Passed as the ``ROCM_VERSION`` build arg.
        base_docker: Passed as the ``BASE_DOCKER`` build arg.
        python_versions: Sequence of version strings; only the first entry is
            used, as the ``PYTHON_VERSION`` build arg.
        xla_path: Forwarded to ``_fetch_jax_metadata`` to look up the XLA
            commit.
        rocm_build_job: Passed as the ``ROCM_BUILD_JOB`` build arg.
        rocm_build_num: Passed as the ``ROCM_BUILD_NUM`` build arg.
        tag: Tag applied to the built image.
        dockerfile: Dockerfile path; a falsy value selects
            ``build/rocm/Dockerfile.ms``.
        keep_image: When false, ``--rm`` is added to the ``docker build``
            command line.

    Raises:
        subprocess.CalledProcessError: If ``docker build`` exits non-zero.
    """
    dockerfile = dockerfile or "build/rocm/Dockerfile.ms"
    python_version = python_versions[0]
    md = _fetch_jax_metadata(xla_path)

    build_args = [
        ("ROCM_VERSION", rocm_version),
        ("ROCM_BUILD_JOB", rocm_build_job),
        ("ROCM_BUILD_NUM", rocm_build_num),
        ("BASE_DOCKER", base_docker),
        ("PYTHON_VERSION", python_version),
        ("JAX_VERSION", md["jax_version"]),
        ("JAX_COMMIT", md["jax_commit"]),
        ("XLA_COMMIT", md["xla_commit"]),
    ]

    cmd = ["docker", "build", "-f", dockerfile, "--target", "rt_build"]
    cmd.extend("--build-arg=%s=%s" % pair for pair in build_args)
    cmd.append("--tag=%s" % tag)

    if not keep_image:
        cmd.append("--rm")

    # The build context is the current directory and must come last.
    cmd.append(".")

    subprocess.check_call(cmd)
|
|
|
|
|
|
def test(image_name, test_cmd):
    """Run unit tests like CI would inside a JAX image.

    Starts ``image_name`` with ``docker run --rm``, mounting the current
    directory at ``/jax`` and passing the GPU device and security options
    listed below, then runs ``cd /jax && <test_cmd>`` through ``bash -c``.

    Raises:
        subprocess.CalledProcessError: If the container exits non-zero.
    """
    # Only ask for an interactive TTY when stdout is one.
    interactive = ["-it"] if os.isatty(sys.stdout.fileno()) else []

    # NOTE(mrodden): we need jax source dir for the unit test code only,
    # JAX and jaxlib are already installed from wheels
    source_mount = ["-v", os.path.abspath("./") + ":/jax"]

    gpu_args = [
        "--device=/dev/kfd",
        "--device=/dev/dri",
        "--group-add",
        "video",
        "--cap-add=SYS_PTRACE",
        "--security-opt",
        "seccomp=unconfined",
        "--shm-size",
        "16G",
    ]

    in_container = ["bash", "-c", "cd /jax && " + test_cmd]

    cmd = ["docker", "run", "--rm"]
    cmd += interactive
    cmd += source_mount
    cmd += gpu_args
    cmd.append(image_name)
    cmd += in_container

    subprocess.check_call(cmd)
|
|
|
|
|
|
def parse_args():
    """Parse the command line for the ``ci_build`` script.

    Global options must appear before the sub-command. The sub-command is
    required and is stored in ``action``; it is one of ``dist_wheels``,
    ``test`` or ``dist_docker``.

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    p = argparse.ArgumentParser()
    p.add_argument(
        "--base-docker",
        default="",
        help="Argument to override base docker in dockerfile",
    )

    # argparse runs string defaults through `type`, so the default also ends
    # up as a list (["3.12"]).
    p.add_argument(
        "--python-versions",
        type=lambda x: x.split(","),
        default="3.12",
        help="Comma separated list of CPython versions to build wheels for",
    )

    p.add_argument(
        "--rocm-version",
        default="6.1.1",
        help="ROCm version used for building wheels, testing, and installing into Docker image",
    )

    p.add_argument(
        "--rocm-build-job",
        default="",
        help="ROCm build job for development ROCm builds",
    )

    p.add_argument(
        "--rocm-build-num",
        default="",
        help="ROCm build number for development ROCm builds",
    )

    p.add_argument(
        "--xla-source-dir",
        help="Path to XLA source to use during jaxlib build, instead of builtin XLA",
    )

    # Default to gcc so the value forwarded to the wheel build is never None.
    p.add_argument(
        "--compiler",
        choices=["gcc", "clang"],
        default="gcc",
        help="Compiler backend to use when compiling jax/jaxlib",
    )

    # add_subparsers() only accepts required= from Python 3.7 on; setting the
    # attribute afterwards gives the same behavior and still works on 3.6.
    subp = p.add_subparsers(dest="action")
    subp.required = True

    subp.add_parser("dist_wheels")

    testp = subp.add_parser("test")
    testp.add_argument("image_name")
    testp.add_argument(
        "--test-cmd",
        default="./build/rocm/build_rocm.sh && ./build/rocm/run_single_gpu.py -c && ./build/rocm/run_multi_gpu.sh",
    )

    ddp = subp.add_parser("dist_docker")
    ddp.add_argument("--dockerfile", default="build/rocm/Dockerfile.ms")
    ddp.add_argument("--keep-image", action="store_true")
    ddp.add_argument("--image-tag", default="rocm/jax-dev")

    return p.parse_args()
|
|
|
|
|
|
def main():
    """Entry point: configure logging, parse arguments, run the chosen action.

    ``dist_wheels`` builds the wheels, ``test`` runs the test command inside
    an image, and ``dist_docker`` builds the wheels and then the Docker image.
    """
    logging.basicConfig(level=logging.INFO)
    args = parse_args()
    action = args.action

    if action == "test":
        test(args.image_name, args.test_cmd)
        return

    # Both remaining actions begin with the same wheel build.
    if action in ("dist_wheels", "dist_docker"):
        dist_wheels(
            args.rocm_version,
            args.python_versions,
            args.xla_source_dir,
            args.rocm_build_job,
            args.rocm_build_num,
            compiler=args.compiler,
        )

    if action == "dist_docker":
        dist_docker(
            args.rocm_version,
            args.base_docker,
            args.python_versions,
            args.xla_source_dir,
            rocm_build_job=args.rocm_build_job,
            rocm_build_num=args.rocm_build_num,
            tag=args.image_tag,
            dockerfile=args.dockerfile,
            keep_image=args.keep_image,
        )
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|