mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add support for cross-compiling jaxlib for Mac ARM.
This commit is contained in:
parent
25e44821dd
commit
f5c61a892a
@ -135,9 +135,9 @@ bazel_packages = {
|
||||
}
|
||||
|
||||
|
||||
def download_and_verify_bazel():
|
||||
def download_and_verify_bazel(target_cpu):
|
||||
"""Downloads a bazel binary from Github, verifying its SHA256 hash."""
|
||||
package = bazel_packages.get((platform.system(), platform.machine()))
|
||||
package = bazel_packages.get((platform.system(), target_cpu))
|
||||
if package is None:
|
||||
return None
|
||||
|
||||
@ -182,16 +182,16 @@ def download_and_verify_bazel():
|
||||
return os.path.join(".", package.file)
|
||||
|
||||
|
||||
def get_bazel_paths(bazel_path_flag):
|
||||
def get_bazel_paths(bazel_path_flag, target_cpu):
|
||||
"""Yields a sequence of guesses about bazel path. Some of sequence elements
|
||||
can be None. The resulting iterator is lazy and potentially has a side
|
||||
effects."""
|
||||
yield bazel_path_flag
|
||||
yield which("bazel")
|
||||
yield download_and_verify_bazel()
|
||||
yield download_and_verify_bazel(target_cpu)
|
||||
|
||||
|
||||
def get_bazel_path(bazel_path_flag):
|
||||
def get_bazel_path(bazel_path_flag, target_cpu):
|
||||
"""Returns the path to a Bazel binary, downloading Bazel if not found. Also,
|
||||
it checks Bazel's version at lease newer than 2.0.0.
|
||||
|
||||
@ -199,7 +199,7 @@ def get_bazel_path(bazel_path_flag):
|
||||
releases performs version check against .bazelversion (see for details
|
||||
https://blog.bazel.build/2019/12/19/bazel-2.0.html#other-important-changes).
|
||||
"""
|
||||
for path in filter(None, get_bazel_paths(bazel_path_flag)):
|
||||
for path in filter(None, get_bazel_paths(bazel_path_flag, target_cpu)):
|
||||
if check_bazel_version(path):
|
||||
return path
|
||||
|
||||
@ -230,7 +230,6 @@ build --repo_env TF_NEED_CUDA="{tf_need_cuda}"
|
||||
build --action_env TF_CUDA_COMPUTE_CAPABILITIES="{cuda_compute_capabilities}"
|
||||
build --repo_env TF_NEED_ROCM="{tf_need_rocm}"
|
||||
build --action_env TF_ROCM_AMDGPU_TARGETS="{rocm_amdgpu_targets}"
|
||||
build --distinct_host_configuration=false
|
||||
build:posix --copt=-Wno-sign-compare
|
||||
build -c opt
|
||||
build:avx_posix --copt=-mavx
|
||||
@ -314,7 +313,8 @@ build:linux --copt=-Wno-stringop-truncation
|
||||
|
||||
|
||||
def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None,
|
||||
cuda_version=None, cudnn_version=None, rocm_toolkit_path=None, **kwargs):
|
||||
cuda_version=None, cudnn_version=None, rocm_toolkit_path=None,
|
||||
cpu=None, **kwargs):
|
||||
with open("../.bazelrc", "w") as f:
|
||||
f.write(BAZELRC_TEMPLATE.format(**kwargs))
|
||||
if cuda_toolkit_path:
|
||||
@ -332,6 +332,12 @@ def write_bazelrc(cuda_toolkit_path=None, cudnn_install_path=None,
|
||||
if rocm_toolkit_path:
|
||||
f.write("build --action_env ROCM_PATH=\"{rocm_toolkit_path}\"\n"
|
||||
.format(rocm_toolkit_path=rocm_toolkit_path))
|
||||
if cpu is not None:
|
||||
f.write("build --distinct_host_configuration=true\n")
|
||||
f.write(f"build --cpu={cpu}\n")
|
||||
else:
|
||||
f.write("build --distinct_host_configuration=false\n")
|
||||
|
||||
|
||||
BANNER = r"""
|
||||
_ _ __ __
|
||||
@ -463,6 +469,11 @@ def main():
|
||||
"--output_path",
|
||||
default=os.path.join(cwd, "dist"),
|
||||
help="Directory to which the jaxlib wheel should be written")
|
||||
parser.add_argument(
|
||||
"--target_cpu",
|
||||
default=None,
|
||||
help="CPU platform to target. Default is the same as the host machine. "
|
||||
"Currently supported values are 'darwin_arm64' and 'darwin_x86_64'.")
|
||||
args = parser.parse_args()
|
||||
|
||||
if is_windows() and args.enable_cuda:
|
||||
@ -479,8 +490,17 @@ def main():
|
||||
output_path = os.path.abspath(args.output_path)
|
||||
os.chdir(os.path.dirname(__file__ or args.prog) or '.')
|
||||
|
||||
host_cpu = platform.machine()
|
||||
wheel_cpus = {
|
||||
"darwin_arm64": "arm64",
|
||||
"darwin_x86_64": "x86_64",
|
||||
}
|
||||
# TODO(phawkins): support other bazel cpu overrides.
|
||||
wheel_cpu = (wheel_cpus[args.target_cpu] if args.target_cpu is not None
|
||||
else host_cpu)
|
||||
|
||||
# Find a working Bazel.
|
||||
bazel_path = get_bazel_path(args.bazel_path)
|
||||
bazel_path = get_bazel_path(args.bazel_path, wheel_cpu)
|
||||
print("Bazel binary path: {}".format(bazel_path))
|
||||
|
||||
python_bin_path = get_python_bin_path(args.python_bin_path)
|
||||
@ -495,6 +515,7 @@ def main():
|
||||
print("SciPy version: {}".format(scipy_version))
|
||||
|
||||
print("MKL-DNN enabled: {}".format("yes" if args.enable_mkl_dnn else "no"))
|
||||
print("Target CPU: {}".format(wheel_cpu))
|
||||
print("Target CPU features: {}".format(args.target_cpu_features))
|
||||
|
||||
cuda_toolkit_path = args.cuda_path
|
||||
@ -532,13 +553,15 @@ def main():
|
||||
cudnn_version=args.cudnn_version,
|
||||
rocm_toolkit_path=rocm_toolkit_path,
|
||||
rocm_amdgpu_targets=args.rocm_amdgpu_targets,
|
||||
)
|
||||
cpu=args.target_cpu,
|
||||
)
|
||||
|
||||
|
||||
print("\nBuilding XLA and installing it in the jaxlib source tree...")
|
||||
config_args = args.bazel_options
|
||||
config_args += ["--config=short_logs"]
|
||||
if args.target_cpu_features == "release":
|
||||
if platform.uname().machine == "x86_64":
|
||||
if wheel_cpu == "x86_64":
|
||||
config_args += ["--config=avx_windows" if is_windows()
|
||||
else "--config=avx_posix"]
|
||||
elif args.target_cpu_features == "native":
|
||||
@ -563,7 +586,8 @@ def main():
|
||||
command = ([bazel_path] + args.bazel_startup_options +
|
||||
["run", "--verbose_failures=true"] + config_args +
|
||||
[":build_wheel", "--",
|
||||
f"--output_path={output_path}"])
|
||||
f"--output_path={output_path}",
|
||||
f"--cpu={wheel_cpu}"])
|
||||
print(" ".join(command))
|
||||
shell(command)
|
||||
shell([bazel_path, "shutdown"])
|
||||
|
@ -40,6 +40,11 @@ parser.add_argument(
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to which the output wheel should be written. Required.")
|
||||
parser.add_argument(
|
||||
"--cpu",
|
||||
default=None,
|
||||
required=True,
|
||||
help="Target CPU architecture. Required.")
|
||||
args = parser.parse_args()
|
||||
|
||||
r = runfiles.Create()
|
||||
@ -215,18 +220,15 @@ def prepare_wheel(sources_path):
|
||||
patch_copy_tpu_client_py(jaxlib_dir)
|
||||
|
||||
|
||||
def build_wheel(sources_path, output_path):
|
||||
def build_wheel(sources_path, output_path, cpu):
|
||||
"""Builds a wheel in `output_path` using the source tree in `sources_path`."""
|
||||
if platform.system() == "Windows":
|
||||
cpu_name = "amd64"
|
||||
platform_name = "win"
|
||||
else:
|
||||
platform_name, cpu_name = {
|
||||
("Linux", "x86_64"): ("manylinux2010", "x86_64"),
|
||||
("Linux", "aarch64"): ("manylinux2014", "aarch64"),
|
||||
("Darwin", "x86_64"): ("macosx_10_9", "x86_64"),
|
||||
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
|
||||
}[(platform.system(), platform.machine())]
|
||||
platform_name, cpu_name = {
|
||||
("Linux", "x86_64"): ("manylinux2010", "x86_64"),
|
||||
("Linux", "aarch64"): ("manylinux2014", "aarch64"),
|
||||
("Darwin", "x86_64"): ("macosx_10_9", "x86_64"),
|
||||
("Darwin", "arm64"): ("macosx_11_0", "arm64"),
|
||||
("Windows", "x86_64"): ("win", "amd64"),
|
||||
}[(platform.system(), cpu)]
|
||||
python_tag_arg = (f"--python-tag=cp{sys.version_info.major}"
|
||||
f"{sys.version_info.minor}")
|
||||
platform_tag_arg = f"--plat-name={platform_name}_{cpu_name}"
|
||||
@ -252,7 +254,7 @@ if sources_path is None:
|
||||
try:
|
||||
os.makedirs(args.output_path, exist_ok=True)
|
||||
prepare_wheel(sources_path)
|
||||
build_wheel(sources_path, args.output_path)
|
||||
build_wheel(sources_path, args.output_path, args.cpu)
|
||||
finally:
|
||||
if tmpdir:
|
||||
tmpdir.cleanup()
|
||||
|
Loading…
x
Reference in New Issue
Block a user