mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[PJRT PLUGIN] Provide a register_plugin method that plugin can use to register their backend factory.
The plugin is expected to calls jax._src.xla_bridge.register_plugin with its plugin_name, priority (default to be 400), path to .so file, and optional create options in their initialize() method. Logics to register a plugin from ENV is not deleted to facilitate development with ENV. PiperOrigin-RevId: 533280115
This commit is contained in:
parent
4a5c6f8200
commit
9da52e8905
@ -321,14 +321,8 @@ def discover_pjrt_plugins() -> None:
|
||||
|
||||
The plugins need to (1) be place in a root folder `jax_plugins` and follow
|
||||
other namespace package requirements, and (2) implement an initialize()
|
||||
method, which appends `plugin_name:file_path` to env var
|
||||
`PJRT_NAMES_AND_LIBRARY_PATHS`. The file_path is the path to the .so file, or
|
||||
the path to the json file with configurations. Please refer to the comment of
|
||||
register_pjrt_plugin_factories or
|
||||
jax/tests/testdata/example_pjrt_plugin_config.json about the json file format.
|
||||
|
||||
TODO(b/261345120): using env var `PJRT_NAMES_AND_LIBRARY_PATHS` is a short
|
||||
term solution. What initialize() method should do is in discussion.
|
||||
method, which calls jax._src.xla_bridge.register_plugin with its plugin_name,
|
||||
path to .so file, and optional create options.
|
||||
"""
|
||||
if jax_plugins is None:
|
||||
return
|
||||
@ -340,7 +334,51 @@ def discover_pjrt_plugins() -> None:
|
||||
module.initialize()
|
||||
|
||||
|
||||
def register_pjrt_plugin_factories(plugins_from_env: str) -> None:
|
||||
# TODO(b/261345120): decide on a public name and expose a public method which is
|
||||
# an alias of this method.
|
||||
def register_plugin(
|
||||
plugin_name: str,
|
||||
*,
|
||||
priority: int = 400,
|
||||
library_path: Optional[str] = None,
|
||||
options: Optional[Mapping[str, Union[str, int, List[int], float]]] = None,
|
||||
) -> None:
|
||||
"""Registers a backend factory for the PJRT plugin.
|
||||
|
||||
Args:
|
||||
plugin_name: the name of the plugin.
|
||||
priority: the priority this plugin should be registered in jax backends.
|
||||
Default to be 400.
|
||||
library_path: Optional. The full path to the .so file of the plugin.
|
||||
Required when the plugin is dynamically linked.
|
||||
options: Optional. It is used when creating a PJRT plugin client.
|
||||
"""
|
||||
def factory():
|
||||
if xla_extension_version >= 152:
|
||||
# Plugin may already be statically linked in some configurations.
|
||||
if not xla_client.pjrt_plugin_loaded(plugin_name):
|
||||
if library_path is None:
|
||||
raise ValueError(
|
||||
'The library path is None when trying to dynamically load the'
|
||||
' plugin.'
|
||||
)
|
||||
xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
|
||||
else:
|
||||
if library_path is None:
|
||||
raise ValueError(
|
||||
'The library path is None when trying to dynamically load the'
|
||||
' plugin.'
|
||||
)
|
||||
xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
|
||||
return xla_client.make_c_api_client(plugin_name, options)
|
||||
|
||||
logger.debug(
|
||||
'registering PJRT plugin %s from %s', plugin_name, library_path
|
||||
)
|
||||
register_backend_factory(plugin_name, factory, priority=priority)
|
||||
|
||||
|
||||
def register_pjrt_plugin_factories_from_env() -> None:
|
||||
"""Registers backend factories for PJRT plugins.
|
||||
|
||||
A backend factory will be registered for every PJRT plugin in the input
|
||||
@ -354,44 +392,27 @@ def register_pjrt_plugin_factories(plugins_from_env: str) -> None:
|
||||
|
||||
TPU PJRT plugin will be loaded and registered separately in make_tpu_client.
|
||||
"""
|
||||
|
||||
def make_factory(name: str, path: str):
|
||||
def factory():
|
||||
if path.endswith('.json'):
|
||||
library_path, options = _get_pjrt_plugin_config(path)
|
||||
else:
|
||||
library_path = path
|
||||
options = None
|
||||
|
||||
if xla_extension_version >= 152:
|
||||
# Plugin may already be statically linked in some configurations.
|
||||
if not xla_client.pjrt_plugin_loaded(name):
|
||||
xla_client.load_pjrt_plugin_dynamically(name, library_path)
|
||||
else:
|
||||
xla_client.load_pjrt_plugin_dynamically(name, library_path)
|
||||
return xla_client.make_c_api_client(name, options)
|
||||
|
||||
return factory
|
||||
|
||||
pjrt_plugins = _get_pjrt_plugin_names_and_library_paths(plugins_from_env)
|
||||
for plugin_name, library_path in pjrt_plugins.items():
|
||||
pjrt_plugins = _get_pjrt_plugin_names_and_library_paths(
|
||||
os.getenv('PJRT_NAMES_AND_LIBRARY_PATHS', '')
|
||||
)
|
||||
for plugin_name, path in pjrt_plugins.items():
|
||||
if path.endswith('.json'):
|
||||
library_path, options = _get_pjrt_plugin_config(path)
|
||||
else:
|
||||
library_path = path
|
||||
options = None
|
||||
logger.debug(
|
||||
'registering PJRT plugin %s from %s', plugin_name, library_path
|
||||
)
|
||||
# It is assumed that if a plugin is installed, then the user wants to use
|
||||
# the plugin by default. Therefore, plugins get the highest priority.
|
||||
# For a PJRT plugin, its plugin_name is the same as its platform_name.
|
||||
register_backend_factory(
|
||||
plugin_name, make_factory(plugin_name, library_path), priority=400
|
||||
)
|
||||
register_plugin(plugin_name, library_path=library_path, options=options)
|
||||
|
||||
|
||||
# Plugins in the namespace package `jax_plugins` will be imported.
|
||||
discover_pjrt_plugins()
|
||||
# The plugin names and paths are set in env var PJRT_NAMES_AND_LIBRARY_PATHS,
|
||||
# Registers plugins names and paths set in env var PJRT_NAMES_AND_LIBRARY_PATHS,
|
||||
# in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2' for
|
||||
# windows).
|
||||
register_pjrt_plugin_factories(os.getenv('PJRT_NAMES_AND_LIBRARY_PATHS', ''))
|
||||
register_pjrt_plugin_factories_from_env()
|
||||
|
||||
if iree is not None:
|
||||
register_backend_factory("iree", iree.iree_client_factory, priority=-100)
|
||||
|
@ -92,7 +92,8 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
|
||||
def test_register_plugin(self):
|
||||
with self.assertLogs(level="WARNING") as log_output:
|
||||
xb.register_pjrt_plugin_factories("name1:path1,name2:path2,name3")
|
||||
os.environ['PJRT_NAMES_AND_LIBRARY_PATHS'] = "name1:path1,name2:path2,name3"
|
||||
xb.register_pjrt_plugin_factories_from_env()
|
||||
client_factory, priotiy = xb._backend_factories["name1"]
|
||||
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
|
||||
with mock.patch.object(
|
||||
@ -129,7 +130,8 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
test_json_file_path = os.path.join(
|
||||
os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json"
|
||||
)
|
||||
xb.register_pjrt_plugin_factories(f"name1:{test_json_file_path}")
|
||||
os.environ['PJRT_NAMES_AND_LIBRARY_PATHS'] = f"name1:{test_json_file_path}"
|
||||
xb.register_pjrt_plugin_factories_from_env()
|
||||
client_factory, priority = xb._backend_factories["name1"]
|
||||
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
|
||||
with mock.patch.object(
|
||||
|
Loading…
x
Reference in New Issue
Block a user