[PJRT C API] Let framework explicitly check whether a plugin is initialized and initialize the plugin.

Before this change, PJRT_Plugin_Initialize was called in LoadPjrtPlugin, which is only used in dynamic linking case. This change adds a bool and a method to check whether the plugin is initialized. The framework will explicitly check whether a plugin is initialized, and call InitializePjrtPlugin if it is not. This will be apply to both static linking and dynamic linking case.

PiperOrigin-RevId: 557268670
This commit is contained in:
Jieying Luo 2023-08-15 15:23:17 -07:00 committed by jax authors
parent b7796710e4
commit c7e8b81a74
2 changed files with 36 additions and 8 deletions

View File

@ -37,6 +37,7 @@ from jax._src import distributed
from jax._src import config as jax_config
from jax._src.config import config
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src import traceback_util
from jax._src import util
@ -351,7 +352,8 @@ def register_plugin(
options: Optional. It is used when creating a PJRT plugin client.
"""
def factory():
# Plugin may already be statically linked in some configurations.
# Plugin may already be statically linked in some configurations, or we
# could be creating a client twice.
if not xla_client.pjrt_plugin_loaded(plugin_name):
if library_path is None:
raise ValueError(
@ -359,6 +361,9 @@ def register_plugin(
' plugin.'
)
xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
if xla_extension_version >= 183:
if not xla_client.pjrt_plugin_initialized(plugin_name):
xla_client.initialize_pjrt_plugin(plugin_name)
if distributed.global_state.client is None:
return xla_client.make_c_api_client(plugin_name, options, None)

View File

@ -23,6 +23,7 @@ from jax._src import compiler
from jax._src import test_util as jtu
from jax._src import xla_bridge as xb
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.interpreters import xla
from jax._src import config as jax_config
@ -152,8 +153,18 @@ class XlaBridgeTest(jtu.JaxTestCase):
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True):
with mock.patch.object(
xc, "pjrt_plugin_loaded", autospec=True) as mock_plugin_loaded:
registration.factory()
xc, "pjrt_plugin_loaded", autospec=True
) as mock_plugin_loaded:
if xla_extension_version < 183:
registration.factory()
else:
with mock.patch.object(
xc, "pjrt_plugin_initialized", autospec=True, return_vale=True
):
with mock.patch.object(
xc, "initialize_pjrt_plugin", autospec=True
):
registration.factory()
self.assertRegex(
log_output[1][0],
@ -171,16 +182,28 @@ class XlaBridgeTest(jtu.JaxTestCase):
test_json_file_path = os.path.join(
os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json"
)
os.environ['PJRT_NAMES_AND_LIBRARY_PATHS'] = (
f"name1;{test_json_file_path}" if platform.system() == "Windows"
else f"name1:{test_json_file_path}")
os.environ["PJRT_NAMES_AND_LIBRARY_PATHS"] = (
f"name1;{test_json_file_path}"
if platform.system() == "Windows"
else f"name1:{test_json_file_path}"
)
xb.register_pjrt_plugin_factories_from_env()
registration = xb._backend_factories["name1"]
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True):
with mock.patch.object(
xc, "pjrt_plugin_loaded", autospec=True) as mock_plugin_loaded:
registration.factory()
xc, "pjrt_plugin_loaded", autospec=True
) as mock_plugin_loaded:
if xla_extension_version < 183:
registration.factory()
else:
with mock.patch.object(
xc, "pjrt_plugin_initialized", autospec=True, return_vale=True
):
with mock.patch.object(
xc, "initialize_pjrt_plugin", autospec=True
):
registration.factory()
self.assertIn("name1", xb._backend_factories)
self.assertEqual(registration.priority, 400)