mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[PJRT C API] Adding Profiler C APIs and related framework changes.
C API changes: - Profiler C APIs are added in profiler_c_api.h. - Add a PJRT C API extension for the profiler C APIs in pjrt_c_api_profiler_extension.h. Framework changes: - Add a plugin_tracer that calls profiler C APIs. - Add a pybind method xla_client.profiler.register_plugin_profiler to register plugin_tracer with the plugin's PJRT_Api*. - Update xla_bridge.register_plugin to call register_plugin_profiler to register profiler for that plugin. PiperOrigin-RevId: 572027222
This commit is contained in:
parent
84b58ec7f3
commit
cb51e37008
@ -460,7 +460,10 @@ def register_plugin(
|
||||
fail_quietly=False, experimental=experimental)
|
||||
if library_path is not None:
|
||||
if xla_extension_version >= 198:
|
||||
return xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
|
||||
c_api = xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path) # type: ignore
|
||||
if xla_extension_version >= 203:
|
||||
xla_client.profiler.register_plugin_profiler(c_api)
|
||||
return c_api
|
||||
else:
|
||||
xla_client.load_pjrt_plugin_dynamically(plugin_name, library_path)
|
||||
return None
|
||||
|
@ -175,7 +175,13 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
os.environ["PJRT_NAMES_AND_LIBRARY_PATHS"] = (
|
||||
"name1:path1,name2:path2,name3"
|
||||
)
|
||||
xb.register_pjrt_plugin_factories_from_env()
|
||||
if xla_extension_version < 203:
|
||||
xb.register_pjrt_plugin_factories_from_env()
|
||||
else:
|
||||
with mock.patch.object(
|
||||
xc.profiler, "register_plugin_profiler", autospec=True
|
||||
):
|
||||
xb.register_pjrt_plugin_factories_from_env()
|
||||
registration = xb._backend_factories["name1"]
|
||||
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
|
||||
if xla_extension_version < 183:
|
||||
@ -208,7 +214,13 @@ class XlaBridgeTest(jtu.JaxTestCase):
|
||||
else f"name1:{test_json_file_path}"
|
||||
)
|
||||
with mock.patch.object(xc, "load_pjrt_plugin_dynamically", autospec=True):
|
||||
xb.register_pjrt_plugin_factories_from_env()
|
||||
if xla_extension_version < 203:
|
||||
xb.register_pjrt_plugin_factories_from_env()
|
||||
else:
|
||||
with mock.patch.object(
|
||||
xc.profiler, "register_plugin_profiler", autospec=True
|
||||
):
|
||||
xb.register_pjrt_plugin_factories_from_env()
|
||||
registration = xb._backend_factories["name1"]
|
||||
with mock.patch.object(xc, "make_c_api_client", autospec=True) as mock_make:
|
||||
if xla_extension_version < 183:
|
||||
|
Loading…
x
Reference in New Issue
Block a user