mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Add scalar event logging function
This commit is contained in:
parent
6b8821148d
commit
123ce5221b
@ -46,10 +46,18 @@ class EventTimeSpanListenerWithMetadata(Protocol):
|
||||
) -> None:
|
||||
...
|
||||
|
||||
class ScalarListenerWithMetadata(Protocol):
|
||||
|
||||
def __call__(
|
||||
self, event: str, value: float | int, **kwargs: str | int,
|
||||
) -> None:
|
||||
...
|
||||
|
||||
|
||||
_event_listeners: list[EventListenerWithMetadata] = []
|
||||
_event_duration_secs_listeners: list[EventDurationListenerWithMetadata] = []
|
||||
_event_time_span_listeners: list[EventTimeSpanListenerWithMetadata] = []
|
||||
_scalar_listeners: list[ScalarListenerWithMetadata] = []
|
||||
|
||||
|
||||
def record_event(event: str, **kwargs: str | int) -> None:
|
||||
@ -81,6 +89,14 @@ def record_event_time_span(
|
||||
callback(event, start_time, end_time, **kwargs)
|
||||
|
||||
|
||||
def record_scalar(
|
||||
event: str, value: float | int, **kwargs: str | int
|
||||
) -> None:
|
||||
"""Record a scalar summary value."""
|
||||
for callback in _scalar_listeners:
|
||||
callback(event, value, **kwargs)
|
||||
|
||||
|
||||
def register_event_listener(
|
||||
callback: EventListenerWithMetadata,
|
||||
) -> None:
|
||||
@ -100,6 +116,14 @@ def register_event_duration_secs_listener(
|
||||
"""Register a callback to be invoked during record_event_duration_secs()."""
|
||||
_event_duration_secs_listeners.append(callback)
|
||||
|
||||
|
||||
def register_scalar_listener(
|
||||
callback : ScalarListenerWithMetadata,
|
||||
) -> None:
|
||||
"""Register a callback to be invoked during record_scalar()."""
|
||||
_scalar_listeners.append(callback)
|
||||
|
||||
|
||||
def get_event_duration_listeners() -> list[EventDurationListenerWithMetadata]:
|
||||
"""Get event duration listeners."""
|
||||
return list(_event_duration_secs_listeners)
|
||||
@ -114,12 +138,20 @@ def get_event_listeners() -> list[EventListenerWithMetadata]:
|
||||
"""Get event listeners."""
|
||||
return list(_event_listeners)
|
||||
|
||||
|
||||
def get_scalar_listeners() -> list[ScalarListenerWithMetadata]:
|
||||
"""Get scalar event listeners."""
|
||||
return list(_scalar_listeners)
|
||||
|
||||
|
||||
def clear_event_listeners():
|
||||
"""Clear event listeners."""
|
||||
global _event_listeners, _event_duration_secs_listeners, _event_time_span_listeners
|
||||
_event_listeners = []
|
||||
_event_duration_secs_listeners = []
|
||||
_event_time_span_listeners = []
|
||||
_scalar_listeners = []
|
||||
|
||||
|
||||
def _unregister_event_duration_listener_by_callback(
|
||||
callback: EventDurationListenerWithMetadata) -> None:
|
||||
@ -159,3 +191,14 @@ def _unregister_event_listener_by_callback(
|
||||
"""
|
||||
assert callback in _event_listeners
|
||||
_event_listeners.remove(callback)
|
||||
|
||||
|
||||
def _unregister_scalar_listener_by_callback(
|
||||
callback: ScalarListenerWithMetadata,
|
||||
) -> None:
|
||||
"""Unregister a scalar event listener by callback.
|
||||
|
||||
This function is supposed to be called for testing only.
|
||||
"""
|
||||
assert callback in _scalar_listeners
|
||||
_scalar_listeners.remove(callback)
|
||||
|
@ -26,7 +26,9 @@ from jax._src.monitoring import (
|
||||
record_event_duration_secs as record_event_duration_secs,
|
||||
record_event_time_span as record_event_time_span,
|
||||
record_event as record_event,
|
||||
record_scalar as record_scalar,
|
||||
register_event_duration_secs_listener as register_event_duration_secs_listener,
|
||||
register_event_listener as register_event_listener,
|
||||
register_event_time_span_listener as register_event_time_span_listener,
|
||||
register_scalar_listener as register_scalar_listener,
|
||||
)
|
||||
|
@ -29,7 +29,7 @@ class MonitoringTest(absltest.TestCase):
|
||||
|
||||
def test_record_event(self):
|
||||
events = []
|
||||
counters = {} # Map event names to frequency.
|
||||
counters = {} # Map event names to frequency.
|
||||
def increment_event_counter(event):
|
||||
if event not in counters:
|
||||
counters[event] = 0
|
||||
@ -48,7 +48,7 @@ class MonitoringTest(absltest.TestCase):
|
||||
"test_common_event": 2})
|
||||
|
||||
def test_record_event_durations(self):
|
||||
durations = {} # Map event names to frequency.
|
||||
durations = {} # Map event names to frequency.
|
||||
def increment_event_duration(event, duration):
|
||||
if event not in durations:
|
||||
durations[event] = 0.
|
||||
@ -62,6 +62,30 @@ class MonitoringTest(absltest.TestCase):
|
||||
self.assertDictEqual(durations, {"test_short_event": 3,
|
||||
"test_long_event": 10})
|
||||
|
||||
def test_record_scalar(self):
|
||||
observed_keys = []
|
||||
observed_values = []
|
||||
|
||||
monitoring.register_scalar_listener(
|
||||
lambda key, _: observed_keys.append(key),
|
||||
)
|
||||
monitoring.register_scalar_listener(
|
||||
lambda _, value: observed_values.append(value),
|
||||
)
|
||||
|
||||
monitoring.record_scalar("test_unique_event", 1)
|
||||
monitoring.record_scalar("test_common_event", 2.5)
|
||||
monitoring.record_scalar("test_common_event", 5e5)
|
||||
|
||||
self.assertListEqual(
|
||||
observed_keys,
|
||||
["test_unique_event", "test_common_event", "test_common_event"],
|
||||
)
|
||||
self.assertListEqual(
|
||||
observed_values,
|
||||
[1, 2.5, 5e5],
|
||||
)
|
||||
|
||||
def test_unregister_exist_callback_success(self):
|
||||
original_duration_listeners = jax_src_monitoring.get_event_duration_listeners()
|
||||
callback = lambda event, durations: None
|
||||
|
Loading…
x
Reference in New Issue
Block a user