[PJRT C API] Adding Profiler C APIs and related framework changes.

C API changes:
- Profiler C APIs are added in profiler_c_api.h.
- Add a PJRT C API extension for the profiler C APIs in pjrt_c_api_profiler_extension.h.

Framework changes:
- Add a plugin_tracer that calls profiler C APIs.
- Add a pybind method xla_client.profiler.register_plugin_profiler to register plugin_tracer with the plugin's PJRT_Api*.
- Update xla_bridge.register_plugin to call register_plugin_profiler to register profiler for that plugin.

PiperOrigin-RevId: 572027222
This commit is contained in:
Jieying Luo 2023-10-09 13:35:33 -07:00 committed by jax authors
parent 84b58ec7f3
commit cb51e37008
2 changed files with 18 additions and 3 deletions

View File

@ -460,7 +460,10 @@ def register_plugin(
fail_quietly=False, experimental=experimental)
if library_path is not None:
if xla_extension_version >= 198:
return xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
c_api = xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path) # type: ignore
if xla_extension_version >= 203:
xla_client.profiler.register_plugin_profiler(c_api)
return c_api
else:
xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
return None

View File

@ -175,7 +175,13 @@ class XlaBridgeTest(jtu.JaxTestCase):
os.environ["PJRT_NAMES_AND_LIBRARY_PATHS"] = (
"name1:path1,name2:path2,name3"
)
xb.register_pjrt_plugin_factories_from_env()
if xla_extension_version < 203:
xb.register_pjrt_plugin_factories_from_env()
else:
with mock.patch.object(
xc.profiler, "register_plugin_profiler", autospec=True
):
xb.register_pjrt_plugin_factories_from_env()
registration = xb._backend_factories["name1"]
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
if xla_extension_version < 183:
@ -208,7 +214,13 @@ class XlaBridgeTest(jtu.JaxTestCase):
else f"name1:{test_json_file_path}"
)
with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True):
xb.register_pjrt_plugin_factories_from_env()
if xla_extension_version < 203:
xb.register_pjrt_plugin_factories_from_env()
else:
with mock.patch.object(
xc.profiler, "register_plugin_profiler", autospec=True
):
xb.register_pjrt_plugin_factories_from_env()
registration = xb._backend_factories["name1"]
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
if xla_extension_version < 183: