mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add an unregister event listener function in JAX monitoring.
Add a private function _unregister_event_listener_by_callback to remove registered event listeners. The functions are supposed to be called in test only. Add a getter function for event listeners to help unit testing the unregister function. PiperOrigin-RevId: 558309557
This commit is contained in:
parent
1a9109f32e
commit
a8945fd2bd
@ -45,8 +45,13 @@ def register_event_duration_secs_listener(
|
||||
_event_duration_secs_listeners.append(callback)
|
||||
|
||||
def get_event_duration_listeners() -> list[Callable[[str, float], None]]:
|
||||
"""Get event duration listeners."""
|
||||
return list(_event_duration_secs_listeners)
|
||||
|
||||
def get_event_listeners() -> list[Callable[[str], None]]:
|
||||
"""Get event listeners."""
|
||||
return list(_event_listeners)
|
||||
|
||||
def _clear_event_listeners():
|
||||
"""Clear event listeners."""
|
||||
global _event_listeners, _event_duration_secs_listeners
|
||||
@ -70,3 +75,12 @@ def _unregister_event_duration_listener_by_index(index: int) -> None:
|
||||
size = len(_event_duration_secs_listeners)
|
||||
assert -size <= index < size
|
||||
del _event_duration_secs_listeners[index]
|
||||
|
||||
def _unregister_event_listener_by_callback(
|
||||
callback: Callable[[str], None]) -> None:
|
||||
"""Unregister an event listener by callback.
|
||||
|
||||
This function is supposed to be called for testing only.
|
||||
"""
|
||||
assert callback in _event_listeners
|
||||
_event_listeners.remove(callback)
|
||||
|
@ -118,5 +118,27 @@ class MonitoringTest(absltest.TestCase):
|
||||
self.assertNotEqual(original_duration_listeners,
|
||||
jax_src_monitoring.get_event_duration_listeners())
|
||||
|
||||
def test_unregister_exist_event_callback_success(self):
|
||||
original_event_listeners = jax_src_monitoring.get_event_listeners()
|
||||
callback = lambda event: None
|
||||
self.assertNotIn(callback, original_event_listeners)
|
||||
monitoring.register_event_listener(callback)
|
||||
self.assertIn(callback, jax_src_monitoring.get_event_listeners())
|
||||
# Verify that original listeners list is not modified by register function.
|
||||
self.assertNotEqual(original_event_listeners,
|
||||
jax_src_monitoring.get_event_listeners())
|
||||
|
||||
jax_src_monitoring._unregister_event_listener_by_callback(callback)
|
||||
|
||||
self.assertEqual(original_event_listeners,
|
||||
jax_src_monitoring.get_event_listeners())
|
||||
|
||||
def test_unregister_not_exist_event_callback_fail(self):
|
||||
callback = lambda event: None
|
||||
self.assertNotIn(callback, jax_src_monitoring.get_event_listeners())
|
||||
|
||||
with self.assertRaises(AssertionError):
|
||||
jax_src_monitoring._unregister_event_listener_by_callback(callback)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user