Do not init backends from topology construction, instead directly init the

plugin.

PiperOrigin-RevId: 577331743
This commit is contained in:
Parker Schuh 2023-10-27 16:20:28 -07:00 committed by jax authors
parent b99958db37
commit 19c65353d2

View File

@ -940,9 +940,16 @@ def using_pjrt_c_api(backend=None):
# TODO(parkers): Get rid of this in favor of a generic way to get topologies.
def make_pjrt_tpu_topology(topology_name='', **kwargs):
# TODO(b/261484192): Make a system for lazily loading libtpu.so and call
# that inside make_tfrt_tpu_c_api_device_topology.
get_backend() # Properly initialize libtpu.so.
if not xla_client.pjrt_plugin_loaded("tpu"):
library_path = _get_tpu_library_path()
if library_path is None:
raise RuntimeError(
"JAX TPU support not installed; cannot generate TPU topology. See"
" https://github.com/google/jax#installation")
xla_client.load_pjrt_plugin_dynamically("tpu", library_path)
assert xla_client.pjrt_plugin_loaded("tpu")
if not xla_client.pjrt_plugin_initialized("tpu"):
xla_client.initialize_pjrt_plugin("tpu")
return xla_client.make_tfrt_tpu_c_api_device_topology(
topology_name, **kwargs
)