Add a build flag that allows disabling remote TPU builds.

Disable remote TPU by default.
This commit is contained in:
Peter Hawkins 2022-06-23 21:14:52 +00:00
parent e4d1e1beb3
commit 22304eeb2e
3 changed files with 26 additions and 5 deletions

View File

@ -14,6 +14,7 @@
# JAX is Autograd and XLA
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_not_windows", "if_windows")
@ -22,6 +23,18 @@ licenses(["notice"]) # Apache 2
package(default_visibility = ["//visibility:public"])
bool_flag(
name = "enable_remote_tpu",
build_setting_default = False,
)
config_setting(
name = "remote_tpu_enabled",
flag_values = {
":enable_remote_tpu": "True",
},
)
py_binary(
name = "build_wheel",
srcs = ["build_wheel.py"],
@ -40,9 +53,10 @@ py_binary(
"//jaxlib/mlir:transforms",
] + if_windows([
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
]) + if_not_windows([
"@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client",
]) + if_cuda([
]) + select({
":remote_tpu_enabled": ["@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client"],
"//conditions:default": [],
}) + if_cuda([
"//jaxlib/cuda:cuda_gpu_support",
"@local_config_cuda//cuda:cuda-nvvm",
]) + if_rocm([

View File

@ -343,7 +343,11 @@ def main():
add_boolean_argument(
parser,
"enable_tpu",
help_str="Should we build with Cloud TPU support enabled?")
help_str="Should we build with Cloud TPU VM support enabled?")
add_boolean_argument(
parser,
"enable_remote_tpu",
help_str="Should we build with remote Cloud TPU support enabled?")
add_boolean_argument(
parser,
"enable_rocm",
@ -468,6 +472,7 @@ def main():
print("NCCL enabled: {}".format("yes" if args.enable_nccl else "no"))
print("TPU enabled: {}".format("yes" if args.enable_tpu else "no"))
print("Remote TPU enabled: {}".format("yes" if args.enable_remote_tpu else "no"))
print("ROCm enabled: {}".format("yes" if args.enable_rocm else "no"))
if args.enable_rocm:
@ -509,6 +514,8 @@ def main():
config_args += ["--config=nonccl"]
if args.enable_tpu:
config_args += ["--config=tpu"]
if args.enable_remote_tpu:
config_args += ["--//build:enable_remote_tpu=true"]
if args.enable_rocm:
config_args += ["--config=rocm"]
if not args.enable_nccl:

View File

@ -251,7 +251,7 @@ def prepare_wheel(sources_path):
patch_copy_xla_extension_stubs(jaxlib_dir)
patch_copy_xla_client_py(jaxlib_dir)
if not _is_windows():
if exists("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"):
copy_to_jaxlib("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so")
patch_copy_tpu_client_py(jaxlib_dir)