rocm_jax/build/rocm/tools/build_wheels.py
Mathew Odden ec4b8ee1ed Fixes for 0.5.0 build ported to rocm-main
(cherry picked from commit c23a81461192a2b6da3d364076a261714d2dc64f)
2025-03-25 17:51:30 -05:00

342 lines
10 KiB
Python

#!/usr/bin/env python3
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# NOTE(mrodden): This file is part of the ROCm build scripts, and
# needs to be compatible with Python 3.6. Please do not include these
# in any "upgrade" scripts
import argparse
from collections import deque
import fcntl
import logging
import os
import re
import select
import subprocess
import shutil
import sys
# Module-level logger used by the helpers in this script.
LOG = logging.getLogger(__name__)

# Space-separated GPU device target names; main() passes this string to
# update_rocm_targets(), which writes it to <rocm_path>/bin/target.lst.
GPU_DEVICE_TARGETS = " ".join(
    [
        "gfx900",
        "gfx906",
        "gfx908",
        "gfx90a",
        "gfx942",
        "gfx1030",
        "gfx1100",
        "gfx1101",
        "gfx1200",
        "gfx1201",
    ]
)
def build_rocm_path(rocm_version_str):
    """Return the ROCm install directory to build against.

    Args:
        rocm_version_str: ROCm version string, e.g. ``"6.1.1"``.

    Returns:
        ``/opt/rocm-<version>`` when that path exists, otherwise the
        resolved real path of ``/opt/rocm``.
    """
    versioned_dir = "/opt/rocm-%s" % rocm_version_str
    if not os.path.exists(versioned_dir):
        # No version-specific install; fall back to the generic location.
        return os.path.realpath("/opt/rocm")
    return versioned_dir
def update_rocm_targets(rocm_path, targets):
    """Write the GPU target list into a ROCm install tree.

    Args:
        rocm_path: Root directory of the ROCm installation.
        targets: Text written (followed by a newline) to
            ``<rocm_path>/bin/target.lst``, replacing any previous content.

    The file ``<rocm_path>/.info/version`` is also opened in append mode and
    closed again, which creates it if missing and leaves existing content
    untouched.
    """
    target_list_file = os.path.join(rocm_path, "bin/target.lst")
    version_file = os.path.join(rocm_path, ".info/version")

    with open(target_list_file, "w") as out:
        out.write("%s\n" % targets)

    # mimic touch
    with open(version_file, "a"):
        pass
def find_clang_path(llvm_base_path="/usr/lib/"):
    """Locate a clang binary under the newest ``llvm-*`` directory.

    Directories named ``llvm-<version>`` directly under *llvm_base_path* are
    ordered by their numeric version (so ``llvm-18`` outranks ``llvm-9``;
    a plain string sort would get this backwards). They are searched from
    highest to lowest, and the first one whose ``bin`` directory contains a
    clang binary wins. Within a ``bin`` directory a versioned binary
    (``clang-<N>``, highest ``N``) is preferred over plain ``clang``.

    Args:
        llvm_base_path: Directory that holds the ``llvm-*`` installs.
            Defaults to ``/usr/lib/``.

    Returns:
        Full path to the selected clang binary, or ``None`` if no candidate
        was found.

    Raises:
        OSError: If *llvm_base_path* itself cannot be listed.
    """

    def _llvm_sort_key(dir_name):
        # Compare numerically on every digit run after the "llvm-" prefix;
        # the name is a tie-breaker so the ordering is deterministic.
        digits = re.findall(r"\d+", dir_name[len("llvm-"):])
        return (tuple(int(d) for d in digits), dir_name)

    llvm_dirs = [d for d in os.listdir(llvm_base_path) if d.startswith("llvm-")]
    llvm_dirs.sort(key=_llvm_sort_key, reverse=True)

    for llvm_dir in llvm_dirs:
        clang_bin_dir = os.path.join(llvm_base_path, llvm_dir, "bin")
        if not os.path.isdir(clang_bin_dir):
            # Not a usable install (no bin directory); try the next version.
            continue

        entries = os.listdir(clang_bin_dir)

        # Prefer versioned clang binaries (e.g., clang-18).
        versioned = []
        for entry in entries:
            m = re.fullmatch(r"clang-(\d+)", entry)
            if m:
                versioned.append((int(m.group(1)), entry))
        if versioned:
            return os.path.join(clang_bin_dir, max(versioned)[1])

        # Fallback to non-versioned clang.
        if "clang" in entries:
            return os.path.join(clang_bin_dir, "clang")

    return None
def build_jaxlib_wheel(
    jax_path, rocm_path, python_version, xla_path=None, compiler="gcc"
):
    """Run ``build/build.py`` to build the jaxlib and ROCm plugin/PJRT wheels.

    Args:
        jax_path: JAX source checkout; used as the working directory.
        rocm_path: ROCm install directory passed as ``--rocm_path``.
        python_version: Dotted CPython version (e.g. ``"3.12"``); the
            matching ``/opt/python/<tag>-<tag>/bin`` is put first on PATH.
        xla_path: Optional XLA source directory that overrides the ``xla``
            Bazel repository.
        compiler: ``"clang"`` enables the clang build and requires a clang
            binary to be discoverable; any other value builds without it.

    Raises:
        subprocess.CalledProcessError: If configuring git fails.
        RuntimeError: If clang was requested but no binary was found.
        Exception: Propagated from the build run (non-zero exit status or
            no wheel name in the scanned output).
    """
    wants_clang = compiler == "clang"

    # Avoid git warning by setting safe.directory.
    try:
        subprocess.run(
            ["git", "config", "--global", "--add", "safe.directory", "*"],
            check=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"Failed to configure Git safe directory: {e}")
        raise

    build_cmd = ["python", "build/build.py", "build"]
    build_cmd += [
        "--wheels=jaxlib,jax-rocm-plugin,jax-rocm-pjrt",
        "--rocm_path=%s" % rocm_path,
        "--rocm_version=60",
        "--use_clang=%s" % ("true" if wants_clang else "false"),
        "--verbose",
    ]

    # Add clang path if clang is used.
    if wants_clang:
        clang_path = find_clang_path()
        if not clang_path:
            raise RuntimeError("Clang binary not found in /usr/lib/llvm-*")
        LOG.info("Found clang at path: %s", clang_path)
        build_cmd.append("--clang_path=%s" % clang_path)

    if xla_path:
        build_cmd.append(
            "--bazel_options=--override_repository=xla=%s" % xla_path
        )

    cpy_tag = to_cpy_ver(python_version)
    interpreter_bin = "/opt/python/%s-%s/bin" % (cpy_tag, cpy_tag)

    # Child environment: release flags set, selected CPython first on PATH.
    child_env = dict(os.environ)
    child_env["JAX_RELEASE"] = "1"
    child_env["JAXLIB_RELEASE"] = "1"
    child_env["PATH"] = "%s:%s" % (interpreter_bin, child_env["PATH"])

    LOG.info("Running %r from cwd=%r" % (build_cmd, jax_path))
    wheel_line = re.compile("Output wheel: (.+)\n")

    _run_scan_for_output(
        build_cmd, wheel_line, env=child_env, cwd=jax_path, capture="stderr"
    )
def build_jax_wheel(jax_path, python_version):
    """Build the jax wheel by running ``python -m build`` in *jax_path*.

    Args:
        jax_path: JAX source checkout; used as the working directory.
        python_version: Dotted CPython version (e.g. ``"3.12"``); the
            matching ``/opt/python/<tag>-<tag>/bin`` is put first on PATH.

    Raises:
        Exception: Propagated from the build run (non-zero exit status or
            no wheel name in the scanned stdout).
    """
    cpy_tag = to_cpy_ver(python_version)
    interpreter_bin = "/opt/python/%s-%s/bin" % (cpy_tag, cpy_tag)

    # Child environment: release flags set, selected CPython first on PATH.
    child_env = dict(os.environ)
    child_env["JAX_RELEASE"] = "1"
    child_env["JAXLIB_RELEASE"] = "1"
    child_env["PATH"] = "%s:%s" % (interpreter_bin, child_env["PATH"])

    build_cmd = ["python", "-m", "build"]
    LOG.info("Running %r from cwd=%r" % (build_cmd, jax_path))

    wheel_line = re.compile(r"Successfully built jax-.+ and (jax-.+\.whl)\n")
    _run_scan_for_output(
        build_cmd, wheel_line, env=child_env, cwd=jax_path, capture="stdout"
    )
def _run_scan_for_output(cmd, pattern, env=None, cwd=None, capture=None):
    """Run *cmd*, echo one of its output streams live, and scan it for wheels.

    The captured stream is forwarded to this process's matching stream as it
    arrives. Only the last 20000 characters are retained, and *pattern* is
    applied to that tail once the child has exited.

    Args:
        cmd: Argument list for the child process.
        pattern: Compiled regex; ``pattern.findall`` is run on the retained
            output tail.
        env: Environment mapping for the child (``None`` inherits).
        cwd: Working directory for the child (``None`` inherits).
        capture: ``"stderr"`` captures the child's stderr; any other value
            captures stdout.

    Returns:
        The list of matches produced by ``pattern.findall``.

    Raises:
        Exception: If the child exits with a non-zero status, or if
            *pattern* matches nothing in the retained output.
    """
    # Function-scope import keeps this block self-contained.
    import codecs

    buf = deque(maxlen=20000)

    if capture == "stderr":
        p = subprocess.Popen(cmd, env=env, cwd=cwd, stderr=subprocess.PIPE)
        redir = sys.stderr
        cap_fd = p.stderr
    else:
        p = subprocess.Popen(cmd, env=env, cwd=cwd, stdout=subprocess.PIPE)
        redir = sys.stdout
        cap_fd = p.stdout

    # Non-blocking reads: read() returns None instead of stalling when the
    # pipe is momentarily empty.
    flags = fcntl.fcntl(cap_fd, fcntl.F_GETFL)
    fcntl.fcntl(cap_fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)

    # Chunks are cut at arbitrary byte offsets, so a multi-byte UTF-8
    # character can straddle two reads. Decoding each chunk independently
    # would raise UnicodeDecodeError; the incremental decoder carries the
    # partial sequence over to the next chunk. Undecodable bytes are
    # replaced rather than aborting the build over log output.
    decoder = codecs.getincrementaldecoder("utf8")(errors="replace")

    eof = False
    while not eof:
        r, _, _ = select.select([cap_fd], [], [])
        for fd in r:
            dat = fd.read(512)
            if dat is None:
                # Nothing available right now.
                continue
            if dat:
                t = decoder.decode(dat)
            else:
                # Empty read means EOF; flush any pending partial sequence.
                t = decoder.decode(b"", final=True)
                eof = True
            if t:
                redir.write(t)
                buf.extend(t)

    # wait and drain pipes
    _, _ = p.communicate()

    if p.returncode != 0:
        raise Exception(
            "Child process exited with nonzero result: rc=%d" % p.returncode
        )

    text = "".join(buf)

    matches = pattern.findall(text)

    if not matches:
        LOG.error("No wheel name found in output: %r" % text)
        raise Exception("No wheel name found in output")

    wheels = []
    for match in matches:
        LOG.info("Found built wheel: %r" % match)
        wheels.append(match)

    return wheels
def to_cpy_ver(python_version):
    """Convert a dotted Python version into a CPython tag.

    Only the first two components are used, so ``"3.10.19"`` and ``"3.10"``
    both yield ``"cp310"``.

    Args:
        python_version: Version string with at least ``major.minor``.

    Returns:
        The tag string ``cp<major><minor>``.
    """
    parts = python_version.split(".")
    major = int(parts[0])
    minor = int(parts[1])
    return "cp{:d}{:d}".format(major, minor)
def fix_wheel(path, jax_path):
    """Run the ROCm ``fixwheel.py`` tool on one wheel.

    Installs a pinned auditwheel with ``pip`` and then runs
    ``<jax_path>/build/rocm/tools/fixwheel.py`` on *path*, both with
    ``/opt/python/cp310-cp310/bin`` first on PATH.

    Args:
        path: Path of the wheel file to fix.
        jax_path: JAX source checkout containing the fixwheel script.

    Raises:
        subprocess.CalledProcessError: If either command fails.
        Exception: Any other error; it is logged and re-raised.
    """
    try:
        # NOTE(mrodden): fixwheel needs auditwheel 6.0.0, which has a min python of 3.8
        # so use one of the CPythons in /opt to run
        tool_env = dict(os.environ)
        tool_env["PATH"] = "%s:%s" % (
            "/opt/python/cp310-cp310/bin",
            tool_env["PATH"],
        )

        # NOTE(mrodden): auditwheel 6.0 added lddtree module, but 6.3.0 changed
        # the function to ldd and also changed its behavior
        # constrain range to 6.0 to 6.2.x
        subprocess.run(
            ["pip", "install", "auditwheel>=6,<6.3"], check=True, env=tool_env
        )

        fixwheel_script = os.path.join(jax_path, "build/rocm/tools/fixwheel.py")
        subprocess.run(
            ["python", fixwheel_script, path], check=True, env=tool_env
        )
        LOG.info("Wheel fix completed successfully.")
    except subprocess.CalledProcessError as cpe:
        LOG.error(f"Subprocess failed with error: {cpe}")
        raise
    except Exception as e:
        LOG.error(f"An unexpected error occurred: {e}")
        raise
def parse_args(argv=None):
    """Parse the command-line options for the wheel build.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        The ``argparse.Namespace`` with ``rocm_version``, ``python_versions``
        (a comma-separated string), ``xla_path``, ``compiler`` and
        ``jax_path``.
    """
    p = argparse.ArgumentParser()
    p.add_argument(
        "--rocm-version", default="6.1.1", help="ROCM Version to build JAX against"
    )
    p.add_argument(
        "--python-versions",
        # Must be a string, not a list: the caller splits this value on ","
        # and a list default would fail whenever the flag is omitted.
        default="3.10.19,3.12",
        help="Comma separated CPython versions that wheels will be built and output for",
    )
    p.add_argument(
        "--xla-path",
        type=str,
        default=None,
        help="Optional directory where XLA source is located to use instead of JAX builtin XLA",
    )
    p.add_argument(
        "--compiler",
        type=str,
        default="gcc",
        help="Compiler backend to use when compiling jax/jaxlib",
    )

    p.add_argument("jax_path", help="Directory where JAX source directory is located")

    return p.parse_args(argv)
def find_wheels(path):
    """List the wheel files directly inside *path*.

    Args:
        path: Directory to scan (not recursive).

    Returns:
        A list of ``path``-joined names for every entry ending in ``.whl``,
        in directory-listing order.
    """
    found = [
        os.path.join(path, entry)
        for entry in os.listdir(path)
        if entry.endswith(".whl")
    ]
    LOG.info("Found wheels: %r" % found)
    return found
def main():
    """Build, fix and collect the ROCm JAX wheels.

    Steps, in order:
      1. Parse arguments and print the effective configuration.
      2. Resolve the ROCm path and write the GPU target list into it.
      3. For each requested Python version, build the jaxlib/plugin wheels
         and run the wheel fixer on every non-``jax-`` wheel in ``dist``.
      4. Build the jax wheel with the last listed Python version and copy
         the ``jax-`` wheel(s) to ``/wheelhouse/``.
      5. Remove ``dist``, ``jax.egg-info`` and ``jax/__pycache__`` under the
         JAX path, then chmod the files in ``<jax_path>/wheelhouse``.

    NOTE(review): step 4 copies into ``/wheelhouse/`` while step 5 adjusts
    permissions under ``<jax_path>/wheelhouse`` -- presumably the fixer
    writes its output there; confirm these are the intended locations.
    """
    args = parse_args()
    python_versions = args.python_versions.split(",")

    # Echo the effective configuration.
    for template, value in (
        ("ROCM_VERSION=%s", args.rocm_version),
        ("PYTHON_VERSIONS=%r", python_versions),
        ("JAX_PATH=%s", args.jax_path),
        ("XLA_PATH=%s", args.xla_path),
        ("COMPILER=%s", args.compiler),
    ):
        print(template % value)

    rocm_path = build_rocm_path(args.rocm_version)
    update_rocm_targets(rocm_path, GPU_DEVICE_TARGETS)

    dist_dir = os.path.join(args.jax_path, "dist")

    for py in python_versions:
        build_jaxlib_wheel(
            args.jax_path, rocm_path, py, args.xla_path, args.compiler
        )
        for built in find_wheels(dist_dir):
            # skip jax wheel since it is non-platform
            if os.path.basename(built).startswith("jax-"):
                continue
            fix_wheel(built, args.jax_path)

    # build JAX wheel for completeness
    build_jax_wheel(args.jax_path, python_versions[-1])
    all_wheels = find_wheels(dist_dir)

    # NOTE(mrodden): the jax wheel is a "non-platform wheel", so auditwheel will
    # do nothing, and in fact will throw an Exception. we just need to copy it
    # along with the jaxlib and plugin ones

    # copy jax wheel(s) to wheelhouse
    wheelhouse_dir = "/wheelhouse/"
    jax_wheels = (
        w for w in all_wheels if os.path.basename(w).startswith("jax-")
    )
    for jax_wheel in jax_wheels:
        LOG.info("Copying %s into %s" % (jax_wheel, wheelhouse_dir))
        shutil.copy(jax_wheel, wheelhouse_dir)

    # Delete the 'dist' directory since it causes permissions issues
    logging.info("Deleting dist, egg-info and cache directory")
    for leftover in (
        dist_dir,
        os.path.join(args.jax_path, "jax.egg-info"),
        os.path.join(args.jax_path, "jax", "__pycache__"),
    ):
        shutil.rmtree(leftover)

    # Make the wheels deleteable by the runner
    whl_house = os.path.join(args.jax_path, "wheelhouse")
    logging.info("Changing permissions for %s" % whl_house)
    for entry in os.listdir(whl_house):
        candidate = os.path.join(whl_house, entry)
        if os.path.isfile(candidate):
            os.chmod(candidate, 0o664)
# Script entry point: configure root logging at INFO level, then run the build.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    main()