diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index 1231d3066..bfe874305 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -67,6 +67,7 @@ from __future__ import annotations from collections.abc import Callable, Sequence from functools import partial import re +import time from typing import Any, Hashable, NamedTuple import warnings import weakref @@ -446,7 +447,7 @@ def _check_input_type(in_type: core.InputType) -> None: def cache(call: Callable, *, - explain: Callable[[WrappedFun, bool, dict, tuple], None] | None = None): + explain: Callable[[WrappedFun, bool, dict, tuple, float], None] | None = None): """Memoization decorator for functions taking a WrappedFun as first argument. Args: @@ -455,7 +456,8 @@ def cache(call: Callable, *, memoization cache key. explain: a function that is invoked upon cache misses to log an explanation - of the miss. Invoked with `(fun, is_cache_first_use, cache, key)`. + of the miss. + Invoked with `(fun, is_cache_first_use, cache, key, elapsed_sec)`. Returns: A memoized version of ``call``. @@ -470,9 +472,11 @@ def cache(call: Callable, *, ans, stores = result fun.populate_stores(stores) else: + if do_explain := explain and config.explain_cache_misses.value: + start = time.time() ans = call(fun, *args) - if explain and config.explain_cache_misses.value: - explain(fun, cache is new_cache, cache, key) + if do_explain: + explain(fun, cache is new_cache, cache, key, time.time() - start) # type: ignore cache[key] = (ans, fun.stores) return ans diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index afc7a5bed..a10856cbc 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -1357,7 +1357,8 @@ def diff_tracing_cache_keys( def explain_tracing_cache_miss( - fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple): + fun: lu.WrappedFun, unseen_f: bool, cache: dict, + key: tuple, elapsed_sec: float): if config.check_tracer_leaks.value: return if key[3][2].val: return # No explanations for "inline" functions @@ -1371,7 +1372,7 @@ def explain_tracing_cache_miss( done = lambda: logger.log(logging.WARNING, "\n".join(msg)) callsite = source_info_util.summarize(source_info_util.current()) - p(f"TRACING CACHE MISS at {callsite} because:") + p(f"TRACING CACHE MISS at {callsite} costing {elapsed_sec * 1e3:.3f} ms because:") # have we seen this function before at all? src_info = ""