Enable jax's cloud-tpu configs when libtpu is present, whether installed via "pip install" or provided as a custom build through the $TPU_LIBRARY_PATH env var

PiperOrigin-RevId: 642688204
This commit is contained in:
jax authors 2024-06-12 11:54:57 -07:00 committed by jax authors
parent 544975f622
commit 987a2f0850
2 changed files with 19 additions and 17 deletions

View File

@ -16,6 +16,7 @@ import os
from jax import version
from jax._src import config
from jax._src import hardware_utils
from typing import Optional
running_in_cloud_tpu_vm: bool = False
@ -34,6 +35,18 @@ def maybe_import_libtpu():
return libtpu
def get_tpu_library_path() -> Optional[str]:
  """Locate the libtpu shared library, returning None when none is found.

  The TPU_LIBRARY_PATH environment variable takes precedence, but only when it
  names an existing file. Otherwise the path reported by the libtpu module is
  used, provided maybe_import_libtpu() returns a module.
  """
  env_path = os.environ.get("TPU_LIBRARY_PATH")
  if env_path is not None and os.path.isfile(env_path):
    return env_path

  libtpu = maybe_import_libtpu()
  return None if libtpu is None else libtpu.get_library_path()
def jax_force_tpu_init() -> bool:
  """Report whether the JAX_FORCE_TPU_INIT environment variable is set.

  Only the variable's presence matters; its value is never inspected, so an
  empty string or "0" still counts as set.
  """
  return os.environ.get('JAX_FORCE_TPU_INIT') is not None
@ -57,9 +70,9 @@ def cloud_tpu_init() -> None:
global running_in_cloud_tpu_vm
# Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed.
libtpu_module = maybe_import_libtpu()
libtpu_path = get_tpu_library_path()
num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0]
if (libtpu_module is None or num_tpu_chips == 0) and not jax_force_tpu_init():
if (libtpu_path is None or num_tpu_chips == 0) and not jax_force_tpu_init():
return
running_in_cloud_tpu_vm = True

View File

@ -41,7 +41,7 @@ from jax._src import distributed
from jax._src import hardware_utils
from jax._src import traceback_util
from jax._src import util
from jax._src.cloud_tpu_init import maybe_import_libtpu
from jax._src.cloud_tpu_init import get_tpu_library_path
from jax._src.lib import cuda_versions
from jax._src.lib import xla_client
from jax._src.lib import xla_extension
@ -133,17 +133,6 @@ _at_fork_handler_installed = False
# Backends
def _get_tpu_library_path() -> str | None:
path_from_env = os.getenv("TPU_LIBRARY_PATH")
if path_from_env is not None:
return path_from_env
libtpu_module = maybe_import_libtpu()
if libtpu_module is not None:
return libtpu_module.get_library_path()
return None
def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None:
def _log_warning():
@ -160,11 +149,11 @@ def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None:
try:
if xla_extension_version >= 267:
client = xla_client.make_tpu_client( # type: ignore
_get_tpu_library_path(),
get_tpu_library_path(),
_options_from_jax_configs("tpu"))
else:
client = xla_client.make_tpu_client(
_get_tpu_library_path())
get_tpu_library_path())
finally:
t.cancel()
@ -1223,7 +1212,7 @@ def make_pjrt_topology(platform: str, topology_name='', **kwargs):
# TODO(parkers): Get rid of this in favor of a generic way to get topologies.
def make_pjrt_tpu_topology(topology_name='', **kwargs):
if not xla_client.pjrt_plugin_loaded("tpu"):
library_path = _get_tpu_library_path()
library_path = get_tpu_library_path()
if library_path is None:
raise RuntimeError(
"JAX TPU support not installed; cannot generate TPU topology. See"