diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index d1e5c3bfb..afc7a5bed 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -21,7 +21,6 @@ import dataclasses
 from functools import partial
 import inspect
 import logging
-import operator as op
 import weakref
 from typing import NamedTuple, Any, Union, cast
 import warnings
@@ -1158,17 +1157,209 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
 
 callsites_with_tracing_cache_miss: set[str] = set()
 
+def diff_tracing_cache_keys(
+    k: tuple, oldk: tuple, debug_info: lu.DebugInfo) -> tuple[Sequence[str], int]:
+  """Explanations of differences between the cache keys, along with diff sizes.
+
+  Result: a pair of a list of explanations for differences, and the total size
+  of the differences. The sizes are used to pick the old key with the smallest
+  difference size for the explanation that is shown to the user.
+  """
+  (fun_transforms_k, fun_params_k, fun_in_type_k,
+   (arg_in_type_k, arg_attr_data_k, arg_inline_k), ctx_k) = k
+  (fun_transforms_ok, fun_params_ok, fun_in_type_ok,
+   (arg_in_type_ok, arg_attr_data_ok, arg_inline_ok), ctx_ok) = oldk
+
+  diffs: list[tuple[str, int]] = []  # each difference with its size
+  def unavailable(key_field: str, what_k, what_ok):
+    diffs.append(
+        (f"different {key_field}:\n now: {what_k}\n != before: {what_ok}.\n"
+         "explanation unavailable! "
+         "please open an issue at https://github.com/jax-ml/jax.",
+         10))
+
+  def list_diff_size(s1: Sequence, s2: Sequence) -> int:
+    min_len = min(len(s1), len(s2))
+    diff_size = max(len(s1), len(s2)) - min_len
+    diff_size += sum(e1 != e2 for e1, e2 in zip(s1[:min_len],
+                                                s2[:min_len]))
+    return diff_size
+
+  different_leaf_count = False
+
+  def explain_transform_argnums_partial(param_k: tuple, param_ok: tuple):
+    dyn_argnums_k, static_args_k = param_k
+    dyn_argnums_ok, static_args_ok = param_ok
+    if dyn_argnums_k != dyn_argnums_ok:
+      diffs.append(
+          ("different static_argnums:\n"
+           f" dynamic argnums now {dyn_argnums_k} and before {dyn_argnums_ok}",
+           1))
+    if static_args_k != static_args_ok:
+      diffs.append(
+          ("different value of static args:\n"
+           f" now {', '.join(repr(a.val) for a in static_args_k)}"
+           f" and before {', '.join(repr(a.val) for a in static_args_ok)}",
+           list_diff_size(static_args_k, static_args_ok)))
+
+  def explain_transform_argnames_partial(param_k: tuple, param_ok: tuple):
+    static_kwargs_k, = param_k
+    static_kwargs_ok, = param_ok
+    static_kwargs_k = [(k, v.val) for k, v in
+                       sorted(static_kwargs_k.val.items())]
+    static_kwargs_ok = [(k, v.val) for k, v in
+                        sorted(static_kwargs_ok.val.items())]
+    if static_kwargs_k != static_kwargs_ok:
+      diffs.append(
+          ("different value of static kwargs:\n"
+           f" now {{{', '.join(f'{k}: {repr(v)}' for k, v in static_kwargs_k)}}}"
+           f" and before {{{', '.join(f'{k}: {repr(v)}' for k, v in static_kwargs_ok)}}}",
+           list_diff_size(static_kwargs_k, static_kwargs_ok)))
+
+  def explain_in_tree_diff(in_tree_k: PyTreeDef, in_tree_ok: PyTreeDef):
+    nonlocal different_leaf_count
+    different_leaf_count = (in_tree_k.num_leaves != in_tree_ok.num_leaves)
+    if not different_leaf_count:
+      # Look for the special case of passing positional args as kwargs or
+      # vice-versa; the common prefix of positional args match.
+      args_tree_k, kwargs_tree_k = treedef_children(in_tree_k)
+      nr_args_k = len(treedef_children(args_tree_k))
+      args_tree_ok, kwargs_tree_ok = treedef_children(in_tree_ok)
+      nr_args_ok = len(treedef_children(args_tree_ok))
+      if (treedef_children(args_tree_k)[:min(nr_args_k, nr_args_ok)] ==
+          treedef_children(args_tree_ok)[:min(nr_args_k, nr_args_ok)]):
+        keys_k = kwargs_tree_k.node_data()[1]  # type: ignore[index]
+        keys_ok = kwargs_tree_ok.node_data()[1]  # type: ignore[index]
+        diffs.append(
+            (("different number of args and kwargs, but same total number.\n"
+              f" now {nr_args_k} args and kwargs "
+              f"with keys {keys_k}\n"
+              f" before {nr_args_ok} args and kwargs "
+              f"with keys {keys_ok}"),
+             abs(nr_args_ok - nr_args_k)))
+        return
+
+    in_tree_k_str = str(in_tree_k)
+    in_tree_k_str = (in_tree_k_str if len(in_tree_k_str) < 73
+                     else in_tree_k_str[:73] + "...")
+    in_tree_ok_str = str(in_tree_ok)
+    in_tree_ok_str = (in_tree_ok_str if len(in_tree_ok_str) < 73
+                      else in_tree_ok_str[:73] + "...")
+    diff = [f"different input pytree:\n now: {in_tree_k_str}\n"
+            f" before: {in_tree_ok_str}"]
+
+    errs = list(tree_util.equality_errors_pytreedef(in_tree_k, in_tree_ok))
+    for path, thing1, thing2, explanation in errs:
+      fst, *path = path  # type: ignore
+      base = ["args", "kwargs"][fst.idx]
+      diff.append(
+          f" * at {base}{keystr(tuple(path))}, now {thing1} and before {thing2},"
+          f" so {explanation}")
+    diffs.append(("\n".join(diff), len(errs)))
+
+  def explain_args_type_diff(args_k: tuple[core.AbstractValue],
+                             args_ok: tuple[core.AbstractValue]):
+    diff_size = 0
+    arg_names = debug_info.safe_arg_names(len(args_k))
+    def arg_type_to_str(at):
+      if hasattr(at, "str_short"):
+        return at.str_short(short_dtypes=True)
+      else:
+        return str(at)
+    args_k_str = ", ".join(f"{an}: {arg_type_to_str(at)}"
+                           for an, at in zip(arg_names, args_k))
+    args_k_str = args_k_str if len(args_k_str) < 73 else args_k_str[:73] + "..."
+    diff = [f"different input types:\n types now: {args_k_str}"]
+    add_weak_type_hint = False
+
+    for name, arg_t_k, arg_t_ok in zip(arg_names, args_k, args_ok):
+      if arg_t_k == arg_t_ok: continue
+      this_arg_diff_size = 0
+      if type(arg_t_k) == type(arg_t_ok) == core.ShapedArray:
+        s1, s2 = arg_type_to_str(arg_t_k), arg_type_to_str(arg_t_ok)
+        this_arg_diff_size += list_diff_size(arg_t_k.shape, arg_t_ok.shape)  # type: ignore
+
+        if arg_t_k.weak_type != arg_t_ok.weak_type:  # type: ignore
+          s1 += f"{{weak_type={arg_t_k.weak_type}}}"  # type: ignore
+          s2 += f"{{weak_type={arg_t_ok.weak_type}}}"  # type: ignore
+          add_weak_type_hint = True
+          this_arg_diff_size += 1
+        elif arg_t_k.sharding != arg_t_ok.sharding:  # type: ignore
+          s1 = arg_t_k.str_short(short_dtypes=True, mesh_axis_types=True)  # type: ignore
+          s2 = arg_t_ok.str_short(short_dtypes=True, mesh_axis_types=True)  # type: ignore
+          this_arg_diff_size += 1
+      else:
+        s1, s2 = str(arg_t_k), str(arg_t_ok)
+      diff_size += max(1, this_arg_diff_size)
+      diff.append(f" * at {name}, now {s1} and before {s2}")
+
+    if add_weak_type_hint:
+      diff.append(
+          "where weak_type=True often means a Python builtin numeric value, and \n"
+          "weak_type=False means a jax.Array.\n"
+          "See https://docs.jax.dev/en/latest/type_promotion.html#weak-types.")
+    diffs.append(("\n".join(diff), diff_size))
+
+  if fun_transforms_k != fun_transforms_ok:
+    if len(fun_transforms_k) != len(fun_transforms_ok):
+      different_leaf_count = True  # Skip other more precise checks
+      unavailable("fun_transforms length",
+                  fun_transforms_k, fun_transforms_ok)
+    else:
+      for i, (t, ot) in enumerate(zip(fun_transforms_k, fun_transforms_ok)):
+        t_name = t[0].__name__
+        if t == ot: continue
+        if t[0] != ot[0]:
+          unavailable(f"fun_transforms[{i}] transform", t, ot)
+          continue
+
+        if t_name == "flatten_fun":
+          explain_in_tree_diff(t[1][0], ot[1][0])
+          continue
+        if t_name == "_argnums_partial":
+          explain_transform_argnums_partial(t[1], ot[1])
+          continue
+        if t_name == "_argnames_partial":
+          explain_transform_argnames_partial(t[1], ot[1])
+          continue
+        unavailable(f"fun_transforms.{t_name} params", t[1:], ot[1:])
+        continue
+
+  # If we had different leaf counts, we can discard the _argnums_partial
+  # difference. That transform sometimes occurs before the flatten_fun
+  if different_leaf_count:
+    diffs = [d for d in diffs if "fun_transforms._argnums_partial" not in d[0]]
+  if fun_params_k != fun_params_ok:
+    unavailable("fun_params", fun_params_k, fun_params_ok)
+  if fun_in_type_k != fun_in_type_ok:
+    unavailable("fun_in_type", fun_in_type_k, fun_in_type_ok)
+  if arg_in_type_k != arg_in_type_ok and not different_leaf_count:
+    explain_args_type_diff(arg_in_type_k, arg_in_type_ok)
+  if arg_attr_data_k != arg_attr_data_ok:
+    unavailable("arg_attr_data", arg_attr_data_k, arg_attr_data_ok)
+  if arg_inline_k != arg_inline_ok:
+    unavailable("arg_inline", arg_inline_k, arg_inline_ok)
+  if ctx_k != ctx_ok:
+    assert len(ctx_k) == len(ctx_ok)
+    idxs = [f" [{i}]: now {c_k} and before {c_ok}"
+            for i, (c_k, c_ok) in enumerate(zip(ctx_k, ctx_ok)) if c_k != c_ok]
+    diffs.append(
+        ("different tracing context, e.g. due to config or context manager.\n"
+         "found differences at positions\n" +
+         ", and\n".join(idxs) +
+         "\ncompare to tuple returned by "
+         "config.trace_context() in jax/_src/config.py.",
+         len(idxs)))
+  if not diffs:  # Should never happen, but let's not crash
+    unavailable("something (unexpected empty diffs)", k, oldk)
+  diffs_and_sizes = util.unzip2(sorted(diffs, key=lambda d: d[1]))
+  return (diffs_and_sizes[0], sum(diffs_and_sizes[1]))
+
+
 def explain_tracing_cache_miss(
     fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple):
   if config.check_tracer_leaks.value: return
-
-  def unpack(key):
-    transforms, (), _, (in_type, _, inline), *_, ctx = key
-    # TODO(dougalm,mattjj): enable cache miss explanation with attrs
-    _, (_, (in_tree,)), *_ = transforms
-    return in_tree, in_type, inline.val, ctx
-  in_tree, in_type, inline, ctx = unpack(key)
-  if inline: return
+  if key[3][2].val: return  # No explanations for "inline" functions
 
   debug_info = fun.debug_info
   func_filename = debug_info.func_filename
@@ -1177,7 +1368,7 @@ def explain_tracing_cache_miss(
 
   msg: list[str] = []
   p = msg.append
-  done = lambda: logger.log(logging.WARNING, '\n'.join(msg))
+  done = lambda: logger.log(logging.WARNING, "\n".join(msg))
 
   callsite = source_info_util.summarize(source_info_util.current())
   p(f"TRACING CACHE MISS at {callsite} because:")
@@ -1188,110 +1379,42 @@ def explain_tracing_cache_miss(
     src_info += f" defined at {func_filename}"
   if func_lineno := debug_info.func_lineno:
     src_info += f":{func_lineno}"
-  if unseen_f:
-    p(f" never seen function:\n {debug_info.func_name} id={id(fun.f)}{src_info}")
+  func_name = debug_info.func_name
+  if unseen_f or not cache:
+    p(f" never seen function:\n {func_name} id={id(fun.f)}{src_info}")
     if callsite in callsites_with_tracing_cache_miss:
       p(" but seen another function defined on the same line; maybe the function is\n"
         " being re-defined repeatedly, preventing caching?")
     else:
       callsites_with_tracing_cache_miss.add(callsite)
     return done()
+
+  p(f" for {func_name}{src_info}")
+
+  diffs = [diff_tracing_cache_keys(key, ok, debug_info)
+           for ok in cache.keys() if key != ok]
+  assert diffs, "we must find some diffs if key differs from all cache keys"
+  min_diff = min(diffs, key=lambda v: v[1])
+  smallest_diffs: Sequence[Sequence[str]]  # the diffs for the closest keys
+  smallest_diffs = [d[0] for d in diffs if d[1] == min_diff[1]]
+  def indent_subsequent_lines(indent: int, msg: str) -> str:
+    return msg.replace("\n", "\n" + " " * indent)
+  def p_one_diff(diff: Sequence[str]):
+    for d in diff:
+      p(" * key with " + indent_subsequent_lines(4, d))
+
+  if len(smallest_diffs) == 1:
+    p(" all previously seen cache keys are different. Closest previous key:")
+    p_one_diff(smallest_diffs[0])
   else:
-    p(f" for {debug_info.func_name}{src_info}")
+    p(" all previously seen cache keys are different. "
+      "Several previous keys are closest:")
+    for d in smallest_diffs:
+      p_one_diff(d)
 
-  seen_keys = map(unpack, cache.keys())
+  done()
+  return
 
-  # have we maybe switched some args to be kwargs or visa-versa?
-  args_tree, kwargs_tree = treedef_children(in_tree)
-  args_kwargs_trees = [treedef_children(k) for k, *_ in seen_keys]
-  args_kwargs_match = [t for t in args_kwargs_trees
-                       if t == [args_tree, kwargs_tree]]
-  if not args_kwargs_match:
-    num_args = len(treedef_children(args_tree))
-    _, kwarg_keys = kwargs_tree.node_data()  # type: ignore
-    p(f" never seen passing {num_args} positional args and {len(kwarg_keys)} "
-      "keyword args with keys:\n"
-      f" {', '.join(map(repr, kwarg_keys))}")
-    dont_match = [set(t[1].node_data()[1]) for t in args_kwargs_trees  # type: ignore
-                  if t != [args_tree, kwargs_tree]]
-    close_kwargs = min(
-        dont_match, key=set(kwarg_keys).symmetric_difference, default=None
-    )
-    if not close_kwargs:
-      p(" closest seen is passing no keyword args")
-    else:
-      p(f" closest seen passes {len(close_kwargs)} keyword args with keys:\n"
-        f" {', '.join(map(repr, close_kwargs))}")
-    return done()
-
-  # have we never seen this tracing context before?
-  ctxs_match = [c for *_, c in seen_keys if c == ctx]
-  if not ctxs_match:
-    p(" tracing context doesn't match, e.g. due to config or context manager")
-    dont_match = [c for *_, c in seen_keys if c != ctx]
-    closest_ctx = min(dont_match, key=lambda c: sum(map(op.ne, c, ctx)))
-    idxs = [i for i, (c1, c2) in enumerate(zip(ctx, closest_ctx)) if c1 != c2]
-    p(" closest seen context tuple differs at positions:\n"
-      f" {', '.join(map(str, idxs))}\n"
-      " compare to tuple returned by config._trace_context() in jax/_src/config.py.")
-    return done()
-
-  # have we never seen this input pytree before?
-  trees_match = [k for k in seen_keys if k[0] == in_tree]
-  if not trees_match:
-    in_tree_str = f':\n {in_tree}' if len(str(in_tree)) < 76 else ''
-    p(f" never seen input pytree{in_tree_str}")
-    dont_match = [t for t, *_ in seen_keys if t != in_tree]
-    closest_tree = min(dont_match, key=lambda t: abs(t.num_leaves - in_tree.num_leaves))
-    errs = list(tree_util.equality_errors_pytreedef(in_tree, closest_tree))  # type: ignore[arg-type]
-    p(f" closest seen input pytree has {len(errs)} mismatches, including:")
-    for path, thing1, thing2, explanation in errs:
-      fst, *path = path  # type: ignore
-      base = ['args', 'kwargs'][fst.idx]
-      p(f" * at {base}{keystr(tuple(path))}, seen {thing2} but now given {thing1},"
-        f" so {explanation}")
-    return done()
-
-  # have we never seen these input types (eg shapes, dtypes) before?
-  types_match = [k for k in trees_match if k[1] == in_type]
-  if not types_match:
-    if len(in_type) < 5:
-      in_type_str = ":\n {}".format(", ".join(
-          f"{n}: {ty.str_short(short_dtypes=True)}"
-          for n, ty in zip(debug_info.arg_names, in_type)))
-    else:
-      in_type_str = ''
-    p(f" never seen input type signature{in_type_str}")
-    dont_match = [t for _, t, *_ in trees_match if t != in_type]
-    closest_ty = min(dont_match, key=lambda t: sum(map(op.ne, t, in_type)))
-    num_mismatch = sum(map(op.ne, closest_ty, in_type))
-    p(f" closest seen input type signature has {num_mismatch} mismatches, including:")
-    add_weak_type_hint = False
-    arg_names = debug_info.safe_arg_names(len(in_type))
-
-    for name, ty1, ty2 in zip(arg_names, closest_ty, in_type):
-      if ty1 != ty2:
-        if type(ty1) == type(ty2) == core.ShapedArray:
-          s1, s2 = ty1.str_short(True), ty2.str_short(True)
-          if ty1.weak_type != ty2.weak_type:
-            s1 += f"{{weak_type={ty1.weak_type}}}"
-            s2 += f"{{weak_type={ty2.weak_type}}}"
-            add_weak_type_hint = True
-          elif ty1.sharding != ty2.sharding:
-            s1 = ty1.str_short(short_dtypes=True, mesh_axis_types=True)
-            s2 = ty2.str_short(short_dtypes=True, mesh_axis_types=True)
-        else:
-          s1, s2 = str(ty1), str(ty2)
-        p(f" * at {name}, seen {s1}, but now given {s2}")
-    if add_weak_type_hint:
-      p("where weak_type=True often means a Python builtin numeric value, and ")
-      p("weak_type=False means a jax.Array.")
-      p("See https://docs.jax.dev/en/latest/type_promotion.html#weak-types")
-    return done()
-
-  # we think this is unreachable...
-  p("explanation unavailable! please open an issue at https://github.com/jax-ml/jax")
-  return done()
 
 @partial(lu.cache, explain=explain_tracing_cache_miss)
 def _create_pjit_jaxpr(
diff --git a/tests/api_test.py b/tests/api_test.py
index a5e192a9f..8705d2021 100644
--- a/tests/api_test.py
+++ b/tests/api_test.py
@@ -4465,66 +4465,6 @@ class APITest(jtu.JaxTestCase):
       tracing_add_count += 1
     self.assertEqual(tracing_add_count, 2)
 
-  @jtu.thread_unsafe_test()  # logging is not thread-safe
-  def test_cache_miss_explanations(self):
-    @jax.jit
-    def f(x, y):
-      return jnp.sin(x) * y['hi']
-
-    x = jnp.float32(1.)
-    y = {'hi': jnp.arange(3., dtype='float32')}
-
-    expected_log_len = 1 if not is_persistent_cache_enabled() else 3
-
-    # print on first miss, not on hit
-    with config.explain_cache_misses(True):
-      with self.assertLogs(level='WARNING') as cm:
-        f(x, y)
-        f(x, y)
-    self.assertLen(cm.output, expected_log_len)
-    msg = cm.output[0]
-    self.assertIn('TRACING CACHE MISS', msg)
-    self.assertIn('never seen function', msg)
-
-    # shape change
-    y_ = {'hi': jnp.arange(4, dtype='float32')}
-    with config.explain_cache_misses(True):
-      with self.assertLogs(level='WARNING') as cm:
-        f(x, y_)
-    self.assertLen(cm.output, expected_log_len)
-    msg = cm.output[0]
-    self.assertIn('never seen input type signature', msg)
-    self.assertIn('closest seen input type signature has 1 mismatches', msg)
-    self.assertIn('seen f32[3], but now given f32[4]', msg)
-
-    # weak type change (assuming no x64)
-    if not config.enable_x64.value:
-      with config.explain_cache_misses(True):
-        with self.assertLogs(level='WARNING') as cm:
-          f(1., y)
-      self.assertLen(cm.output, expected_log_len)
-      msg = cm.output[0]
-      self.assertIn('weak_type=True', msg)
-      self.assertIn('https://docs.jax.dev/en/latest/type_promotion.html#weak-types', msg)
-
-    # kwarg change
-    with config.explain_cache_misses(True):
-      with self.assertLogs(level='WARNING') as cm:
-        f(1, y=y)
-    self.assertLen(cm.output, expected_log_len)
-    msg = cm.output[0]
-    self.assertIn('never seen passing 1 positional args and 1 keyword args', msg)
-
-    # tracing config change
-    with config.explain_cache_misses(True):
-      with self.assertLogs(level='WARNING') as cm:
-        with jax.numpy_rank_promotion('warn'):
-          f(x, y)
-    # depending on the backend, we may or may not get persistent cache warnings
-    self.assertTrue(1 <= len(cm.output) <= expected_log_len)
-    msg = cm.output[0]
-    self.assertIn("tracing context doesn't match", msg)
-
   @jtu.thread_unsafe_test()  # logging is not thread-safe
   def test_cache_miss_explanations_skip_internals(self):
     if is_persistent_cache_enabled():
@@ -4535,6 +4475,211 @@ class APITest(jtu.JaxTestCase):
         for i in range(2):
           jnp.sin(jnp.arange(i + 1, dtype=np.float32))
 
+  @jtu.thread_unsafe_test()  # logging is not thread-safe
+  def test_cache_miss_explanations_first_miss(self):
+    @jax.jit
+    def f(x): return x
+    x = jnp.float32(1.)
+
+    expected_log_len = 1 if not is_persistent_cache_enabled() else 3
+    # print on first miss, not on hit
+    with config.explain_cache_misses(True):
+      with self.assertLogs(level="WARNING") as cm:
+        f(x)
+        f(x)
+    self.assertLen(cm.output, expected_log_len)
+    msg = cm.output[0]
+    self.assertIn("TRACING CACHE MISS", msg)
+    self.assertIn("never seen function", msg)
+    self.assertNotIn("explanation unavailable!", msg)
+
+  @jtu.thread_unsafe_test()  # logging is not thread-safe
+  def test_cache_miss_explanations_other_in_tree(self):
+    @jax.jit
+    def f(*args, **kwargs): return args[0]
+
+    f(0., 1., y=(2., 2.1))
+
+    with config.explain_cache_misses(True):
+      with self.assertLogs(level="WARNING") as cm:
+        # Same number of leaves but different trees
+        f(0., (1., 1.1), y=2.)
+    expected_log_len = 1 if not is_persistent_cache_enabled() else 3
+    self.assertLen(cm.output, expected_log_len)
+    msg = cm.output[0]
+    self.assertIn("different input pytree", msg)
+    self.assertNotIn("explanation unavailable!", msg)
+
+  @jtu.thread_unsafe_test()  # logging is not thread-safe
+  def test_cache_miss_explanations_other_arg_passed_as_kwarg(self):
+    @jax.jit
+    def f(x, y): return jnp.sin(x) + y
+
+    f(0., 1.)
+
+    # kwarg change
+    with config.explain_cache_misses(True):
+      with self.assertLogs(level="WARNING") as cm:
+        f(0., y=1.)
+    expected_log_len = 1 if not is_persistent_cache_enabled() else 3
+    self.assertLen(cm.output, expected_log_len)
+    msg = cm.output[0]
+    self.assertIn("different number of args and kwargs, but same total number", msg)
+    self.assertIn("now 1 args and kwargs with keys ['y']", msg)
+    self.assertIn("before 2 args and kwargs with keys []", msg)
+    self.assertNotIn("explanation unavailable!", msg)
+
+  @jtu.thread_unsafe_test()  # logging is not thread-safe
+  def test_cache_miss_explanations_other_static_argnums(self):
+    @partial(jax.jit, static_argnums=(0, 2))
+    def f(x, y, z):
+      return y
+
+    f(1., 2., "foo")
+
+    with config.explain_cache_misses(True):
+      with self.assertLogs(level="WARNING") as cm:
+        f(1., 2., "bar")
+    expected_log_len = 1 if not is_persistent_cache_enabled() else 3
+    self.assertLen(cm.output, expected_log_len)
+    msg = cm.output[0]
+    self.assertIn("different value of static args", msg)
+    self.assertIn("now 1.0, 'bar' and before 1.0, 'foo'", msg)
+    self.assertNotIn('explanation unavailable!', msg)
+
+  @jtu.thread_unsafe_test()  # logging is not thread-safe
+  def test_cache_miss_explanations_other_static_argnames(self):
+    @partial(jax.jit, static_argnames='foo')
+    def f(*, foo):
+      return 1
+
+    f(foo="foo")
+
+    with config.explain_cache_misses(True):
+      with self.assertLogs(level="WARNING") as cm:
+        f(foo="bar")
+    expected_log_len = 1 if not is_persistent_cache_enabled() else 3
+    self.assertLen(cm.output, expected_log_len)
+    msg = cm.output[0]
+    self.assertIn("different value of static kwargs", msg)
+    self.assertIn("now {foo: 'bar'} and before {foo: 'foo'}", msg)
+    self.assertNotIn('explanation unavailable!', msg)
+
+  @jtu.thread_unsafe_test()  # logging is not thread-safe
+  def test_cache_miss_explanations_other_dtype(self):
+    @jax.jit
+    def f(x, y): return x
+    f(np.float32(0), np.float32(1))
+
+    with config.explain_cache_misses(True):
+      with self.assertLogs(level='WARNING') as cm:
+        f(np.float32(0), np.int32(1))
+    expected_log_len = 1 if not is_persistent_cache_enabled() else 3
+    self.assertLen(cm.output, expected_log_len)
+    msg = cm.output[0]
+    self.assertIn("different input types", msg)
+    self.assertIn("at y, now i32[] and before f32[]", msg)
+    self.assertNotIn("explanation unavailable!", msg)
+
+  @jtu.thread_unsafe_test()  # logging is not thread-safe
+  def test_cache_miss_explanations_other_weak_type(self):
+    @jax.jit
+    def f(x, y): return jnp.sin(x) + y
+
+    y = jnp.arange(4, dtype="float32")
+    f(jnp.float32(0.), y)
+    # weak type change (assuming no x64)
+    if config.enable_x64.value:
+      self.skipTest("Work only for 32 bit mode")
+    with config.explain_cache_misses(True):
+      with self.assertLogs(level="WARNING") as cm:
+        f(0., y)
+    expected_log_len = 1 if not is_persistent_cache_enabled() else 3
+    self.assertLen(cm.output, expected_log_len)
+    msg = cm.output[0]
+    self.assertIn("different input types", msg)
+    self.assertIn("at x, now f32[]{weak_type=True} and before f32[]{weak_type=False}", msg)
+    self.assertIn("https://docs.jax.dev/en/latest/type_promotion.html#weak-types", msg)
+    self.assertNotIn("explanation unavailable!", msg)
+
+  @jtu.thread_unsafe_test()  # logging is not thread-safe
+  def test_cache_miss_explanations_other_shape(self):
+    @jax.jit
+    def f(x, y): return jnp.sin(x) + y
+    f(np.float32(0), np.arange(1, dtype=np.float32))
+
+    with config.explain_cache_misses(True):
+      with self.assertLogs(level='WARNING') as cm:
+        f(np.float32(0), np.arange(2, dtype=np.float32))
+    expected_log_len = 1 if not is_persistent_cache_enabled() else 3
+    self.assertLen(cm.output, expected_log_len)
+    msg = cm.output[0]
+    self.assertIn("different input types", msg)
+    self.assertIn("at y, now f32[2] and before f32[1]", msg)
+    self.assertNotIn("explanation unavailable!", msg)
+
+  @jtu.thread_unsafe_test()  # logging is not thread-safe
+  def test_cache_miss_explanations_other_shape_explain_closest(self):
+    @jax.jit
+    def f(x): return x
+    f(np.ones((1, 2), dtype=np.float32))
+    f(np.ones((10, 20, 30), dtype=np.float32))
+    f(np.ones((1, 2, 3), dtype=np.float32))
+
+    with config.explain_cache_misses(True):
+      with self.assertLogs(level='WARNING') as cm:
+        f(np.ones((10, 2, 30), dtype=np.float32))
+    expected_log_len = 1 if not is_persistent_cache_enabled() else 3
+    self.assertLen(cm.output, expected_log_len)
+    msg = cm.output[0]
+    self.assertIn("key with different input types", msg)
+    self.assertIn("at x, now f32[10,2,30] and before f32[10,20,30]", msg)
+    self.assertNotIn("explanation unavailable!", msg)
+
+  @jtu.thread_unsafe_test()  # logging is not thread-safe
+  def test_cache_miss_explanations_other_tracing_config(self):
+    @jax.jit
+    def f(x, y): return jnp.sin(x) + y
+
+    f(0., 1.)
+    # tracing config change
+    with config.explain_cache_misses(True):
+      with self.assertLogs(level="WARNING") as cm:
+        with jax.numpy_rank_promotion("warn"):
+          with jax.default_matmul_precision("high"):
+            f(0., 1.)
+
+    expected_log_len = 1 if not is_persistent_cache_enabled() else 3
+    self.assertTrue(1 <= len(cm.output) <= expected_log_len)
+    msg = cm.output[0]
+    self.assertIn("key with different tracing context", msg)
+    self.assertIn("now warn and before", msg)
+    self.assertIn("now high and before", msg)
+    self.assertNotIn("explanation unavailable!", msg)
+
+  @jtu.thread_unsafe_test()  # logging is not thread-safe
+  def test_cache_miss_explanations_multiple_changes(self):
+    @jax.jit
+    def f(x): return jnp.sin(x)
+
+    call_1 = f(np.arange(4, dtype=np.float32))
+    with jax.numpy_rank_promotion("warn"):
+      call_2 = f(np.arange(8, dtype=np.float32))
+
+    with config.explain_cache_misses(True):
+      with self.assertLogs(level='WARNING') as cm:
+        # Matches call_2 in shape but not context, and call_1 in context but
+        # not in shape.
+        f(np.arange(8, dtype=np.float32))
+
+    expected_log_len = 1 if not is_persistent_cache_enabled() else 3
+    self.assertLen(cm.output, expected_log_len)
+    msg = cm.output[0]
+    self.assertIn("key with different input types", msg)
+    self.assertIn("at x, now f32[8] and before f32[4]", msg)
+    self.assertIn("key with different tracing context", msg)
+    self.assertNotIn("explanation unavailable!", msg)
+
   @jtu.thread_unsafe_test()  # logging is not thread-safe
   def test_cache_miss_explanations_new_function_in_loop(self):
     @jax.jit
diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py
index 1f5ddba89..0fc1aabba 100644
--- a/tests/debug_info_test.py
+++ b/tests/debug_info_test.py
@@ -392,66 +392,6 @@ class DebugInfoTest(jtu.JaxTestCase):
     with self.assertRaisesRegex(TypeError, err_str):
       jax.jit(f)(jnp.int32)
 
-  @jtu.thread_unsafe_test()  # logging is not thread-safe
-  def test_arg_names_cache_miss_explanations(self):
-    @jax.jit
-    def f(x, y):
-      return jnp.sin(x) * y['hi']
-
-    x = jnp.float32(1.)
-    y = {'hi': jnp.arange(3., dtype='float32')}
-
-    expected_log_len = 1 if not is_persistent_cache_enabled() else 3
-
-    # print on first miss, not on hit
-    with config.explain_cache_misses(True):
-      with self.assertLogs(level='WARNING') as cm:
-        f(x, y)
-        f(x, y)
-    self.assertLen(cm.output, expected_log_len)
-    msg = cm.output[0]
-    self.assertIn('TRACING CACHE MISS', msg)
-    self.assertIn('never seen function', msg)
-
-    # shape change
-    y_ = {'hi': jnp.arange(4, dtype='float32')}
-    with config.explain_cache_misses(True):
-      with self.assertLogs(level='WARNING') as cm:
-        f(x, y_)
-    self.assertLen(cm.output, expected_log_len)
-    msg = cm.output[0]
-    self.assertIn('never seen input type signature', msg)
-    self.assertIn('closest seen input type signature has 1 mismatches', msg)
-    self.assertIn('seen f32[3], but now given f32[4]', msg)
-
-    # weak type change (assuming no x64)
-    if not config.enable_x64.value:
-      with config.explain_cache_misses(True):
-        with self.assertLogs(level='WARNING') as cm:
-          f(1., y)
-      self.assertLen(cm.output, expected_log_len)
-      msg = cm.output[0]
-      self.assertIn('weak_type=True', msg)
-      self.assertIn('https://docs.jax.dev/en/latest/type_promotion.html#weak-types', msg)
-
-    # kwarg change
-    with config.explain_cache_misses(True):
-      with self.assertLogs(level='WARNING') as cm:
-        f(1, y=y)
-    self.assertLen(cm.output, expected_log_len)
-    msg = cm.output[0]
-    self.assertIn('never seen passing 1 positional args and 1 keyword args', msg)
-
-    # tracing config change
-    with config.explain_cache_misses(True):
-      with self.assertLogs(level='WARNING') as cm:
-        with jax.numpy_rank_promotion('warn'):
-          f(x, y)
-    # depending on the backend, we may or may not get persistent cache warnings
-    self.assertTrue(1 <= len(cm.output) <= expected_log_len)
-    msg = cm.output[0]
-    self.assertIn("tracing context doesn't match", msg)
-
   @jtu.thread_unsafe_test()  # logging is not thread-safe
   def test_arg_names_cache_miss_explanations_new_function_in_loop(self):
     @jax.jit
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 0e7867cc2..025512121 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -3467,9 +3467,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
         f(x_, y)
     self.assertLen(cm.output, expected_log_len)
     msg = cm.output[0]
-    self.assertIn('never seen input type signature', msg)
-    self.assertIn('closest seen input type signature has 1 mismatches', msg)
-    self.assertIn("seen f32[8]({}), but now given f32[8]({Auto: ('x',)})", msg)
+    self.assertIn("different input types", msg)
+    self.assertIn("at x, now f32[8]({Auto: ('x',)}) and before f32[8]({})", msg)
 
   def test_pjit_function_cache_cpp(self):
     def f(x):