#!/usr/bin/env python3 # Copyright 2024 The JAX Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # NOTE(mrodden): This file is part of the ROCm build scripts, and # needs be compatible with Python 3.6. Please do not include these # in any "upgrade" scripts import argparse from collections import deque import fcntl import logging import os import re import select import subprocess import shutil import sys LOG = logging.getLogger(__name__) GPU_DEVICE_TARGETS = "gfx900 gfx906 gfx908 gfx90a gfx942 gfx1030 gfx1100 gfx1101 gfx1200 gfx1201" def build_rocm_path(rocm_version_str): path = "/opt/rocm-%s" % rocm_version_str if os.path.exists(path): return path else: return os.path.realpath("/opt/rocm") def update_rocm_targets(rocm_path, targets): target_fp = os.path.join(rocm_path, "bin/target.lst") version_fp = os.path.join(rocm_path, ".info/version") with open(target_fp, "w") as fd: fd.write("%s\n" % targets) # mimic touch open(version_fp, "a").close() def find_clang_path(): llvm_base_path = "/usr/lib/" # Search for llvm directories and pick the highest version. llvm_dirs = [d for d in os.listdir(llvm_base_path) if d.startswith("llvm-")] if llvm_dirs: # Sort to get the highest llvm version. llvm_dirs.sort(reverse=True) clang_bin_dir = os.path.join(llvm_base_path, llvm_dirs[0], "bin") # Prefer versioned clang binaries (e.g., clang-18). versioned_clang = None generic_clang = None for f in os.listdir(clang_bin_dir): # Checks for versioned clang binaries. if f.startswith("clang-") and f[6:].isdigit(): versioned_clang = os.path.join(clang_bin_dir, f) # Fallback to non-versioned clang. elif f == "clang": generic_clang = os.path.join(clang_bin_dir, f) # Return versioned clang if available, otherwise return generic clang. if versioned_clang: return versioned_clang elif generic_clang: return generic_clang return None def build_jaxlib_wheel( jax_path, rocm_path, python_version, xla_path=None, compiler="gcc" ): use_clang = "true" if compiler == "clang" else "false" # Avoid git warning by setting safe.directory. try: subprocess.run( ["git", "config", "--global", "--add", "safe.directory", "*"], check=True, ) except subprocess.CalledProcessError as e: print(f"Failed to configure Git safe directory: {e}") raise cmd = [ "python", "build/build.py", "build", "--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt", "--rocm_path=%s" % rocm_path, "--rocm_version=60", "--use_clang=%s" % use_clang, "--verbose", ] # Add clang path if clang is used. if compiler == "clang": clang_path = find_clang_path() if clang_path: LOG.info("Found clang at path: %s", clang_path) cmd.append("--clang_path=%s" % clang_path) else: raise RuntimeError("Clang binary not found in /usr/lib/llvm-*") if xla_path: cmd.append("--bazel_options=--override_repository=xla=%s" % xla_path) cpy = to_cpy_ver(python_version) py_bin = "/opt/python/%s-%s/bin" % (cpy, cpy) env = dict(os.environ) env["JAX_RELEASE"] = str(1) env["JAXLIB_RELEASE"] = str(1) env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) LOG.info("Running %r from cwd=%r" % (cmd, jax_path)) pattern = re.compile("Output wheel: (.+)\n") _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stderr") def build_jax_wheel(jax_path, python_version): cmd = [ "python", "-m", "build", ] cpy = to_cpy_ver(python_version) py_bin = "/opt/python/%s-%s/bin" % (cpy, cpy) env = dict(os.environ) env["JAX_RELEASE"] = str(1) env["JAXLIB_RELEASE"] = str(1) env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) LOG.info("Running %r from cwd=%r" % (cmd, jax_path)) pattern = re.compile(r"Successfully built jax-.+ and (jax-.+\.whl)\n") _run_scan_for_output(cmd, pattern, env=env, cwd=jax_path, capture="stdout") def _run_scan_for_output(cmd, pattern, env=None, cwd=None, capture=None): buf = deque(maxlen=20000) if capture == "stderr": p = subprocess.Popen(cmd, env=env, cwd=cwd, stderr=subprocess.PIPE) redir = sys.stderr cap_fd = p.stderr else: p = subprocess.Popen(cmd, env=env, cwd=cwd, stdout=subprocess.PIPE) redir = sys.stdout cap_fd = p.stdout flags = fcntl.fcntl(cap_fd, fcntl.F_GETFL) fcntl.fcntl(cap_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) eof = False while not eof: r, _, _ = select.select([cap_fd], [], []) for fd in r: dat = fd.read(512) if dat is None: continue elif dat: t = dat.decode("utf8") redir.write(t) buf.extend(t) else: eof = True # wait and drain pipes _, _ = p.communicate() if p.returncode != 0: raise Exception( "Child process exited with nonzero result: rc=%d" % p.returncode ) text = "".join(buf) matches = pattern.findall(text) if not matches: LOG.error("No wheel name found in output: %r" % text) raise Exception("No wheel name found in output") wheels = [] for match in matches: LOG.info("Found built wheel: %r" % match) wheels.append(match) return wheels def to_cpy_ver(python_version): tup = python_version.split(".") return "cp%d%d" % (int(tup[0]), int(tup[1])) def fix_wheel(path, jax_path): try: # NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8 # so use one of the CPythons in /opt to run env = dict(os.environ) py_bin = "/opt/python/cp310-cp310/bin" env["PATH"] = "%s:%s" % (py_bin, env["PATH"]) # NOTE(mrodden): auditwheel 6.0 added lddtree module, but 6.3.0 changed # the fuction to ldd and also changed its behavior # constrain range to 6.0 to 6.2.x cmd = ["pip", "install", "auditwheel>=6,<6.3"] subprocess.run(cmd, check=True, env=env) fixwheel_path = os.path.join(jax_path, "build/rocm/tools/fixwheel.py") cmd = ["python", fixwheel_path, path] subprocess.run(cmd, check=True, env=env) LOG.info("Wheel fix completed successfully.") except subprocess.CalledProcessError as cpe: LOG.error(f"Subprocess failed with error: {cpe}") raise except Exception as e: LOG.error(f"An unexpected error occurred: {e}") raise def parse_args(): p = argparse.ArgumentParser() p.add_argument( "--rocm-version", default="6.1.1", help="ROCM Version to build JAX against" ) p.add_argument( "--python-versions", default=["3.10.19,3.12"], help="Comma separated CPython versions that wheels will be built and output for", ) p.add_argument( "--xla-path", type=str, default=None, help="Optional directory where XLA source is located to use instead of JAX builtin XLA", ) p.add_argument( "--compiler", type=str, default="gcc", help="Compiler backend to use when compiling jax/jaxlib", ) p.add_argument("jax_path", help="Directory where JAX source directory is located") return p.parse_args() def find_wheels(path): wheels = [] for f in os.listdir(path): if f.endswith(".whl"): wheels.append(os.path.join(path, f)) LOG.info("Found wheels: %r" % wheels) return wheels def main(): args = parse_args() python_versions = args.python_versions.split(",") print("ROCM_VERSION=%s" % args.rocm_version) print("PYTHON_VERSIONS=%r" % python_versions) print("JAX_PATH=%s" % args.jax_path) print("XLA_PATH=%s" % args.xla_path) print("COMPILER=%s" % args.compiler) rocm_path = build_rocm_path(args.rocm_version) update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS) for py in python_versions: build_jaxlib_wheel(args.jax_path, rocm_path, py, args.xla_path, args.compiler) wheel_paths = find_wheels(os.path.join(args.jax_path, "dist")) for wheel_path in wheel_paths: # skip jax wheel since it is non-platform if not os.path.basename(wheel_path).startswith("jax-"): fix_wheel(wheel_path, args.jax_path) # build JAX wheel for completeness build_jax_wheel(args.jax_path, python_versions[-1]) wheels = find_wheels(os.path.join(args.jax_path, "dist")) # NOTE(mrodden): the jax wheel is a "non-platform wheel", so auditwheel will # do nothing, and in fact will throw an Exception. we just need to copy it # along with the jaxlib and plugin ones # copy jax wheel(s) to wheelhouse wheelhouse_dir = "/wheelhouse/" for whl in wheels: if os.path.basename(whl).startswith("jax-"): LOG.info("Copying %s into %s" % (whl, wheelhouse_dir)) shutil.copy(whl, wheelhouse_dir) # Delete the 'dist' directory since it causes permissions issues logging.info("Deleting dist, egg-info and cache directory") shutil.rmtree(os.path.join(args.jax_path, "dist")) shutil.rmtree(os.path.join(args.jax_path, "jax.egg-info")) shutil.rmtree(os.path.join(args.jax_path, "jax", "__pycache__")) # Make the wheels deleteable by the runner whl_house = os.path.join(args.jax_path, "wheelhouse") logging.info("Changing permissions for %s" % whl_house) mode = 0o664 for item in os.listdir(whl_house): whl_path = os.path.join(whl_house, item) if os.path.isfile(whl_path): os.chmod(whl_path, mode) if __name__ == "__main__": logging.basicConfig(level=logging.INFO) main()