Add support for cross-compiling jaxlib for Mac ARM.

This commit is contained in:
Peter Hawkins 2021-07-12 16:33:12 -04:00
parent 25e44821dd
commit f5c61a892a
2 changed files with 50 additions and 24 deletions

View File

@ -135,9 +135,9 @@ bazel_packages = {
}
def download_and_verify_bazel():
def download_and_verify_bazel(target_cpu):
"""Downloads a bazel binary from Github, verifying its SHA256 hash."""
package = bazel_packages.get((platform.system(), platform.machine()))
package = bazel_packages.get((platform.system(), target_cpu))
if package is None:
return None
@ -182,16 +182,16 @@ def download_and_verify_bazel():
return os.path.join(".", package.file)
def get_bazel_paths(bazel_path_flag):
def get_bazel_paths(bazel_path_flag, target_cpu):
"""Yields a sequence of guesses about bazel path. Some of sequence elements
can be None. The resulting iterator is lazy and potentially has a side
effects."""
yield bazel_path_flag
yield which("bazel")
yield download_and_verify_bazel()
yield download_and_verify_bazel(target_cpu)
def get_bazel_path(bazel_path_flag):
def get_bazel_path(bazel_path_flag, target_cpu):
"""Returns the path to a Bazel binary, downloading Bazel if not found. Also,
it checks Bazel's version at lease newer than 2.0.0.
@ -199,7 +199,7 @@ def get_bazel_path(bazel_path_flag):
releases performs version check against .bazelversion (see for details
https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes).
"""
for path in filter(None, get_bazel_paths(bazel_path_flag)):
for path in filter(None, get_bazel_paths(bazel_path_flag, target_cpu)):
if check_bazel_version(path):
return path
@ -230,7 +230,6 @@ build --repo_env TF_NEED_CUDA="{tf_need_cuda}"
build --action_env TF_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"
build --repo_env TF_NEED_ROCM="{tf_need_rocm}"
build --action_env TF_ROCM_AMDGPU_TARGETS="{rocm_amdgpu_targets}"
build --distinct_host_configuration=false
build:posix --copt=-Wno-sign-compare
build -c opt
build:avx_posix --copt=-mavx
@ -314,7 +313,8 @@ build:linux --copt=-Wno-stringop-truncation
def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None,
cuda_version=None, cudnn_version=None, rocm_toolkit_path=None, **kwargs):
cuda_version=None, cudnn_version=None, rocm_toolkit_path=None,
cpu=None, **kwargs):
with open("../.bazelrc", "w") as f:
f.write(BAZELRC_TEMPLATE.format(**kwargs))
if cuda_toolkit_path:
@ -332,6 +332,12 @@ def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None,
if rocm_toolkit_path:
f.write("build --action_env ROCM_PATH=\"{rocm_toolkit_path}\"\n"
.format(rocm_toolkit_path=rocm_toolkit_path))
if cpu is not None:
f.write("build --distinct_host_configuration=true\n")
f.write(f"build --cpu={cpu}\n")
else:
f.write("build --distinct_host_configuration=false\n")
BANNER = r"""
_ _ __ __
@ -463,6 +469,11 @@ def main():
"--output_path",
default=os.path.join(cwd, "dist"),
help="Directory to which the jaxlib wheel should be written")
parser.add_argument(
"--target_cpu",
default=None,
help="CPU platform to target. Default is the same as the host machine. "
"Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.")
args = parser.parse_args()
if is_windows() and args.enable_cuda:
@ -479,8 +490,17 @@ def main():
output_path = os.path.abspath(args.output_path)
os.chdir(os.path.dirname(__file__ or args.prog) or '.')
host_cpu = platform.machine()
wheel_cpus = {
"darwin_arm64": "arm64",
"darwin_x86_64": "x86_64",
}
# TODO(phawkins): support other bazel cpu overrides.
wheel_cpu = (wheel_cpus[args.target_cpu] if args.target_cpu is not None
else host_cpu)
# Find a working Bazel.
bazel_path = get_bazel_path(args.bazel_path)
bazel_path = get_bazel_path(args.bazel_path, wheel_cpu)
print("Bazel binary path: {}".format(bazel_path))
python_bin_path = get_python_bin_path(args.python_bin_path)
@ -495,6 +515,7 @@ def main():
print("SciPy version: {}".format(scipy_version))
print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no"))
print("Target CPU: {}".format(wheel_cpu))
print("Target CPU features: {}".format(args.target_cpu_features))
cuda_toolkit_path = args.cuda_path
@ -532,13 +553,15 @@ def main():
cudnn_version=args.cudnn_version,
rocm_toolkit_path=rocm_toolkit_path,
rocm_amdgpu_targets=args.rocm_amdgpu_targets,
)
cpu=args.target_cpu,
)
print("\nBuilding XLA and installing it in the jaxlib source tree...")
config_args = args.bazel_options
config_args += ["--config=short_logs"]
if args.target_cpu_features == "release":
if platform.uname().machine == "x86_64":
if wheel_cpu == "x86_64":
config_args += ["--config=avx_windows" if is_windows()
else "--config=avx_posix"]
elif args.target_cpu_features == "native":
@ -563,7 +586,8 @@ def main():
command = ([bazel_path] + args.bazel_startup_options +
["run", "--verbose_failures=true"] + config_args +
[":build_wheel", "--",
f"--output_path={output_path}"])
f"--output_path={output_path}",
f"--cpu={wheel_cpu}"])
print(" ".join(command))
shell(command)
shell([bazel_path, "shutdown"])

View File

@ -40,6 +40,11 @@ parser.add_argument(
default=None,
required=True,
help="Path to which the output wheel should be written. Required.")
parser.add_argument(
"--cpu",
default=None,
required=True,
help="Target CPU architecture. Required.")
args = parser.parse_args()
r = runfiles.Create()
@ -215,18 +220,15 @@ def prepare_wheel(sources_path):
patch_copy_tpu_client_py(jaxlib_dir)
def build_wheel(sources_path, output_path):
def build_wheel(sources_path, output_path, cpu):
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
if platform.system() == "Windows":
cpu_name = "amd64"
platform_name = "win"
else:
platform_name, cpu_name = {
("Linux", "x86_64"): ("manylinux2010", "x86_64"),
("Linux", "aarch64"): ("manylinux2014", "aarch64"),
("Darwin", "x86_64"): ("macosx_10_9", "x86_64"),
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
}[(platform.system(), platform.machine())]
platform_name, cpu_name = {
("Linux", "x86_64"): ("manylinux2010", "x86_64"),
("Linux", "aarch64"): ("manylinux2014", "aarch64"),
("Darwin", "x86_64"): ("macosx_10_9", "x86_64"),
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
("Windows", "x86_64"): ("win", "amd64"),
}[(platform.system(), cpu)]
python_tag_arg = (f"--python-tag=cp{sys.version_info.major}"
f"{sys.version_info.minor}")
platform_tag_arg = f"--plat-name={platform_name}_{cpu_name}"
@ -252,7 +254,7 @@ if sources_path is None:
try:
os.makedirs(args.output_path, exist_ok=True)
prepare_wheel(sources_path)
build_wheel(sources_path, args.output_path)
build_wheel(sources_path, args.output_path, args.cpu)
finally:
if tmpdir:
tmpdir.cleanup()