Merge pull request #27993 from gnecula:explain_timing

PiperOrigin-RevId: 747480248
This commit is contained in:
jax authors 2025-04-14 10:41:05 -07:00
commit 30669dc219
2 changed files with 11 additions and 6 deletions

View File

@ -67,6 +67,7 @@ from __future__ import annotations
from collections.abc import Callable, Sequence
from functools import partial
import re
import time
from typing import Any, Hashable, NamedTuple
import warnings
import weakref
@ -446,7 +447,7 @@ def _check_input_type(in_type: core.InputType) -> None:
def cache(call: Callable, *,
explain: Callable[[WrappedFun, bool, dict, tuple], None] | None = None):
explain: Callable[[WrappedFun, bool, dict, tuple, float], None] | None = None):
"""Memoization decorator for functions taking a WrappedFun as first argument.
Args:
@ -455,7 +456,8 @@ def cache(call: Callable, *,
memoization cache key.
explain: a function that is invoked upon cache misses to log an explanation
of the miss. Invoked with `(fun, is_cache_first_use, cache, key)`.
of the miss.
Invoked with `(fun, is_cache_first_use, cache, key, elapsed_sec)`.
Returns:
A memoized version of ``call``.
@ -470,9 +472,11 @@ def cache(call: Callable, *,
ans, stores = result
fun.populate_stores(stores)
else:
if do_explain := explain and config.explain_cache_misses.value:
start = time.time()
ans = call(fun, *args)
if explain and config.explain_cache_misses.value:
explain(fun, cache is new_cache, cache, key)
if do_explain:
explain(fun, cache is new_cache, cache, key, time.time() - start) # type: ignore
cache[key] = (ans, fun.stores)
return ans

View File

@ -1357,7 +1357,8 @@ def diff_tracing_cache_keys(
def explain_tracing_cache_miss(
fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple):
fun: lu.WrappedFun, unseen_f: bool, cache: dict,
key: tuple, elapsed_sec: float):
if config.check_tracer_leaks.value: return
if key[3][2].val: return # No explanations for "inline" functions
@ -1371,7 +1372,7 @@ def explain_tracing_cache_miss(
done = lambda: logger.log(logging.WARNING, "\n".join(msg))
callsite = source_info_util.summarize(source_info_util.current())
p(f"TRACING CACHE MISS at {callsite} because:")
p(f"TRACING CACHE MISS at {callsite} costing {elapsed_sec * 1e3:.3f} ms because:")
# have we seen this function before at all?
src_info = ""