mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[PJRT C API] Register a backend factory for every PJRT plugin set in PJRT_NAMES_AND_LIBRARY_PATHS.
Loading TPU PJRT plugin is moved to make_tpu_client. This change is based on https://github.com/google/jax/pull/14011. PiperOrigin-RevId: 508477737
This commit is contained in:
parent
15c9bca67f
commit
668b82d529
@ -254,6 +254,71 @@ if hasattr(xla_client, "make_plugin_device_client"):
|
||||
register_backend_factory("plugin", xla_client.make_plugin_device_client,
|
||||
priority=400)
|
||||
|
||||
|
||||
def _get_pjrt_plugin_names_and_library_paths(
|
||||
plugins_from_env: str,
|
||||
) -> Dict[str, str]:
|
||||
"""Gets the names and library paths of PJRT plugins to load from env var.
|
||||
|
||||
Args:
|
||||
plugins_from_env: plugin name and pathes from env var. It is in the format
|
||||
of 'name1:path1,name2:path2' ('name1;path1,name2;path2' for windows).
|
||||
|
||||
Returns:
|
||||
A dict of {plugin_name: library path} for the PJRT plugins to load.
|
||||
"""
|
||||
if not plugins_from_env:
|
||||
return {}
|
||||
|
||||
pjrt_plugins = {}
|
||||
for plugin in plugins_from_env.split(','):
|
||||
try:
|
||||
name, library_path = plugin.split(os.path.pathsep)
|
||||
pjrt_plugins[name] = library_path
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
'invalid value %s in env var PJRT_NAMES_AND_LIBRARY_PATHS %s',
|
||||
plugin,
|
||||
plugins_from_env,
|
||||
)
|
||||
return pjrt_plugins
|
||||
|
||||
|
||||
def register_pjrt_plugin_factories(plugins_from_env: str):
|
||||
"""Registers backend factories for PJRT plugins.
|
||||
|
||||
A backend factory will be registered for every PJRT plugin in the input
|
||||
string, in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2'
|
||||
for windows). TPU PJRT plugin will be loaded and registered separately in
|
||||
make_tpu_client.
|
||||
"""
|
||||
|
||||
def make_factory(name, path):
|
||||
def factory():
|
||||
xla_client.load_pjrt_plugin_dynamically(name, path)
|
||||
return xla_client.make_c_api_client(name)
|
||||
|
||||
return factory
|
||||
|
||||
pjrt_plugins = _get_pjrt_plugin_names_and_library_paths(plugins_from_env)
|
||||
for plugin_name, library_path in pjrt_plugins.items():
|
||||
logger.debug(
|
||||
'registering PJRT plugin %s from %s', plugin_name, library_path
|
||||
)
|
||||
# It is assumed that if a plugin is installed, then the user wants to use
|
||||
# the plugin by default. Therefore, plugins get the highest priority.
|
||||
# For a PJRT plugin, its plugin_name is the same as its platform_name.
|
||||
register_backend_factory(
|
||||
plugin_name, make_factory(plugin_name, library_path), priority=400
|
||||
)
|
||||
|
||||
|
||||
if lib.xla_extension_version >= 126:
|
||||
# The plugin names and paths are set in env var PJRT_NAMES_AND_LIBRARY_PATHS,
|
||||
# in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2' for
|
||||
# windows).
|
||||
register_pjrt_plugin_factories(os.getenv('PJRT_NAMES_AND_LIBRARY_PATHS', ''))
|
||||
|
||||
if iree is not None:
|
||||
register_backend_factory("iree", iree.iree_client_factory, priority=-100)
|
||||
|
||||
@ -328,8 +393,6 @@ def backends():
|
||||
(platform, priority) for platform, (_, priority)
|
||||
in _backend_factories.items())
|
||||
default_priority = -1000
|
||||
if hasattr(xla_client, "maybe_load_pjrt_plugins"):
|
||||
xla_client.maybe_load_pjrt_plugins()
|
||||
for platform, priority in platforms_and_priorites:
|
||||
try:
|
||||
backend = _init_backend(platform)
|
||||
|
@ -89,6 +89,30 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
side_effect=_mock_tpu_client):
|
||||
xb.tpu_client_timer_callback(0.01)
|
||||
|
||||
def test_register_plugin(self):
|
||||
if xc._version < 126:
|
||||
return
|
||||
|
||||
with self.assertLogs(level="WARNING") as log_output:
|
||||
xb.register_pjrt_plugin_factories("name1:path1,name2:path2,name3")
|
||||
client_factory, priotiy = xb._backend_factories["name1"]
|
||||
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
|
||||
with mock.patch.object(
|
||||
xc, "load_pjrt_plugin_dynamically", autospec=True
|
||||
) as mock_load_plugin:
|
||||
client_factory()
|
||||
|
||||
self.assertRegex(
|
||||
log_output[1][0],
|
||||
r"invalid value name3 in env var PJRT_NAMES_AND_LIBRARY_PATHS"
|
||||
r" name1:path1,name2:path2,name3",
|
||||
)
|
||||
self.assertIn("name1", xb._backend_factories)
|
||||
self.assertIn("name2", xb._backend_factories)
|
||||
self.assertEqual(priotiy, 400)
|
||||
mock_load_plugin.assert_called_once_with("name1", "path1")
|
||||
mock_make.assert_called_once_with("name1")
|
||||
|
||||
|
||||
class GetBackendTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user