mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[explain-cache-miss] Improve tracing-cache-miss explanations
The previous approach was to report, for several elements of the cache key, the closest mismatch. Some parts of the cache key were ignored, which led to "explanation unavailable". The same happened when we had two keys close to the current one, each differring in a different part of the key. No explanation was produced because for each part of the key, there was a matching key already in the cache, even though the key taken as a whole did not match. Now, we scan *all* parts of they key and compute the differences. We keep track of the "size" of the differences, and we explain the differences to those keys that are closest (possibly more than one key if equidistant). For example, for shape differences we'll report the closest matching shape. If a type differs in both the dtype and some parts of the shape, or sharding, it is considered farther away. We add new tests and explanations for different static argnums and argnames. There are still cases when we do not produce an explanation, but now the "explanation unavailable" includes a description of which component of the key is different, and what the difference is. This may still be hard to understand by the user but at least they can file a clearer bug. Refactored the tests, and added a few new ones.
This commit is contained in:
parent
19d3d954bf
commit
f070cdecb3
333
jax/_src/pjit.py
333
jax/_src/pjit.py
@ -21,7 +21,6 @@ import dataclasses
|
||||
from functools import partial
|
||||
import inspect
|
||||
import logging
|
||||
import operator as op
|
||||
import weakref
|
||||
from typing import NamedTuple, Any, Union, cast
|
||||
import warnings
|
||||
@ -1158,17 +1157,209 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
|
||||
|
||||
callsites_with_tracing_cache_miss: set[str] = set()
|
||||
|
||||
def diff_tracing_cache_keys(
|
||||
k: tuple, oldk: tuple, debug_info: lu.DebugInfo) -> tuple[Sequence[str], int]:
|
||||
"""Explanations of differences between the cache keys, along with diff sizes.
|
||||
|
||||
Result: a pair of a list of explanations for differences, and the total size
|
||||
of the differences. The sizes are used to pick the old key with the smallest
|
||||
different size for the explanation that is shown to the user.
|
||||
"""
|
||||
(fun_transforms_k, fun_params_k, fun_in_type_k,
|
||||
(arg_in_type_k, arg_attr_data_k, arg_inline_k), ctx_k) = k
|
||||
(fun_transforms_ok, fun_params_ok, fun_in_type_ok,
|
||||
(arg_in_type_ok, arg_attr_data_ok, arg_inline_ok), ctx_ok) = oldk
|
||||
|
||||
diffs: list[tuple[str, int]] = [] # each difference with its size
|
||||
def unavailable(key_field: str, what_k, what_ok):
|
||||
diffs.append(
|
||||
(f"different {key_field}:\n now: {what_k}\n != before: {what_ok}.\n"
|
||||
"explanation unavailable! "
|
||||
"please open an issue at https://github.com/jax-ml/jax.",
|
||||
10))
|
||||
|
||||
def list_diff_size(s1: Sequence, s2: Sequence) -> int:
|
||||
min_len = min(len(s1), len(s2))
|
||||
diff_size = max(len(s1), len(s2)) - min_len
|
||||
diff_size += sum(e1 != e2 for e1, e2 in zip(s1[:min_len],
|
||||
s2[:min_len]))
|
||||
return diff_size
|
||||
|
||||
different_leaf_count = False
|
||||
|
||||
def explain_transform_argnums_partial(param_k: tuple, param_ok: tuple):
|
||||
dyn_argnums_k, static_args_k = param_k
|
||||
dyn_argnums_ok, static_args_ok = param_ok
|
||||
if dyn_argnums_k != dyn_argnums_ok:
|
||||
diffs.append(
|
||||
("different static_argnums:\n"
|
||||
f" dynamic argnums now {dyn_argnums_k} and before {dyn_argnums_ok}",
|
||||
1))
|
||||
if static_args_k != static_args_ok:
|
||||
diffs.append(
|
||||
("different value of static args:\n"
|
||||
f" now {', '.join(repr(a.val) for a in static_args_k)}"
|
||||
f" and before {', '.join(repr(a.val) for a in static_args_ok)}",
|
||||
list_diff_size(static_args_k, static_args_ok)))
|
||||
|
||||
def explain_transform_argnames_partial(param_k: tuple, param_ok: tuple):
|
||||
static_kwargs_k, = param_k
|
||||
static_kwargs_ok, = param_ok
|
||||
static_kwargs_k = [(k, v.val) for k, v in
|
||||
sorted(static_kwargs_k.val.items())]
|
||||
static_kwargs_ok = [(k, v.val) for k, v in
|
||||
sorted(static_kwargs_ok.val.items())]
|
||||
if static_kwargs_k != static_kwargs_ok:
|
||||
diffs.append(
|
||||
("different value of static kwargs:\n"
|
||||
f" now {{{', '.join(f'{k}: {repr(v)}' for k, v in static_kwargs_k)}}}"
|
||||
f" and before {{{', '.join(f'{k}: {repr(v)}' for k, v in static_kwargs_ok)}}}",
|
||||
list_diff_size(static_kwargs_k, static_kwargs_ok)))
|
||||
|
||||
def explain_in_tree_diff(in_tree_k: PyTreeDef, in_tree_ok: PyTreeDef):
|
||||
nonlocal different_leaf_count
|
||||
different_leaf_count = (in_tree_k.num_leaves != in_tree_ok.num_leaves)
|
||||
if not different_leaf_count:
|
||||
# Look for the special case of passing positional args as kwargs or
|
||||
# vice-versa; the common prefix of positional args match.
|
||||
args_tree_k, kwargs_tree_k = treedef_children(in_tree_k)
|
||||
nr_args_k = len(treedef_children(args_tree_k))
|
||||
args_tree_ok, kwargs_tree_ok = treedef_children(in_tree_ok)
|
||||
nr_args_ok = len(treedef_children(args_tree_k))
|
||||
if (treedef_children(args_tree_k)[:min(nr_args_k, nr_args_ok)] ==
|
||||
treedef_children(args_tree_ok)[:min(nr_args_k, nr_args_ok)]):
|
||||
keys_k = kwargs_tree_k.node_data()[1] # type: ignore[index]
|
||||
keys_ok = kwargs_tree_ok.node_data()[1] # type: ignore[index]
|
||||
diffs.append(
|
||||
(("different number of args and kwargs, but same total number.\n"
|
||||
f" now {nr_args_k} args and kwargs "
|
||||
f"with keys {keys_k}\n"
|
||||
f" before {nr_args_ok} args and kwargs "
|
||||
f"with keys {keys_ok}"),
|
||||
abs(nr_args_ok - nr_args_k)))
|
||||
return
|
||||
|
||||
in_tree_k_str = str(in_tree_k)
|
||||
in_tree_k_str = (in_tree_k_str if len(in_tree_k_str) < 73
|
||||
else in_tree_k_str[:73] + "...")
|
||||
in_tree_ok_str = str(in_tree_ok)
|
||||
in_tree_ok_str = (in_tree_ok_str if len(in_tree_ok_str) < 73
|
||||
else in_tree_ok_str[:73] + "...")
|
||||
diff = [f"different input pytree:\n now: {in_tree_k_str}\n"
|
||||
f" before: {in_tree_ok_str}"]
|
||||
|
||||
errs = list(tree_util.equality_errors_pytreedef(in_tree_k, in_tree_ok))
|
||||
for path, thing1, thing2, explanation in errs:
|
||||
fst, *path = path # type: ignore
|
||||
base = ["args", "kwargs"][fst.idx]
|
||||
diff.append(
|
||||
f" * at {base}{keystr(tuple(path))}, now {thing1} and before {thing2},"
|
||||
f" so {explanation}")
|
||||
diffs.append(("\n".join(diff), len(errs)))
|
||||
|
||||
def explain_args_type_diff(args_k: tuple[core.AbstractValue],
|
||||
args_ok: tuple[core.AbstractValue]):
|
||||
diff_size = 0
|
||||
arg_names = debug_info.safe_arg_names(len(args_k))
|
||||
def arg_type_to_str(at):
|
||||
if hasattr(at, "str_short"):
|
||||
return at.str_short(short_dtypes=True)
|
||||
else:
|
||||
return str(at)
|
||||
args_k_str = ", ".join(f"{an}: {arg_type_to_str(at)}"
|
||||
for an, at in zip(arg_names, args_k))
|
||||
args_k_str = args_k_str if len(args_k_str) < 73 else args_k_str[:73] + "..."
|
||||
diff = [f"different input types:\n types now: {args_k_str}"]
|
||||
add_weak_type_hint = False
|
||||
|
||||
for name, arg_t_k, arg_t_ok in zip(arg_names, args_k, args_ok):
|
||||
if arg_t_k == arg_t_ok: continue
|
||||
this_arg_diff_size = 0
|
||||
if type(arg_t_k) == type(arg_t_ok) == core.ShapedArray:
|
||||
s1, s2 = arg_type_to_str(arg_t_k), arg_type_to_str(arg_t_ok)
|
||||
this_arg_diff_size += list_diff_size(arg_t_k.shape, arg_t_ok.shape) # type: ignore
|
||||
|
||||
if arg_t_k.weak_type != arg_t_ok.weak_type: # type: ignore
|
||||
s1 += f"{{weak_type={arg_t_k.weak_type}}}" # type: ignore
|
||||
s2 += f"{{weak_type={arg_t_ok.weak_type}}}" # type: ignore
|
||||
add_weak_type_hint = True
|
||||
this_arg_diff_size += 1
|
||||
elif arg_t_k.sharding != arg_t_ok.sharding: # type: ignore
|
||||
s1 = arg_t_k.str_short(short_dtypes=True, mesh_axis_types=True) # type: ignore
|
||||
s2 = arg_t_ok.str_short(short_dtypes=True, mesh_axis_types=True) # type: ignore
|
||||
this_arg_diff_size += 1
|
||||
else:
|
||||
s1, s2 = str(arg_t_k), str(arg_t_ok)
|
||||
diff_size += max(1, this_arg_diff_size)
|
||||
diff.append(f" * at {name}, now {s1} and before {s2}")
|
||||
|
||||
if add_weak_type_hint:
|
||||
diff.append(
|
||||
"where weak_type=True often means a Python builtin numeric value, and \n"
|
||||
"weak_type=False means a jax.Array.\n"
|
||||
"See https://docs.jax.dev/en/latest/type_promotion.html#weak-types.")
|
||||
diffs.append(("\n".join(diff), diff_size))
|
||||
|
||||
if fun_transforms_k != fun_transforms_ok:
|
||||
if len(fun_transforms_k) != len(fun_transforms_ok):
|
||||
different_leaf_count = True # Skip other more precise checks
|
||||
unavailable("fun_transforms length",
|
||||
fun_transforms_k, fun_transforms_ok)
|
||||
else:
|
||||
for i, (t, ot) in enumerate(zip(fun_transforms_k, fun_transforms_ok)):
|
||||
t_name = t[0].__name__
|
||||
if t == ot: continue
|
||||
if t[0] != ot[0]:
|
||||
unavailable(f"fun_transforms[{i}] transform", t, ot)
|
||||
continue
|
||||
|
||||
if t_name == "flatten_fun":
|
||||
explain_in_tree_diff(t[1][0], ot[1][0])
|
||||
continue
|
||||
if t_name == "_argnums_partial":
|
||||
explain_transform_argnums_partial(t[1], ot[1])
|
||||
continue
|
||||
if t_name == "_argnames_partial":
|
||||
explain_transform_argnames_partial(t[1], ot[1])
|
||||
continue
|
||||
unavailable(f"fun_transforms.{t_name} params", t[1:], ot[1:])
|
||||
continue
|
||||
|
||||
# If we had different leaf counts, we can discard the _argnums_partial
|
||||
# difference. That transform sometimes occurs before the flatten_fun
|
||||
if different_leaf_count:
|
||||
diffs = [d for d in diffs if "fun_transforms._argnums_partial" not in d[0]]
|
||||
if fun_params_k != fun_params_ok:
|
||||
unavailable("fun_params", fun_params_k, fun_params_ok)
|
||||
if fun_in_type_k != fun_in_type_ok:
|
||||
unavailable("fun_in_type", fun_params_k, fun_params_ok)
|
||||
if arg_in_type_k != arg_in_type_ok and not different_leaf_count:
|
||||
explain_args_type_diff(arg_in_type_k, arg_in_type_ok)
|
||||
if arg_attr_data_k != arg_attr_data_ok:
|
||||
unavailable("arg_attr_data", arg_attr_data_k, arg_attr_data_ok)
|
||||
if arg_inline_k != arg_inline_ok:
|
||||
unavailable("arg_inline", arg_inline_k, arg_inline_ok)
|
||||
if ctx_k != ctx_ok:
|
||||
assert len(ctx_k) == len(ctx_ok)
|
||||
idxs = [f" [{i}]: now {c_k} and before {c_ok}"
|
||||
for i, (c_k, c_ok) in enumerate(zip(ctx_k, ctx_ok)) if c_k != c_ok]
|
||||
diffs.append(
|
||||
("different tracing context, e.g. due to config or context manager.\n"
|
||||
"found differences at positions\n" +
|
||||
", and\n".join(idxs) +
|
||||
"\ncompare to tuple returned by "
|
||||
"config.trace_context() in jax/_src/config.py.",
|
||||
len(idxs)))
|
||||
if not diffs: # Should never happen, but let's not crash
|
||||
unavailable("something (unexpected empty diffs)", k, oldk)
|
||||
diffs_and_sizes = util.unzip2(sorted(diffs, key=lambda d: d[1]))
|
||||
return (diffs_and_sizes[0], sum(diffs_and_sizes[1]))
|
||||
|
||||
|
||||
def explain_tracing_cache_miss(
|
||||
fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple):
|
||||
if config.check_tracer_leaks.value: return
|
||||
|
||||
def unpack(key):
|
||||
transforms, (), _, (in_type, _, inline), *_, ctx = key
|
||||
# TODO(dougalm,mattjj): enable cache miss explanation with attrs
|
||||
_, (_, (in_tree,)), *_ = transforms
|
||||
return in_tree, in_type, inline.val, ctx
|
||||
in_tree, in_type, inline, ctx = unpack(key)
|
||||
if inline: return
|
||||
if key[3][2].val: return # No explanations for "inline" functions
|
||||
|
||||
debug_info = fun.debug_info
|
||||
func_filename = debug_info.func_filename
|
||||
@ -1177,7 +1368,7 @@ def explain_tracing_cache_miss(
|
||||
|
||||
msg: list[str] = []
|
||||
p = msg.append
|
||||
done = lambda: logger.log(logging.WARNING, '\n'.join(msg))
|
||||
done = lambda: logger.log(logging.WARNING, "\n".join(msg))
|
||||
|
||||
callsite = source_info_util.summarize(source_info_util.current())
|
||||
p(f"TRACING CACHE MISS at {callsite} because:")
|
||||
@ -1188,110 +1379,42 @@ def explain_tracing_cache_miss(
|
||||
src_info += f" defined at {func_filename}"
|
||||
if func_lineno := debug_info.func_lineno:
|
||||
src_info += f":{func_lineno}"
|
||||
if unseen_f:
|
||||
p(f" never seen function:\n {debug_info.func_name} id={id(fun.f)}{src_info}")
|
||||
func_name = debug_info.func_name
|
||||
if unseen_f or not cache:
|
||||
p(f" never seen function:\n {func_name} id={id(fun.f)}{src_info}")
|
||||
if callsite in callsites_with_tracing_cache_miss:
|
||||
p(" but seen another function defined on the same line; maybe the function is\n"
|
||||
" being re-defined repeatedly, preventing caching?")
|
||||
else:
|
||||
callsites_with_tracing_cache_miss.add(callsite)
|
||||
return done()
|
||||
|
||||
p(f" for {func_name}{src_info}")
|
||||
|
||||
diffs = [diff_tracing_cache_keys(key, ok, debug_info)
|
||||
for ok in cache.keys() if key != ok]
|
||||
assert diffs, "we must find some diffs if key differs from all cache keys"
|
||||
min_diff = min(diffs, key=lambda v: v[1])
|
||||
smallest_diffs: Sequence[Sequence[str]] # the diffs for the closest keys
|
||||
smallest_diffs = [d[0] for d in diffs if d[1] == min_diff[1]]
|
||||
def indent_subsequent_lines(indent: int, msg: str) -> str:
|
||||
return msg.replace("\n", "\n" + " " * indent)
|
||||
def p_one_diff(diff: Sequence[str]):
|
||||
for d in diff:
|
||||
p(" * key with " + indent_subsequent_lines(4, d))
|
||||
|
||||
if len(smallest_diffs) == 1:
|
||||
p(" all previously seen cache keys are different. Closest previous key:")
|
||||
p_one_diff(smallest_diffs[0])
|
||||
else:
|
||||
p(f" for {debug_info.func_name}{src_info}")
|
||||
p(" all previously seen cache keys are different. "
|
||||
"Several previous keys are closest:")
|
||||
for d in smallest_diffs:
|
||||
p_one_diff(d)
|
||||
|
||||
seen_keys = map(unpack, cache.keys())
|
||||
done()
|
||||
return
|
||||
|
||||
# have we maybe switched some args to be kwargs or visa-versa?
|
||||
args_tree, kwargs_tree = treedef_children(in_tree)
|
||||
args_kwargs_trees = [treedef_children(k) for k, *_ in seen_keys]
|
||||
args_kwargs_match = [t for t in args_kwargs_trees
|
||||
if t == [args_tree, kwargs_tree]]
|
||||
if not args_kwargs_match:
|
||||
num_args = len(treedef_children(args_tree))
|
||||
_, kwarg_keys = kwargs_tree.node_data() # type: ignore
|
||||
p(f" never seen passing {num_args} positional args and {len(kwarg_keys)} "
|
||||
"keyword args with keys:\n"
|
||||
f" {', '.join(map(repr, kwarg_keys))}")
|
||||
dont_match = [set(t[1].node_data()[1]) for t in args_kwargs_trees # type: ignore
|
||||
if t != [args_tree, kwargs_tree]]
|
||||
close_kwargs = min(
|
||||
dont_match, key=set(kwarg_keys).symmetric_difference, default=None
|
||||
)
|
||||
if not close_kwargs:
|
||||
p(" closest seen is passing no keyword args")
|
||||
else:
|
||||
p(f" closest seen passes {len(close_kwargs)} keyword args with keys:\n"
|
||||
f" {', '.join(map(repr, close_kwargs))}")
|
||||
return done()
|
||||
|
||||
# have we never seen this tracing context before?
|
||||
ctxs_match = [c for *_, c in seen_keys if c == ctx]
|
||||
if not ctxs_match:
|
||||
p(" tracing context doesn't match, e.g. due to config or context manager")
|
||||
dont_match = [c for *_, c in seen_keys if c != ctx]
|
||||
closest_ctx = min(dont_match, key=lambda c: sum(map(op.ne, c, ctx)))
|
||||
idxs = [i for i, (c1, c2) in enumerate(zip(ctx, closest_ctx)) if c1 != c2]
|
||||
p(" closest seen context tuple differs at positions:\n"
|
||||
f" {', '.join(map(str, idxs))}\n"
|
||||
" compare to tuple returned by config._trace_context() in jax/_src/config.py.")
|
||||
return done()
|
||||
|
||||
# have we never seen this input pytree before?
|
||||
trees_match = [k for k in seen_keys if k[0] == in_tree]
|
||||
if not trees_match:
|
||||
in_tree_str = f':\n {in_tree}' if len(str(in_tree)) < 76 else ''
|
||||
p(f" never seen input pytree{in_tree_str}")
|
||||
dont_match = [t for t, *_ in seen_keys if t != in_tree]
|
||||
closest_tree = min(dont_match, key=lambda t: abs(t.num_leaves - in_tree.num_leaves))
|
||||
errs = list(tree_util.equality_errors_pytreedef(in_tree, closest_tree)) # type: ignore[arg-type]
|
||||
p(f" closest seen input pytree has {len(errs)} mismatches, including:")
|
||||
for path, thing1, thing2, explanation in errs:
|
||||
fst, *path = path # type: ignore
|
||||
base = ['args', 'kwargs'][fst.idx]
|
||||
p(f" * at {base}{keystr(tuple(path))}, seen {thing2} but now given {thing1},"
|
||||
f" so {explanation}")
|
||||
return done()
|
||||
|
||||
# have we never seen these input types (eg shapes, dtypes) before?
|
||||
types_match = [k for k in trees_match if k[1] == in_type]
|
||||
if not types_match:
|
||||
if len(in_type) < 5:
|
||||
in_type_str = ":\n {}".format(", ".join(
|
||||
f"{n}: {ty.str_short(short_dtypes=True)}"
|
||||
for n, ty in zip(debug_info.arg_names, in_type)))
|
||||
else:
|
||||
in_type_str = ''
|
||||
p(f" never seen input type signature{in_type_str}")
|
||||
dont_match = [t for _, t, *_ in trees_match if t != in_type]
|
||||
closest_ty = min(dont_match, key=lambda t: sum(map(op.ne, t, in_type)))
|
||||
num_mismatch = sum(map(op.ne, closest_ty, in_type))
|
||||
p(f" closest seen input type signature has {num_mismatch} mismatches, including:")
|
||||
add_weak_type_hint = False
|
||||
arg_names = debug_info.safe_arg_names(len(in_type))
|
||||
|
||||
for name, ty1, ty2 in zip(arg_names, closest_ty, in_type):
|
||||
if ty1 != ty2:
|
||||
if type(ty1) == type(ty2) == core.ShapedArray:
|
||||
s1, s2 = ty1.str_short(True), ty2.str_short(True)
|
||||
if ty1.weak_type != ty2.weak_type:
|
||||
s1 += f"{{weak_type={ty1.weak_type}}}"
|
||||
s2 += f"{{weak_type={ty2.weak_type}}}"
|
||||
add_weak_type_hint = True
|
||||
elif ty1.sharding != ty2.sharding:
|
||||
s1 = ty1.str_short(short_dtypes=True, mesh_axis_types=True)
|
||||
s2 = ty2.str_short(short_dtypes=True, mesh_axis_types=True)
|
||||
else:
|
||||
s1, s2 = str(ty1), str(ty2)
|
||||
p(f" * at {name}, seen {s1}, but now given {s2}")
|
||||
if add_weak_type_hint:
|
||||
p("where weak_type=True often means a Python builtin numeric value, and ")
|
||||
p("weak_type=False means a jax.Array.")
|
||||
p("See https://docs.jax.dev/en/latest/type_promotion.html#weak-types")
|
||||
return done()
|
||||
|
||||
# we think this is unreachable...
|
||||
p("explanation unavailable! please open an issue at https://github.com/jax-ml/jax")
|
||||
return done()
|
||||
|
||||
@partial(lu.cache, explain=explain_tracing_cache_miss)
|
||||
def _create_pjit_jaxpr(
|
||||
|
@ -4465,66 +4465,6 @@ class APITest(jtu.JaxTestCase):
|
||||
tracing_add_count += 1
|
||||
self.assertEqual(tracing_add_count, 2)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations(self):
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
return jnp.sin(x) * y['hi']
|
||||
|
||||
x = jnp.float32(1.)
|
||||
y = {'hi': jnp.arange(3., dtype='float32')}
|
||||
|
||||
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
|
||||
|
||||
# print on first miss, not on hit
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(x, y)
|
||||
f(x, y)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('TRACING CACHE MISS', msg)
|
||||
self.assertIn('never seen function', msg)
|
||||
|
||||
# shape change
|
||||
y_ = {'hi': jnp.arange(4, dtype='float32')}
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(x, y_)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('never seen input type signature', msg)
|
||||
self.assertIn('closest seen input type signature has 1 mismatches', msg)
|
||||
self.assertIn('seen f32[3], but now given f32[4]', msg)
|
||||
|
||||
# weak type change (assuming no x64)
|
||||
if not config.enable_x64.value:
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(1., y)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('weak_type=True', msg)
|
||||
self.assertIn('https://docs.jax.dev/en/latest/type_promotion.html#weak-types', msg)
|
||||
|
||||
# kwarg change
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(1, y=y)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('never seen passing 1 positional args and 1 keyword args', msg)
|
||||
|
||||
# tracing config change
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
with jax.numpy_rank_promotion('warn'):
|
||||
f(x, y)
|
||||
# depending on the backend, we may or may not get persistent cache warnings
|
||||
self.assertTrue(1 <= len(cm.output) <= expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn("tracing context doesn't match", msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_skip_internals(self):
|
||||
if is_persistent_cache_enabled():
|
||||
@ -4535,6 +4475,211 @@ class APITest(jtu.JaxTestCase):
|
||||
for i in range(2):
|
||||
jnp.sin(jnp.arange(i + 1, dtype=np.float32))
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_first_miss(self):
|
||||
@jax.jit
|
||||
def f(x): return x
|
||||
x = jnp.float32(1.)
|
||||
|
||||
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
|
||||
# print on first miss, not on hit
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level="WARNING") as cm:
|
||||
f(x)
|
||||
f(x)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn("TRACING CACHE MISS", msg)
|
||||
self.assertIn("never seen function", msg)
|
||||
self.assertNotIn("explanation unavailable!", msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_in_tree(self):
|
||||
@jax.jit
|
||||
def f(*args, **kwargs): return args[0]
|
||||
|
||||
f(0., 1., y=(2., 2.1))
|
||||
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level="WARNING") as cm:
|
||||
# Same number of leaves but different trees
|
||||
f(0., (1., 1.1), y=2.)
|
||||
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn("different input pytree", msg)
|
||||
self.assertNotIn("explanation unavailable!", msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_arg_passed_as_kwarg(self):
|
||||
@jax.jit
|
||||
def f(x, y): return jnp.sin(x) + y
|
||||
|
||||
f(0., 1.)
|
||||
|
||||
# kwarg change
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level="WARNING") as cm:
|
||||
f(0., y=1.)
|
||||
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn("different number of args and kwargs, but same total number", msg)
|
||||
self.assertIn("now 1 args and kwargs with keys ['y']", msg)
|
||||
self.assertIn("before 1 args and kwargs with keys []", msg)
|
||||
self.assertNotIn("explanation unavailable!", msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_static_argnums(self):
|
||||
@partial(jax.jit, static_argnums=(0, 2))
|
||||
def f(x, y, z):
|
||||
return y
|
||||
|
||||
f(1., 2., "foo")
|
||||
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level="WARNING") as cm:
|
||||
f(1., 2., "bar")
|
||||
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn("different value of static args", msg)
|
||||
self.assertIn("now 1.0, 'bar' and before 1.0, 'foo'", msg)
|
||||
self.assertNotIn('explanation unavailable!', msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_static_argnames(self):
|
||||
@partial(jax.jit, static_argnames='foo')
|
||||
def f(*, foo):
|
||||
return 1
|
||||
|
||||
f(foo="foo")
|
||||
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level="WARNING") as cm:
|
||||
f(foo="bar")
|
||||
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn("different value of static kwargs", msg)
|
||||
self.assertIn("now {foo: 'bar'} and before {foo: 'foo'}", msg)
|
||||
self.assertNotIn('explanation unavailable!', msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_dtype(self):
|
||||
@jax.jit
|
||||
def f(x, y): return x
|
||||
f(np.float32(0), np.float32(1))
|
||||
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(np.float32(0), np.int32(1))
|
||||
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn("different input types", msg)
|
||||
self.assertIn("at y, now i32[] and before f32[]", msg)
|
||||
self.assertNotIn("explanation unavailable!", msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_weak_type(self):
|
||||
@jax.jit
|
||||
def f(x, y): return jnp.sin(x) + y
|
||||
|
||||
y = jnp.arange(4, dtype="float32")
|
||||
f(jnp.float32(0.), y)
|
||||
# weak type change (assuming no x64)
|
||||
if config.enable_x64.value:
|
||||
self.skipTest("Work only for 32 bit mode")
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level="WARNING") as cm:
|
||||
f(0., y)
|
||||
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn("different input types", msg)
|
||||
self.assertIn("at x, now f32[]{weak_type=True} and before f32[]{weak_type=False}", msg)
|
||||
self.assertIn("https://docs.jax.dev/en/latest/type_promotion.html#weak-types", msg)
|
||||
self.assertNotIn("explanation unavailable!", msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_shape(self):
|
||||
@jax.jit
|
||||
def f(x, y): return jnp.sin(x) + y
|
||||
f(np.float32(0), np.arange(1, dtype=np.float32))
|
||||
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(np.float32(0), np.arange(2, dtype=np.float32))
|
||||
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn("different input types", msg)
|
||||
self.assertIn("at y, now f32[2] and before f32[1]", msg)
|
||||
self.assertNotIn("explanation unavailable!", msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_shape_explain_closest(self):
|
||||
@jax.jit
|
||||
def f(x): return x
|
||||
f(np.ones((1, 2), dtype=np.float32))
|
||||
f(np.ones((10, 20, 30), dtype=np.float32))
|
||||
f(np.ones((1, 2, 3), dtype=np.float32))
|
||||
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(np.ones((10, 2, 30), dtype=np.float32))
|
||||
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn("key with different input types", msg)
|
||||
self.assertIn("at x, now f32[10,2,30] and before f32[10,20,30]", msg)
|
||||
self.assertNotIn("explanation unavailable!", msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_tracing_config(self):
|
||||
@jax.jit
|
||||
def f(x, y): return jnp.sin(x) + y
|
||||
|
||||
f(0., 1.)
|
||||
# tracing config change
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level="WARNING") as cm:
|
||||
with jax.numpy_rank_promotion("warn"):
|
||||
with jax.default_matmul_precision("high"):
|
||||
f(0., 1.)
|
||||
|
||||
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
|
||||
self.assertTrue(1 <= len(cm.output) <= expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn("key with different tracing context", msg)
|
||||
self.assertIn("now warn and before", msg)
|
||||
self.assertIn("now high and before", msg)
|
||||
self.assertNotIn("explanation unavailable!", msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_multiple_changes(self):
|
||||
@jax.jit
|
||||
def f(x): return jnp.sin(x)
|
||||
|
||||
call_1 = f(np.arange(4, dtype=np.float32))
|
||||
with jax.numpy_rank_promotion("warn"):
|
||||
call_2 = f(np.arange(8, dtype=np.float32))
|
||||
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
# Matches call_2 in shape but not context, and call_1 in context but
|
||||
# not in shape.
|
||||
f(np.arange(8, dtype=np.float32))
|
||||
|
||||
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn("key with different input types", msg)
|
||||
self.assertIn("at x, now f32[8] and before f32[4]", msg)
|
||||
self.assertIn("key with different tracing context", msg)
|
||||
self.assertNotIn("explanation unavailable!", msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_new_function_in_loop(self):
|
||||
@jax.jit
|
||||
|
@ -392,66 +392,6 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError, err_str):
|
||||
jax.jit(f)(jnp.int32)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_arg_names_cache_miss_explanations(self):
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
return jnp.sin(x) * y['hi']
|
||||
|
||||
x = jnp.float32(1.)
|
||||
y = {'hi': jnp.arange(3., dtype='float32')}
|
||||
|
||||
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
|
||||
|
||||
# print on first miss, not on hit
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(x, y)
|
||||
f(x, y)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('TRACING CACHE MISS', msg)
|
||||
self.assertIn('never seen function', msg)
|
||||
|
||||
# shape change
|
||||
y_ = {'hi': jnp.arange(4, dtype='float32')}
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(x, y_)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('never seen input type signature', msg)
|
||||
self.assertIn('closest seen input type signature has 1 mismatches', msg)
|
||||
self.assertIn('seen f32[3], but now given f32[4]', msg)
|
||||
|
||||
# weak type change (assuming no x64)
|
||||
if not config.enable_x64.value:
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(1., y)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('weak_type=True', msg)
|
||||
self.assertIn('https://docs.jax.dev/en/latest/type_promotion.html#weak-types', msg)
|
||||
|
||||
# kwarg change
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(1, y=y)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('never seen passing 1 positional args and 1 keyword args', msg)
|
||||
|
||||
# tracing config change
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
with jax.numpy_rank_promotion('warn'):
|
||||
f(x, y)
|
||||
# depending on the backend, we may or may not get persistent cache warnings
|
||||
self.assertTrue(1 <= len(cm.output) <= expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn("tracing context doesn't match", msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_arg_names_cache_miss_explanations_new_function_in_loop(self):
|
||||
@jax.jit
|
||||
|
@ -3467,9 +3467,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
f(x_, y)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('never seen input type signature', msg)
|
||||
self.assertIn('closest seen input type signature has 1 mismatches', msg)
|
||||
self.assertIn("seen f32[8]({}), but now given f32[8]({Auto: ('x',)})", msg)
|
||||
self.assertIn("different input types", msg)
|
||||
self.assertIn("at x, now f32[8]({Auto: ('x',)}) and before f32[8]({})", msg)
|
||||
|
||||
def test_pjit_function_cache_cpp(self):
|
||||
def f(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user