Add an unregister event listener function in JAX monitoring.

Add a private function _unregister_event_listener_by_callback to remove registered event listeners. The functions are supposed to be called in test only. Add a getter function for event listeners to help unit testing the unregister function.

PiperOrigin-RevId: 558309557
This commit is contained in:
jax authors 2023-08-18 19:59:26 -07:00
parent 1a9109f32e
commit a8945fd2bd
2 changed files with 36 additions and 0 deletions

View File

@ -45,8 +45,13 @@ def register_event_duration_secs_listener(
_event_duration_secs_listeners.append(callback)
def get_event_duration_listeners() -> list[Callable[[str, float], None]]:
"""Get event duration listeners."""
return list(_event_duration_secs_listeners)
def get_event_listeners() -> list[Callable[[str], None]]:
"""Get event listeners."""
return list(_event_listeners)
def _clear_event_listeners():
"""Clear event listeners."""
global _event_listeners, _event_duration_secs_listeners
@ -70,3 +75,12 @@ def _unregister_event_duration_listener_by_index(index: int) -> None:
size = len(_event_duration_secs_listeners)
assert -size <= index < size
del _event_duration_secs_listeners[index]
def _unregister_event_listener_by_callback(
callback: Callable[[str], None]) -> None:
"""Unregister an event listener by callback.
This function is supposed to be called for testing only.
"""
assert callback in _event_listeners
_event_listeners.remove(callback)

View File

@ -118,5 +118,27 @@ class MonitoringTest(absltest.TestCase):
self.assertNotEqual(original_duration_listeners,
jax_src_monitoring.get_event_duration_listeners())
def test_unregister_exist_event_callback_success(self):
original_event_listeners = jax_src_monitoring.get_event_listeners()
callback = lambda event: None
self.assertNotIn(callback, original_event_listeners)
monitoring.register_event_listener(callback)
self.assertIn(callback, jax_src_monitoring.get_event_listeners())
# Verify that original listeners list is not modified by register function.
self.assertNotEqual(original_event_listeners,
jax_src_monitoring.get_event_listeners())
jax_src_monitoring._unregister_event_listener_by_callback(callback)
self.assertEqual(original_event_listeners,
jax_src_monitoring.get_event_listeners())
def test_unregister_not_exist_event_callback_fail(self):
callback = lambda event: None
self.assertNotIn(callback, jax_src_monitoring.get_event_listeners())
with self.assertRaises(AssertionError):
jax_src_monitoring._unregister_event_listener_by_callback(callback)
if __name__ == "__main__":
absltest.main()