diff --git a/jax/_src/xla_bridge.py b/jax/_src/xla_bridge.py index 768bac203..ec7b1079b 100644 --- a/jax/_src/xla_bridge.py +++ b/jax/_src/xla_bridge.py @@ -37,6 +37,7 @@ from jax._src import distributed from jax._src import config as jax_config from jax._src.config import config from jax._src.lib import xla_client +from jax._src.lib import xla_extension_version from jax._src import traceback_util from jax._src import util @@ -351,7 +352,8 @@ def register_plugin( options: Optional. It is used when creating a PJRT plugin client. """ def factory(): - # Plugin may already be statically linked in some configurations. + # Plugin may already be statically linked in some configurations, or we + # could be creating a client twice. if not xla_client.pjrt_plugin_loaded(plugin_name): if library_path is None: raise ValueError( @@ -359,6 +361,9 @@ def register_plugin( ' plugin.' ) xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path) + if xla_extension_version >= 183: + if not xla_client.pjrt_plugin_initialized(plugin_name): + xla_client.initialize_pjrt_plugin(plugin_name) if distributed.global_state.client is None: return xla_client.make_c_api_client(plugin_name, options, None) diff --git a/tests/xla_bridge_test.py b/tests/xla_bridge_test.py index 56770c7bf..9b59bebe9 100644 --- a/tests/xla_bridge_test.py +++ b/tests/xla_bridge_test.py @@ -23,6 +23,7 @@ from jax._src import compiler from jax._src import test_util as jtu from jax._src import xla_bridge as xb from jax._src.lib import xla_client as xc +from jax._src.lib import xla_extension_version from jax._src.interpreters import xla from jax._src import config as jax_config @@ -152,8 +153,18 @@ class XlaBridgeTest(jtu.JaxTestCase): with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make: with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True): with mock.patch.object( - xc, "pjrt_plugin_loaded", autospec=True) as mock_plugin_loaded: - registration.factory() + xc, "pjrt_plugin_loaded", autospec=True + ) as mock_plugin_loaded: + if xla_extension_version < 183: + registration.factory() + else: + with mock.patch.object( + xc, "pjrt_plugin_initialized", autospec=True, return_vale=True + ): + with mock.patch.object( + xc, "initialize_pjrt_plugin", autospec=True + ): + registration.factory() self.assertRegex( log_output[1][0], @@ -171,16 +182,28 @@ class XlaBridgeTest(jtu.JaxTestCase): test_json_file_path = os.path.join( os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json" ) - os.environ['PJRT_NAMES_AND_LIBRARY_PATHS'] = ( - f"name1;{test_json_file_path}" if platform.system() == "Windows" - else f"name1:{test_json_file_path}") + os.environ["PJRT_NAMES_AND_LIBRARY_PATHS"] = ( + f"name1;{test_json_file_path}" + if platform.system() == "Windows" + else f"name1:{test_json_file_path}" + ) xb.register_pjrt_plugin_factories_from_env() registration = xb._backend_factories["name1"] with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make: with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True): with mock.patch.object( - xc, "pjrt_plugin_loaded", autospec=True) as mock_plugin_loaded: - registration.factory() + xc, "pjrt_plugin_loaded", autospec=True + ) as mock_plugin_loaded: + if xla_extension_version < 183: + registration.factory() + else: + with mock.patch.object( + xc, "pjrt_plugin_initialized", autospec=True, return_vale=True + ): + with mock.patch.object( + xc, "initialize_pjrt_plugin", autospec=True + ): + registration.factory() self.assertIn("name1", xb._backend_factories) self.assertEqual(registration.priority, 400)