mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #16073 from stellaraccident:extplugin
PiperOrigin-RevId: 534237189
This commit is contained in:
commit
a7b8129ffb
@ -27,6 +27,7 @@ import logging
|
||||
import os
|
||||
import platform as py_platform
|
||||
import pkgutil
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
|
||||
import warnings
|
||||
@ -319,19 +320,80 @@ def _get_pjrt_plugin_config(
|
||||
def discover_pjrt_plugins() -> None:
|
||||
"""Discovers plugins in the namespace package `jax_plugins` and import them.
|
||||
|
||||
The plugins need to (1) be place in a root folder `jax_plugins` and follow
|
||||
other namespace package requirements, and (2) implement an initialize()
|
||||
method, which calls jax._src.xla_bridge.register_plugin with its plugin_name,
|
||||
path to .so file, and optional create options.
|
||||
There are two methods used to discover plugin modules. They are intended
|
||||
to be used together by implementors in order to cover all packaging and
|
||||
development cases:
|
||||
|
||||
1. Define a globally unique module under the `jax_plugins` namespace
|
||||
package (i.e. just create a `jax_plugins` directory and define your
|
||||
module below it).
|
||||
2. If building a package via pyproject.toml or setup.py, advertise your
|
||||
plugin module name by including an entry-point under the `jax_plugins`
|
||||
group which points to your full module name.
|
||||
|
||||
During Jax startup, Jax will load each module discovered in such a way and
|
||||
call its `initialize()` function. It is expected that this function should
|
||||
register its concrete plugin name/implementations via call(s) to
|
||||
`jax._src.xla_bridge.register_plugin(name, priority=, library_paty=,
|
||||
options=)`. Since `initialize()` functions are called for all installed
|
||||
plugins, they should avoid doing expensive, non-registration related work.
|
||||
|
||||
TODO: We should provide a variant of `register_plugin` which allows the
|
||||
library_path and options to be resolved via a callback. This would enable
|
||||
light-weight plugin registration in cases where options need to be derived
|
||||
from heavy-weight system initialization.
|
||||
"""
|
||||
if jax_plugins is None:
|
||||
return
|
||||
for _, name, _ in pkgutil.iter_modules(
|
||||
jax_plugins.__path__, jax_plugins.__name__ + '.'
|
||||
):
|
||||
# TODO(b/261345120): Add a try-catch to defend against a broken plugin.
|
||||
module = importlib.import_module(name)
|
||||
module.initialize()
|
||||
plugin_modules = set()
|
||||
# Scan installed modules under |jax_plugins|. Note that not all packaging
|
||||
# scenarios are amenable to such scanning, so we also use the entry-point
|
||||
# method to seed the list.
|
||||
if jax_plugins:
|
||||
for _, name, _ in pkgutil.iter_modules(
|
||||
jax_plugins.__path__, jax_plugins.__name__ + '.'
|
||||
):
|
||||
logger.debug("Discovered path based JAX plugin: %s", name)
|
||||
plugin_modules.add(name)
|
||||
else:
|
||||
logger.debug("No jax_plugins namespace packages available")
|
||||
|
||||
# Augment with advertised entrypoints.
|
||||
if sys.version_info < (3, 10):
|
||||
# Use the backport library because it provides a forward-compatible
|
||||
# implementation.
|
||||
try:
|
||||
from importlib_metadata import entry_points
|
||||
except ModuleNotFoundError:
|
||||
logger.debug(
|
||||
"No importlib_metadata found (for Python < 3.10): "
|
||||
"Plugins advertised from entrypoints will not be found.")
|
||||
entry_points = None
|
||||
else:
|
||||
from importlib.metadata import entry_points
|
||||
if entry_points:
|
||||
for entry_point in entry_points(group="jax_plugins"):
|
||||
logger.debug("Discovered entry-point based JAX plugin: %s",
|
||||
entry_point.value)
|
||||
plugin_modules.add(entry_point.value)
|
||||
|
||||
# Now load and initialize them all.
|
||||
for plugin_module_name in plugin_modules:
|
||||
logger.debug("Loading plugin module %s", plugin_module_name)
|
||||
plugin_module = None
|
||||
try:
|
||||
plugin_module = importlib.import_module(plugin_module_name)
|
||||
except ModuleNotFoundError:
|
||||
logger.warning("Jax plugin configuration error: Plugin module %s "
|
||||
"does not exist", plugin_module_name)
|
||||
except ImportError:
|
||||
logger.exception("Jax plugin configuration error: Plugin module %s "
|
||||
"could not be loaded")
|
||||
|
||||
if plugin_module:
|
||||
try:
|
||||
plugin_module.initialize()
|
||||
except:
|
||||
logger.exception("Jax plugin configuration error: Exception when "
|
||||
"calling %s.initialize()", plugin_module_name)
|
||||
|
||||
|
||||
# TODO(b/261345120): decide on a public name and expose a public method which is
|
||||
|
Loading…
x
Reference in New Issue
Block a user