mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Extend plugin discovery to also include entry-points.
This effectively implements a mix of option 2 and option 3 from https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/ as a pragmatic way to cover all packaging cases. The namespace/path based iteration works for situations where code has not been packaged and is present on the PYTHONPATH, whereas the advertised entry-points work around setuptools/pkgutil issues that make it impossible to reliably iterate over installed modules in certain scenarios (noted for editable installs which use a custom finder that does not implement iter_modules()). A plugin entry-point can be advertised in setup.py (or equivalent pyproject.toml) with something like: ``` entry_points={ "jax_plugins": [ "openxla-cpu = jax_plugins.openxla_cpu", ], } ```
This commit is contained in:
parent
1d20d2f301
commit
221aa76d81
@ -27,6 +27,7 @@ import logging
|
||||
import os
|
||||
import platform as py_platform
|
||||
import pkgutil
|
||||
import sys
|
||||
import threading
|
||||
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
|
||||
import warnings
|
||||
@ -324,14 +325,57 @@ def discover_pjrt_plugins() -> None:
|
||||
method, which calls jax._src.xla_bridge.register_plugin with its plugin_name,
|
||||
path to .so file, and optional create options.
|
||||
"""
|
||||
if jax_plugins is None:
|
||||
return
|
||||
for _, name, _ in pkgutil.iter_modules(
|
||||
jax_plugins.__path__, jax_plugins.__name__ + '.'
|
||||
):
|
||||
# TODO(b/261345120): Add a try-catch to defend against a broken plugin.
|
||||
module = importlib.import_module(name)
|
||||
module.initialize()
|
||||
plugin_modules = set()
|
||||
# Scan installed modules under |jax_plugins|. Note that not all packaging
|
||||
# scenarios are amenable to such scanning, so we also use the entry-point
|
||||
# method to seed the list.
|
||||
if jax_plugins:
|
||||
for _, name, _ in pkgutil.iter_modules(
|
||||
jax_plugins.__path__, jax_plugins.__name__ + '.'
|
||||
):
|
||||
logger.debug("Discovered path based JAX plugin: %s", name)
|
||||
plugin_modules.add(name)
|
||||
else:
|
||||
logger.debug("No jax_plugins namespace packages available")
|
||||
|
||||
# Augment with advertised entrypoints.
|
||||
if sys.version_info < (3, 10):
|
||||
# Use the backport library because it provides a forward-compatible
|
||||
# implementation.
|
||||
try:
|
||||
from importlib_metadata import entry_points
|
||||
except ModuleNotFoundError:
|
||||
logger.debug(
|
||||
f"No importlib_metadata found (for Python < 3.10): "
|
||||
f"Plugins advertised from entrypoints will not be found.")
|
||||
entry_points = None
|
||||
else:
|
||||
from importlib.metadata import entry_points
|
||||
if entry_points:
|
||||
for entry_point in entry_points(group="jax_plugins"):
|
||||
logger.debug("Discovered entry-point based JAX plugin: %s",
|
||||
entry_point.value)
|
||||
plugin_modules.add(entry_point.value)
|
||||
|
||||
# Now load and initialize them all.
|
||||
for plugin_module_name in plugin_modules:
|
||||
logger.debug("Loading plugin module %s", plugin_module_name)
|
||||
plugin_module = None
|
||||
try:
|
||||
plugin_module = importlib.import_module(plugin_module_name)
|
||||
except ModuleNotFoundError:
|
||||
logger.warning("Jax plugin configuration error: Plugin module %s "
|
||||
"does not exist", plugin_module_name)
|
||||
except ImportError:
|
||||
logger.exception("Jax plugin configuration error: Plugin module %s "
|
||||
"could not be loaded")
|
||||
|
||||
if plugin_module:
|
||||
try:
|
||||
plugin_module.initialize()
|
||||
except:
|
||||
logger.exception("Jax plugin configuration error: Exception when "
|
||||
"calling %s.initialize()", plugin_module_name)
|
||||
|
||||
|
||||
# TODO(b/261345120): decide on a public name and expose a public method which is
|
||||
|
Loading…
x
Reference in New Issue
Block a user