From 987a2f0850189f839e3b4d22b06340c6ddcc4457 Mon Sep 17 00:00:00 2001 From: jax authors Date: Wed, 12 Jun 2024 11:54:57 -0700 Subject: [PATCH] Enable jax's cloud-tpu configs when libtpu is present, either through "pip install" or set by a custom path through the $TPU_LIBRARY_PATH env var PiperOrigin-RevId: 642688204 --- jax/_src/cloud_tpu_init.py | 17 +++++++++++++++-- jax/_src/xla_bridge.py | 19 ++++--------------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 71827d8f8..9fa45e703 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -16,6 +16,7 @@ import os from jax import version from jax._src import config from jax._src import hardware_utils +from typing import Optional running_in_cloud_tpu_vm: bool = False @@ -34,6 +35,18 @@ def maybe_import_libtpu(): return libtpu +def get_tpu_library_path() -> Optional[str]: + path_from_env = os.getenv("TPU_LIBRARY_PATH") + if path_from_env is not None and os.path.isfile(path_from_env): + return path_from_env + + libtpu_module = maybe_import_libtpu() + if libtpu_module is not None: + return libtpu_module.get_library_path() + + return None + + def jax_force_tpu_init() -> bool: return 'JAX_FORCE_TPU_INIT' in os.environ @@ -57,9 +70,9 @@ def cloud_tpu_init() -> None: global running_in_cloud_tpu_vm # Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed. 
- libtpu_module = maybe_import_libtpu() + libtpu_path = get_tpu_library_path() num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0] - if (libtpu_module is None or num_tpu_chips == 0) and not jax_force_tpu_init(): + if (libtpu_path is None or num_tpu_chips == 0) and not jax_force_tpu_init(): return running_in_cloud_tpu_vm = True diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index fa8bced4f..d370a31b9 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -41,7 +41,7 @@ from jax._src import distributed from jax._src import hardware_utils from jax._src import traceback_util from jax._src import util -from jax._src.cloud_tpu_init import maybe_import_libtpu +from jax._src.cloud_tpu_init import get_tpu_library_path from jax._src.lib import cuda_versions from jax._src.lib import xla_client from jax._src.lib import xla_extension @@ -133,17 +133,6 @@ _at_fork_handler_installed = False # Backends -def _get_tpu_library_path() -> str | None: - path_from_env = os.getenv("TPU_LIBRARY_PATH") - if path_from_env is not None: - return path_from_env - - libtpu_module = maybe_import_libtpu() - if libtpu_module is not None: - return libtpu_module.get_library_path() - - return None - def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None: def _log_warning(): @@ -160,11 +149,11 @@ def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None: try: if xla_extension_version >= 267: client = xla_client.make_tpu_client( # type: ignore - _get_tpu_library_path(), + get_tpu_library_path(), _options_from_jax_configs("tpu")) else: client = xla_client.make_tpu_client( - _get_tpu_library_path()) + get_tpu_library_path()) finally: t.cancel() @@ -1223,7 +1212,7 @@ def make_pjrt_topology(platform: str, topology_name='', **kwargs): # TODO(parkers): Get rid of this in favor of a generic way to get topologies. 
def make_pjrt_tpu_topology(topology_name='', **kwargs): if not xla_client.pjrt_plugin_loaded("tpu"): - library_path = _get_tpu_library_path() + library_path = get_tpu_library_path() if library_path is None: raise RuntimeError( "JAX TPU support not installed; cannot generate TPU topology. See"