mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add JAX monitoring library that instruments code via events.
PiperOrigin-RevId: 488731805
This commit is contained in:
parent
a419e1917a
commit
726b2bc2ee
@ -36,6 +36,7 @@ import jax
|
||||
from jax import core
|
||||
from jax import linear_util as lu
|
||||
from jax.errors import UnexpectedTracerError
|
||||
from jax.monitoring import record_event_duration_secs
|
||||
import jax.interpreters.ad as ad
|
||||
import jax.interpreters.batching as batching
|
||||
import jax.interpreters.mlir as mlir
|
||||
@ -56,6 +57,8 @@ import jax._src.util as util
|
||||
from jax._src.util import flatten, unflatten
|
||||
from jax._src import path
|
||||
|
||||
JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
|
||||
BACKEND_COMPILE_EVENT = "/jax/core/compile/backend_compile_duration"
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
@ -370,7 +373,7 @@ def is_single_device_sharding(sharding) -> bool:
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def log_elapsed_time(fmt: str):
|
||||
def log_elapsed_time(fmt: str, event: Optional[str] = None):
|
||||
if _on_exit:
|
||||
yield
|
||||
else:
|
||||
@ -379,6 +382,8 @@ def log_elapsed_time(fmt: str):
|
||||
yield
|
||||
elapsed_time = time.time() - start_time
|
||||
logger.log(log_priority, fmt.format(elapsed_time=elapsed_time))
|
||||
if event is not None:
|
||||
record_event_duration_secs(event, elapsed_time)
|
||||
|
||||
|
||||
def should_tuple_args(num_args: int, platform: str):
|
||||
@ -441,7 +446,8 @@ def lower_xla_callable(
|
||||
abstract_args = tuple(aval for aval, _ in fun.in_type)
|
||||
|
||||
with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
|
||||
"for jit in {elapsed_time} sec"):
|
||||
"for jit in {elapsed_time} sec",
|
||||
event=JAXPR_TRACE_EVENT):
|
||||
jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
|
||||
fun, pe.debug_info_final(fun, "jit"))
|
||||
out_avals, kept_outputs = util.unzip2(out_type)
|
||||
@ -1190,7 +1196,8 @@ class XlaCompiledComputation(stages.XlaExecutable):
|
||||
device_assignment=(sticky_device,) if sticky_device else None)
|
||||
options.parameter_is_tupled_arguments = tuple_args
|
||||
with log_elapsed_time(f"Finished XLA compilation of {name} "
|
||||
"in {elapsed_time} sec"):
|
||||
"in {elapsed_time} sec",
|
||||
event=BACKEND_COMPILE_EVENT):
|
||||
compiled = compile_or_get_cached(backend, xla_computation, options,
|
||||
host_callbacks)
|
||||
buffer_counts = get_buffer_counts(out_avals, ordered_effects,
|
||||
|
@ -677,7 +677,8 @@ def make_xmap_callable(fun: lu.WrappedFun,
|
||||
for aval, in_axes in zip(in_avals, in_axes)]
|
||||
with core.extend_axis_env_nd(global_axis_sizes.items()):
|
||||
with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
|
||||
"for xmap in {elapsed_time} sec"):
|
||||
"for xmap in {elapsed_time} sec",
|
||||
event=dispatch.JAXPR_TRACE_EVENT):
|
||||
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
|
||||
out_axes = out_axes_thunk()
|
||||
_check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
|
||||
|
@ -658,7 +658,8 @@ def _pjit_jaxpr(fun, out_shardings_thunk, global_in_avals, out_tree):
|
||||
try:
|
||||
maps._positional_semantics.val = maps._PositionalSemantics.GLOBAL
|
||||
with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
|
||||
"for pjit in {elapsed_time} sec"):
|
||||
"for pjit in {elapsed_time} sec",
|
||||
event=dispatch.JAXPR_TRACE_EVENT):
|
||||
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, global_in_avals)
|
||||
finally:
|
||||
maps._positional_semantics.val = prev_positional_val
|
||||
|
@ -1296,7 +1296,8 @@ def stage_parallel_callable(
|
||||
|
||||
with core.extend_axis_env(pci.axis_name, pci.global_axis_size, None): # type: ignore
|
||||
with dispatch.log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
|
||||
"for pmap in {elapsed_time} sec"):
|
||||
"for pmap in {elapsed_time} sec",
|
||||
event=dispatch.JAXPR_TRACE_EVENT):
|
||||
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(
|
||||
fun, global_sharded_avals, pe.debug_info_final(fun, "pmap"))
|
||||
jaxpr = dispatch.apply_outfeed_rewriter(jaxpr)
|
||||
@ -1634,7 +1635,8 @@ class PmapExecutable(stages.XlaExecutable):
|
||||
ordered_effects)
|
||||
|
||||
with dispatch.log_elapsed_time(
|
||||
f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec"):
|
||||
f"Finished XLA compilation of {pci.name} in {{elapsed_time}} sec",
|
||||
event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
compiled = dispatch.compile_or_get_cached(
|
||||
pci.backend, xla_computation, compile_options, host_callbacks)
|
||||
handle_args = InputsHandler(
|
||||
@ -2759,7 +2761,8 @@ def lower_sharding_computation(
|
||||
name_stack = new_name_stack(wrap_name(fun_name, api_name))
|
||||
|
||||
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
|
||||
"in {elapsed_time} sec"):
|
||||
"in {elapsed_time} sec",
|
||||
event=dispatch.JAXPR_TRACE_EVENT):
|
||||
jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
|
||||
fun, global_in_avals, debug_info=pe.debug_info_final(fun, api_name))
|
||||
kept_outputs = [True] * len(global_out_avals)
|
||||
@ -3015,7 +3018,8 @@ def lower_mesh_computation(
|
||||
in_jaxpr_avals = in_tiled_avals
|
||||
with core.extend_axis_env_nd(mesh.shape.items()):
|
||||
with dispatch.log_elapsed_time(f"Finished tracing + transforming {name_stack} "
|
||||
"in {elapsed_time} sec"):
|
||||
"in {elapsed_time} sec",
|
||||
event=dispatch.JAXPR_TRACE_EVENT):
|
||||
jaxpr, out_jaxpr_avals, consts = pe.trace_to_jaxpr_final(fun, in_jaxpr_avals)
|
||||
assert len(out_shardings) == len(out_jaxpr_avals)
|
||||
if spmd_lowering:
|
||||
@ -3404,7 +3408,8 @@ class UnloadedMeshExecutable:
|
||||
kept_var_idx, backend, device_assignment, committed, pmap_nreps)
|
||||
else:
|
||||
with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} "
|
||||
"in {elapsed_time} sec"):
|
||||
"in {elapsed_time} sec",
|
||||
event=dispatch.BACKEND_COMPILE_EVENT):
|
||||
xla_executable = dispatch.compile_or_get_cached(
|
||||
backend, computation, compile_options, host_callbacks)
|
||||
|
||||
|
50
jax/monitoring.py
Normal file
50
jax/monitoring.py
Normal file
@ -0,0 +1,50 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utilities for instrumenting code.
|
||||
|
||||
Code points can be marked as a named event. Every time an event is reached
|
||||
during program execution, the registered listeners will be invoked.
|
||||
|
||||
A typical listener callback is to send an event to a metrics collector for
|
||||
aggregation/exporting.
|
||||
"""
|
||||
from typing import Callable, List
|
||||
|
||||
# Listener registries: module-level state shared by every function below.
# The lists are only ever mutated in place (append/clear), never rebound, so
# a reference to either list always observes the current set of listeners.
_event_listeners: List[Callable[[str], None]] = []
_event_duration_secs_listeners: List[Callable[[str, float], None]] = []


def record_event(event: str) -> None:
  """Record an event.

  Invokes every callback registered through register_event_listener(), in
  registration order, passing `event` to each one.

  Args:
    event: name of the event that was reached.
  """
  # Iterate over a snapshot: a callback may itself register a listener, and
  # appending to a list while a for-loop walks it would make the new listener
  # fire within the same dispatch (or loop forever if it re-registers).
  for callback in tuple(_event_listeners):
    callback(event)


def record_event_duration_secs(event: str, duration: float) -> None:
  """Record an event duration in seconds (float).

  Invokes every callback registered through
  register_event_duration_secs_listener(), in registration order.

  Args:
    event: name of the event that was timed.
    duration: elapsed time of the event, in seconds.
  """
  # Snapshot for the same reason as in record_event().
  for callback in tuple(_event_duration_secs_listeners):
    callback(event, duration)


def register_event_listener(callback: Callable[[str], None]) -> None:
  """Register a callback to be invoked during record_event().

  Args:
    callback: callable taking the event name.
  """
  _event_listeners.append(callback)


def register_event_duration_secs_listener(
    callback: Callable[[str, float], None]) -> None:
  """Register a callback to be invoked during record_event_duration_secs().

  Args:
    callback: callable taking the event name and its duration in seconds.
  """
  _event_duration_secs_listeners.append(callback)


def _clear_event_listeners() -> None:
  """Clear event listeners.

  Removes every registered callback of both kinds. The registries are emptied
  in place rather than rebound, so no `global` statement is needed and
  existing references to the lists stay valid.
  """
  _event_listeners.clear()
  _event_duration_secs_listeners.clear()
|
@ -504,6 +504,15 @@ jax_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "monitoring_test",
|
||||
srcs = ["monitoring_test.py"],
|
||||
deps = [
|
||||
"//jax",
|
||||
"//jax:test_util",
|
||||
],
|
||||
)
|
||||
|
||||
jax_test(
|
||||
name = "multibackend_test",
|
||||
srcs = ["multibackend_test.py"],
|
||||
|
60
tests/monitoring_test.py
Normal file
60
tests/monitoring_test.py
Normal file
@ -0,0 +1,60 @@
|
||||
# Copyright 2022 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for jax.monitoring.
|
||||
|
||||
Verify that callbacks are registered and invoked correctly to record events.
|
||||
"""
|
||||
from absl.testing import absltest
|
||||
from jax import monitoring
|
||||
|
||||
class MonitoringTest(absltest.TestCase):
  """Checks that jax.monitoring listeners are registered and invoked."""

  def tearDown(self):
    # Listeners are stored in module-level lists inside jax.monitoring, so
    # anything registered by a test would otherwise stay registered for every
    # later test in the same process.
    monitoring._clear_event_listeners()
    super().tearDown()

  def test_record_event(self):
    events = []
    counters = {}  # Map event names to frequency.

    def increment_event_counter(event):
      counters[event] = counters.get(event, 0) + 1

    # Test that we can register multiple callbacks.
    monitoring.register_event_listener(events.append)
    monitoring.register_event_listener(increment_event_counter)

    monitoring.record_event("test_unique_event")
    monitoring.record_event("test_common_event")
    monitoring.record_event("test_common_event")

    self.assertListEqual(events, ["test_unique_event",
                                  "test_common_event", "test_common_event"])
    self.assertDictEqual(counters, {"test_unique_event": 1,
                                    "test_common_event": 2})

  def test_record_event_durations(self):
    durations = {}  # Map event names to total recorded duration.

    def increment_event_duration(event, duration):
      durations[event] = durations.get(event, 0.) + duration

    monitoring.register_event_duration_secs_listener(increment_event_duration)

    monitoring.record_event_duration_secs("test_short_event", 1)
    monitoring.record_event_duration_secs("test_short_event", 2)
    monitoring.record_event_duration_secs("test_long_event", 10)

    self.assertDictEqual(durations, {"test_short_event": 3,
                                     "test_long_event": 10})


if __name__ == "__main__":
  absltest.main()
|
Loading…
x
Reference in New Issue
Block a user