Add scalar event logging function

This commit is contained in:
jeffcarp 2025-02-27 15:51:00 -08:00 committed by Jeff Carpenter
parent 6b8821148d
commit 123ce5221b
3 changed files with 71 additions and 2 deletions

View File

@ -46,10 +46,18 @@ class EventTimeSpanListenerWithMetadata(Protocol):
) -> None:
...
class ScalarListenerWithMetadata(Protocol):
def __call__(
self, event: str, value: float | int, **kwargs: str | int,
) -> None:
...
_event_listeners: list[EventListenerWithMetadata] = []
_event_duration_secs_listeners: list[EventDurationListenerWithMetadata] = []
_event_time_span_listeners: list[EventTimeSpanListenerWithMetadata] = []
_scalar_listeners: list[ScalarListenerWithMetadata] = []
def record_event(event: str, **kwargs: str | int) -> None:
@ -81,6 +89,14 @@ def record_event_time_span(
callback(event, start_time, end_time, **kwargs)
def record_scalar(
event: str, value: float | int, **kwargs: str | int
) -> None:
"""Record a scalar summary value."""
for callback in _scalar_listeners:
callback(event, value, **kwargs)
def register_event_listener(
callback: EventListenerWithMetadata,
) -> None:
@ -100,6 +116,14 @@ def register_event_duration_secs_listener(
"""Register a callback to be invoked during record_event_duration_secs()."""
_event_duration_secs_listeners.append(callback)
def register_scalar_listener(
callback : ScalarListenerWithMetadata,
) -> None:
"""Register a callback to be invoked during record_scalar()."""
_scalar_listeners.append(callback)
def get_event_duration_listeners() -> list[EventDurationListenerWithMetadata]:
"""Get event duration listeners."""
return list(_event_duration_secs_listeners)
@ -114,12 +138,20 @@ def get_event_listeners() -> list[EventListenerWithMetadata]:
"""Get event listeners."""
return list(_event_listeners)
def get_scalar_listeners() -> list[ScalarListenerWithMetadata]:
"""Get scalar event listeners."""
return list(_scalar_listeners)
def clear_event_listeners():
"""Clear event listeners."""
global _event_listeners, _event_duration_secs_listeners, _event_time_span_listeners
_event_listeners = []
_event_duration_secs_listeners = []
_event_time_span_listeners = []
_scalar_listeners = []
def _unregister_event_duration_listener_by_callback(
callback: EventDurationListenerWithMetadata) -> None:
@ -159,3 +191,14 @@ def _unregister_event_listener_by_callback(
"""
assert callback in _event_listeners
_event_listeners.remove(callback)
def _unregister_scalar_listener_by_callback(
callback: ScalarListenerWithMetadata,
) -> None:
"""Unregister a scalar event listener by callback.
This function is supposed to be called for testing only.
"""
assert callback in _scalar_listeners
_scalar_listeners.remove(callback)

View File

@ -26,7 +26,9 @@ from jax._src.monitoring import (
record_event_duration_secs as record_event_duration_secs,
record_event_time_span as record_event_time_span,
record_event as record_event,
record_scalar as record_scalar,
register_event_duration_secs_listener as register_event_duration_secs_listener,
register_event_listener as register_event_listener,
register_event_time_span_listener as register_event_time_span_listener,
register_scalar_listener as register_scalar_listener,
)

View File

@ -29,7 +29,7 @@ class MonitoringTest(absltest.TestCase):
def test_record_event(self):
events = []
counters = {} # Map event names to frequency.
counters = {} # Map event names to frequency.
def increment_event_counter(event):
if event not in counters:
counters[event] = 0
@ -48,7 +48,7 @@ class MonitoringTest(absltest.TestCase):
"test_common_event": 2})
def test_record_event_durations(self):
durations = {} # Map event names to frequency.
durations = {} # Map event names to frequency.
def increment_event_duration(event, duration):
if event not in durations:
durations[event] = 0.
@ -62,6 +62,30 @@ class MonitoringTest(absltest.TestCase):
self.assertDictEqual(durations, {"test_short_event": 3,
"test_long_event": 10})
def test_record_scalar(self):
observed_keys = []
observed_values = []
monitoring.register_scalar_listener(
lambda key, _: observed_keys.append(key),
)
monitoring.register_scalar_listener(
lambda _, value: observed_values.append(value),
)
monitoring.record_scalar("test_unique_event", 1)
monitoring.record_scalar("test_common_event", 2.5)
monitoring.record_scalar("test_common_event", 5e5)
self.assertListEqual(
observed_keys,
["test_unique_event", "test_common_event", "test_common_event"],
)
self.assertListEqual(
observed_values,
[1, 2.5, 5e5],
)
def test_unregister_exist_callback_success(self):
original_duration_listeners = jax_src_monitoring.get_event_duration_listeners()
callback = lambda event, durations: None