mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #11961 from jakeh-gc:plugin_device
PiperOrigin-RevId: 476363760
This commit is contained in:
commit
254dc24a8b
3
.bazelrc
3
.bazelrc
@ -35,6 +35,7 @@ build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
|
||||
# these values are overridden.
|
||||
build --@org_tensorflow//tensorflow/compiler/xla/python:enable_gpu=false
|
||||
build --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=false
|
||||
build --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=false
|
||||
|
||||
###########################################################################
|
||||
|
||||
@ -113,6 +114,8 @@ build:short_logs --output_filter=DONT_MATCH_ANYTHING
|
||||
build:tpu --@org_tensorflow//tensorflow/compiler/xla/python:enable_tpu=true
|
||||
build:tpu --define=with_tpu_support=true
|
||||
|
||||
build:plugin_device --@org_tensorflow//tensorflow/compiler/xla/python:enable_plugin_device=true
|
||||
|
||||
#########################################################################
|
||||
# RBE config options below.
|
||||
# Flag to enable remote config
|
||||
|
@ -219,7 +219,8 @@ def write_bazelrc(*, python_bin_path, remote_build,
|
||||
cpu, cuda_compute_capabilities,
|
||||
rocm_amdgpu_targets, bazel_options, target_cpu_features,
|
||||
wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl,
|
||||
enable_tpu, enable_remote_tpu, enable_rocm):
|
||||
enable_tpu, enable_remote_tpu, enable_rocm,
|
||||
enable_plugin_device):
|
||||
tf_cuda_paths = []
|
||||
|
||||
with open("../.jax_configure.bazelrc", "w") as f:
|
||||
@ -291,6 +292,8 @@ def write_bazelrc(*, python_bin_path, remote_build,
|
||||
f.write("build --config=rocm\n")
|
||||
if not enable_nccl:
|
||||
f.write("build --config=nonccl\n")
|
||||
if enable_plugin_device:
|
||||
f.write("build --config=plugin_device\n")
|
||||
|
||||
BANNER = r"""
|
||||
_ _ __ __
|
||||
@ -386,6 +389,11 @@ def main():
|
||||
default=True,
|
||||
help_str="Should we build with NCCL enabled? Has no effect for non-CUDA "
|
||||
"builds.")
|
||||
add_boolean_argument(
|
||||
parser,
|
||||
"enable_plugin_device",
|
||||
default=False,
|
||||
help_str="Should we build with a plugin device enable?")
|
||||
add_boolean_argument(
|
||||
parser,
|
||||
"remote_build",
|
||||
@ -514,6 +522,8 @@ def main():
|
||||
print(f"ROCm toolkit path: {rocm_toolkit_path}")
|
||||
print(f"ROCm amdgpu targets: {args.rocm_amdgpu_targets}")
|
||||
|
||||
print("Plugin device enabled: {}".format("yes" if args.enable_plugin_device else "no"))
|
||||
|
||||
write_bazelrc(
|
||||
python_bin_path=python_bin_path,
|
||||
remote_build=args.remote_build,
|
||||
@ -534,6 +544,7 @@ def main():
|
||||
enable_tpu=args.enable_tpu,
|
||||
enable_remote_tpu=args.enable_remote_tpu,
|
||||
enable_rocm=args.enable_rocm,
|
||||
enable_plugin_device=args.enable_plugin_device,
|
||||
)
|
||||
|
||||
if args.configure_only:
|
||||
|
@ -251,6 +251,13 @@ if hasattr(xla_client, "make_tpu_client"):
|
||||
register_backend_factory(
|
||||
'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300)
|
||||
|
||||
if hasattr(xla_client, "make_plugin_device_client"):
|
||||
# It is assumed that if jax has been built with a plugin client, then the
|
||||
# user wants to use the plugin client by default. Therefore, it gets the
|
||||
# highest priority.
|
||||
register_backend_factory("plugin", xla_client.make_plugin_device_client,
|
||||
priority=400)
|
||||
|
||||
if iree is not None:
|
||||
register_backend_factory("iree", iree.iree_client_factory, priority=-100)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user