From 726b2bc2ee1bbc5690ad2b6ab0e2b77aa88eb868 Mon Sep 17 00:00:00 2001 From: jax authors Date: Tue, 15 Nov 2022 12:41:08 -0800 Subject: [PATCH] Add JAX monitoring library that instruments code via events. PiperOrigin-RevId: 488731805 --- jax/_src/dispatch.py | 13 +++++++-- jax/experimental/maps.py | 3 +- jax/experimental/pjit.py | 3 +- jax/interpreters/pxla.py | 15 ++++++---- jax/monitoring.py | 50 +++++++++++++++++++++++++++++++++ tests/BUILD | 9 ++++++ tests/monitoring_test.py | 60 ++++++++++++++++++++++++++++++++++++++++ 7 files changed, 143 insertions(+), 10 deletions(-) create mode 100644 jax/monitoring.py create mode 100644 tests/monitoring_test.py diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 900b25253..23aaa5105 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -36,6 +36,7 @@ import jax from jax import core from jax import linear_util as lu from jax.errors import UnexpectedTracerError +from jax.monitoring import record_event_duration_secs import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir @@ -56,6 +57,8 @@ import jax._src.util as util from jax._src.util import flatten, unflatten from jax._src import path +JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration" +BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration" FLAGS = flags.FLAGS @@ -370,7 +373,7 @@ def is_single_device_sharding(sharding) -> bool: @contextlib.contextmanager -def log_elapsed_time(fmt: str): +def log_elapsed_time(fmt: str, event: Optional[str] = None): if _on_exit: yield else: @@ -379,6 +382,8 @@ def log_elapsed_time(fmt: str): yield elapsed_time = time.time() - start_time logger.log(log_priority, fmt.format(elapsed_time=elapsed_time)) + if event is not None: + record_event_duration_secs(event, elapsed_time) def should_tuple_args(num_args: int, platform: str): @@ -441,7 +446,8 @@ def lower_xla_callable( abstract_args = tuple(aval for aval, _ in fun.in_type) with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} " - "for jit in {elapsed_time} sec"): + "for jit in {elapsed_time} sec", + event=JAXPR_TRACE_EVENT): jaxpr, out_type, consts = pe.trace_to_jaxpr_final2( fun, pe.debug_info_final(fun, "jit")) out_avals, kept_outputs = util.unzip2(out_type) @@ -1190,7 +1196,8 @@ class XlaCompiledComputation(stages.XlaExecutable): device_assignment=(sticky_device,) if sticky_device else None) options.parameter_is_tupled_arguments = tuple_args with log_elapsed_time(f"Finished XLA compilation of {name} " - "in {elapsed_time} sec"): + "in {elapsed_time} sec", + event=BACKEND_COMPILE_EVENT): compiled = compile_or_get_cached(backend, xla_computation, options, host_callbacks) buffer_counts = get_buffer_counts(out_avals, ordered_effects, diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py index 5204462ff..9bead238e 100644 --- a/jax/experimental/maps.py +++ b/jax/experimental/maps.py @@ -677,7 +677,8 @@ def make_xmap_callable(fun: lu.WrappedFun, for aval, in_axes in zip(in_avals, in_axes)] with core.extend_axis_env_nd(global_axis_sizes.items()): with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} " - "for xmap in {elapsed_time} sec"): + "for xmap in {elapsed_time} sec", + event=dispatch.JAXPR_TRACE_EVENT): jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals) out_axes = out_axes_thunk() _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes) diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index 0c001120c..6dac4fd5d 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -658,7 +658,8 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree): try: maps._positional_semantics.val = maps._PositionalSemantics.GLOBAL with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} " - "for pjit in {elapsed_time} sec"): + "for pjit in {elapsed_time} sec", + event=dispatch.JAXPR_TRACE_EVENT): jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, global_in_avals) finally: maps._positional_semantics.val = prev_positional_val diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 5a8c93b30..81bd26f33 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -1296,7 +1296,8 @@ def stage_parallel_callable( with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): # type: ignore with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} " - "for pmap in {elapsed_time} sec"): + "for pmap in {elapsed_time} sec", + event=dispatch.JAXPR_TRACE_EVENT): jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final( fun, global_sharded_avals, pe.debug_info_final(fun, "pmap")) jaxpr = dispatch.apply_outfeed_rewriter(jaxpr) @@ -1634,7 +1635,8 @@ class PmapExecutable(stages.XlaExecutable): ordered_effects) with dispatch.log_elapsed_time( - f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec"): + f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec", + event=dispatch.BACKEND_COMPILE_EVENT): compiled = dispatch.compile_or_get_cached( pci.backend, xla_computation, compile_options, host_callbacks) handle_args = InputsHandler( @@ -2759,7 +2761,8 @@ def lower_sharding_computation( name_stack = new_name_stack(wrap_name(fun_name, api_name)) with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} " - "in {elapsed_time} sec"): + "in {elapsed_time} sec", + event=dispatch.JAXPR_TRACE_EVENT): jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final( fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name)) kept_outputs = [True] * len(global_out_avals) @@ -3015,7 +3018,8 @@ def lower_mesh_computation( in_jaxpr_avals = in_tiled_avals with core.extend_axis_env_nd(mesh.shape.items()): with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} " - "in {elapsed_time} sec"): + "in {elapsed_time} sec", + event=dispatch.JAXPR_TRACE_EVENT): jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals) assert len(out_shardings) == len(out_jaxpr_avals) if spmd_lowering: @@ -3404,7 +3408,8 @@ class UnloadedMeshExecutable: kept_var_idx, backend, device_assignment, committed, pmap_nreps) else: with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} " - "in {elapsed_time} sec"): + "in {elapsed_time} sec", + event=dispatch.BACKEND_COMPILE_EVENT): xla_executable = dispatch.compile_or_get_cached( backend, computation, compile_options, host_callbacks) diff --git a/jax/monitoring.py b/jax/monitoring.py new file mode 100644 index 000000000..ebcf0c076 --- /dev/null +++ b/jax/monitoring.py @@ -0,0 +1,50 @@ +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for instrumenting code. + +Code points can be marked as a named event. Every time an event is reached +during program execution, the registered listeners will be invoked. + +A typical listener callback is to send an event to a metrics collector for +aggregation/exporting. +""" +from typing import Callable, List + +_event_listeners: List[Callable[[str], None]] = [] +_event_duration_secs_listeners: List[Callable[[str, float], None]] = [] + +def record_event(event: str): + """Record an event.""" + for callback in _event_listeners: + callback(event) + +def record_event_duration_secs(event: str, duration: float): + """Record an event duration in seconds (float).""" + for callback in _event_duration_secs_listeners: + callback(event, duration) + +def register_event_listener(callback: Callable[[str], None]): + """Register a callback to be invoked during record_event().""" + _event_listeners.append(callback) + +def register_event_duration_secs_listener(callback : Callable[[str, float], None]): + """Register a callback to be invoked during record_event_duration_secs().""" + _event_duration_secs_listeners.append(callback) + +def _clear_event_listeners(): + """Clear event listeners.""" + global _event_listeners, _event_duration_secs_listeners + _event_listeners = [] + _event_duration_secs_listeners = [] diff --git a/tests/BUILD b/tests/BUILD index 90ae35558..b5c1cf530 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -504,6 +504,15 @@ jax_test( ], ) +py_test( + name = "monitoring_test", + srcs = ["monitoring_test.py"], + deps = [ + "//jax", + "//jax:test_util", + ], +) + jax_test( name = "multibackend_test", srcs = ["multibackend_test.py"], diff --git a/tests/monitoring_test.py b/tests/monitoring_test.py new file mode 100644 index 000000000..68c38f34e --- /dev/null +++ b/tests/monitoring_test.py @@ -0,0 +1,60 @@ +# Copyright 2022 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tests for jax.monitoring. + +Verify that callbacks are registered and invoked correctly to record events. +""" +from absl.testing import absltest +from jax import monitoring + +class MonitoringTest(absltest.TestCase): + + def test_record_event(self): + events = [] + counters = {} # Map event names to frequency. + def increment_event_counter(event): + if event not in counters: + counters[event] = 0 + counters[event] += 1 + # Test that we can register multiple callbacks. + monitoring.register_event_listener(events.append) + monitoring.register_event_listener(increment_event_counter) + + monitoring.record_event("test_unique_event") + monitoring.record_event("test_common_event") + monitoring.record_event("test_common_event") + + self.assertListEqual(events, ["test_unique_event", + "test_common_event", "test_common_event"]) + self.assertDictEqual(counters, {"test_unique_event": 1, + "test_common_event": 2}) + + def test_record_event_durations(self): + durations = {} # Map event names to frequency. + def increment_event_duration(event, duration): + if event not in durations: + durations[event] = 0. + durations[event] += duration + monitoring.register_event_duration_secs_listener(increment_event_duration) + + monitoring.record_event_duration_secs("test_short_event", 1) + monitoring.record_event_duration_secs("test_short_event", 2) + monitoring.record_event_duration_secs("test_long_event", 10) + + self.assertDictEqual(durations, {"test_short_event": 3, + "test_long_event": 10}) + + +if __name__ == "__main__": + absltest.main()