mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #27980 from gnecula:tracing_cache
PiperOrigin-RevId: 747274185
This commit is contained in:
commit
6ca623f79b
333
jax/_src/pjit.py
333
jax/_src/pjit.py
@ -21,7 +21,6 @@ import dataclasses
|
||||
from functools import partial
|
||||
import inspect
|
||||
import logging
|
||||
import operator as op
|
||||
import weakref
|
||||
from typing import NamedTuple, Any, Union, cast
|
||||
import warnings
|
||||
@ -1158,17 +1157,209 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
|
||||
|
||||
# Call sites (as summarized source-info strings) at which a "never seen
# function" tracing-cache miss has already been reported. A second such miss at
# the same call site is reported with a hint that the function may be
# re-defined repeatedly, which prevents caching.
callsites_with_tracing_cache_miss: set[str] = set()
|
||||
|
||||
def diff_tracing_cache_keys(
    k: tuple, oldk: tuple, debug_info: lu.DebugInfo) -> tuple[Sequence[str], int]:
  """Explanations of differences between the cache keys, along with diff sizes.

  Args:
    k: the tracing-cache key that missed.
    oldk: a key already present in the cache. Both keys are unpacked as
      `(fun_transforms, fun_params, fun_in_type,
      (arg_in_type, arg_attr_data, arg_inline), ctx)`.
    debug_info: only used to obtain argument names (`safe_arg_names`) when
      the input types differ.

  Returns:
    A pair of a sequence of explanations for differences (ordered by
    increasing size), and the total size of the differences. The sizes are
    used to pick the old key with the smallest different size for the
    explanation that is shown to the user.
  """
  (fun_transforms_k, fun_params_k, fun_in_type_k,
   (arg_in_type_k, arg_attr_data_k, arg_inline_k), ctx_k) = k
  (fun_transforms_ok, fun_params_ok, fun_in_type_ok,
   (arg_in_type_ok, arg_attr_data_ok, arg_inline_ok), ctx_ok) = oldk

  diffs: list[tuple[str, int]] = []  # each difference with its size

  def unavailable(key_field: str, what_k, what_ok):
    # Fallback for key fields without a dedicated explanation. The large fixed
    # size (10) makes keys that differ this way rank as "far" from `k`.
    diffs.append(
        (f"different {key_field}:\n now: {what_k}\n != before: {what_ok}.\n"
         "explanation unavailable! "
         "please open an issue at https://github.com/jax-ml/jax.",
         10))

  def list_diff_size(s1: Sequence, s2: Sequence) -> int:
    # Length difference plus the number of mismatches in the common prefix
    # (zip stops at the shorter sequence).
    length_gap = abs(len(s1) - len(s2))
    return length_gap + sum(e1 != e2 for e1, e2 in zip(s1, s2))

  def shorten(s: str, limit: int = 73) -> str:
    # Only mark the string with "..." when something was actually cut off.
    return s if len(s) <= limit else s[:limit] + "..."

  different_leaf_count = False

  def explain_transform_argnums_partial(param_k: tuple, param_ok: tuple):
    dyn_argnums_k, static_args_k = param_k
    dyn_argnums_ok, static_args_ok = param_ok
    if dyn_argnums_k != dyn_argnums_ok:
      diffs.append(
          ("different static_argnums:\n"
           f" dynamic argnums now {dyn_argnums_k} and before {dyn_argnums_ok}",
           1))
    if static_args_k != static_args_ok:
      diffs.append(
          ("different value of static args:\n"
           f" now {', '.join(repr(a.val) for a in static_args_k)}"
           f" and before {', '.join(repr(a.val) for a in static_args_ok)}",
           list_diff_size(static_args_k, static_args_ok)))

  def explain_transform_argnames_partial(param_k: tuple, param_ok: tuple):
    static_kwargs_k, = param_k
    static_kwargs_ok, = param_ok
    # Unwrap to plain (name, value) pairs, sorted so that the comparison does
    # not depend on dict ordering.
    kwargs_now = [(name, v.val) for name, v in
                  sorted(static_kwargs_k.val.items())]
    kwargs_before = [(name, v.val) for name, v in
                     sorted(static_kwargs_ok.val.items())]

    def fmt(pairs) -> str:
      return ", ".join(f"{name}: {val!r}" for name, val in pairs)

    if kwargs_now != kwargs_before:
      diffs.append(
          ("different value of static kwargs:\n"
           f" now {{{fmt(kwargs_now)}}}"
           f" and before {{{fmt(kwargs_before)}}}",
           list_diff_size(kwargs_now, kwargs_before)))

  def explain_in_tree_diff(in_tree_k: PyTreeDef, in_tree_ok: PyTreeDef):
    nonlocal different_leaf_count
    different_leaf_count = (in_tree_k.num_leaves != in_tree_ok.num_leaves)
    if not different_leaf_count:
      # Look for the special case of passing positional args as kwargs or
      # vice-versa; the common prefix of positional args match.
      args_tree_k, kwargs_tree_k = treedef_children(in_tree_k)
      args_children_k = treedef_children(args_tree_k)
      nr_args_k = len(args_children_k)
      args_tree_ok, kwargs_tree_ok = treedef_children(in_tree_ok)
      # The "before" count must come from the old key's args tree.
      args_children_ok = treedef_children(args_tree_ok)
      nr_args_ok = len(args_children_ok)
      common = min(nr_args_k, nr_args_ok)
      # Only claim "different number of args and kwargs" when the positional
      # counts really differ; otherwise fall through to the generic pytree
      # explanation below.
      if (nr_args_k != nr_args_ok and
          args_children_k[:common] == args_children_ok[:common]):
        keys_k = kwargs_tree_k.node_data()[1]  # type: ignore[index]
        keys_ok = kwargs_tree_ok.node_data()[1]  # type: ignore[index]
        diffs.append(
            (("different number of args and kwargs, but same total number.\n"
              f" now {nr_args_k} args and kwargs "
              f"with keys {keys_k}\n"
              f" before {nr_args_ok} args and kwargs "
              f"with keys {keys_ok}"),
             abs(nr_args_ok - nr_args_k)))
        return

    in_tree_k_str = shorten(str(in_tree_k))
    in_tree_ok_str = shorten(str(in_tree_ok))
    diff = [f"different input pytree:\n now: {in_tree_k_str}\n"
            f" before: {in_tree_ok_str}"]

    errs = list(tree_util.equality_errors_pytreedef(in_tree_k, in_tree_ok))
    for err_path, thing1, thing2, explanation in errs:
      # The first path entry selects between the positional-args tuple and
      # the kwargs dict of the flattened (args, kwargs) pair.
      fst, *rest = err_path  # type: ignore
      base = ["args", "kwargs"][fst.idx]
      diff.append(
          f" * at {base}{keystr(tuple(rest))}, now {thing1} and before {thing2},"
          f" so {explanation}")
    diffs.append(("\n".join(diff), len(errs)))

  def explain_args_type_diff(args_k: tuple[core.AbstractValue],
                             args_ok: tuple[core.AbstractValue]):
    diff_size = 0
    arg_names = debug_info.safe_arg_names(len(args_k))

    def arg_type_to_str(at):
      if hasattr(at, "str_short"):
        return at.str_short(short_dtypes=True)
      else:
        return str(at)

    args_k_str = shorten(", ".join(f"{an}: {arg_type_to_str(at)}"
                                   for an, at in zip(arg_names, args_k)))
    diff = [f"different input types:\n types now: {args_k_str}"]
    add_weak_type_hint = False

    for name, arg_t_k, arg_t_ok in zip(arg_names, args_k, args_ok):
      if arg_t_k == arg_t_ok: continue
      this_arg_diff_size = 0
      if type(arg_t_k) == type(arg_t_ok) == core.ShapedArray:
        s1, s2 = arg_type_to_str(arg_t_k), arg_type_to_str(arg_t_ok)
        this_arg_diff_size += list_diff_size(arg_t_k.shape, arg_t_ok.shape)  # type: ignore

        if arg_t_k.weak_type != arg_t_ok.weak_type:  # type: ignore
          s1 += f"{{weak_type={arg_t_k.weak_type}}}"  # type: ignore
          s2 += f"{{weak_type={arg_t_ok.weak_type}}}"  # type: ignore
          add_weak_type_hint = True
          this_arg_diff_size += 1
        elif arg_t_k.sharding != arg_t_ok.sharding:  # type: ignore
          s1 = arg_t_k.str_short(short_dtypes=True, mesh_axis_types=True)  # type: ignore
          s2 = arg_t_ok.str_short(short_dtypes=True, mesh_axis_types=True)  # type: ignore
          this_arg_diff_size += 1
      else:
        s1, s2 = str(arg_t_k), str(arg_t_ok)
      # Every differing argument counts for at least 1 (e.g. when only the
      # dtype differs and none of the checks above added to the size).
      diff_size += max(1, this_arg_diff_size)
      diff.append(f" * at {name}, now {s1} and before {s2}")

    if add_weak_type_hint:
      diff.append(
          "where weak_type=True often means a Python builtin numeric value, and \n"
          "weak_type=False means a jax.Array.\n"
          "See https://docs.jax.dev/en/latest/type_promotion.html#weak-types.")
    diffs.append(("\n".join(diff), diff_size))

  if fun_transforms_k != fun_transforms_ok:
    if len(fun_transforms_k) != len(fun_transforms_ok):
      different_leaf_count = True  # Skip other more precise checks
      unavailable("fun_transforms length",
                  fun_transforms_k, fun_transforms_ok)
    else:
      for i, (t, ot) in enumerate(zip(fun_transforms_k, fun_transforms_ok)):
        t_name = t[0].__name__
        if t == ot: continue
        if t[0] != ot[0]:
          unavailable(f"fun_transforms[{i}] transform", t, ot)
        elif t_name == "flatten_fun":
          explain_in_tree_diff(t[1][0], ot[1][0])
        elif t_name == "_argnums_partial":
          explain_transform_argnums_partial(t[1], ot[1])
        elif t_name == "_argnames_partial":
          explain_transform_argnames_partial(t[1], ot[1])
        else:
          unavailable(f"fun_transforms.{t_name} params", t[1:], ot[1:])

  # If we had different leaf counts, we can discard the _argnums_partial
  # difference. That transform sometimes occurs before the flatten_fun
  # NOTE(review): this filter matches only the text of the `unavailable`
  # fallback message, which is never emitted for `_argnums_partial` because
  # that transform has a dedicated explainer above. Confirm whether the
  # "different static_argnums" entry was meant to be dropped here instead.
  if different_leaf_count:
    diffs = [d for d in diffs if "fun_transforms._argnums_partial" not in d[0]]
  if fun_params_k != fun_params_ok:
    unavailable("fun_params", fun_params_k, fun_params_ok)
  if fun_in_type_k != fun_in_type_ok:
    # Report the field that actually differs.
    unavailable("fun_in_type", fun_in_type_k, fun_in_type_ok)
  if arg_in_type_k != arg_in_type_ok and not different_leaf_count:
    explain_args_type_diff(arg_in_type_k, arg_in_type_ok)
  if arg_attr_data_k != arg_attr_data_ok:
    unavailable("arg_attr_data", arg_attr_data_k, arg_attr_data_ok)
  if arg_inline_k != arg_inline_ok:
    unavailable("arg_inline", arg_inline_k, arg_inline_ok)
  if ctx_k != ctx_ok:
    assert len(ctx_k) == len(ctx_ok)
    idxs = [f" [{i}]: now {c_k} and before {c_ok}"
            for i, (c_k, c_ok) in enumerate(zip(ctx_k, ctx_ok)) if c_k != c_ok]
    diffs.append(
        ("different tracing context, e.g. due to config or context manager.\n"
         "found differences at positions\n" +
         ", and\n".join(idxs) +
         "\ncompare to tuple returned by "
         "config.trace_context() in jax/_src/config.py.",
         len(idxs)))
  if not diffs:  # Should never happen, but let's not crash
    unavailable("something (unexpected empty diffs)", k, oldk)
  # Smallest differences first; `sorted` is stable, so ties keep the order in
  # which they were found.
  ordered = sorted(diffs, key=lambda d: d[1])
  return (tuple(text for text, _ in ordered),
          sum(size for _, size in ordered))
|
||||
|
||||
|
||||
def explain_tracing_cache_miss(
|
||||
fun: lu.WrappedFun, unseen_f: bool, cache: dict, key: tuple):
|
||||
if config.check_tracer_leaks.value: return
|
||||
|
||||
def unpack(key):
|
||||
transforms, (), _, (in_type, _, inline), *_, ctx = key
|
||||
# TODO(dougalm,mattjj): enable cache miss explanation with attrs
|
||||
_, (_, (in_tree,)), *_ = transforms
|
||||
return in_tree, in_type, inline.val, ctx
|
||||
in_tree, in_type, inline, ctx = unpack(key)
|
||||
if inline: return
|
||||
if key[3][2].val: return # No explanations for "inline" functions
|
||||
|
||||
debug_info = fun.debug_info
|
||||
func_filename = debug_info.func_filename
|
||||
@ -1177,7 +1368,7 @@ def explain_tracing_cache_miss(
|
||||
|
||||
msg: list[str] = []
|
||||
p = msg.append
|
||||
done = lambda: logger.log(logging.WARNING, '\n'.join(msg))
|
||||
done = lambda: logger.log(logging.WARNING, "\n".join(msg))
|
||||
|
||||
callsite = source_info_util.summarize(source_info_util.current())
|
||||
p(f"TRACING CACHE MISS at {callsite} because:")
|
||||
@ -1188,110 +1379,42 @@ def explain_tracing_cache_miss(
|
||||
src_info += f" defined at {func_filename}"
|
||||
if func_lineno := debug_info.func_lineno:
|
||||
src_info += f":{func_lineno}"
|
||||
if unseen_f:
|
||||
p(f" never seen function:\n {debug_info.func_name} id={id(fun.f)}{src_info}")
|
||||
func_name = debug_info.func_name
|
||||
if unseen_f or not cache:
|
||||
p(f" never seen function:\n {func_name} id={id(fun.f)}{src_info}")
|
||||
if callsite in callsites_with_tracing_cache_miss:
|
||||
p(" but seen another function defined on the same line; maybe the function is\n"
|
||||
" being re-defined repeatedly, preventing caching?")
|
||||
else:
|
||||
callsites_with_tracing_cache_miss.add(callsite)
|
||||
return done()
|
||||
|
||||
p(f" for {func_name}{src_info}")
|
||||
|
||||
diffs = [diff_tracing_cache_keys(key, ok, debug_info)
|
||||
for ok in cache.keys() if key != ok]
|
||||
assert diffs, "we must find some diffs if key differs from all cache keys"
|
||||
min_diff = min(diffs, key=lambda v: v[1])
|
||||
smallest_diffs: Sequence[Sequence[str]] # the diffs for the closest keys
|
||||
smallest_diffs = [d[0] for d in diffs if d[1] == min_diff[1]]
|
||||
def indent_subsequent_lines(indent: int, msg: str) -> str:
|
||||
return msg.replace("\n", "\n" + " " * indent)
|
||||
def p_one_diff(diff: Sequence[str]):
|
||||
for d in diff:
|
||||
p(" * key with " + indent_subsequent_lines(4, d))
|
||||
|
||||
if len(smallest_diffs) == 1:
|
||||
p(" all previously seen cache keys are different. Closest previous key:")
|
||||
p_one_diff(smallest_diffs[0])
|
||||
else:
|
||||
p(f" for {debug_info.func_name}{src_info}")
|
||||
p(" all previously seen cache keys are different. "
|
||||
"Several previous keys are closest:")
|
||||
for d in smallest_diffs:
|
||||
p_one_diff(d)
|
||||
|
||||
seen_keys = map(unpack, cache.keys())
|
||||
done()
|
||||
return
|
||||
|
||||
# have we maybe switched some args to be kwargs or visa-versa?
|
||||
args_tree, kwargs_tree = treedef_children(in_tree)
|
||||
args_kwargs_trees = [treedef_children(k) for k, *_ in seen_keys]
|
||||
args_kwargs_match = [t for t in args_kwargs_trees
|
||||
if t == [args_tree, kwargs_tree]]
|
||||
if not args_kwargs_match:
|
||||
num_args = len(treedef_children(args_tree))
|
||||
_, kwarg_keys = kwargs_tree.node_data() # type: ignore
|
||||
p(f" never seen passing {num_args} positional args and {len(kwarg_keys)} "
|
||||
"keyword args with keys:\n"
|
||||
f" {', '.join(map(repr, kwarg_keys))}")
|
||||
dont_match = [set(t[1].node_data()[1]) for t in args_kwargs_trees # type: ignore
|
||||
if t != [args_tree, kwargs_tree]]
|
||||
close_kwargs = min(
|
||||
dont_match, key=set(kwarg_keys).symmetric_difference, default=None
|
||||
)
|
||||
if not close_kwargs:
|
||||
p(" closest seen is passing no keyword args")
|
||||
else:
|
||||
p(f" closest seen passes {len(close_kwargs)} keyword args with keys:\n"
|
||||
f" {', '.join(map(repr, close_kwargs))}")
|
||||
return done()
|
||||
|
||||
# have we never seen this tracing context before?
|
||||
ctxs_match = [c for *_, c in seen_keys if c == ctx]
|
||||
if not ctxs_match:
|
||||
p(" tracing context doesn't match, e.g. due to config or context manager")
|
||||
dont_match = [c for *_, c in seen_keys if c != ctx]
|
||||
closest_ctx = min(dont_match, key=lambda c: sum(map(op.ne, c, ctx)))
|
||||
idxs = [i for i, (c1, c2) in enumerate(zip(ctx, closest_ctx)) if c1 != c2]
|
||||
p(" closest seen context tuple differs at positions:\n"
|
||||
f" {', '.join(map(str, idxs))}\n"
|
||||
" compare to tuple returned by config._trace_context() in jax/_src/config.py.")
|
||||
return done()
|
||||
|
||||
# have we never seen this input pytree before?
|
||||
trees_match = [k for k in seen_keys if k[0] == in_tree]
|
||||
if not trees_match:
|
||||
in_tree_str = f':\n {in_tree}' if len(str(in_tree)) < 76 else ''
|
||||
p(f" never seen input pytree{in_tree_str}")
|
||||
dont_match = [t for t, *_ in seen_keys if t != in_tree]
|
||||
closest_tree = min(dont_match, key=lambda t: abs(t.num_leaves - in_tree.num_leaves))
|
||||
errs = list(tree_util.equality_errors_pytreedef(in_tree, closest_tree)) # type: ignore[arg-type]
|
||||
p(f" closest seen input pytree has {len(errs)} mismatches, including:")
|
||||
for path, thing1, thing2, explanation in errs:
|
||||
fst, *path = path # type: ignore
|
||||
base = ['args', 'kwargs'][fst.idx]
|
||||
p(f" * at {base}{keystr(tuple(path))}, seen {thing2} but now given {thing1},"
|
||||
f" so {explanation}")
|
||||
return done()
|
||||
|
||||
# have we never seen these input types (eg shapes, dtypes) before?
|
||||
types_match = [k for k in trees_match if k[1] == in_type]
|
||||
if not types_match:
|
||||
if len(in_type) < 5:
|
||||
in_type_str = ":\n {}".format(", ".join(
|
||||
f"{n}: {ty.str_short(short_dtypes=True)}"
|
||||
for n, ty in zip(debug_info.arg_names, in_type)))
|
||||
else:
|
||||
in_type_str = ''
|
||||
p(f" never seen input type signature{in_type_str}")
|
||||
dont_match = [t for _, t, *_ in trees_match if t != in_type]
|
||||
closest_ty = min(dont_match, key=lambda t: sum(map(op.ne, t, in_type)))
|
||||
num_mismatch = sum(map(op.ne, closest_ty, in_type))
|
||||
p(f" closest seen input type signature has {num_mismatch} mismatches, including:")
|
||||
add_weak_type_hint = False
|
||||
arg_names = debug_info.safe_arg_names(len(in_type))
|
||||
|
||||
for name, ty1, ty2 in zip(arg_names, closest_ty, in_type):
|
||||
if ty1 != ty2:
|
||||
if type(ty1) == type(ty2) == core.ShapedArray:
|
||||
s1, s2 = ty1.str_short(True), ty2.str_short(True)
|
||||
if ty1.weak_type != ty2.weak_type:
|
||||
s1 += f"{{weak_type={ty1.weak_type}}}"
|
||||
s2 += f"{{weak_type={ty2.weak_type}}}"
|
||||
add_weak_type_hint = True
|
||||
elif ty1.sharding != ty2.sharding:
|
||||
s1 = ty1.str_short(short_dtypes=True, mesh_axis_types=True)
|
||||
s2 = ty2.str_short(short_dtypes=True, mesh_axis_types=True)
|
||||
else:
|
||||
s1, s2 = str(ty1), str(ty2)
|
||||
p(f" * at {name}, seen {s1}, but now given {s2}")
|
||||
if add_weak_type_hint:
|
||||
p("where weak_type=True often means a Python builtin numeric value, and ")
|
||||
p("weak_type=False means a jax.Array.")
|
||||
p("See https://docs.jax.dev/en/latest/type_promotion.html#weak-types")
|
||||
return done()
|
||||
|
||||
# we think this is unreachable...
|
||||
p("explanation unavailable! please open an issue at https://github.com/jax-ml/jax")
|
||||
return done()
|
||||
|
||||
@partial(lu.cache, explain=explain_tracing_cache_miss)
|
||||
def _create_pjit_jaxpr(
|
||||
|
@ -4465,66 +4465,6 @@ class APITest(jtu.JaxTestCase):
|
||||
tracing_add_count += 1
|
||||
self.assertEqual(tracing_add_count, 2)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations(self):
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
return jnp.sin(x) * y['hi']
|
||||
|
||||
x = jnp.float32(1.)
|
||||
y = {'hi': jnp.arange(3., dtype='float32')}
|
||||
|
||||
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
|
||||
|
||||
# print on first miss, not on hit
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(x, y)
|
||||
f(x, y)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('TRACING CACHE MISS', msg)
|
||||
self.assertIn('never seen function', msg)
|
||||
|
||||
# shape change
|
||||
y_ = {'hi': jnp.arange(4, dtype='float32')}
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(x, y_)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('never seen input type signature', msg)
|
||||
self.assertIn('closest seen input type signature has 1 mismatches', msg)
|
||||
self.assertIn('seen f32[3], but now given f32[4]', msg)
|
||||
|
||||
# weak type change (assuming no x64)
|
||||
if not config.enable_x64.value:
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(1., y)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('weak_type=True', msg)
|
||||
self.assertIn('https://docs.jax.dev/en/latest/type_promotion.html#weak-types', msg)
|
||||
|
||||
# kwarg change
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(1, y=y)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('never seen passing 1 positional args and 1 keyword args', msg)
|
||||
|
||||
# tracing config change
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
with jax.numpy_rank_promotion('warn'):
|
||||
f(x, y)
|
||||
# depending on the backend, we may or may not get persistent cache warnings
|
||||
self.assertTrue(1 <= len(cm.output) <= expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn("tracing context doesn't match", msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_skip_internals(self):
|
||||
if is_persistent_cache_enabled():
|
||||
@ -4535,6 +4475,211 @@ class APITest(jtu.JaxTestCase):
|
||||
for i in range(2):
|
||||
jnp.sin(jnp.arange(i + 1, dtype=np.float32))
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_first_miss(self):
  """The first call of a jitted function logs a "never seen function" miss."""
  @jax.jit
  def f(x): return x
  arg = jnp.float32(1.)

  expected_log_len = 3 if is_persistent_cache_enabled() else 1
  # print on first miss, not on hit
  with config.explain_cache_misses(True), \
       self.assertLogs(level="WARNING") as logs:
    f(arg)
    f(arg)
  self.assertLen(logs.output, expected_log_len)
  first = logs.output[0]
  for expected in ("TRACING CACHE MISS", "never seen function"):
    self.assertIn(expected, first)
  self.assertNotIn("explanation unavailable!", first)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_in_tree(self):
  """A miss caused by a differently structured input pytree is explained."""
  @jax.jit
  def f(*args, **kwargs): return args[0]

  f(0., 1., y=(2., 2.1))

  with config.explain_cache_misses(True), \
       self.assertLogs(level="WARNING") as logs:
    # Same number of leaves but different trees
    f(0., (1., 1.1), y=2.)
  expected_log_len = 3 if is_persistent_cache_enabled() else 1
  self.assertLen(logs.output, expected_log_len)
  first = logs.output[0]
  self.assertIn("different input pytree", first)
  self.assertNotIn("explanation unavailable!", first)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_arg_passed_as_kwarg(self):
  """A miss caused by passing a positional argument as a keyword is explained."""
  @jax.jit
  def f(x, y): return jnp.sin(x) + y

  f(0., 1.)

  # kwarg change
  with config.explain_cache_misses(True):
    with self.assertLogs(level="WARNING") as cm:
      f(0., y=1.)
    expected_log_len = 1 if not is_persistent_cache_enabled() else 3
    self.assertLen(cm.output, expected_log_len)
    msg = cm.output[0]
    self.assertIn("different number of args and kwargs, but same total number", msg)
    self.assertIn("now 1 args and kwargs with keys ['y']", msg)
    # The earlier call `f(0., 1.)` passed two positional args and no kwargs.
    # Do not pin the reported "before" count to 1 (that is the count of the
    # current call); only check the shape of the message and the kwarg keys.
    self.assertRegex(msg, r"before \d+ args and kwargs with keys \[\]")
    self.assertNotIn("explanation unavailable!", msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_static_argnums(self):
  """A miss caused by a changed static positional argument value is explained."""
  @partial(jax.jit, static_argnums=(0, 2))
  def f(x, y, z):
    return y

  f(1., 2., "foo")

  with config.explain_cache_misses(True), \
       self.assertLogs(level="WARNING") as logs:
    f(1., 2., "bar")
  expected_log_len = 3 if is_persistent_cache_enabled() else 1
  self.assertLen(logs.output, expected_log_len)
  first = logs.output[0]
  for expected in ("different value of static args",
                   "now 1.0, 'bar' and before 1.0, 'foo'"):
    self.assertIn(expected, first)
  self.assertNotIn("explanation unavailable!", first)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_static_argnames(self):
  """A miss caused by a changed static keyword argument value is explained."""
  @partial(jax.jit, static_argnames="foo")
  def f(*, foo):
    return 1

  f(foo="foo")

  with config.explain_cache_misses(True), \
       self.assertLogs(level="WARNING") as logs:
    f(foo="bar")
  expected_log_len = 3 if is_persistent_cache_enabled() else 1
  self.assertLen(logs.output, expected_log_len)
  first = logs.output[0]
  for expected in ("different value of static kwargs",
                   "now {foo: 'bar'} and before {foo: 'foo'}"):
    self.assertIn(expected, first)
  self.assertNotIn("explanation unavailable!", first)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_dtype(self):
  """A miss caused by a changed argument dtype names the argument and types."""
  @jax.jit
  def f(x, y): return x
  f(np.float32(0), np.float32(1))

  with config.explain_cache_misses(True), \
       self.assertLogs(level="WARNING") as logs:
    f(np.float32(0), np.int32(1))
  expected_log_len = 3 if is_persistent_cache_enabled() else 1
  self.assertLen(logs.output, expected_log_len)
  first = logs.output[0]
  for expected in ("different input types",
                   "at y, now i32[] and before f32[]"):
    self.assertIn(expected, first)
  self.assertNotIn("explanation unavailable!", first)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_weak_type(self):
  """A miss caused by a weak_type change is explained, with a docs link."""
  @jax.jit
  def f(x, y): return jnp.sin(x) + y

  vec = jnp.arange(4, dtype="float32")
  f(jnp.float32(0.), vec)
  # weak type change (assuming no x64)
  if config.enable_x64.value:
    self.skipTest("Work only for 32 bit mode")
  with config.explain_cache_misses(True), \
       self.assertLogs(level="WARNING") as logs:
    f(0., vec)
  expected_log_len = 3 if is_persistent_cache_enabled() else 1
  self.assertLen(logs.output, expected_log_len)
  first = logs.output[0]
  for expected in (
      "different input types",
      "at x, now f32[]{weak_type=True} and before f32[]{weak_type=False}",
      "https://docs.jax.dev/en/latest/type_promotion.html#weak-types"):
    self.assertIn(expected, first)
  self.assertNotIn("explanation unavailable!", first)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_shape(self):
  """A miss caused by a changed argument shape names the argument and shapes."""
  @jax.jit
  def f(x, y): return jnp.sin(x) + y
  f(np.float32(0), np.arange(1, dtype=np.float32))

  with config.explain_cache_misses(True), \
       self.assertLogs(level="WARNING") as logs:
    f(np.float32(0), np.arange(2, dtype=np.float32))
  expected_log_len = 3 if is_persistent_cache_enabled() else 1
  self.assertLen(logs.output, expected_log_len)
  first = logs.output[0]
  for expected in ("different input types",
                   "at y, now f32[2] and before f32[1]"):
    self.assertIn(expected, first)
  self.assertNotIn("explanation unavailable!", first)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_shape_explain_closest(self):
  """With several cached shapes, the explanation compares against f32[10,20,30]."""
  @jax.jit
  def f(x): return x
  # Populate the cache with three different input shapes.
  for shape in ((1, 2), (10, 20, 30), (1, 2, 3)):
    f(np.ones(shape, dtype=np.float32))

  with config.explain_cache_misses(True), \
       self.assertLogs(level="WARNING") as logs:
    f(np.ones((10, 2, 30), dtype=np.float32))
  expected_log_len = 3 if is_persistent_cache_enabled() else 1
  self.assertLen(logs.output, expected_log_len)
  first = logs.output[0]
  for expected in ("key with different input types",
                   "at x, now f32[10,2,30] and before f32[10,20,30]"):
    self.assertIn(expected, first)
  self.assertNotIn("explanation unavailable!", first)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_other_tracing_config(self):
  """A miss caused by changed tracing-context settings lists what changed."""
  @jax.jit
  def f(x, y): return jnp.sin(x) + y

  f(0., 1.)
  # tracing config change
  with config.explain_cache_misses(True), \
       self.assertLogs(level="WARNING") as logs:
    with jax.numpy_rank_promotion("warn"), \
         jax.default_matmul_precision("high"):
      f(0., 1.)

  expected_log_len = 3 if is_persistent_cache_enabled() else 1
  self.assertTrue(1 <= len(logs.output) <= expected_log_len)
  first = logs.output[0]
  for expected in ("key with different tracing context",
                   "now warn and before",
                   "now high and before"):
    self.assertIn(expected, first)
  self.assertNotIn("explanation unavailable!", first)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_multiple_changes(self):
  """When two cached keys each differ in one aspect, both are explained."""
  @jax.jit
  def f(x): return jnp.sin(x)

  call_1 = f(np.arange(4, dtype=np.float32))
  with jax.numpy_rank_promotion("warn"):
    call_2 = f(np.arange(8, dtype=np.float32))

  with config.explain_cache_misses(True), \
       self.assertLogs(level="WARNING") as logs:
    # Matches call_2 in shape but not context, and call_1 in context but
    # not in shape.
    f(np.arange(8, dtype=np.float32))

  expected_log_len = 3 if is_persistent_cache_enabled() else 1
  self.assertLen(logs.output, expected_log_len)
  first = logs.output[0]
  for expected in ("key with different input types",
                   "at x, now f32[8] and before f32[4]",
                   "key with different tracing context"):
    self.assertIn(expected, first)
  self.assertNotIn("explanation unavailable!", first)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_cache_miss_explanations_new_function_in_loop(self):
|
||||
@jax.jit
|
||||
|
@ -392,66 +392,6 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(TypeError, err_str):
|
||||
jax.jit(f)(jnp.int32)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_arg_names_cache_miss_explanations(self):
|
||||
@jax.jit
|
||||
def f(x, y):
|
||||
return jnp.sin(x) * y['hi']
|
||||
|
||||
x = jnp.float32(1.)
|
||||
y = {'hi': jnp.arange(3., dtype='float32')}
|
||||
|
||||
expected_log_len = 1 if not is_persistent_cache_enabled() else 3
|
||||
|
||||
# print on first miss, not on hit
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(x, y)
|
||||
f(x, y)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('TRACING CACHE MISS', msg)
|
||||
self.assertIn('never seen function', msg)
|
||||
|
||||
# shape change
|
||||
y_ = {'hi': jnp.arange(4, dtype='float32')}
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(x, y_)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('never seen input type signature', msg)
|
||||
self.assertIn('closest seen input type signature has 1 mismatches', msg)
|
||||
self.assertIn('seen f32[3], but now given f32[4]', msg)
|
||||
|
||||
# weak type change (assuming no x64)
|
||||
if not config.enable_x64.value:
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(1., y)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('weak_type=True', msg)
|
||||
self.assertIn('https://docs.jax.dev/en/latest/type_promotion.html#weak-types', msg)
|
||||
|
||||
# kwarg change
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
f(1, y=y)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('never seen passing 1 positional args and 1 keyword args', msg)
|
||||
|
||||
# tracing config change
|
||||
with config.explain_cache_misses(True):
|
||||
with self.assertLogs(level='WARNING') as cm:
|
||||
with jax.numpy_rank_promotion('warn'):
|
||||
f(x, y)
|
||||
# depending on the backend, we may or may not get persistent cache warnings
|
||||
self.assertTrue(1 <= len(cm.output) <= expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn("tracing context doesn't match", msg)
|
||||
|
||||
@jtu.thread_unsafe_test() # logging is not thread-safe
|
||||
def test_arg_names_cache_miss_explanations_new_function_in_loop(self):
|
||||
@jax.jit
|
||||
|
@ -3467,9 +3467,8 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
f(x_, y)
|
||||
self.assertLen(cm.output, expected_log_len)
|
||||
msg = cm.output[0]
|
||||
self.assertIn('never seen input type signature', msg)
|
||||
self.assertIn('closest seen input type signature has 1 mismatches', msg)
|
||||
self.assertIn("seen f32[8]({}), but now given f32[8]({Auto: ('x',)})", msg)
|
||||
self.assertIn("different input types", msg)
|
||||
self.assertIn("at x, now f32[8]({Auto: ('x',)}) and before f32[8]({})", msg)
|
||||
|
||||
def test_pjit_function_cache_cpp(self):
|
||||
def f(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user