mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Do not init backends from topology construction, instead directly init the
plugin. PiperOrigin-RevId: 577331743
This commit is contained in:
parent
b99958db37
commit
19c65353d2
@ -940,9 +940,16 @@ def using_pjrt_c_api(backend=None):
|
||||
|
||||
# TODO(parkers): Get rid of this in favor of a generic way to get topologies.
|
||||
def make_pjrt_tpu_topology(topology_name='', **kwargs):
|
||||
# TODO(b/261484192): Make a system for lazily loading libtpu.so and call
|
||||
# that inside make_tfrt_tpu_c_api_device_topology.
|
||||
get_backend() # Properly initialize libtpu.so.
|
||||
if not xla_client.pjrt_plugin_loaded("tpu"):
|
||||
library_path = _get_tpu_library_path()
|
||||
if library_path is None:
|
||||
raise RuntimeError(
|
||||
"JAX TPU support not installed; cannot generate TPU topology. See"
|
||||
" https://github.com/google/jax#installation")
|
||||
xla_client.load_pjrt_plugin_dynamically("tpu", library_path)
|
||||
assert xla_client.pjrt_plugin_loaded("tpu")
|
||||
if not xla_client.pjrt_plugin_initialized("tpu"):
|
||||
xla_client.initialize_pjrt_plugin("tpu")
|
||||
return xla_client.make_tfrt_tpu_c_api_device_topology(
|
||||
topology_name, **kwargs
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user