Add JAX events that have time spans, not only durations.

Log such events for log_elapsed_time.

The rationale for not replacing durations with it is that it appears that
record_event_duration_secs() is widely used outside of the code of JAX itself.

PiperOrigin-RevId: 713167192
This commit is contained in:
jax authors 2025-01-07 23:07:38 -08:00
parent 6d08f36f5b
commit 1bd781d992
3 changed files with 63 additions and 17 deletions

View File

@ -22,39 +22,38 @@ import dataclasses
import enum
from functools import partial
import itertools
import time
from typing import Any, NamedTuple
import logging
import threading
import numpy as np
import time
from typing import Any, NamedTuple
import jax
from jax._src import api
from jax._src import array
from jax._src import basearray
from jax._src import config
from jax._src import core
from jax._src import api
from jax._src import array
from jax._src import dtypes
from jax._src import lib
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.abstract_arrays import array_types
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.abstract_arrays import array_types
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax._src import lib
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.interpreters import xla
from jax._src.layout import DeviceLocalLayout, Layout
from jax._src.lib import xla_client as xc
from jax._src.monitoring import record_event_duration_secs
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.monitoring import record_event_duration_secs, record_event_time_span
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
SingleDeviceSharding, NamedSharding, TransferToMemoryKind,
from jax._src.sharding_impls import ( NamedSharding,
SingleDeviceSharding, TransferToMemoryKind,
is_single_device_sharding)
from jax._src.layout import Layout, DeviceLocalLayout
import numpy as np
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
@ -177,12 +176,14 @@ def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None):
log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG
start_time = time.time()
yield
elapsed_time = time.time() - start_time
end_time = time.time()
elapsed_time = end_time - start_time
if logger.isEnabledFor(log_priority):
logger.log(log_priority, fmt.format(
fun_name=fun_name, elapsed_time=elapsed_time))
if event is not None:
record_event_duration_secs(event, elapsed_time)
record_event_time_span(event, start_time, end_time)
def should_tuple_args(num_args: int, platform: str) -> bool:

View File

@ -39,8 +39,17 @@ class EventDurationListenerWithMetadata(Protocol):
...
class EventTimeSpanListenerWithMetadata(Protocol):
def __call__(
self, event: str, start_time: float, end_time: float, **kwargs: str | int
) -> None:
...
_event_listeners: list[EventListenerWithMetadata] = []
_event_duration_secs_listeners: list[EventDurationListenerWithMetadata] = []
_event_time_span_listeners: list[EventTimeSpanListenerWithMetadata] = []
def record_event(event: str, **kwargs: str | int) -> None:
@ -64,6 +73,14 @@ def record_event_duration_secs(event: str, duration: float,
callback(event, duration, **kwargs)
def record_event_time_span(
event: str, start_time: float, end_time: float, **kwargs: str | int
) -> None:
"""Record an event start and end time in seconds (float)."""
for callback in _event_time_span_listeners:
callback(event, start_time, end_time, **kwargs)
def register_event_listener(
callback: EventListenerWithMetadata,
) -> None:
@ -71,6 +88,13 @@ def register_event_listener(
_event_listeners.append(callback)
def register_event_time_span_listener(
callback: EventTimeSpanListenerWithMetadata,
) -> None:
"""Register a callback to be invoked during record_event_time_span()."""
_event_time_span_listeners.append(callback)
def register_event_duration_secs_listener(
callback : EventDurationListenerWithMetadata) -> None:
"""Register a callback to be invoked during record_event_duration_secs()."""
@ -80,15 +104,22 @@ def get_event_duration_listeners() -> list[EventDurationListenerWithMetadata]:
"""Get event duration listeners."""
return list(_event_duration_secs_listeners)
def get_event_time_span_listeners() -> list[EventTimeSpanListenerWithMetadata]:
"""Get event time span listeners."""
return list(_event_time_span_listeners)
def get_event_listeners() -> list[EventListenerWithMetadata]:
"""Get event listeners."""
return list(_event_listeners)
def clear_event_listeners():
"""Clear event listeners."""
global _event_listeners, _event_duration_secs_listeners
global _event_listeners, _event_duration_secs_listeners, _event_time_span_listeners
_event_listeners = []
_event_duration_secs_listeners = []
_event_time_span_listeners = []
def _unregister_event_duration_listener_by_callback(
callback: EventDurationListenerWithMetadata) -> None:
@ -108,6 +139,18 @@ def _unregister_event_duration_listener_by_index(index: int) -> None:
assert -size <= index < size
del _event_duration_secs_listeners[index]
def _unregister_event_time_span_listener_by_callback(
callback: EventTimeSpanListenerWithMetadata,
) -> None:
"""Unregister an event time span listener by callback.
This function is supposed to be called for testing only.
"""
assert callback in _event_time_span_listeners
_event_time_span_listeners.remove(callback)
def _unregister_event_listener_by_callback(
callback: EventListenerWithMetadata) -> None:
"""Unregister an event listener by callback.

View File

@ -22,9 +22,11 @@ aggregation/exporting.
"""
from jax._src.monitoring import (
clear_event_listeners as clear_event_listeners,
record_event_duration_secs as record_event_duration_secs,
record_event_time_span as record_event_time_span,
record_event as record_event,
register_event_duration_secs_listener as register_event_duration_secs_listener,
register_event_listener as register_event_listener,
clear_event_listeners as clear_event_listeners,
register_event_time_span_listener as register_event_time_span_listener,
)