Add JAX monitoring library that instruments code via events.

PiperOrigin-RevId: 488731805
This commit is contained in:
jax authors 2022-11-15 12:41:08 -08:00
parent a419e1917a
commit 726b2bc2ee
7 changed files with 143 additions and 10 deletions

View File

@ -36,6 +36,7 @@ import jax
from jax import core
from jax import linear_util as lu
from jax.errors import UnexpectedTracerError
from jax.monitoring import record_event_duration_secs
import jax.interpreters.ad as ad
import jax.interpreters.batching as batching
import jax.interpreters.mlir as mlir
@ -56,6 +57,8 @@ import jax._src.util as util
from jax._src.util import flatten, unflatten
from jax._src import path
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
FLAGS = flags.FLAGS
@ -370,7 +373,7 @@ def is_single_device_sharding(sharding) -> bool:
@contextlib.contextmanager
def log_elapsed_time(fmt: str):
def log_elapsed_time(fmt: str, event: Optional[str] = None):
if _on_exit:
yield
else:
@ -379,6 +382,8 @@ def log_elapsed_time(fmt: str):
yield
elapsed_time = time.time() - start_time
logger.log(log_priority, fmt.format(elapsed_time=elapsed_time))
if event is not None:
record_event_duration_secs(event, elapsed_time)
def should_tuple_args(num_args: int, platform: str):
@ -441,7 +446,8 @@ def lower_xla_callable(
abstract_args = tuple(aval for aval, _ in fun.in_type)
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for jit in {elapsed_time} sec"):
"for jit in {elapsed_time} sec",
event=JAXPR_TRACE_EVENT):
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
fun, pe.debug_info_final(fun, "jit"))
out_avals, kept_outputs = util.unzip2(out_type)
@ -1190,7 +1196,8 @@ class XlaCompiledComputation(stages.XlaExecutable):
device_assignment=(sticky_device,) if sticky_device else None)
options.parameter_is_tupled_arguments = tuple_args
with log_elapsed_time(f"Finished XLA compilation of {name} "
"in {elapsed_time} sec"):
"in {elapsed_time} sec",
event=BACKEND_COMPILE_EVENT):
compiled = compile_or_get_cached(backend, xla_computation, options,
host_callbacks)
buffer_counts = get_buffer_counts(out_avals, ordered_effects,

View File

@ -677,7 +677,8 @@ def make_xmap_callable(fun: lu.WrappedFun,
for aval, in_axes in zip(in_avals, in_axes)]
with core.extend_axis_env_nd(global_axis_sizes.items()):
with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for xmap in {elapsed_time} sec"):
"for xmap in {elapsed_time} sec",
event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
out_axes = out_axes_thunk()
_check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)

View File

@ -658,7 +658,8 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree):
try:
maps._positional_semantics.val = maps._PositionalSemantics.GLOBAL
with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for pjit in {elapsed_time} sec"):
"for pjit in {elapsed_time} sec",
event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, global_in_avals)
finally:
maps._positional_semantics.val = prev_positional_val

View File

@ -1296,7 +1296,8 @@ def stage_parallel_callable(
with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): # type: ignore
with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
"for pmap in {elapsed_time} sec"):
"for pmap in {elapsed_time} sec",
event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
@ -1634,7 +1635,8 @@ class PmapExecutable(stages.XlaExecutable):
ordered_effects)
with dispatch.log_elapsed_time(
f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec"):
f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec",
event=dispatch.BACKEND_COMPILE_EVENT):
compiled = dispatch.compile_or_get_cached(
pci.backend, xla_computation, compile_options, host_callbacks)
handle_args = InputsHandler(
@ -2759,7 +2761,8 @@ def lower_sharding_computation(
name_stack = new_name_stack(wrap_name(fun_name, api_name))
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
"in {elapsed_time} sec"):
"in {elapsed_time} sec",
event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name))
kept_outputs = [True] * len(global_out_avals)
@ -3015,7 +3018,8 @@ def lower_mesh_computation(
in_jaxpr_avals = in_tiled_avals
with core.extend_axis_env_nd(mesh.shape.items()):
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
"in {elapsed_time} sec"):
"in {elapsed_time} sec",
event=dispatch.JAXPR_TRACE_EVENT):
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals)
assert len(out_shardings) == len(out_jaxpr_avals)
if spmd_lowering:
@ -3404,7 +3408,8 @@ class UnloadedMeshExecutable:
kept_var_idx, backend, device_assignment, committed, pmap_nreps)
else:
with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} "
"in {elapsed_time} sec"):
"in {elapsed_time} sec",
event=dispatch.BACKEND_COMPILE_EVENT):
xla_executable = dispatch.compile_or_get_cached(
backend, computation, compile_options, host_callbacks)

50
jax/monitoring.py Normal file
View File

@ -0,0 +1,50 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for instrumenting code.
Code points can be marked as a named event. Every time an event is reached
during program execution, the registered listeners will be invoked.
A typical listener callback sends the event to a metrics collector for
aggregation/exporting.
"""
from typing import Callable, List
# Registered callbacks, kept in registration order.
_event_listeners: List[Callable[[str], None]] = []
_event_duration_secs_listeners: List[Callable[[str, float], None]] = []


def record_event(event: str) -> None:
  """Record an event.

  Invokes every callback registered through `register_event_listener`, in
  registration order, passing it the event name.

  Args:
    event: name of the event that was reached.
  """
  # Iterate over a snapshot: a callback that registers another listener must
  # not extend (or, if it registers itself, never terminate) the dispatch that
  # is already in progress. Newly registered listeners see the next event.
  for callback in tuple(_event_listeners):
    callback(event)


def record_event_duration_secs(event: str, duration: float) -> None:
  """Record an event duration in seconds (float).

  Invokes every callback registered through
  `register_event_duration_secs_listener`, in registration order.

  Args:
    event: name of the event that was timed.
    duration: elapsed time of the event, in seconds.
  """
  # Snapshot for the same reason as in record_event().
  for callback in tuple(_event_duration_secs_listeners):
    callback(event, duration)


def register_event_listener(callback: Callable[[str], None]) -> None:
  """Register a callback to be invoked during record_event().

  Args:
    callback: called with the event name each time an event is recorded.
  """
  _event_listeners.append(callback)


def register_event_duration_secs_listener(
    callback: Callable[[str, float], None]) -> None:
  """Register a callback to be invoked during record_event_duration_secs().

  Args:
    callback: called with the event name and its duration in seconds each
      time an event duration is recorded.
  """
  _event_duration_secs_listeners.append(callback)


def _clear_event_listeners() -> None:
  """Clear event listeners of both kinds."""
  # Empty the lists in place rather than rebinding the globals, so that any
  # existing reference to the list objects observes the cleared state too.
  _event_listeners.clear()
  _event_duration_secs_listeners.clear()

View File

@ -504,6 +504,15 @@ jax_test(
],
)
# Plain py_test target for monitoring_test.py; depends on //jax and //jax:test_util.
py_test(
name = "monitoring_test",
srcs = ["monitoring_test.py"],
deps = [
"//jax",
"//jax:test_util",
],
)
jax_test(
name = "multibackend_test",
srcs = ["multibackend_test.py"],

60
tests/monitoring_test.py Normal file
View File

@ -0,0 +1,60 @@
# Copyright 2022 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for jax.monitoring.
Verify that callbacks are registered and invoked correctly to record events.
"""
from absl.testing import absltest
from jax import monitoring
class MonitoringTest(absltest.TestCase):
  """Checks that registered listeners receive recorded events and durations."""

  def tearDown(self):
    # Listeners are stored in module-level lists inside jax.monitoring. Drop
    # them after every test so callbacks registered here are not invoked by
    # (and do not keep state alive for) tests that run later in the process.
    monitoring._clear_event_listeners()
    super().tearDown()

  def test_record_event(self):
    events = []
    counters = {}  # Map event names to frequency.

    def increment_event_counter(event):
      counters[event] = counters.get(event, 0) + 1

    # Test that we can register multiple callbacks.
    monitoring.register_event_listener(events.append)
    monitoring.register_event_listener(increment_event_counter)

    monitoring.record_event("test_unique_event")
    monitoring.record_event("test_common_event")
    monitoring.record_event("test_common_event")

    self.assertListEqual(events, ["test_unique_event",
                                  "test_common_event", "test_common_event"])
    self.assertDictEqual(counters, {"test_unique_event": 1,
                                    "test_common_event": 2})

  def test_record_event_durations(self):
    durations = {}  # Map event names to total recorded duration (seconds).

    def increment_event_duration(event, duration):
      durations[event] = durations.get(event, 0.) + duration

    monitoring.register_event_duration_secs_listener(increment_event_duration)

    monitoring.record_event_duration_secs("test_short_event", 1)
    monitoring.record_event_duration_secs("test_short_event", 2)
    monitoring.record_event_duration_secs("test_long_event", 10)

    self.assertDictEqual(durations, {"test_short_event": 3,
                                     "test_long_event": 10})
if __name__ == "__main__":
  # Hand control to absltest's runner when this file is executed as a script.
  absltest.main()