diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index f011e756d..471b0c4af 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -22,39 +22,38 @@ import dataclasses import enum from functools import partial import itertools -import time -from typing import Any, NamedTuple import logging import threading - -import numpy as np +import time +from typing import Any, NamedTuple import jax +from jax._src import api +from jax._src import array from jax._src import basearray from jax._src import config from jax._src import core -from jax._src import api -from jax._src import array from jax._src import dtypes +from jax._src import lib from jax._src import source_info_util from jax._src import traceback_util from jax._src import util +from jax._src.abstract_arrays import array_types from jax._src.interpreters import ad from jax._src.interpreters import batching -from jax._src.abstract_arrays import array_types from jax._src.interpreters import mlir -from jax._src.interpreters import xla from jax._src.interpreters import pxla -from jax._src import lib -from jax._src.mesh import AbstractMesh, Mesh +from jax._src.interpreters import xla +from jax._src.layout import DeviceLocalLayout, Layout from jax._src.lib import xla_client as xc -from jax._src.monitoring import record_event_duration_secs +from jax._src.mesh import AbstractMesh, Mesh +from jax._src.monitoring import record_event_duration_secs, record_event_time_span from jax._src.partition_spec import PartitionSpec from jax._src.sharding import Sharding -from jax._src.sharding_impls import ( - SingleDeviceSharding, NamedSharding, TransferToMemoryKind, +from jax._src.sharding_impls import ( NamedSharding, + SingleDeviceSharding, TransferToMemoryKind, is_single_device_sharding) -from jax._src.layout import Layout, DeviceLocalLayout +import numpy as np JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration" @@ -177,12 +176,14 @@ def log_elapsed_time(fmt: str, fun_name: str, event: str | None = None): log_priority = logging.WARNING if config.log_compiles.value else logging.DEBUG start_time = time.time() yield - elapsed_time = time.time() - start_time + end_time = time.time() + elapsed_time = end_time - start_time if logger.isEnabledFor(log_priority): logger.log(log_priority, fmt.format( fun_name=fun_name, elapsed_time=elapsed_time)) if event is not None: record_event_duration_secs(event, elapsed_time) + record_event_time_span(event, start_time, end_time) def should_tuple_args(num_args: int, platform: str) -> bool: diff --git a/jax/_src/monitoring.py b/jax/_src/monitoring.py index 3b291de00..99e957733 100644 --- a/jax/_src/monitoring.py +++ b/jax/_src/monitoring.py @@ -39,8 +39,17 @@ class EventDurationListenerWithMetadata(Protocol): ... +class EventTimeSpanListenerWithMetadata(Protocol): + + def __call__( + self, event: str, start_time: float, end_time: float, **kwargs: str | int + ) -> None: + ... + + _event_listeners: list[EventListenerWithMetadata] = [] _event_duration_secs_listeners: list[EventDurationListenerWithMetadata] = [] +_event_time_span_listeners: list[EventTimeSpanListenerWithMetadata] = [] def record_event(event: str, **kwargs: str | int) -> None: @@ -64,6 +73,14 @@ def record_event_duration_secs(event: str, duration: float, callback(event, duration, **kwargs) +def record_event_time_span( + event: str, start_time: float, end_time: float, **kwargs: str | int +) -> None: + """Record an event start and end time in seconds (float).""" + for callback in _event_time_span_listeners: + callback(event, start_time, end_time, **kwargs) + + def register_event_listener( callback: EventListenerWithMetadata, ) -> None: @@ -71,6 +88,13 @@ def register_event_listener( _event_listeners.append(callback) +def register_event_time_span_listener( + callback: EventTimeSpanListenerWithMetadata, +) -> None: + """Register a callback to be invoked during record_event_time_span().""" + _event_time_span_listeners.append(callback) + + def register_event_duration_secs_listener( callback : EventDurationListenerWithMetadata) -> None: """Register a callback to be invoked during record_event_duration_secs().""" @@ -80,15 +104,22 @@ def get_event_duration_listeners() -> list[EventDurationListenerWithMetadata]: """Get event duration listeners.""" return list(_event_duration_secs_listeners) + +def get_event_time_span_listeners() -> list[EventTimeSpanListenerWithMetadata]: + """Get event time span listeners.""" + return list(_event_time_span_listeners) + + def get_event_listeners() -> list[EventListenerWithMetadata]: """Get event listeners.""" return list(_event_listeners) def clear_event_listeners(): """Clear event listeners.""" - global _event_listeners, _event_duration_secs_listeners + global _event_listeners, _event_duration_secs_listeners, _event_time_span_listeners _event_listeners = [] _event_duration_secs_listeners = [] + _event_time_span_listeners = [] def _unregister_event_duration_listener_by_callback( callback: EventDurationListenerWithMetadata) -> None: @@ -108,6 +139,18 @@ def _unregister_event_duration_listener_by_index(index: int) -> None: assert -size <= index < size del _event_duration_secs_listeners[index] + +def _unregister_event_time_span_listener_by_callback( + callback: EventTimeSpanListenerWithMetadata, +) -> None: + """Unregister an event time span listener by callback. + + This function is supposed to be called for testing only. + """ + assert callback in _event_time_span_listeners + _event_time_span_listeners.remove(callback) + + def _unregister_event_listener_by_callback( callback: EventListenerWithMetadata) -> None: """Unregister an event listener by callback. diff --git a/jax/monitoring.py b/jax/monitoring.py index 374e301b9..4c9996da5 100644 --- a/jax/monitoring.py +++ b/jax/monitoring.py @@ -22,9 +22,11 @@ aggregation/exporting. """ from jax._src.monitoring import ( + clear_event_listeners as clear_event_listeners, record_event_duration_secs as record_event_duration_secs, + record_event_time_span as record_event_time_span, record_event as record_event, register_event_duration_secs_listener as register_event_duration_secs_listener, register_event_listener as register_event_listener, - clear_event_listeners as clear_event_listeners, + register_event_time_span_listener as register_event_time_span_listener, )