mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #27993 from gnecula:explain_timing
PiperOrigin-RevId: 747480248
This commit is contained in:
commit
30669dc219
@ -67,6 +67,7 @@ from __future__ import annotations
|
||||
from collections.abc import Callable, Sequence
|
||||
from functools import partial
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Hashable, NamedTuple
|
||||
import warnings
|
||||
import weakref
|
||||
@ -446,7 +447,7 @@ def _check_input_type(in_type: core.InputType) -> None:
|
||||
|
||||
|
||||
def cache(call: Callable, *,
|
||||
explain: Callable[[WrappedFun, bool, dict, tuple], None] | None = None):
|
||||
explain: Callable[[WrappedFun, bool, dict, tuple, float], None] | None = None):
|
||||
"""Memoization decorator for functions taking a WrappedFun as first argument.
|
||||
|
||||
Args:
|
||||
@ -455,7 +456,8 @@ def cache(call: Callable, *,
|
||||
memoization cache key.
|
||||
|
||||
explain: a function that is invoked upon cache misses to log an explanation
|
||||
of the miss. Invoked with `(fun, is_cache_first_use, cache, key)`.
|
||||
of the miss.
|
||||
Invoked with `(fun, is_cache_first_use, cache, key, elapsed_sec)`.
|
||||
|
||||
Returns:
|
||||
A memoized version of ``call``.
|
||||
@ -470,9 +472,11 @@ def cache(call: Callable, *,
|
||||
ans, stores = result
|
||||
fun.populate_stores(stores)
|
||||
else:
|
||||
if do_explain := explain and config.explain_cache_misses.value:
|
||||
start = time.time()
|
||||
ans = call(fun, *args)
|
||||
if explain and config.explain_cache_misses.value:
|
||||
explain(fun, cache is new_cache, cache, key)
|
||||
if do_explain:
|
||||
explain(fun, cache is new_cache, cache, key, time.time() - start) # type: ignore
|
||||
cache[key] = (ans, fun.stores)
|
||||
|
||||
return ans
|
||||
|
@ -1357,7 +1357,8 @@ def diff_tracing_cache_keys(
|
||||
|
||||
|
||||
def explain_tracing_cache_miss(
|
||||
fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple):
|
||||
fun: lu.WrappedFun, unseen_f: bool, cache: dict,
|
||||
key: tuple, elapsed_sec: float):
|
||||
if config.check_tracer_leaks.value: return
|
||||
if key[3][2].val: return # No explanations for "inline" functions
|
||||
|
||||
@ -1371,7 +1372,7 @@ def explain_tracing_cache_miss(
|
||||
done = lambda: logger.log(logging.WARNING, "\n".join(msg))
|
||||
|
||||
callsite = source_info_util.summarize(source_info_util.current())
|
||||
p(f"TRACING CACHE MISS at {callsite} because:")
|
||||
p(f"TRACING CACHE MISS at {callsite} costing {elapsed_sec * 1e3:.3f} ms because:")
|
||||
|
||||
# have we seen this function before at all?
|
||||
src_info = ""
|
||||
|
Loading…
x
Reference in New Issue
Block a user