diff --git a/jax/BUILD b/jax/BUILD index 9a98cf4fa..ea3ddcc76 100644 --- a/jax/BUILD +++ b/jax/BUILD @@ -389,6 +389,7 @@ pytype_strict_library( name = "cloud_tpu_init", srcs = ["_src/cloud_tpu_init.py"], deps = [ + ":config", ":hardware_utils", ":version", ], diff --git a/jax/_src/cloud_tpu_init.py b/jax/_src/cloud_tpu_init.py index 68fe25562..71827d8f8 100644 --- a/jax/_src/cloud_tpu_init.py +++ b/jax/_src/cloud_tpu_init.py @@ -13,8 +13,9 @@ # limitations under the License. import os -from jax._src import hardware_utils from jax import version +from jax._src import config +from jax._src import hardware_utils running_in_cloud_tpu_vm: bool = False @@ -73,3 +74,9 @@ def cloud_tpu_init() -> None: # this makes tensorstore serialization work better on TPU os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_TIME_SECONDS', '60') os.environ.setdefault('TENSORSTORE_CURL_LOW_SPEED_LIMIT_BYTES', '256') + + if config.jax_pjrt_client_create_options.value is None: + config.update( + 'jax_pjrt_client_create_options', + f'ml_framework_name:JAX;ml_framework_version:{version.__version__}' + ) diff --git a/jax/_src/config.py b/jax/_src/config.py index 8ae6dab0b..4ee6b16ab 100644 --- a/jax/_src/config.py +++ b/jax/_src/config.py @@ -935,6 +935,12 @@ jax_platforms = define_optional_string_state( 'otherwise.' )) +jax_pjrt_client_create_options = define_optional_string_state( + name='jax_pjrt_client_create_options', + default=None, + help=('A set of key-value pairs in the format of "k1:v1;k2:v2" strings ' + 'provided to a device platform pjrt client as extra arguments.')) + enable_checks = define_bool_state( name='jax_enable_checks', default=False, diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 55ec8455d..c68d39abf 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -47,6 +47,7 @@ from jax._src.cloud_tpu_init import maybe_import_libtpu from jax._src.lib import cuda_versions from jax._src.lib import xla_client from jax._src.lib import xla_extension +from jax._src.lib import xla_extension_version from jax._src.lib import jaxlib logger = logging.getLogger(__name__) @@ -160,7 +161,13 @@ def tpu_client_timer_callback(timer_secs: float) -> xla_client.Client | None: t.start() try: - client = xla_client.make_tpu_client(_get_tpu_library_path()) + if xla_extension_version >= 267: + client = xla_client.make_tpu_client( # type: ignore + _get_tpu_library_path(), + _options_from_jax_configs("tpu")) + else: + client = xla_client.make_tpu_client( + _get_tpu_library_path()) finally: t.cancel() @@ -618,16 +625,30 @@ def discover_pjrt_plugins() -> None: def _options_from_jax_configs(plugin_name): - if plugin_name != "cuda": - return {} - options = {} - visible_devices = CUDA_VISIBLE_DEVICES.value - if visible_devices != 'all': - options['visible_devices'] = [int(x) for x in visible_devices.split(',')] - options['enable_mock_nccl'] = _USE_MOCK_GPU_CLIENT.value - if options['enable_mock_nccl']: - options['num_nodes'] = _MOCK_NUM_GPUS.value + + pjrt_client_options = config.jax_pjrt_client_create_options.value + pjrt_client_option_list = [] + if pjrt_client_options: + pjrt_client_option_list = pjrt_client_options.split(";") + + for option in pjrt_client_option_list: + option_list = option.split(":") + if (len(option_list) != 2): + raise RuntimeError( + "Multiple ':' separators for option in " + f"jax_pjrt_client_create_options: '{option}'. " + "Should be in format 'key:value'") + options[option_list[0]] = option_list[1] + + if plugin_name == "cuda": + visible_devices = CUDA_VISIBLE_DEVICES.value + if visible_devices != 'all': + options['visible_devices'] = [int(x) for x in visible_devices.split(',')] + options['enable_mock_nccl'] = _USE_MOCK_GPU_CLIENT.value + if options['enable_mock_nccl']: + options['num_nodes'] = _MOCK_NUM_GPUS.value + return options diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index ecbf2e202..d118d0e64 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -26,6 +26,7 @@ from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.interpreters import xla from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version config.parse_flags_with_absl() @@ -143,7 +144,7 @@ class XlaBridgeTest(jtu.JaxTestCase): with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") - def _mock_tpu_client(library_path=None): + def _mock_tpu_client_with_options(library_path=None, options=None): time_to_wait = 5 start = time.time() while not w: @@ -157,9 +158,17 @@ class XlaBridgeTest(jtu.JaxTestCase): msg = str(w[-1].message) self.assertIn("Did you run your code on all TPU hosts?", msg) - with mock.patch.object(xc, "make_tpu_client", - side_effect=_mock_tpu_client): - xb.tpu_client_timer_callback(0.01) + def _mock_tpu_client(library_path=None): + _mock_tpu_client_with_options(library_path=library_path, options=None) + + if xla_extension_version >= 267: + with mock.patch.object(xc, "make_tpu_client", + side_effect=_mock_tpu_client_with_options): + xb.tpu_client_timer_callback(0.01) + else: + with mock.patch.object(xc, "make_tpu_client", + side_effect=_mock_tpu_client): + xb.tpu_client_timer_callback(0.01) def test_register_plugin(self): with self.assertLogs(level="WARNING") as log_output: