[PJRT C API] Register a backend factory for every PJRT plugin set in PJRT_NAMES_AND_LIBRARY_PATHS.

Loading TPU PJRT plugin is moved to make_tpu_client.

This change is based on https://github.com/google/jax/pull/14011.

PiperOrigin-RevId: 508477737
This commit is contained in:
Jieying Luo 2023-02-09 14:33:05 -08:00 committed by jax authors
parent 15c9bca67f
commit 668b82d529
2 changed files with 89 additions and 2 deletions

View File

@ -254,6 +254,71 @@ if hasattr(xla_client, "make_plugin_device_client"):
register_backend_factory("plugin", xla_client.make_plugin_device_client,
priority=400)
def _get_pjrt_plugin_names_and_library_paths(
plugins_from_env: str,
) -> Dict[str, str]:
"""Gets the names and library paths of PJRT plugins to load from env var.
Args:
plugins_from_env: plugin name and pathes from env var. It is in the format
of 'name1:path1,name2:path2' ('name1;path1,name2;path2' for windows).
Returns:
A dict of {plugin_name: library path} for the PJRT plugins to load.
"""
if not plugins_from_env:
return {}
pjrt_plugins = {}
for plugin in plugins_from_env.split(','):
try:
name, library_path = plugin.split(os.path.pathsep)
pjrt_plugins[name] = library_path
except ValueError:
logger.warning(
'invalid value %s in env var PJRT_NAMES_AND_LIBRARY_PATHS %s',
plugin,
plugins_from_env,
)
return pjrt_plugins
def register_pjrt_plugin_factories(plugins_from_env: str):
"""Registers backend factories for PJRT plugins.
A backend factory will be registered for every PJRT plugin in the input
string, in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2'
for windows). TPU PJRT plugin will be loaded and registered separately in
make_tpu_client.
"""
def make_factory(name, path):
def factory():
xla_client.load_pjrt_plugin_dynamically(name, path)
return xla_client.make_c_api_client(name)
return factory
pjrt_plugins = _get_pjrt_plugin_names_and_library_paths(plugins_from_env)
for plugin_name, library_path in pjrt_plugins.items():
logger.debug(
'registering PJRT plugin %s from %s', plugin_name, library_path
)
# It is assumed that if a plugin is installed, then the user wants to use
# the plugin by default. Therefore, plugins get the highest priority.
# For a PJRT plugin, its plugin_name is the same as its platform_name.
register_backend_factory(
plugin_name, make_factory(plugin_name, library_path), priority=400
)
if lib.xla_extension_version >= 126:
# The plugin names and paths are set in env var PJRT_NAMES_AND_LIBRARY_PATHS,
# in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2' for
# windows).
register_pjrt_plugin_factories(os.getenv('PJRT_NAMES_AND_LIBRARY_PATHS', ''))
if iree is not None:
register_backend_factory("iree", iree.iree_client_factory, priority=-100)
@ -328,8 +393,6 @@ def backends():
(platform, priority) for platform, (_, priority)
in _backend_factories.items())
default_priority = -1000
if hasattr(xla_client, "maybe_load_pjrt_plugins"):
xla_client.maybe_load_pjrt_plugins()
for platform, priority in platforms_and_priorites:
try:
backend = _init_backend(platform)

View File

@ -89,6 +89,30 @@ class XlaBridgeTest(jtu.JaxTestCase):
side_effect=_mock_tpu_client):
xb.tpu_client_timer_callback(0.01)
def test_register_plugin(self):
if xc._version < 126:
return
with self.assertLogs(level="WARNING") as log_output:
xb.register_pjrt_plugin_factories("name1:path1,name2:path2,name3")
client_factory, priotiy = xb._backend_factories["name1"]
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
with mock.patch.object(
xc, "load_pjrt_plugin_dynamically", autospec=True
) as mock_load_plugin:
client_factory()
self.assertRegex(
log_output[1][0],
r"invalid value name3 in env var PJRT_NAMES_AND_LIBRARY_PATHS"
r" name1:path1,name2:path2,name3",
)
self.assertIn("name1", xb._backend_factories)
self.assertIn("name2", xb._backend_factories)
self.assertEqual(priotiy, 400)
mock_load_plugin.assert_called_once_with("name1", "path1")
mock_make.assert_called_once_with("name1")
class GetBackendTest(jtu.JaxTestCase):