mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #18108 from google:libtpu_import_fix
PiperOrigin-RevId: 573302310
This commit is contained in:
commit
ce4d0d2b2d
@ -128,6 +128,9 @@ def tpu_client_timer_callback(timer_secs: float) -> Optional[xla_client.Client]:
|
||||
if xla_extension_version >= 205:
|
||||
client = xla_client.make_tpu_client(_get_tpu_library_path()) # type: ignore
|
||||
else:
|
||||
libtpu_module = maybe_import_libtpu()
|
||||
if libtpu_module is not None:
|
||||
libtpu_module.configure_library_path()
|
||||
client = xla_client.make_tpu_client() # type: ignore
|
||||
finally:
|
||||
t.cancel()
|
||||
|
Loading…
x
Reference in New Issue
Block a user