mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 20:36:05 +00:00
[PJRT C API] Let framework explicitly check whether a plugin is initialized and initialize the plugin.
Before this change, PJRT_Plugin_Initialize was called in LoadPjrtPlugin, which is only used in dynamic linking case. This change adds a bool and a method to check whether the plugin is initialized. The framework will explicitly check whether a plugin is initialized, and call InitializePjrtPlugin if it is not. This will be apply to both static linking and dynamic linking case. PiperOrigin-RevId: 557268670
This commit is contained in:
parent
b7796710e4
commit
c7e8b81a74
@ -37,6 +37,7 @@ from jax._src import distributed
|
||||
from jax._src import config as jax_config
|
||||
from jax._src.config import config
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
|
||||
@ -351,7 +352,8 @@ def register_plugin(
|
||||
options: Optional. It is used when creating a PJRT plugin client.
|
||||
"""
|
||||
def factory():
|
||||
# Plugin may already be statically linked in some configurations.
|
||||
# Plugin may already be statically linked in some configurations, or we
|
||||
# could be creating a client twice.
|
||||
if not xla_client.pjrt_plugin_loaded(plugin_name):
|
||||
if library_path is None:
|
||||
raise ValueError(
|
||||
@ -359,6 +361,9 @@ def register_plugin(
|
||||
' plugin.'
|
||||
)
|
||||
xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
|
||||
if xla_extension_version >= 183:
|
||||
if not xla_client.pjrt_plugin_initialized(plugin_name):
|
||||
xla_client.initialize_pjrt_plugin(plugin_name)
|
||||
|
||||
if distributed.global_state.client is None:
|
||||
return xla_client.make_c_api_client(plugin_name, options, None)
|
||||
|
@ -23,6 +23,7 @@ from jax._src import compiler
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.interpreters import xla
|
||||
|
||||
from jax._src import config as jax_config
|
||||
@ -152,8 +153,18 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
|
||||
with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True):
|
||||
with mock.patch.object(
|
||||
xc, "pjrt_plugin_loaded", autospec=True) as mock_plugin_loaded:
|
||||
registration.factory()
|
||||
xc, "pjrt_plugin_loaded", autospec=True
|
||||
) as mock_plugin_loaded:
|
||||
if xla_extension_version < 183:
|
||||
registration.factory()
|
||||
else:
|
||||
with mock.patch.object(
|
||||
xc, "pjrt_plugin_initialized", autospec=True, return_vale=True
|
||||
):
|
||||
with mock.patch.object(
|
||||
xc, "initialize_pjrt_plugin", autospec=True
|
||||
):
|
||||
registration.factory()
|
||||
|
||||
self.assertRegex(
|
||||
log_output[1][0],
|
||||
@ -171,16 +182,28 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
test_json_file_path = os.path.join(
|
||||
os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json"
|
||||
)
|
||||
os.environ['PJRT_NAMES_AND_LIBRARY_PATHS'] = (
|
||||
f"name1;{test_json_file_path}" if platform.system() == "Windows"
|
||||
else f"name1:{test_json_file_path}")
|
||||
os.environ["PJRT_NAMES_AND_LIBRARY_PATHS"] = (
|
||||
f"name1;{test_json_file_path}"
|
||||
if platform.system() == "Windows"
|
||||
else f"name1:{test_json_file_path}"
|
||||
)
|
||||
xb.register_pjrt_plugin_factories_from_env()
|
||||
registration = xb._backend_factories["name1"]
|
||||
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
|
||||
with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True):
|
||||
with mock.patch.object(
|
||||
xc, "pjrt_plugin_loaded", autospec=True) as mock_plugin_loaded:
|
||||
registration.factory()
|
||||
xc, "pjrt_plugin_loaded", autospec=True
|
||||
) as mock_plugin_loaded:
|
||||
if xla_extension_version < 183:
|
||||
registration.factory()
|
||||
else:
|
||||
with mock.patch.object(
|
||||
xc, "pjrt_plugin_initialized", autospec=True, return_vale=True
|
||||
):
|
||||
with mock.patch.object(
|
||||
xc, "initialize_pjrt_plugin", autospec=True
|
||||
):
|
||||
registration.factory()
|
||||
|
||||
self.assertIn("name1", xb._backend_factories)
|
||||
self.assertEqual(registration.priority, 400)
|
||||
|
Loading…
x
Reference in New Issue
Block a user