mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add a build flag that allows disabling remote TPU builds.
Disable remote TPU by default.
This commit is contained in:
parent
e4d1e1beb3
commit
22304eeb2e
@ -14,6 +14,7 @@
|
||||
|
||||
# JAX is Autograd and XLA
|
||||
|
||||
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
|
||||
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
|
||||
load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm")
|
||||
load("@org_tensorflow//tensorflow:tensorflow.bzl", "if_not_windows", "if_windows")
|
||||
@ -22,6 +23,18 @@ licenses(["notice"]) # Apache 2
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
|
||||
bool_flag(
|
||||
name = "enable_remote_tpu",
|
||||
build_setting_default = False,
|
||||
)
|
||||
|
||||
config_setting(
|
||||
name = "remote_tpu_enabled",
|
||||
flag_values = {
|
||||
":enable_remote_tpu": "True",
|
||||
},
|
||||
)
|
||||
|
||||
py_binary(
|
||||
name = "build_wheel",
|
||||
srcs = ["build_wheel.py"],
|
||||
@ -40,9 +53,10 @@ py_binary(
|
||||
"//jaxlib/mlir:transforms",
|
||||
] + if_windows([
|
||||
"//jaxlib/mlir/_mlir_libs:jaxlib_mlir_capi.dll",
|
||||
]) + if_not_windows([
|
||||
"@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client",
|
||||
]) + if_cuda([
|
||||
]) + select({
|
||||
":remote_tpu_enabled": ["@org_tensorflow//tensorflow/compiler/xla/python/tpu_driver/client:py_tpu_client"],
|
||||
"//conditions:default": [],
|
||||
}) + if_cuda([
|
||||
"//jaxlib/cuda:cuda_gpu_support",
|
||||
"@local_config_cuda//cuda:cuda-nvvm",
|
||||
]) + if_rocm([
|
||||
|
@ -343,7 +343,11 @@ def main():
|
||||
add_boolean_argument(
|
||||
parser,
|
||||
"enable_tpu",
|
||||
help_str="Should we build with Cloud TPU support enabled?")
|
||||
help_str="Should we build with Cloud TPU VM support enabled?")
|
||||
add_boolean_argument(
|
||||
parser,
|
||||
"enable_remote_tpu",
|
||||
help_str="Should we build with remote Cloud TPU support enabled?")
|
||||
add_boolean_argument(
|
||||
parser,
|
||||
"enable_rocm",
|
||||
@ -468,6 +472,7 @@ def main():
|
||||
print("NCCL enabled: {}".format("yes" if args.enable_nccl else "no"))
|
||||
|
||||
print("TPU enabled: {}".format("yes" if args.enable_tpu else "no"))
|
||||
print("Remote TPU enabled: {}".format("yes" if args.enable_remote_tpu else "no"))
|
||||
|
||||
print("ROCm enabled: {}".format("yes" if args.enable_rocm else "no"))
|
||||
if args.enable_rocm:
|
||||
@ -509,6 +514,8 @@ def main():
|
||||
config_args += ["--config=nonccl"]
|
||||
if args.enable_tpu:
|
||||
config_args += ["--config=tpu"]
|
||||
if args.enable_remote_tpu:
|
||||
config_args += ["--//build:enable_remote_tpu=true"]
|
||||
if args.enable_rocm:
|
||||
config_args += ["--config=rocm"]
|
||||
if not args.enable_nccl:
|
||||
|
@ -251,7 +251,7 @@ def prepare_wheel(sources_path):
|
||||
patch_copy_xla_extension_stubs(jaxlib_dir)
|
||||
patch_copy_xla_client_py(jaxlib_dir)
|
||||
|
||||
if not _is_windows():
|
||||
if exists("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so"):
|
||||
copy_to_jaxlib("org_tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/tpu_client_extension.so")
|
||||
patch_copy_tpu_client_py(jaxlib_dir)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user