Merge pull request #11961 from jakeh-gc:plugin_device

PiperOrigin-RevId: 476363760
This commit is contained in:
jax authors 2022-09-23 07:29:17 -07:00
commit 254dc24a8b
3 changed files with 22 additions and 1 deletions

View File

@ -35,6 +35,7 @@ build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
# these values are overridden.
build --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=false
build --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=false
build --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=false
###########################################################################
@ -113,6 +114,8 @@ build:short_logs --output_filter=DONT_MATCH_ANYTHING
build:tpu --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=true
build:tpu --define=with_tpu_support=true
build:plugin_device --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=true
#########################################################################
# RBE config options below.
# Flag to enable remote config

View File

@ -219,7 +219,8 @@ def write_bazelrc(*, python_bin_path, remote_build,
cpu, cuda_compute_capabilities,
rocm_amdgpu_targets, bazel_options, target_cpu_features,
wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl,
enable_tpu, enable_remote_tpu, enable_rocm):
enable_tpu, enable_remote_tpu, enable_rocm,
enable_plugin_device):
tf_cuda_paths = []
with open("../.jax_configure.bazelrc", "w") as f:
@ -291,6 +292,8 @@ def write_bazelrc(*, python_bin_path, remote_build,
f.write("build --config=rocm\n")
if not enable_nccl:
f.write("build --config=nonccl\n")
if enable_plugin_device:
f.write("build --config=plugin_device\n")
BANNER = r"""
_ _ __ __
@ -386,6 +389,11 @@ def main():
default=True,
help_str="Should we build with NCCL enabled? Has no effect for non-CUDA "
"builds.")
add_boolean_argument(
parser,
"enable_plugin_device",
default=False,
help_str="Should we build with a plugin device enable?")
add_boolean_argument(
parser,
"remote_build",
@ -514,6 +522,8 @@ def main():
print(f"ROCm toolkit path: {rocm_toolkit_path}")
print(f"ROCm amdgpu targets: {args.rocm_amdgpu_targets}")
print("Plugin device enabled: {}".format("yes" if args.enable_plugin_device else "no"))
write_bazelrc(
python_bin_path=python_bin_path,
remote_build=args.remote_build,
@ -534,6 +544,7 @@ def main():
enable_tpu=args.enable_tpu,
enable_remote_tpu=args.enable_remote_tpu,
enable_rocm=args.enable_rocm,
enable_plugin_device=args.enable_plugin_device,
)
if args.configure_only:

View File

@ -251,6 +251,13 @@ if hasattr(xla_client, "make_tpu_client"):
register_backend_factory(
'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300)
if hasattr(xla_client, "make_plugin_device_client"):
# It is assumed that if jax has been built with a plugin client, then the
# user wants to use the plugin client by default. Therefore, it gets the
# highest priority.
register_backend_factory("plugin", xla_client.make_plugin_device_client,
priority=400)
if iree is not None:
register_backend_factory("iree", iree.iree_client_factory, priority=-100)