[explain-cache-miss] Improve tracing-cache-miss explanations

The previous approach was to report, for several elements
of the cache key, the closest mismatch. Some parts of
the cache key were ignored, which led to "explanation unavailable".
The same happened when we had two keys close to the current
one, each differring in a different part of the key.
No explanation was produced because for each part of the key,
there was a matching key already in the cache, even though
the key taken as a whole did not match.

Now, we scan *all* parts of they key and compute the differences.
We keep track of the "size" of the differences, and we explain
the differences to those keys that are closest (possibly more
than one key if equidistant).
For example, for shape differences we'll report the
closest matching shape. If a type differs in both the dtype
and some parts of the shape, or sharding, it is considered
farther away.

We add new tests and explanations for  different
static argnums and argnames.

There are still cases when we do not produce an explanation, but
now the "explanation unavailable" includes a description
of which component of the key is different, and what the
difference is. This may still be hard to understand by the
user but at least they can file a clearer bug.

Refactored the tests, and added a few new ones.
This commit is contained in:
George Necula 2025-04-10 08:46:11 +02:00
parent 19d3d954bf
commit f070cdecb3
4 changed files with 435 additions and 228 deletions

View File

@ -21,7 +21,6 @@ import dataclasses
from functools import partial
import inspect
import logging
import operator as op
import weakref
from typing import NamedTuple, Any, Union, cast
import warnings
@ -1158,17 +1157,209 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
callsites_with_tracing_cache_miss: set[str] = set()
def diff_tracing_cache_keys(
k: tuple, oldk: tuple, debug_info: lu.DebugInfo) -> tuple[Sequence[str], int]:
"""Explanations of differences between the cache keys, along with diff sizes.
Result: a pair of a list of explanations for differences, and the total size
of the differences. The sizes are used to pick the old key with the smallest
different size for the explanation that is shown to the user.
"""
(fun_transforms_k, fun_params_k, fun_in_type_k,
(arg_in_type_k, arg_attr_data_k, arg_inline_k), ctx_k) = k
(fun_transforms_ok, fun_params_ok, fun_in_type_ok,
(arg_in_type_ok, arg_attr_data_ok, arg_inline_ok), ctx_ok) = oldk
diffs: list[tuple[str, int]] = [] # each difference with its size
def unavailable(key_field: str, what_k, what_ok):
diffs.append(
(f"different {key_field}:\n now: {what_k}\n != before: {what_ok}.\n"
"explanation unavailable! "
"please open an issue at https://github.com/jax-ml/jax.",
10))
def list_diff_size(s1: Sequence, s2: Sequence) -> int:
min_len = min(len(s1), len(s2))
diff_size = max(len(s1), len(s2)) - min_len
diff_size += sum(e1 != e2 for e1, e2 in zip(s1[:min_len],
s2[:min_len]))
return diff_size
different_leaf_count = False
def explain_transform_argnums_partial(param_k: tuple, param_ok: tuple):
dyn_argnums_k, static_args_k = param_k
dyn_argnums_ok, static_args_ok = param_ok
if dyn_argnums_k != dyn_argnums_ok:
diffs.append(
("different static_argnums:\n"
f" dynamic argnums now {dyn_argnums_k} and before {dyn_argnums_ok}",
1))
if static_args_k != static_args_ok:
diffs.append(
("different value of static args:\n"
f" now {', '.join(repr(a.val) for a in static_args_k)}"
f" and before {', '.join(repr(a.val) for a in static_args_ok)}",
list_diff_size(static_args_k, static_args_ok)))
def explain_transform_argnames_partial(param_k: tuple, param_ok: tuple):
static_kwargs_k, = param_k
static_kwargs_ok, = param_ok
static_kwargs_k = [(k, v.val) for k, v in
sorted(static_kwargs_k.val.items())]
static_kwargs_ok = [(k, v.val) for k, v in
sorted(static_kwargs_ok.val.items())]
if static_kwargs_k != static_kwargs_ok:
diffs.append(
("different value of static kwargs:\n"
f" now {{{', '.join(f'{k}: {repr(v)}' for k, v in static_kwargs_k)}}}"
f" and before {{{', '.join(f'{k}: {repr(v)}' for k, v in static_kwargs_ok)}}}",
list_diff_size(static_kwargs_k, static_kwargs_ok)))
def explain_in_tree_diff(in_tree_k: PyTreeDef, in_tree_ok: PyTreeDef):
nonlocal different_leaf_count
different_leaf_count = (in_tree_k.num_leaves != in_tree_ok.num_leaves)
if not different_leaf_count:
# Look for the special case of passing positional args as kwargs or
# vice-versa; the common prefix of positional args match.
args_tree_k, kwargs_tree_k = treedef_children(in_tree_k)
nr_args_k = len(treedef_children(args_tree_k))
args_tree_ok, kwargs_tree_ok = treedef_children(in_tree_ok)
nr_args_ok = len(treedef_children(args_tree_k))
if (treedef_children(args_tree_k)[:min(nr_args_k, nr_args_ok)] ==
treedef_children(args_tree_ok)[:min(nr_args_k, nr_args_ok)]):
keys_k = kwargs_tree_k.node_data()[1] # type: ignore[index]
keys_ok = kwargs_tree_ok.node_data()[1] # type: ignore[index]
diffs.append(
(("different number of args and kwargs, but same total number.\n"
f" now {nr_args_k} args and kwargs "
f"with keys {keys_k}\n"
f" before {nr_args_ok} args and kwargs "
f"with keys {keys_ok}"),
abs(nr_args_ok - nr_args_k)))
return
in_tree_k_str = str(in_tree_k)
in_tree_k_str = (in_tree_k_str if len(in_tree_k_str) < 73
else in_tree_k_str[:73] + "...")
in_tree_ok_str = str(in_tree_ok)
in_tree_ok_str = (in_tree_ok_str if len(in_tree_ok_str) < 73
else in_tree_ok_str[:73] + "...")
diff = [f"different input pytree:\n now: {in_tree_k_str}\n"
f" before: {in_tree_ok_str}"]
errs = list(tree_util.equality_errors_pytreedef(in_tree_k, in_tree_ok))
for path, thing1, thing2, explanation in errs:
fst, *path = path # type: ignore
base = ["args", "kwargs"][fst.idx]
diff.append(
f" * at {base}{keystr(tuple(path))}, now {thing1} and before {thing2},"
f" so {explanation}")
diffs.append(("\n".join(diff), len(errs)))
def explain_args_type_diff(args_k: tuple[core.AbstractValue],
args_ok: tuple[core.AbstractValue]):
diff_size = 0
arg_names = debug_info.safe_arg_names(len(args_k))
def arg_type_to_str(at):
if hasattr(at, "str_short"):
return at.str_short(short_dtypes=True)
else:
return str(at)
args_k_str = ", ".join(f"{an}: {arg_type_to_str(at)}"
for an, at in zip(arg_names, args_k))
args_k_str = args_k_str if len(args_k_str) < 73 else args_k_str[:73] + "..."
diff = [f"different input types:\n types now: {args_k_str}"]
add_weak_type_hint = False
for name, arg_t_k, arg_t_ok in zip(arg_names, args_k, args_ok):
if arg_t_k == arg_t_ok: continue
this_arg_diff_size = 0
if type(arg_t_k) == type(arg_t_ok) == core.ShapedArray:
s1, s2 = arg_type_to_str(arg_t_k), arg_type_to_str(arg_t_ok)
this_arg_diff_size += list_diff_size(arg_t_k.shape, arg_t_ok.shape) # type: ignore
if arg_t_k.weak_type != arg_t_ok.weak_type: # type: ignore
s1 += f"{{weak_type={arg_t_k.weak_type}}}" # type: ignore
s2 += f"{{weak_type={arg_t_ok.weak_type}}}" # type: ignore
add_weak_type_hint = True
this_arg_diff_size += 1
elif arg_t_k.sharding != arg_t_ok.sharding: # type: ignore
s1 = arg_t_k.str_short(short_dtypes=True, mesh_axis_types=True) # type: ignore
s2 = arg_t_ok.str_short(short_dtypes=True, mesh_axis_types=True) # type: ignore
this_arg_diff_size += 1
else:
s1, s2 = str(arg_t_k), str(arg_t_ok)
diff_size += max(1, this_arg_diff_size)
diff.append(f" * at {name}, now {s1} and before {s2}")
if add_weak_type_hint:
diff.append(
"where weak_type=True often means a Python builtin numeric value, and \n"
"weak_type=False means a jax.Array.\n"
"See https://docs.jax.dev/en/latest/type_promotion.html#weak-types.")
diffs.append(("\n".join(diff), diff_size))
if fun_transforms_k != fun_transforms_ok:
if len(fun_transforms_k) != len(fun_transforms_ok):
different_leaf_count = True # Skip other more precise checks
unavailable("fun_transforms length",
fun_transforms_k, fun_transforms_ok)
else:
for i, (t, ot) in enumerate(zip(fun_transforms_k, fun_transforms_ok)):
t_name = t[0].__name__
if t == ot: continue
if t[0] != ot[0]:
unavailable(f"fun_transforms[{i}] transform", t, ot)
continue
if t_name == "flatten_fun":
explain_in_tree_diff(t[1][0], ot[1][0])
continue
if t_name == "_argnums_partial":
explain_transform_argnums_partial(t[1], ot[1])
continue
if t_name == "_argnames_partial":
explain_transform_argnames_partial(t[1], ot[1])
continue
unavailable(f"fun_transforms.{t_name} params", t[1:], ot[1:])
continue
# If we had different leaf counts, we can discard the _argnums_partial
# difference. That transform sometimes occurs before the flatten_fun
if different_leaf_count:
diffs = [d for d in diffs if "fun_transforms._argnums_partial" not in d[0]]
if fun_params_k != fun_params_ok:
unavailable("fun_params", fun_params_k, fun_params_ok)
if fun_in_type_k != fun_in_type_ok:
unavailable("fun_in_type", fun_params_k, fun_params_ok)
if arg_in_type_k != arg_in_type_ok and not different_leaf_count:
explain_args_type_diff(arg_in_type_k, arg_in_type_ok)
if arg_attr_data_k != arg_attr_data_ok:
unavailable("arg_attr_data", arg_attr_data_k, arg_attr_data_ok)
if arg_inline_k != arg_inline_ok:
unavailable("arg_inline", arg_inline_k, arg_inline_ok)
if ctx_k != ctx_ok:
assert len(ctx_k) == len(ctx_ok)
idxs = [f" [{i}]: now {c_k} and before {c_ok}"
for i, (c_k, c_ok) in enumerate(zip(ctx_k, ctx_ok)) if c_k != c_ok]
diffs.append(
("different tracing context, e.g. due to config or context manager.\n"
"found differences at positions\n" +
", and\n".join(idxs) +
"\ncompare to tuple returned by "
"config.trace_context() in jax/_src/config.py.",
len(idxs)))
if not diffs: # Should never happen, but let's not crash
unavailable("something (unexpected empty diffs)", k, oldk)
diffs_and_sizes = util.unzip2(sorted(diffs, key=lambda d: d[1]))
return (diffs_and_sizes[0], sum(diffs_and_sizes[1]))
def explain_tracing_cache_miss(
fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple):
if config.check_tracer_leaks.value: return
def unpack(key):
transforms, (), _, (in_type, _, inline), *_, ctx = key
# TODO(dougalm,mattjj): enable cache miss explanation with attrs
_, (_, (in_tree,)), *_ = transforms
return in_tree, in_type, inline.val, ctx
in_tree, in_type, inline, ctx = unpack(key)
if inline: return
if key[3][2].val: return # No explanations for "inline" functions
debug_info = fun.debug_info
func_filename = debug_info.func_filename
@ -1177,7 +1368,7 @@ def explain_tracing_cache_miss(
msg: list[str] = []
p = msg.append
done = lambda: logger.log(logging.WARNING, '\n'.join(msg))
done = lambda: logger.log(logging.WARNING, "\n".join(msg))
callsite = source_info_util.summarize(source_info_util.current())
p(f"TRACING CACHE MISS at {callsite} because:")
@ -1188,110 +1379,42 @@ def explain_tracing_cache_miss(
src_info += f" defined at {func_filename}"
if func_lineno := debug_info.func_lineno:
src_info += f":{func_lineno}"
if unseen_f:
p(f" never seen function:\n {debug_info.func_name} id={id(fun.f)}{src_info}")
func_name = debug_info.func_name
if unseen_f or not cache:
p(f" never seen function:\n {func_name} id={id(fun.f)}{src_info}")
if callsite in callsites_with_tracing_cache_miss:
p(" but seen another function defined on the same line; maybe the function is\n"
" being re-defined repeatedly, preventing caching?")
else:
callsites_with_tracing_cache_miss.add(callsite)
return done()
p(f" for {func_name}{src_info}")
diffs = [diff_tracing_cache_keys(key, ok, debug_info)
for ok in cache.keys() if key != ok]
assert diffs, "we must find some diffs if key differs from all cache keys"
min_diff = min(diffs, key=lambda v: v[1])
smallest_diffs: Sequence[Sequence[str]] # the diffs for the closest keys
smallest_diffs = [d[0] for d in diffs if d[1] == min_diff[1]]
def indent_subsequent_lines(indent: int, msg: str) -> str:
return msg.replace("\n", "\n" + " " * indent)
def p_one_diff(diff: Sequence[str]):
for d in diff:
p(" * key with " + indent_subsequent_lines(4, d))
if len(smallest_diffs) == 1:
p(" all previously seen cache keys are different. Closest previous key:")
p_one_diff(smallest_diffs[0])
else:
p(f" for {debug_info.func_name}{src_info}")
p(" all previously seen cache keys are different. "
"Several previous keys are closest:")
for d in smallest_diffs:
p_one_diff(d)
seen_keys = map(unpack, cache.keys())
done()
return
# have we maybe switched some args to be kwargs or visa-versa?
args_tree, kwargs_tree = treedef_children(in_tree)
args_kwargs_trees = [treedef_children(k) for k, *_ in seen_keys]
args_kwargs_match = [t for t in args_kwargs_trees
if t == [args_tree, kwargs_tree]]
if not args_kwargs_match:
num_args = len(treedef_children(args_tree))
_, kwarg_keys = kwargs_tree.node_data() # type: ignore
p(f" never seen passing {num_args} positional args and {len(kwarg_keys)} "
"keyword args with keys:\n"
f" {', '.join(map(repr, kwarg_keys))}")
dont_match = [set(t[1].node_data()[1]) for t in args_kwargs_trees # type: ignore
if t != [args_tree, kwargs_tree]]
close_kwargs = min(
dont_match, key=set(kwarg_keys).symmetric_difference, default=None
)
if not close_kwargs:
p(" closest seen is passing no keyword args")
else:
p(f" closest seen passes {len(close_kwargs)} keyword args with keys:\n"
f" {', '.join(map(repr, close_kwargs))}")
return done()
# have we never seen this tracing context before?
ctxs_match = [c for *_, c in seen_keys if c == ctx]
if not ctxs_match:
p(" tracing context doesn't match, e.g. due to config or context manager")
dont_match = [c for *_, c in seen_keys if c != ctx]
closest_ctx = min(dont_match, key=lambda c: sum(map(op.ne, c, ctx)))
idxs = [i for i, (c1, c2) in enumerate(zip(ctx, closest_ctx)) if c1 != c2]
p(" closest seen context tuple differs at positions:\n"
f" {', '.join(map(str, idxs))}\n"
" compare to tuple returned by config._trace_context() in jax/_src/config.py.")
return done()
# have we never seen this input pytree before?
trees_match = [k for k in seen_keys if k[0] == in_tree]
if not trees_match:
in_tree_str = f':\n {in_tree}' if len(str(in_tree)) < 76 else ''
p(f" never seen input pytree{in_tree_str}")
dont_match = [t for t, *_ in seen_keys if t != in_tree]
closest_tree = min(dont_match, key=lambda t: abs(t.num_leaves - in_tree.num_leaves))
errs = list(tree_util.equality_errors_pytreedef(in_tree, closest_tree)) # type: ignore[arg-type]
p(f" closest seen input pytree has {len(errs)} mismatches, including:")
for path, thing1, thing2, explanation in errs:
fst, *path = path # type: ignore
base = ['args', 'kwargs'][fst.idx]
p(f" * at {base}{keystr(tuple(path))}, seen {thing2} but now given {thing1},"
f" so {explanation}")
return done()
# have we never seen these input types (eg shapes, dtypes) before?
types_match = [k for k in trees_match if k[1] == in_type]
if not types_match:
if len(in_type) < 5:
in_type_str = ":\n {}".format(", ".join(
f"{n}: {ty.str_short(short_dtypes=True)}"
for n, ty in zip(debug_info.arg_names, in_type)))
else:
in_type_str = ''
p(f" never seen input type signature{in_type_str}")
dont_match = [t for _, t, *_ in trees_match if t != in_type]
closest_ty = min(dont_match, key=lambda t: sum(map(op.ne, t, in_type)))
num_mismatch = sum(map(op.ne, closest_ty, in_type))
p(f" closest seen input type signature has {num_mismatch} mismatches, including:")
add_weak_type_hint = False
arg_names = debug_info.safe_arg_names(len(in_type))
for name, ty1, ty2 in zip(arg_names, closest_ty, in_type):
if ty1 != ty2:
if type(ty1) == type(ty2) == core.ShapedArray:
s1, s2 = ty1.str_short(True), ty2.str_short(True)
if ty1.weak_type != ty2.weak_type:
s1 += f"{{weak_type={ty1.weak_type}}}"
s2 += f"{{weak_type={ty2.weak_type}}}"
add_weak_type_hint = True
elif ty1.sharding != ty2.sharding:
s1 = ty1.str_short(short_dtypes=True, mesh_axis_types=True)
s2 = ty2.str_short(short_dtypes=True, mesh_axis_types=True)
else:
s1, s2 = str(ty1), str(ty2)
p(f" * at {name}, seen {s1}, but now given {s2}")
if add_weak_type_hint:
p("where weak_type=True often means a Python builtin numeric value, and ")
p("weak_type=False means a jax.Array.")
p("See https://docs.jax.dev/en/latest/type_promotion.html#weak-types")
return done()
# we think this is unreachable...
p("explanation unavailable! please open an issue at https://github.com/jax-ml/jax")
return done()
@partial(lu.cache, explain=explain_tracing_cache_miss)
def _create_pjit_jaxpr(

View File

@ -4465,66 +4465,6 @@ class APITest(jtu.JaxTestCase):
tracing_add_count += 1
self.assertEqual(tracing_add_count, 2)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations(self):
@jax.jit
def f(x, y):
return jnp.sin(x) * y['hi']
x = jnp.float32(1.)
y = {'hi': jnp.arange(3., dtype='float32')}
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
# print on first miss, not on hit
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
f(x, y)
f(x, y)
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn('TRACING CACHE MISS', msg)
self.assertIn('never seen function', msg)
# shape change
y_ = {'hi': jnp.arange(4, dtype='float32')}
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
f(x, y_)
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn('never seen input type signature', msg)
self.assertIn('closest seen input type signature has 1 mismatches', msg)
self.assertIn('seen f32[3], but now given f32[4]', msg)
# weak type change (assuming no x64)
if not config.enable_x64.value:
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
f(1., y)
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn('weak_type=True', msg)
self.assertIn('https://docs.jax.dev/en/latest/type_promotion.html#weak-types', msg)
# kwarg change
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
f(1, y=y)
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn('never seen passing 1 positional args and 1 keyword args', msg)
# tracing config change
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
with jax.numpy_rank_promotion('warn'):
f(x, y)
# depending on the backend, we may or may not get persistent cache warnings
self.assertTrue(1 <= len(cm.output) <= expected_log_len)
msg = cm.output[0]
self.assertIn("tracing context doesn't match", msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_skip_internals(self):
if is_persistent_cache_enabled():
@ -4535,6 +4475,211 @@ class APITest(jtu.JaxTestCase):
for i in range(2):
jnp.sin(jnp.arange(i + 1, dtype=np.float32))
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_first_miss(self):
@jax.jit
def f(x): return x
x = jnp.float32(1.)
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
# print on first miss, not on hit
with config.explain_cache_misses(True):
with self.assertLogs(level="WARNING") as cm:
f(x)
f(x)
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn("TRACING CACHE MISS", msg)
self.assertIn("never seen function", msg)
self.assertNotIn("explanation unavailable!", msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_other_in_tree(self):
@jax.jit
def f(*args, **kwargs): return args[0]
f(0., 1., y=(2., 2.1))
with config.explain_cache_misses(True):
with self.assertLogs(level="WARNING") as cm:
# Same number of leaves but different trees
f(0., (1., 1.1), y=2.)
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn("different input pytree", msg)
self.assertNotIn("explanation unavailable!", msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_other_arg_passed_as_kwarg(self):
@jax.jit
def f(x, y): return jnp.sin(x) + y
f(0., 1.)
# kwarg change
with config.explain_cache_misses(True):
with self.assertLogs(level="WARNING") as cm:
f(0., y=1.)
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn("different number of args and kwargs, but same total number", msg)
self.assertIn("now 1 args and kwargs with keys ['y']", msg)
self.assertIn("before 1 args and kwargs with keys []", msg)
self.assertNotIn("explanation unavailable!", msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_other_static_argnums(self):
@partial(jax.jit, static_argnums=(0, 2))
def f(x, y, z):
return y
f(1., 2., "foo")
with config.explain_cache_misses(True):
with self.assertLogs(level="WARNING") as cm:
f(1., 2., "bar")
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn("different value of static args", msg)
self.assertIn("now 1.0, 'bar' and before 1.0, 'foo'", msg)
self.assertNotIn('explanation unavailable!', msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_other_static_argnames(self):
@partial(jax.jit, static_argnames='foo')
def f(*, foo):
return 1
f(foo="foo")
with config.explain_cache_misses(True):
with self.assertLogs(level="WARNING") as cm:
f(foo="bar")
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn("different value of static kwargs", msg)
self.assertIn("now {foo: 'bar'} and before {foo: 'foo'}", msg)
self.assertNotIn('explanation unavailable!', msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_other_dtype(self):
@jax.jit
def f(x, y): return x
f(np.float32(0), np.float32(1))
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
f(np.float32(0), np.int32(1))
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn("different input types", msg)
self.assertIn("at y, now i32[] and before f32[]", msg)
self.assertNotIn("explanation unavailable!", msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_other_weak_type(self):
@jax.jit
def f(x, y): return jnp.sin(x) + y
y = jnp.arange(4, dtype="float32")
f(jnp.float32(0.), y)
# weak type change (assuming no x64)
if config.enable_x64.value:
self.skipTest("Work only for 32 bit mode")
with config.explain_cache_misses(True):
with self.assertLogs(level="WARNING") as cm:
f(0., y)
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn("different input types", msg)
self.assertIn("at x, now f32[]{weak_type=True} and before f32[]{weak_type=False}", msg)
self.assertIn("https://docs.jax.dev/en/latest/type_promotion.html#weak-types", msg)
self.assertNotIn("explanation unavailable!", msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_other_shape(self):
@jax.jit
def f(x, y): return jnp.sin(x) + y
f(np.float32(0), np.arange(1, dtype=np.float32))
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
f(np.float32(0), np.arange(2, dtype=np.float32))
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn("different input types", msg)
self.assertIn("at y, now f32[2] and before f32[1]", msg)
self.assertNotIn("explanation unavailable!", msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_other_shape_explain_closest(self):
@jax.jit
def f(x): return x
f(np.ones((1, 2), dtype=np.float32))
f(np.ones((10, 20, 30), dtype=np.float32))
f(np.ones((1, 2, 3), dtype=np.float32))
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
f(np.ones((10, 2, 30), dtype=np.float32))
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn("key with different input types", msg)
self.assertIn("at x, now f32[10,2,30] and before f32[10,20,30]", msg)
self.assertNotIn("explanation unavailable!", msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_other_tracing_config(self):
@jax.jit
def f(x, y): return jnp.sin(x) + y
f(0., 1.)
# tracing config change
with config.explain_cache_misses(True):
with self.assertLogs(level="WARNING") as cm:
with jax.numpy_rank_promotion("warn"):
with jax.default_matmul_precision("high"):
f(0., 1.)
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
self.assertTrue(1 <= len(cm.output) <= expected_log_len)
msg = cm.output[0]
self.assertIn("key with different tracing context", msg)
self.assertIn("now warn and before", msg)
self.assertIn("now high and before", msg)
self.assertNotIn("explanation unavailable!", msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_multiple_changes(self):
@jax.jit
def f(x): return jnp.sin(x)
call_1 = f(np.arange(4, dtype=np.float32))
with jax.numpy_rank_promotion("warn"):
call_2 = f(np.arange(8, dtype=np.float32))
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
# Matches call_2 in shape but not context, and call_1 in context but
# not in shape.
f(np.arange(8, dtype=np.float32))
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn("key with different input types", msg)
self.assertIn("at x, now f32[8] and before f32[4]", msg)
self.assertIn("key with different tracing context", msg)
self.assertNotIn("explanation unavailable!", msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_cache_miss_explanations_new_function_in_loop(self):
@jax.jit

View File

@ -392,66 +392,6 @@ class DebugInfoTest(jtu.JaxTestCase):
with self.assertRaisesRegex(TypeError, err_str):
jax.jit(f)(jnp.int32)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_arg_names_cache_miss_explanations(self):
@jax.jit
def f(x, y):
return jnp.sin(x) * y['hi']
x = jnp.float32(1.)
y = {'hi': jnp.arange(3., dtype='float32')}
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
# print on first miss, not on hit
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
f(x, y)
f(x, y)
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn('TRACING CACHE MISS', msg)
self.assertIn('never seen function', msg)
# shape change
y_ = {'hi': jnp.arange(4, dtype='float32')}
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
f(x, y_)
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn('never seen input type signature', msg)
self.assertIn('closest seen input type signature has 1 mismatches', msg)
self.assertIn('seen f32[3], but now given f32[4]', msg)
# weak type change (assuming no x64)
if not config.enable_x64.value:
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
f(1., y)
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn('weak_type=True', msg)
self.assertIn('https://docs.jax.dev/en/latest/type_promotion.html#weak-types', msg)
# kwarg change
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
f(1, y=y)
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn('never seen passing 1 positional args and 1 keyword args', msg)
# tracing config change
with config.explain_cache_misses(True):
with self.assertLogs(level='WARNING') as cm:
with jax.numpy_rank_promotion('warn'):
f(x, y)
# depending on the backend, we may or may not get persistent cache warnings
self.assertTrue(1 <= len(cm.output) <= expected_log_len)
msg = cm.output[0]
self.assertIn("tracing context doesn't match", msg)
@jtu.thread_unsafe_test() # logging is not thread-safe
def test_arg_names_cache_miss_explanations_new_function_in_loop(self):
@jax.jit

View File

@ -3467,9 +3467,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
f(x_, y)
self.assertLen(cm.output, expected_log_len)
msg = cm.output[0]
self.assertIn('never seen input type signature', msg)
self.assertIn('closest seen input type signature has 1 mismatches', msg)
self.assertIn("seen f32[8]({}), but now given f32[8]({Auto: ('x',)})", msg)
self.assertIn("different input types", msg)
self.assertIn("at x, now f32[8]({Auto: ('x',)}) and before f32[8]({})", msg)
def test_pjit_function_cache_cpp(self):
def f(x):