From f5c61a892acc657167f9d9ddb507ae7b373c4b0f Mon Sep 17 00:00:00 2001 From: Peter Hawkins Date: Mon, 12 Jul 2021 16:33:12 -0400 Subject: [PATCH] Add support for cross-compiling jaxlib for Mac ARM. --- build/build.py | 48 +++++++++++++++++++++++++++++++++----------- build/build_wheel.py | 26 +++++++++++++----------- 2 files changed, 50 insertions(+), 24 deletions(-) diff --git a/build/build.py b/build/build.py index 30a0385a4..f1eea0317 100755 --- a/build/build.py +++ b/build/build.py @@ -135,9 +135,9 @@ bazel_packages = { } -def download_and_verify_bazel(): +def download_and_verify_bazel(target_cpu): """Downloads a bazel binary from Github, verifying its SHA256 hash.""" - package = bazel_packages.get((platform.system(), platform.machine())) + package = bazel_packages.get((platform.system(), target_cpu)) if package is None: return None @@ -182,16 +182,16 @@ def download_and_verify_bazel(): return os.path.join(".", package.file) -def get_bazel_paths(bazel_path_flag): +def get_bazel_paths(bazel_path_flag, target_cpu): """Yields a sequence of guesses about bazel path. Some of sequence elements can be None. The resulting iterator is lazy and potentially has a side effects.""" yield bazel_path_flag yield which("bazel") - yield download_and_verify_bazel() + yield download_and_verify_bazel(target_cpu) -def get_bazel_path(bazel_path_flag): +def get_bazel_path(bazel_path_flag, target_cpu): """Returns the path to a Bazel binary, downloading Bazel if not found. Also, it checks Bazel's version at lease newer than 2.0.0. @@ -199,7 +199,7 @@ def get_bazel_path(bazel_path_flag): releases performs version check against .bazelversion (see for details https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes). """ - for path in filter(None, get_bazel_paths(bazel_path_flag)): + for path in filter(None, get_bazel_paths(bazel_path_flag, target_cpu)): if check_bazel_version(path): return path @@ -230,7 +230,6 @@ build --repo_env TF_NEED_CUDA="{tf_need_cuda}" build --action_env TF_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}" build --repo_env TF_NEED_ROCM="{tf_need_rocm}" build --action_env TF_ROCM_AMDGPU_TARGETS="{rocm_amdgpu_targets}" -build --distinct_host_configuration=false build:posix --copt=-Wno-sign-compare build -c opt build:avx_posix --copt=-mavx @@ -314,7 +313,8 @@ build:linux --copt=-Wno-stringop-truncation def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None, - cuda_version=None, cudnn_version=None, rocm_toolkit_path=None, **kwargs): + cuda_version=None, cudnn_version=None, rocm_toolkit_path=None, + cpu=None, **kwargs): with open("../.bazelrc", "w") as f: f.write(BAZELRC_TEMPLATE.format(**kwargs)) if cuda_toolkit_path: @@ -332,6 +332,12 @@ def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None, if rocm_toolkit_path: f.write("build --action_env ROCM_PATH=\"{rocm_toolkit_path}\"\n" .format(rocm_toolkit_path=rocm_toolkit_path)) + if cpu is not None: + f.write("build --distinct_host_configuration=true\n") + f.write(f"build --cpu={cpu}\n") + else: + f.write("build --distinct_host_configuration=false\n") + BANNER = r""" _ _ __ __ @@ -463,6 +469,11 @@ def main(): "--output_path", default=os.path.join(cwd, "dist"), help="Directory to which the jaxlib wheel should be written") + parser.add_argument( + "--target_cpu", + default=None, + help="CPU platform to target. Default is the same as the host machine. " + "Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.") args = parser.parse_args() if is_windows() and args.enable_cuda: @@ -479,8 +490,17 @@ def main(): output_path = os.path.abspath(args.output_path) os.chdir(os.path.dirname(__file__ or args.prog) or '.') + host_cpu = platform.machine() + wheel_cpus = { + "darwin_arm64": "arm64", + "darwin_x86_64": "x86_64", + } + # TODO(phawkins): support other bazel cpu overrides. + wheel_cpu = (wheel_cpus[args.target_cpu] if args.target_cpu is not None + else host_cpu) + # Find a working Bazel. - bazel_path = get_bazel_path(args.bazel_path) + bazel_path = get_bazel_path(args.bazel_path, wheel_cpu) print("Bazel binary path: {}".format(bazel_path)) python_bin_path = get_python_bin_path(args.python_bin_path) @@ -495,6 +515,7 @@ def main(): print("SciPy version: {}".format(scipy_version)) print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no")) + print("Target CPU: {}".format(wheel_cpu)) print("Target CPU features: {}".format(args.target_cpu_features)) cuda_toolkit_path = args.cuda_path @@ -532,13 +553,15 @@ def main(): cudnn_version=args.cudnn_version, rocm_toolkit_path=rocm_toolkit_path, rocm_amdgpu_targets=args.rocm_amdgpu_targets, -) + cpu=args.target_cpu, + ) + print("\nBuilding XLA and installing it in the jaxlib source tree...") config_args = args.bazel_options config_args += ["--config=short_logs"] if args.target_cpu_features == "release": - if platform.uname().machine == "x86_64": + if wheel_cpu == "x86_64": config_args += ["--config=avx_windows" if is_windows() else "--config=avx_posix"] elif args.target_cpu_features == "native": @@ -563,7 +586,8 @@ def main(): command = ([bazel_path] + args.bazel_startup_options + ["run", "--verbose_failures=true"] + config_args + [":build_wheel", "--", - f"--output_path={output_path}"]) + f"--output_path={output_path}", + f"--cpu={wheel_cpu}"]) print(" ".join(command)) shell(command) shell([bazel_path, "shutdown"]) diff --git a/build/build_wheel.py b/build/build_wheel.py index b0bf88901..2f102984e 100644 --- a/build/build_wheel.py +++ b/build/build_wheel.py @@ -40,6 +40,11 @@ parser.add_argument( default=None, required=True, help="Path to which the output wheel should be written. Required.") +parser.add_argument( + "--cpu", + default=None, + required=True, + help="Target CPU architecture. Required.") args = parser.parse_args() r = runfiles.Create() @@ -215,18 +220,15 @@ def prepare_wheel(sources_path): patch_copy_tpu_client_py(jaxlib_dir) -def build_wheel(sources_path, output_path): +def build_wheel(sources_path, output_path, cpu): """Builds a wheel in `output_path` using the source tree in `sources_path`.""" - if platform.system() == "Windows": - cpu_name = "amd64" - platform_name = "win" - else: - platform_name, cpu_name = { - ("Linux", "x86_64"): ("manylinux2010", "x86_64"), - ("Linux", "aarch64"): ("manylinux2014", "aarch64"), - ("Darwin", "x86_64"): ("macosx_10_9", "x86_64"), - ("Darwin", "arm64"): ("macosx_11_0", "arm64"), - }[(platform.system(), platform.machine())] + platform_name, cpu_name = { + ("Linux", "x86_64"): ("manylinux2010", "x86_64"), + ("Linux", "aarch64"): ("manylinux2014", "aarch64"), + ("Darwin", "x86_64"): ("macosx_10_9", "x86_64"), + ("Darwin", "arm64"): ("macosx_11_0", "arm64"), + ("Windows", "x86_64"): ("win", "amd64"), + }[(platform.system(), cpu)] python_tag_arg = (f"--python-tag=cp{sys.version_info.major}" f"{sys.version_info.minor}") platform_tag_arg = f"--plat-name={platform_name}_{cpu_name}" @@ -252,7 +254,7 @@ if sources_path is None: try: os.makedirs(args.output_path, exist_ok=True) prepare_wheel(sources_path) - build_wheel(sources_path, args.output_path) + build_wheel(sources_path, args.output_path, args.cpu) finally: if tmpdir: tmpdir.cleanup()