mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Enable jax's cloud-tpu configs when libtpu is present via through "pip install" or set by custom through the $TPU_LIBRARY_PATH env var
PiperOrigin-RevId: 642688204
This commit is contained in:
parent
544975f622
commit
987a2f0850
@ -16,6 +16,7 @@ import os
|
||||
from jax import version
|
||||
from jax._src import config
|
||||
from jax._src import hardware_utils
|
||||
from typing import Optional
|
||||
|
||||
running_in_cloud_tpu_vm: bool = False
|
||||
|
||||
@ -34,6 +35,18 @@ def maybe_import_libtpu():
|
||||
return libtpu
|
||||
|
||||
|
||||
def get_tpu_library_path() -> Optional[str]:
|
||||
path_from_env = os.getenv("TPU_LIBRARY_PATH")
|
||||
if path_from_env is not None and os.path.isfile(path_from_env):
|
||||
return path_from_env
|
||||
|
||||
libtpu_module = maybe_import_libtpu()
|
||||
if libtpu_module is not None:
|
||||
return libtpu_module.get_library_path()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def jax_force_tpu_init() -> bool:
|
||||
return 'JAX_FORCE_TPU_INIT' in os.environ
|
||||
|
||||
@ -57,9 +70,9 @@ def cloud_tpu_init() -> None:
|
||||
global running_in_cloud_tpu_vm
|
||||
|
||||
# Exit early if we're not running on a Cloud TPU VM or libtpu isn't installed.
|
||||
libtpu_module = maybe_import_libtpu()
|
||||
libtpu_path = get_tpu_library_path()
|
||||
num_tpu_chips = hardware_utils.num_available_tpu_chips_and_device_id()[0]
|
||||
if (libtpu_module is None or num_tpu_chips == 0) and not jax_force_tpu_init():
|
||||
if (libtpu_path is None or num_tpu_chips == 0) and not jax_force_tpu_init():
|
||||
return
|
||||
|
||||
running_in_cloud_tpu_vm = True
|
||||
|
@ -41,7 +41,7 @@ from jax._src import distributed
|
||||
from jax._src import hardware_utils
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.cloud_tpu_init import maybe_import_libtpu
|
||||
from jax._src.cloud_tpu_init import get_tpu_library_path
|
||||
from jax._src.lib import cuda_versions
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension
|
||||
@ -133,17 +133,6 @@ _at_fork_handler_installed = False
|
||||
# Backends
|
||||
|
||||
|
||||
def _get_tpu_library_path() -> str | None:
|
||||
path_from_env = os.getenv("TPU_LIBRARY_PATH")
|
||||
if path_from_env is not None:
|
||||
return path_from_env
|
||||
|
||||
libtpu_module = maybe_import_libtpu()
|
||||
if libtpu_module is not None:
|
||||
return libtpu_module.get_library_path()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None:
|
||||
def _log_warning():
|
||||
@ -160,11 +149,11 @@ def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None:
|
||||
try:
|
||||
if xla_extension_version >= 267:
|
||||
client = xla_client.make_tpu_client( # type: ignore
|
||||
_get_tpu_library_path(),
|
||||
get_tpu_library_path(),
|
||||
_options_from_jax_configs("tpu"))
|
||||
else:
|
||||
client = xla_client.make_tpu_client(
|
||||
_get_tpu_library_path())
|
||||
get_tpu_library_path())
|
||||
finally:
|
||||
t.cancel()
|
||||
|
||||
@ -1223,7 +1212,7 @@ def make_pjrt_topology(platform: str, topology_name='', **kwargs):
|
||||
# TODO(parkers): Get rid of this in favor of a generic way to get topologies.
|
||||
def make_pjrt_tpu_topology(topology_name='', **kwargs):
|
||||
if not xla_client.pjrt_plugin_loaded("tpu"):
|
||||
library_path = _get_tpu_library_path()
|
||||
library_path = get_tpu_library_path()
|
||||
if library_path is None:
|
||||
raise RuntimeError(
|
||||
"JAX TPU support not installed; cannot generate TPU topology. See"
|
||||
|
Loading…
x
Reference in New Issue
Block a user