Remove older plugin device integration.

Users of this mechanism should migrate to the newer PJRT plugin registration mechanism (see the comments on discover_plugins() in this file).
This commit is contained in:
Peter Hawkins 2023-06-14 09:52:38 -04:00
parent b42282d30d
commit 119661ce6b
3 changed files with 1 additions and 21 deletions

View File

@ -41,7 +41,6 @@ build --copt=-DMLIR_PYTHON_PACKAGE_PREFIX=jaxlib.mlir.
# these values are overridden.
build --@xla//xla/python:enable_gpu=false
build --@xla//xla/python:enable_tpu=false
build --@xla//xla/python:enable_plugin_device=false
###########################################################################
@ -121,8 +120,6 @@ build:short_logs --output_filter=DONT_MATCH_ANYTHING
build:tpu --@xla//xla/python:enable_tpu=true
build:tpu --define=with_tpu_support=true
build:plugin_device --@xla//xla/python:enable_plugin_device=true
#########################################################################
# RBE config options below.
# Flag to enable remote config

View File

@ -219,7 +219,7 @@ def write_bazelrc(*, python_bin_path, remote_build,
cpu, cuda_compute_capabilities,
rocm_amdgpu_targets, bazel_options, target_cpu_features,
wheel_cpu, enable_mkl_dnn, enable_cuda, enable_nccl,
enable_tpu, enable_rocm, enable_plugin_device):
enable_tpu, enable_rocm):
tf_cuda_paths = []
with open("../.jax_configure.bazelrc", "w") as f:
@ -289,8 +289,6 @@ def write_bazelrc(*, python_bin_path, remote_build,
f.write("build --config=rocm\n")
if not enable_nccl:
f.write("build --config=nonccl\n")
if enable_plugin_device:
f.write("build --config=plugin_device\n")
BANNER = r"""
_ _ __ __
@ -382,11 +380,6 @@ def main():
default=True,
help_str="Should we build with NCCL enabled? Has no effect for non-CUDA "
"builds.")
add_boolean_argument(
parser,
"enable_plugin_device",
default=False,
help_str="Should we build with a plugin device enable?")
add_boolean_argument(
parser,
"remote_build",
@ -518,8 +511,6 @@ def main():
print(f"ROCm toolkit path: {rocm_toolkit_path}")
print(f"ROCm amdgpu targets: {args.rocm_amdgpu_targets}")
print("Plugin device enabled: {}".format("yes" if args.enable_plugin_device else "no"))
write_bazelrc(
python_bin_path=python_bin_path,
remote_build=args.remote_build,
@ -539,7 +530,6 @@ def main():
enable_nccl=args.enable_nccl,
enable_tpu=args.enable_tpu,
enable_rocm=args.enable_rocm,
enable_plugin_device=args.enable_plugin_device,
)
if args.configure_only:

View File

@ -268,13 +268,6 @@ if hasattr(xla_client, "make_tpu_client"):
register_backend_factory(
'tpu', partial(tpu_client_timer_callback, timer_secs=60.0), priority=300)
if hasattr(xla_client, "make_plugin_device_client"):
# It is assumed that if jax has been built with a plugin client, then the
# user wants to use the plugin client by default. Therefore, it gets the
# highest priority.
register_backend_factory("plugin", xla_client.make_plugin_device_client,
priority=400)
def _get_pjrt_plugin_names_and_library_paths(
plugins_from_env: str,