[PJRT PLUGIN] Provide a register_plugin method that plugin can use to register their backend factory.

The plugin is expected to calls jax._src.xla_bridge.register_plugin with its plugin_name, priority (default to be 400), path to .so file, and optional create options in their initialize() method.

Logics to register a plugin from ENV is not deleted to facilitate development with ENV.

PiperOrigin-RevId: 533280115
This commit is contained in:
Jieying Luo 2023-05-18 16:12:23 -07:00 committed by jax authors
parent 4a5c6f8200
commit 9da52e8905
2 changed files with 63 additions and 40 deletions

View File

@ -321,14 +321,8 @@ def discover_pjrt_plugins() -> None:
The plugins need to (1) be place in a root folder `jax_plugins` and follow
other namespace package requirements, and (2) implement an initialize()
method, which appends `plugin_name:file_path` to env var
`PJRT_NAMES_AND_LIBRARY_PATHS`. The file_path is the path to the .so file, or
the path to the json file with configurations. Please refer to the comment of
register_pjrt_plugin_factories or
jax/tests/testdata/example_pjrt_plugin_config.json about the json file format.
TODO(b/261345120): using env var `PJRT_NAMES_AND_LIBRARY_PATHS` is a short
term solution. What initialize() method should do is in discussion.
method, which calls jax._src.xla_bridge.register_plugin with its plugin_name,
path to .so file, and optional create options.
"""
if jax_plugins is None:
return
@ -340,7 +334,51 @@ def discover_pjrt_plugins() -> None:
module.initialize()
def register_pjrt_plugin_factories(plugins_from_env: str) -> None:
# TODO(b/261345120): decide on a public name and expose a public method which is
# an alias of this method.
def register_plugin(
plugin_name: str,
*,
priority: int = 400,
library_path: Optional[str] = None,
options: Optional[Mapping[str, Union[str, int, List[int], float]]] = None,
) -> None:
"""Registers a backend factory for the PJRT plugin.
Args:
plugin_name: the name of the plugin.
priority: the priority this plugin should be registered in jax backends.
Default to be 400.
library_path: Optional. The full path to the .so file of the plugin.
Required when the plugin is dynamically linked.
options: Optional. It is used when creating a PJRT plugin client.
"""
def factory():
if xla_extension_version >= 152:
# Plugin may already be statically linked in some configurations.
if not xla_client.pjrt_plugin_loaded(plugin_name):
if library_path is None:
raise ValueError(
'The library path is None when trying to dynamically load the'
' plugin.'
)
xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
else:
if library_path is None:
raise ValueError(
'The library path is None when trying to dynamically load the'
' plugin.'
)
xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
return xla_client.make_c_api_client(plugin_name, options)
logger.debug(
'registering PJRT plugin %s from %s', plugin_name, library_path
)
register_backend_factory(plugin_name, factory, priority=priority)
def register_pjrt_plugin_factories_from_env() -> None:
"""Registers backend factories for PJRT plugins.
A backend factory will be registered for every PJRT plugin in the input
@ -354,44 +392,27 @@ def register_pjrt_plugin_factories(plugins_from_env: str) -> None:
TPU PJRT plugin will be loaded and registered separately in make_tpu_client.
"""
def make_factory(name: str, path: str):
def factory():
if path.endswith('.json'):
library_path, options = _get_pjrt_plugin_config(path)
else:
library_path = path
options = None
if xla_extension_version >= 152:
# Plugin may already be statically linked in some configurations.
if not xla_client.pjrt_plugin_loaded(name):
xla_client.load_pjrt_plugin_dynamically(name, library_path)
else:
xla_client.load_pjrt_plugin_dynamically(name, library_path)
return xla_client.make_c_api_client(name, options)
return factory
pjrt_plugins = _get_pjrt_plugin_names_and_library_paths(plugins_from_env)
for plugin_name, library_path in pjrt_plugins.items():
pjrt_plugins = _get_pjrt_plugin_names_and_library_paths(
os.getenv('PJRT_NAMES_AND_LIBRARY_PATHS', '')
)
for plugin_name, path in pjrt_plugins.items():
if path.endswith('.json'):
library_path, options = _get_pjrt_plugin_config(path)
else:
library_path = path
options = None
logger.debug(
'registering PJRT plugin %s from %s', plugin_name, library_path
)
# It is assumed that if a plugin is installed, then the user wants to use
# the plugin by default. Therefore, plugins get the highest priority.
# For a PJRT plugin, its plugin_name is the same as its platform_name.
register_backend_factory(
plugin_name, make_factory(plugin_name, library_path), priority=400
)
register_plugin(plugin_name, library_path=library_path, options=options)
# Plugins in the namespace package `jax_plugins` will be imported.
discover_pjrt_plugins()
# The plugin names and paths are set in env var PJRT_NAMES_AND_LIBRARY_PATHS,
# Registers plugins names and paths set in env var PJRT_NAMES_AND_LIBRARY_PATHS,
# in the format of 'name1:path1,name2:path2' ('name1;path1,name2;path2' for
# windows).
register_pjrt_plugin_factories(os.getenv('PJRT_NAMES_AND_LIBRARY_PATHS', ''))
register_pjrt_plugin_factories_from_env()
if iree is not None:
register_backend_factory("iree", iree.iree_client_factory, priority=-100)

View File

@ -92,7 +92,8 @@ class XlaBridgeTest(jtu.JaxTestCase):
def test_register_plugin(self):
with self.assertLogs(level="WARNING") as log_output:
xb.register_pjrt_plugin_factories("name1:path1,name2:path2,name3")
os.environ['PJRT_NAMES_AND_LIBRARY_PATHS'] = "name1:path1,name2:path2,name3"
xb.register_pjrt_plugin_factories_from_env()
client_factory, priotiy = xb._backend_factories["name1"]
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
with mock.patch.object(
@ -129,7 +130,8 @@ class XlaBridgeTest(jtu.JaxTestCase):
test_json_file_path = os.path.join(
os.path.dirname(__file__), "testdata/example_pjrt_plugin_config.json"
)
xb.register_pjrt_plugin_factories(f"name1:{test_json_file_path}")
os.environ['PJRT_NAMES_AND_LIBRARY_PATHS'] = f"name1:{test_json_file_path}"
xb.register_pjrt_plugin_factories_from_env()
client_factory, priority = xb._backend_factories["name1"]
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
with mock.patch.object(