[better_errors] Clean up use of DebugInfo.arg_names and result_paths

Previously, we represented a missing arg name with `None`,
and a missing result path with the empty string. We now
adopt the same convention for arg names and use empty strings.
This simplifies the typing, and prevents the string "None" from
appearing in error messages.

I changed how we encode the result paths. Previously, for a
function that returns a single array, the path was the empty
string (the same as for an unknown path), and for a function
that returns a pair of arrays it was `("[0]", "[1]")`. Now we
add the "result" prefix: `("result",)` for a function returning a
single array and `("result[0]", "result[1]")` for a function returning
a pair of arrays.

Finally, in debug_info_test, I removed the `check_tracer_arg_name`
parameter so that all spied tracers are printed with the argument
name they depend on.
This commit is contained in:
George Necula 2025-02-18 10:09:47 +01:00
parent d695aa4c63
commit 1be801bac8
16 changed files with 266 additions and 249 deletions

View File

@ -56,7 +56,7 @@ some other features along the way. An example:
>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = ""}) {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32> {jax.result_info = "result"}) {
%c = stablehlo.constant dense<2> : tensor<i32>
%0 = stablehlo.multiply %c, %arg0 : tensor<i32>
%1 = stablehlo.add %0, %arg1 : tensor<i32>
@ -140,7 +140,7 @@ to invoke the resulting compiled function. Continuing with our example above:
>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<i32>) -> (tensor<i32> {jax.result_info = ""}) {
func.func public @main(%arg0: tensor<i32>) -> (tensor<i32> {jax.result_info = "result"}) {
%c = stablehlo.constant dense<14> : tensor<i32>
%0 = stablehlo.add %c, %arg0 : tensor<i32>
return %0 : tensor<i32>

View File

@ -44,7 +44,7 @@ Here is an example:
(ShapedArray(float32[]),)
>>> print(re.search(r".*@main.*", exported.mlir_module()).group(0))
func.func public @main(%arg0: tensor<f32> loc("x")) -> (tensor<f32> {jax.result_info = ""}) {
func.func public @main(%arg0: tensor<f32> loc("x")) -> (tensor<f32> {jax.result_info = "result"}) {
>>> # And you can serialize the Exported to a bytearray.
>>> serialized: bytearray = exported.serialize()
@ -206,7 +206,7 @@ as in the following example:
>>> _ = mlir.register_lowering(new_prim, lambda ctx, o: mlir.custom_call("my_new_prim", operands=[o], result_types=[o.type]).results)
>>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir())
module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = ""}) {
func.func public @main(%arg0: tensor<f32>) -> (tensor<f32> {jax.result_info = "result"}) {
%0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor<f32>) -> tensor<f32>
return %0 : tensor<f32>
}

View File

@ -458,7 +458,7 @@ def saved_residuals(f: Callable,
return _saved_residuals(jaxpr, debug_info.arg_names)
def _saved_residuals(jaxpr: core.Jaxpr,
arg_names: tuple[str | None, ...]) -> list[tuple[core.AbstractValue, str]]:
arg_names: Sequence[str]) -> list[tuple[core.AbstractValue, str]]:
res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)]
res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)}
@ -473,7 +473,7 @@ def _saved_residuals(jaxpr: core.Jaxpr,
for i, v in enumerate(jaxpr.invars):
if v in res_vars:
if arg_names[i] is not None:
if arg_names[i]:
src = f'from the argument {arg_names[i]}'
else:
src = 'from the argument at flattened index {i}'

View File

@ -2274,7 +2274,7 @@ def _check_sharding(aval, s):
aval = core.get_token_aval()
if not isinstance(s, PmapSharding):
pjit.pjit_check_aval_sharding(
(s,), (aval,), None, "device_put args", allow_uneven_sharding=False)
(s,), (aval,), ("",), "device_put args", allow_uneven_sharding=False)
s.shard_shape(aval.shape) # should raise an Error if incompatible

View File

@ -662,7 +662,7 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None,
args: Sequence[Any], kwargs: dict[str, Any],
static_argnums: Sequence[int],
static_argnames: Sequence[str],
) -> tuple[str | None, ...]:
) -> tuple[str, ...]:
"""Returns the names of the non-static arguments.
If the `fn_signature` is given then we get from it the names of the

View File

@ -149,8 +149,8 @@ class Jaxpr:
debug_info = debug_info or lu._missing_debug_info("core.Jaxpr")
self._debug_info = debug_info.resolve_result_paths()
# TODO(necula): re-enable these safety checks
# assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars)
# assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
# assert (len(debug_info.arg_names) == len(invars)), (debug_info, invars)
# assert (len(debug_info.result_paths) == len(outvars)), (debug_info, outvars)
def __str__(self):
return str(self.pretty_print())

View File

@ -135,7 +135,7 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents):
def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr:
dbg = jaxpr.debug_info._replace(
arg_names=jaxpr.debug_info.arg_names + (None,) * len(jaxpr.constvars))
arg_names=jaxpr.debug_info.arg_names + ("",) * len(jaxpr.constvars))
return core.Jaxpr(constvars=(),
invars=jaxpr.invars + jaxpr.constvars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns,

View File

@ -1113,7 +1113,7 @@ def lower_jaxpr_to_module(
result_shardings: Sequence[JSharding | AUTO | None] | None = None,
in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None,
arg_names: Sequence[str | None] | None = None,
arg_names: Sequence[str] | None = None,
result_names: Sequence[str] | None = None,
num_replicas: int = 1,
num_partitions: int = 1,

View File

@ -842,7 +842,7 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
"""Moves the constvars to the start of invars."""
config.enable_checks.value and core.check_jaxpr(jaxpr)
dbg = jaxpr.debug_info._replace(
arg_names=(None,) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names)
arg_names=("",) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names)
lifted_jaxpr = Jaxpr(constvars=(),
invars=jaxpr.constvars + jaxpr.invars,
outvars=jaxpr.outvars, eqns=jaxpr.eqns,
@ -1574,7 +1574,7 @@ class DynamicJaxprTracer(core.Tracer):
origin = ("The error occurred while tracing the function "
f"{dbg.func_src_info} for {dbg.traced_for}. ")
if invar_pos and dbg.arg_names:
if invar_pos:
try:
arg_names = [dbg.arg_names[i] for i in invar_pos]
except IndexError:

View File

@ -3260,7 +3260,7 @@ def check_array_xla_sharding_layout_match(
from jax._src.array import ArrayImpl
# jaxpr_debug_info.arg_names are before DCE, so need to DCE them.
arg_names = (
[a for i, a in enumerate(jaxpr_debug_info.arg_names) # type: ignore
[a for i, a in enumerate(jaxpr_debug_info.arg_names)
if i in kept_var_idx]
)
errors = []

View File

@ -1554,7 +1554,7 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts,
cond_debug = cond_jaxpr.jaxpr.debug_info
augmented_debug = cond_debug and (
cond_debug._replace(
arg_names=cond_debug.arg_names + (None,) * len(init_dot)
arg_names=cond_debug.arg_names + ("",) * len(init_dot)
)
)
cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars,

View File

@ -276,30 +276,43 @@ class DebugInfo(NamedTuple):
"""Debugging info about a func, its arguments, and results."""
traced_for: str # e.g. 'jit', 'scan', etc
# e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have
# no source location information. The first word is always the function name,
# which may be '<unknown>'.
func_src_info: str
"""e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have
no source location information. The first word is always the function name,
which may be '<unknown>'.
"""
# The paths of the flattened non-static argnames,
# e.g. ('x', 'dict_arg["a"]', ... ).
# Uses `None` for the args that do not correspond to user-named arguments,
# e.g., tangent args in jax.jvp. At the moment, `arg_names` accuracy is
# best-effort. Use `safe_arg_names` to detect and handle an unexpected
# number of elements in `arg_names`.
arg_names: tuple[str | None, ...]
arg_names: tuple[str, ...]
"""The paths of the flattened non-static argnames,
e.g. `('x', 'dict_arg["a"]', ... )`.
Uses the empty string for the args that do not correspond to
user-named arguments, e.g., tangent args in `jax.jvp`, or for arguments that
we are not yet tracking properly.
At the moment, `arg_names` accuracy is best-effort.
Use `safe_arg_names` to detect and handle an unexpected
number of elements in `arg_names`.
"""
# The result paths are not available while we are tracing the function,
# instead we keep a thunk. Once we are done tracing, we use
# `self.resolve_result_paths()` to execute the thunk and replace the
# actual result paths. At the moment, `result_paths` accuracy is
# best-effort. Use `safe_result_paths` to detect and handle an unexpected
# number of elements in `result_paths`.
# e.g. ('[0]', '[1]', ...)
result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None
"""The paths to the flattened results, e.g., `('result[0]', result[1])` for a
function that returns a tuple of arrays, or `(result,)` for a function that
returns a single array.
The result paths are not available while we are tracing the function,
instead we keep a thunk. It is possible for the result paths to be `None`
only when we first create a `DebugInfo`, before we put it in `lu.WrappedFun`
and before we start tracing.
Inside a `lu.WrappedFun` it can be only a thunk or a tuple of strings.
Once we are done tracing, we use
`self.resolve_result_paths()` to execute the thunk and replace the
actual result paths.
At the moment, `result_paths` accuracy is best-effort.
Use `safe_result_paths` to detect and handle an unexpected
number of elements in `result_paths`.
"""
def resolve_result_paths(self) -> DebugInfo:
"""Return a debug info with resolved result paths.

If `result_paths` is still a thunk (a callable), call it and store what it
returns as a tuple in a new `DebugInfo`; otherwise return `self` unchanged.
`result_paths` must not be `None` when this is called.
"""
assert self.result_paths is not None
if callable(self.result_paths):
return self._replace(result_paths=tuple(self.result_paths()))
return self
@ -308,21 +321,21 @@ class DebugInfo(NamedTuple):
def func_name(self) -> str:
"""Return the first space-separated word of `func_src_info`."""
return self.func_src_info.split(" ")[0]
def safe_arg_names(self, expected: int) -> tuple[str | None, ...]:
def safe_arg_names(self, expected: int) -> tuple[str, ...]:
"""Get the arg_names with a safety check."""
if len(self.arg_names) == expected:
return self.arg_names
else:
# TODO(necula): this should not happen
return (None,) * expected
return ("",) * expected
def filter_arg_names(self, keep: Sequence[bool]) -> tuple[str | None, ...]:
def filter_arg_names(self, keep: Sequence[bool]) -> tuple[str, ...]:
"""Keep only the arg_names for which `keep` is True."""
return tuple(v for v, b in zip(self.safe_arg_names(len(keep)), keep) if b)
def safe_result_paths(self, expected: int) -> tuple[str, ...]:
"""Get the result paths with a safety check."""
assert not callable(self.result_paths), self
assert self.result_paths is not None and not callable(self.result_paths), self
if self.result_paths is not None and len(self.result_paths) == expected:
return self.result_paths
else:
@ -331,7 +344,7 @@ class DebugInfo(NamedTuple):
def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...]:
"""Keep only the result_paths for which `keep` is True."""
assert not callable(self.result_paths), self
assert self.result_paths is not None and not callable(self.result_paths), self
return tuple(v for v, b in zip(self.safe_result_paths(len(keep)), keep) if b)
@ -368,7 +381,7 @@ def _clean_keystr_arg_names(k: KeyPath) -> str:
@transformation_with_aux2
def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs):
ans = _fun(*args, **kwargs)
result_paths = [_clean_keystr_arg_names(path) for path, _ in generate_key_paths(ans)]
result_paths = tuple(f"result{_clean_keystr_arg_names(path)}" for path, _ in generate_key_paths(ans))
if _store:
# In some instances a lu.WrappedFun is called multiple times, e.g.,
# the bwd function in a custom_vjp

View File

@ -538,7 +538,7 @@ class PjitParams(NamedTuple):
in_tree: PyTreeDef
out_tree: PyTreeDef
donated_invars: tuple[bool, ...]
arg_names: tuple[str | None, ...]
arg_names: tuple[str, ...]
num_consts: int
attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]]
@ -663,7 +663,7 @@ def _infer_params_impl(
compiler_options_kvs=ji.compiler_options_kvs,
)
return PjitParams(consts, params, in_avals, in_tree, out_tree(),
donated_invars, dbg.arg_names if dbg else None, len(consts),
donated_invars, dbg.arg_names, len(consts),
attrs_tracked), args_flat
@ -741,13 +741,13 @@ def _infer_input_type(fun: Callable, dbg: core.DebugInfo,
for i, x in enumerate(explicit_args):
avals.append(core.shaped_abstractify(x))
except OverflowError:
arg_path = f"argument path is {dbg.arg_names[i]}" # type: ignore
arg_path = f"argument path is {dbg.arg_names[i]}"
raise OverflowError(
"An overflow was encountered while parsing an argument to a jitted "
f"computation, whose {arg_path}."
) from None
except TypeError:
arg_description = f"path {dbg.arg_names[i]}" # type: ignore
arg_description = f"path {dbg.arg_names[i]}"
raise TypeError(
f"Error interpreting argument to {fun} as an abstract array."
f" The problematic value is of type {type(x)} and was passed to"
@ -1134,7 +1134,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves,
"pjit in_layouts", in_tree, in_layouts, tupled_args=True)
# TODO(dougalm,mattjj): enable debug info with attrs_tracked
attrs_tracked = debug_info and len(debug_info.arg_names) != len(in_avals)
attrs_tracked = len(debug_info.arg_names) != len(in_avals)
if not config.dynamic_shapes.value and not attrs_tracked:
pjit_check_aval_sharding(in_shardings_flat, in_avals,
debug_info.safe_arg_names(len(in_avals)),
@ -1338,11 +1338,11 @@ def _check_and_canonicalize_out_shardings(
if not config.dynamic_shapes.value:
pjit_check_aval_sharding(
out_shardings_flat, out_avals,
debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type]
debug_info.safe_result_paths(len(out_avals)),
"pjit outputs", allow_uneven_sharding=False)
check_aval_layout_compatibility(
out_layouts_flat, out_avals,
debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type]
debug_info.safe_result_paths(len(out_avals)),
"jit outputs")
return out_shardings_flat, out_layouts_flat
@ -1396,10 +1396,9 @@ class IgnoreKey:
def pjit_check_aval_sharding(
shardings, flat_avals, names: tuple[str | None, ...] | None,
shardings, flat_avals, names: Sequence[str],
what_aval: str, allow_uneven_sharding: bool):
new_names = [None] * len(shardings) if names is None else names
for aval, s, name in zip(flat_avals, shardings, new_names):
for aval, s, name in zip(flat_avals, shardings, names):
if isinstance(s, (UnspecifiedValue, AUTO)):
continue
name_str = f' with pytree key path {name}' if name else ''
@ -1431,9 +1430,8 @@ def pjit_check_aval_sharding(
def check_aval_layout_compatibility(
layouts, flat_avals, names: tuple[str, ...] | None, what_aval: str):
new_names = [''] * len(layouts) if names is None else names
for aval, l, name in zip(flat_avals, layouts, new_names):
layouts, flat_avals, names: Sequence[str], what_aval: str):
for aval, l, name in zip(flat_avals, layouts, names):
if l is None or isinstance(l, AutoLayout):
continue
name_str = f' with pytree key path {name}' if name else ''
@ -2557,12 +2555,14 @@ def with_sharding_constraint(x, shardings):
for s in shardings_flat]
pjit_check_aval_sharding(
shardings_flat, x_flat, None, "with_sharding_constraint arguments",
shardings_flat, x_flat, ("",) * len(shardings_flat),
"with_sharding_constraint arguments",
allow_uneven_sharding=True)
check_shardings_are_auto(shardings_flat)
check_aval_layout_compatibility(user_layouts_flat, x_flat, None,
check_aval_layout_compatibility(user_layouts_flat, x_flat,
("",) * len(user_layouts_flat),
"with_sharding_constraint arguments")
outs = [sharding_constraint_p.bind(xf, sharding=s, layout=l,

View File

@ -100,11 +100,12 @@ class TracerSpy:
try:
# We plan to do boolean conversion and catch the exception, but this works
# only for scalars
if t.shape:
t = jnp.sum(t)
if t:
t_scalar = t
while t_scalar.shape:
t_scalar = t_scalar[0]
if t_scalar:
pass
assert False, t
assert False, t_scalar
except Exception as e:
self.tracers.append((t, e))
@ -118,7 +119,6 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy: TracerSpy | None = None,
expected_tracer_debug_infos: list[str | re.Pattern] = [],
check_lowering: bool = True,
check_tracer_arg_name: bool = False,
expected_lowering_lines: list[str | re.Pattern] = [],
**kwargs) -> None:
"""Checks the expected debug info in all jaxprs, in spied tracers, and StableHLO.
@ -172,11 +172,14 @@ class DebugInfoTest(jtu.JaxTestCase):
for t, exc in tracer_spy.tracers:
if hasattr(t, "_debug_info"):
t_debug_info = _debug_info_to_string(t._debug_info)
if check_tracer_arg_name:
msg = str(exc)
m = re.match(r".* while tracing the function (.+) for ([^.]+)\.",
msg,
re.DOTALL)
msg = str(exc)
m = re.match(r".* while tracing the function (.+) for ([^.]+)\.",
msg,
re.DOTALL)
if m is None:
found_tracer_debug_infos.append(
f"{t_debug_info}, from None")
else:
self.assertIsNotNone(m, msg)
self.assertEqual(t._debug_info.func_src_info, m.group(1))
self.assertEqual(t._debug_info.traced_for, m.group(2))
@ -185,8 +188,6 @@ class DebugInfoTest(jtu.JaxTestCase):
re.DOTALL)
found_tracer_debug_infos.append(
f"{t_debug_info}, from {m.group(1) if m else None}")
else:
found_tracer_debug_infos.append(t_debug_info)
else:
found_tracer_debug_infos.append("None")
@ -646,9 +647,8 @@ class DebugInfoTest(jtu.JaxTestCase):
dict(a=1, b=2), 3,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, result_paths=['c'],['d']"
"traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, result_paths=result['c'],result['d']"
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, from x_dict['a']",
])
@ -669,10 +669,10 @@ class DebugInfoTest(jtu.JaxTestCase):
3,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
'traced_for=jit, fun=my_f, arg_names=a, result_paths=',
'traced_for=jit, fun=my_g, arg_names=b, result_paths=',
# TODO(necula): result_paths?
"traced_for=jit, fun=my_f, arg_names=a, result_paths=",
"traced_for=jit, fun=my_g, arg_names=b, result_paths=result",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=jit, fun=my_g, arg_names=b, from b",
"traced_for=jit, fun=my_f, arg_names=a, from a",
@ -690,10 +690,9 @@ class DebugInfoTest(jtu.JaxTestCase):
jax.jit(f),
{"ho": 1.}, {"hi": 2.}, 3., 4., z=5., w=6.,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], result_paths=",
"traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], result_paths=result",
],
tracer_spy=tracer_spy,
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], from kwargs['w']",
"traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], from args[0]",
@ -703,7 +702,7 @@ class DebugInfoTest(jtu.JaxTestCase):
re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(\"args\[1\]\"\)"),
re.compile(r".*func.func public @main\(.*%arg2: tensor<f..> loc\(\"kwargs\['w'\]\"\)"),
re.compile(r".*func.func public @main\(.*%arg3: tensor<f..> loc\(\"kwargs\['z'\]\"\)"),
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\"\}"),
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"),
]
)
@ -721,10 +720,9 @@ class DebugInfoTest(jtu.JaxTestCase):
(1.,), {"hi": 2.}, 3., 4., 5., 6., # x, y, z, args[0], args[1], args[2]
t=11., w=12., # kwargs
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], result_paths=",
"traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], result_paths=result",
],
tracer_spy=tracer_spy,
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], from kwargs['w']",
],
@ -732,7 +730,7 @@ class DebugInfoTest(jtu.JaxTestCase):
re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"y\['hi'\]\"\)"),
re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(\"args\[1\]\"\)"),
re.compile(r".*func.func public @main\(.*%arg2: tensor<f..> loc\(\"kwargs\['t'\]\"\)"),
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\"\}"),
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"),
])
def test_jit_arg_names_static_argnames(self):
@ -748,10 +746,9 @@ class DebugInfoTest(jtu.JaxTestCase):
(1.,), {'hi': 2.}, 3., 4., # x, y, args[0], args[1]
z=5., w=6., a=7., b=8., # kwargs
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], result_paths=",
"traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], result_paths=result",
],
tracer_spy=tracer_spy,
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], from x[0]",
],
@ -760,7 +757,7 @@ class DebugInfoTest(jtu.JaxTestCase):
re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(\"args\[1\]\"\)"),
re.compile(r".*func.func public @main\(.*%arg2: tensor<f..> loc\(\"kwargs\['b'\]\"\)"),
re.compile(r".*func.func public @main\(.*%arg3: tensor<f..> loc\(\"kwargs\['w'\]\"\)"),
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\"\}"),
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"),
])
def test_jit_result_info(self):
@ -770,13 +767,13 @@ class DebugInfoTest(jtu.JaxTestCase):
jax.jit(f),
1., (2.,), [3.],
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=f, arg_names=x,y[0],z[0], result_paths=['a'],['b'][0][0]",
"traced_for=jit, fun=f, arg_names=x,y[0],z[0], result_paths=result['a'],result['b'][0][0]",
],
expected_lowering_lines=[
re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"x\"\)"),
re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(\"y\[0\]\"\)"),
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\['a'\]\"\}"),
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\['b'\]\[0\]\[0\]\"\}"),
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\['a'\]\"\}"),
re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\['b'\]\[0\]\[0\]\"\}"),
])
def test_nested_jit(self):
@ -795,14 +792,35 @@ class DebugInfoTest(jtu.JaxTestCase):
2, 3,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x,y, result_paths=",
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=[\'c\']"
"traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result",
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']"
],
expected_tracer_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x,y",
"traced_for=jit, fun=my_g, arg_names=u,v"
"traced_for=jit, fun=my_f, arg_names=x,y, from x",
"traced_for=jit, fun=my_g, arg_names=u,v, from u"
])
def test_nested_jit_with_const_and_unused_args(self):
# Nested jit where the inner function ignores its first argument and adds
# a constant built with np.ones.
def my_f(x, y): # y is only forwarded to my_g as `u`, which my_g ignores
def my_g(u, v): # u is unused; only v contributes to the result
return v + np.ones(v.shape, v.dtype)
return x + jax.jit(my_g)(y, x)
x = y = np.ones((8,), dtype=np.float32)
self._check_tracers_and_jaxprs(
jax.jit(my_f),
x, y,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result",
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result"
],
expected_lowering_lines=[
re.compile(r".*func.func public @main\(%arg0: tensor<8xf..> loc\(\"x\"\)\)"),
# The call to my_g is expected to be lowered with a single operand.
re.compile(r".*call @my_g\(%arg.\) : \(tensor<8xf..>\)"),
]
)
def test_jvp_of_jit(self):
tracer_spy = TracerSpy()
def f(x, y, z):
@ -813,11 +831,11 @@ class DebugInfoTest(jtu.JaxTestCase):
jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)],
expected_jaxpr_debug_infos=[
# TODO(necula): arg_names, result_paths
"traced_for=jit, fun=f, arg_names=None,None,None,None, result_paths=,,,",
"traced_for=jit, fun=f, arg_names=,,,, result_paths=,,,",
],
tracer_spy=tracer_spy,
expected_tracer_debug_infos=[
"traced_for=jit, fun=f, arg_names=x,y[0],z[0]",
"traced_for=jit, fun=f, arg_names=x,y[0],z[0], from x",
],
expected_lowering_lines=[
# TODO(necula): missing arg_names
@ -840,13 +858,12 @@ class DebugInfoTest(jtu.JaxTestCase):
lambda x, y, z: jax.vjp(jax.jit(my_f), x, y, z)[1](dict(a=x, b=[y])),
jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)],
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=",
re.compile(r"traced_for=jit, fun=convert_element_type at .*dispatch.py:.*, arg_names=args\[0\], result_paths="),
# TODO(necula): arg_names?
"traced_for=jit, fun=my_f, arg_names=None,None,None,None, result_paths=['a'],['b'][0][0]",
"traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=result",
re.compile(r"traced_for=jit, fun=convert_element_type at .*dispatch.py:.*, arg_names=args\[0\], result_paths=result"),
# TODO(necula): arg_names? result_paths?
"traced_for=jit, fun=my_f, arg_names=,,,, result_paths=['a'],['b'][0][0]",
],
tracer_spy=tracer_spy,
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x,y[0],z[0], from y[0]",
],
@ -854,6 +871,7 @@ class DebugInfoTest(jtu.JaxTestCase):
# TODO(necula): missing arg_names
re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(unknown\)"),
re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(unknown\)"),
# TODO(necula): result_paths?
re.compile(r".*func.func public @main\(.*-> \(tensor<f..> {jax.result_info = \"\"}"),
])
@ -873,13 +891,12 @@ class DebugInfoTest(jtu.JaxTestCase):
2., 3., 0.3,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=x,y,res_ct, result_paths=[0],[1]",
"traced_for=jit, fun=<lambda>, arg_names=x,y,res_ct, result_paths=result[0],result[1]",
# TODO(necula): result_paths
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=",
# TODO(necula): arg_names
"traced_for=jit, fun=my_g, arg_names=None,None,u,v, result_paths=['c'],['d']",
"traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
# TODO(necula): missing debug info
"None",
@ -889,8 +906,8 @@ class DebugInfoTest(jtu.JaxTestCase):
re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"x\"\)"),
re.compile(r".*func.func public @main\(.*%arg1: tensor<f..> loc\(\"y\"\)"),
re.compile(r".*func.func public @main\(.*%arg2: tensor<f..> loc\(\"res_ct\"\)"),
re.compile(r".*func.func public @main\(.*jax.result_info = \"\[0\]\"}"),
re.compile(r".*func.func public @main\(.*jax.result_info = \"\[1\]\"}"),
re.compile(r".*func.func public @main\(.*jax.result_info = \"result\[0\]\"}"),
re.compile(r".*func.func public @main\(.*jax.result_info = \"result\[1\]\"}"),
])
def test_vjp_remat(self):
@ -909,11 +926,10 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
# TODO(necula): what are these flat_index components?
"traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=[0],[1][0][0][0][0][0]",
re.compile(r"traced_for=custom_jvp fun, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="),
re.compile(r"traced_for=jit, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="),
"traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=result[0],result[1][0][0][0][0][0]",
re.compile(r"traced_for=custom_jvp fun, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths=result"),
re.compile(r"traced_for=jit, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths=result"),
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=checkpoint / remat, fun=to_remat, arg_names=x, from x",
"traced_for=jit, fun=apply_fn, arg_names=inp, from inp",
@ -942,11 +958,11 @@ class DebugInfoTest(jtu.JaxTestCase):
42.,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=a, result_paths=[0],[1]",
"traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, result_paths=",
"traced_for=jit, fun=<lambda>, arg_names=a, result_paths=result[0],result[1]",
"traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, result_paths=result",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
# TODO(necula): from None?
"traced_for=jit, fun=<lambda>, arg_names=a, from None",
"traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, from y",
])
@ -976,10 +992,9 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
# TODO(necula): arg_names
"traced_for=jit, fun=<lambda>, arg_names=None,a,b, result_paths=[0],[1]",
"traced_for=custom_jvp fun, fun=my_g, arg_names=None,xy[0],xy[1], result_paths=",
"traced_for=jit, fun=<lambda>, arg_names=,a,b, result_paths=result[0],result[1]",
"traced_for=custom_jvp fun, fun=my_g, arg_names=,xy[0],xy[1], result_paths=result",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=custom_jvp fun, fun=my_g, arg_names=xy[0],xy[1], from xy[0]",
# TODO(necula): from None
@ -1010,10 +1025,9 @@ class DebugInfoTest(jtu.JaxTestCase):
{"a" : 3.},
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=to_diff, arg_names=x['a'], result_paths=['a']",
"traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], result_paths=['b']",
"traced_for=jit, fun=to_diff, arg_names=x['a'], result_paths=result['a']",
"traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], result_paths=result['b']",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], from x['a']",
# TODO(necula): from None?
@ -1041,10 +1055,9 @@ class DebugInfoTest(jtu.JaxTestCase):
(3., 3.),
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=xy[0],xy[1], result_paths=[0],[1]",
"traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], result_paths=",
"traced_for=jit, fun=<lambda>, arg_names=xy[0],xy[1], result_paths=result[0],result[1]",
"traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], result_paths=result",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=xy[0],xy[1], from xy[0]",
"traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], from xy[0]",
@ -1097,11 +1110,10 @@ class DebugInfoTest(jtu.JaxTestCase):
x,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=cond, fun=my_f, arg_names=x['c'], result_paths=",
"traced_for=cond, fun=<lambda>, arg_names=x['c'], result_paths=",
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=[0][0][0],[0][0][1]",
"traced_for=cond, fun=my_f, arg_names=x['c'], result_paths=result",
"traced_for=cond, fun=<lambda>, arg_names=x['c'], result_paths=result",
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result[0][0][0],result[0][0][1]",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=custom_transpose fun, fun=fn, arg_names=r,x['c'], from r",
"traced_for=custom_transpose fun, fun=fn, arg_names=r,x['c'], from x['c']",
@ -1129,11 +1141,10 @@ class DebugInfoTest(jtu.JaxTestCase):
x,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=[0]['c']",
"traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], result_paths=['b']",
"traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], result_paths=['c']",
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result[0]['c']",
"traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], result_paths=result['b']",
"traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], result_paths=result['c']",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
# TODO(necula): from None?
"traced_for=jit, fun=<lambda>, arg_names=x, from None",
@ -1167,11 +1178,11 @@ class DebugInfoTest(jtu.JaxTestCase):
xy,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], result_paths=['a']",
"traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], result_paths=result['a']",
],
expected_tracer_debug_infos=[
"traced_for=custom_vmap fun, fun=my_f, arg_names=xdict['x'],xdict['y']",
"traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y']"
"traced_for=custom_vmap fun, fun=my_f, arg_names=xdict['x'],xdict['y'], from xdict['x']",
"traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], from xdict['x']"
])
def test_cond(self):
@ -1192,13 +1203,13 @@ class DebugInfoTest(jtu.JaxTestCase):
0,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
"traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=",
"traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=",
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
"traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=result",
"traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=result",
],
expected_tracer_debug_infos=[
"traced_for=cond, fun=my_true_branch, arg_names=a,b",
"traced_for=cond, fun=my_false_branch, arg_names=c,d"
"traced_for=cond, fun=my_true_branch, arg_names=a,b, from a",
"traced_for=cond, fun=my_false_branch, arg_names=c,d, from c"
])
def test_switch(self):
@ -1220,15 +1231,15 @@ class DebugInfoTest(jtu.JaxTestCase):
2,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
"traced_for=switch, fun=my_branch0, arg_names=x0, result_paths=",
"traced_for=switch, fun=my_branch1, arg_names=x1, result_paths=",
"traced_for=switch, fun=my_branch2, arg_names=x2, result_paths=",
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
"traced_for=switch, fun=my_branch0, arg_names=x0, result_paths=result",
"traced_for=switch, fun=my_branch1, arg_names=x1, result_paths=result",
"traced_for=switch, fun=my_branch2, arg_names=x2, result_paths=result",
],
expected_tracer_debug_infos=[
"traced_for=switch, fun=my_branch0, arg_names=x0",
"traced_for=switch, fun=my_branch1, arg_names=x1",
"traced_for=switch, fun=my_branch2, arg_names=x2"
"traced_for=switch, fun=my_branch0, arg_names=x0, from x0",
"traced_for=switch, fun=my_branch1, arg_names=x1, from x1",
"traced_for=switch, fun=my_branch2, arg_names=x2, from x2"
])
def test_grad_cond_with_remat(self):
@ -1258,18 +1269,19 @@ class DebugInfoTest(jtu.JaxTestCase):
1., 2.,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
'traced_for=jit, fun=my_f, arg_names=x,y, result_paths=',
"traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result",
# TODO(necula): arg_names? result_paths?
"traced_for=cond, fun=my_true_branch, arg_names=None, result_paths=,",
"traced_for=cond, fun=my_false_branch, arg_names=None, result_paths=,",
"traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=[0],[1]",
"traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=[0],[1]",
"traced_for=checkpoint / remat, fun=my_g, arg_names=None,None, result_paths=,",
"traced_for=cond, fun=my_true_branch, arg_names=, result_paths=,",
"traced_for=cond, fun=my_false_branch, arg_names=, result_paths=,",
"traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=result[0],result[1]",
"traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=result[0],result[1]",
"traced_for=checkpoint / remat, fun=my_g, arg_names=,, result_paths=,",
],
expected_tracer_debug_infos=[
'traced_for=cond, fun=my_true_branch, arg_names=a,b',
'traced_for=cond, fun=my_false_branch, arg_names=c,d',
'traced_for=checkpoint / remat, fun=my_g, arg_names=x,y',
"traced_for=cond, fun=my_true_branch, arg_names=a,b, from a",
"traced_for=cond, fun=my_false_branch, arg_names=c,d, from c",
# TODO(necula): from None
"traced_for=checkpoint / remat, fun=my_g, arg_names=x,y, from None",
])
def test_grad_scan(self):
@ -1302,30 +1314,29 @@ class DebugInfoTest(jtu.JaxTestCase):
c, as_,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=[0],[1]",
# TODO(necula): bad result paths
"traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]",
# TODO(necula): arg names, bad result paths
"traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,",
# TODO(necula): arg_names?
"traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=",
"traced_for=for_loop, fun=f, arg_names=None,None,None, result_paths=,",
"traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None, result_paths=,",
"traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None,None,None,None,None,None, result_paths=",
"traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None,None,None,None,None,None,None,None,None,None, result_paths=,",
"traced_for=checkpoint / remat, fun=to_remat, arg_names=None,None,None, result_paths=,",
"traced_for=jit, fun=my_f, arg_names=None,None,x,as_, result_paths=",
"traced_for=for_loop, fun=f, arg_names=,,, result_paths=,",
"traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,",
"traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=",
"traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,",
"traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,",
"traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=",
],
expected_tracer_debug_infos=[
"traced_for=jit, fun=the_grad, arg_names=c,as_",
"traced_for=scan, fun=f, arg_names=c,a",
"traced_for=jit, fun=my_f, arg_names=x,as_",
# TODO(necula): arg_names
"traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2]",
"traced_for=jit, fun=the_grad, arg_names=c,as_, from c",
"traced_for=scan, fun=f, arg_names=c,a, from c",
"traced_for=jit, fun=my_f, arg_names=x,as_, from x",
# TODO(necula): arg_names, and "from x"
"traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], from refs[0]",
],
expected_lowering_lines=[
re.compile(r".*func.func public @main\(%arg0: tensor<f..> loc\(\"c\"\)"),
re.compile(r".*func.func public @main\(.*, %arg1: tensor<3x2xf..> loc\(\"as_\"\)"),
re.compile(r".*func.func public @main\(.* -> .*tensor<f..> {jax.result_info = \"\[0\]\""),
re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"\[1\]\""),
re.compile(r".*func.func public @main\(.* -> .*tensor<f..> {jax.result_info = \"result\[0\]\""),
re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"result\[1\]\""),
# TODO(necula): unnamed function?
re.compile(r".*func.func private @None"),
])
@ -1348,11 +1359,10 @@ class DebugInfoTest(jtu.JaxTestCase):
0,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
'traced_for=while_body, fun=my_body, arg_names=b, result_paths=',
'traced_for=while_cond, fun=my_cond, arg_names=a, result_paths=',
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
"traced_for=while_body, fun=my_body, arg_names=b, result_paths=result",
"traced_for=while_cond, fun=my_cond, arg_names=a, result_paths=result",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=while_cond, fun=my_cond, arg_names=a, from a",
"traced_for=while_body, fun=my_body, arg_names=b, from b",
@ -1370,14 +1380,14 @@ class DebugInfoTest(jtu.JaxTestCase):
3.,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result",
# TODO(necula): bad arg_names, result_paths
'traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], result_paths=[0][0],[0][1]',
"traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], result_paths=result[0][0],result[0][1]",
],
expected_tracer_debug_infos=[
# TODO(necula): the arg_names are not right
"traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1]",
"traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], from loop_carry[1]",
]
)
@ -1389,14 +1399,14 @@ class DebugInfoTest(jtu.JaxTestCase):
5, 3.,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=ub,x, result_paths=",
re.compile(r'traced_for=while_cond, fun=_fori_cond_fun at .*loops.py:.*, arg_names=loop_carry\[0\],loop_carry\[1\],loop_carry\[2\], result_paths='),
"traced_for=jit, fun=<lambda>, arg_names=ub,x, result_paths=result",
re.compile(r"traced_for=while_cond, fun=_fori_cond_fun at .*loops.py:.*, arg_names=loop_carry\[0\],loop_carry\[1\],loop_carry\[2\], result_paths="),
# TODO(necula): arg_names and result_paths are not right
"traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], result_paths=[0],[1],[2]",
"traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], result_paths=result[0],result[1],result[2]",
],
expected_tracer_debug_infos=[
# TODO(necula): the arg_names are not right
"traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2]",
"traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], from loop_carry[2]",
])
def test_scan(self):
@ -1413,10 +1423,9 @@ class DebugInfoTest(jtu.JaxTestCase):
np.arange(8, dtype=np.int32),
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=[0],[1]",
"traced_for=scan, fun=my_scan_body, arg_names=carry,inp, result_paths=[0],[1]",
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result[0],result[1]",
"traced_for=scan, fun=my_scan_body, arg_names=carry,inp, result_paths=result[0],result[1]",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=scan, fun=my_scan_body, arg_names=carry,inp, from carry"
])
@ -1433,7 +1442,7 @@ class DebugInfoTest(jtu.JaxTestCase):
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[],
expected_tracer_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x"],
"traced_for=jit, fun=my_f, arg_names=x, from x"],
)
def test_vmap_of_nested_jit(self):
@ -1452,13 +1461,13 @@ class DebugInfoTest(jtu.JaxTestCase):
np.ones((8,), dtype=np.float32), np.zeros((8,), dtype=np.float32),
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x,y, result_paths=",
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=['c']",
"traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result",
"traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']",
],
expected_tracer_debug_infos=[
# TODO(necula): missing debug info
"None",
"traced_for=jit, fun=my_g, arg_names=u,v"
"traced_for=jit, fun=my_g, arg_names=u,v, from u"
])
def test_pmap(self):
@ -1471,11 +1480,11 @@ class DebugInfoTest(jtu.JaxTestCase):
jax.pmap(my_f),
np.ones((jax.device_count(),), dtype=np.float32),
expected_jaxpr_debug_infos=[
"traced_for=pmap, fun=my_f, arg_names=x, result_paths="
"traced_for=pmap, fun=my_f, arg_names=x, result_paths=result"
],
tracer_spy=tracer_spy,
expected_tracer_debug_infos=[
"traced_for=pmap, fun=my_f, arg_names=x"
"traced_for=pmap, fun=my_f, arg_names=x, from x"
],
)
@ -1493,10 +1502,9 @@ class DebugInfoTest(jtu.JaxTestCase):
1., x, x, x, # x, y, args[0], args[1]
d=x, a=x, b=x, # kwargs
expected_jaxpr_debug_infos=[
"traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], result_paths=['u'],['v']",
"traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], result_paths=result['u'],result['v']",
],
tracer_spy=tracer_spy,
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]",
],
@ -1508,8 +1516,8 @@ class DebugInfoTest(jtu.JaxTestCase):
re.compile(r".*func.func public @main\(.*%arg3: tensor<1xf..> loc\(\"a\"\)"),
re.compile(r".*func.func public @main\(.*%arg4: tensor<1xf..> loc\(\"kwargs\['b'\]\"\)"),
re.compile(r".*func.func public @main\(.*%arg5: tensor<1xf..> loc\(\"kwargs\['d'\]\"\)"),
re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\['u'\]\"\}"),
re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\['v'\]\"\}"),
re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['u'\]\"\}"),
re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['v'\]\"\}"),
]
)
@ -1523,7 +1531,7 @@ class DebugInfoTest(jtu.JaxTestCase):
jax.pmap(jax.grad(my_f)),
np.ones((jax.device_count(),), dtype=np.float32),
expected_jaxpr_debug_infos=[
"traced_for=pmap, fun=my_f, arg_names=x, result_paths=",
"traced_for=pmap, fun=my_f, arg_names=x, result_paths=result",
],
tracer_spy=tracer_spy,
expected_tracer_debug_infos=[
@ -1550,12 +1558,12 @@ class DebugInfoTest(jtu.JaxTestCase):
expected_jaxpr_debug_infos=[
# TODO(necula): why this?
re.compile(r'traced_for=jit, fun=_multi_slice at .*array_methods.py:.*, arg_names=self, result_paths=.*'),
"traced_for=pmap, fun=my_f, arg_names=x,y,args[0],args[1], result_paths=['u'],['v']",
"traced_for=pmap, fun=my_f, arg_names=x,y,args[0],args[1], result_paths=result['u'],result['v']",
],
tracer_spy=tracer_spy,
expected_tracer_debug_infos=[
# TODO(necula): missing debug_info
'None'
"None"
],
)
@ -1574,13 +1582,13 @@ class DebugInfoTest(jtu.JaxTestCase):
jax.jit(lambda x, x_tan: jax.jvp(jax.pmap(my_f), (x, x), (x_tan, x_tan))),
x, x_tan,
expected_jaxpr_debug_infos=[
'traced_for=jit, fun=<lambda>, arg_names=x,x_tan, result_paths=[0],[1]',
"traced_for=pmap, fun=my_f, arg_names=x,y, result_paths=",
"traced_for=jit, fun=<lambda>, arg_names=x,x_tan, result_paths=result[0],result[1]",
"traced_for=pmap, fun=my_f, arg_names=x,y, result_paths=result",
],
tracer_spy=tracer_spy,
expected_tracer_debug_infos=[
# TODO(necula): missing debug_info
'None'
"None"
],
)
@ -1597,13 +1605,12 @@ class DebugInfoTest(jtu.JaxTestCase):
jax.jit(jax.hessian(jax.jit(my_f))),
x,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
# TODO(necula): arg_names and result_paths?
"traced_for=jit, fun=my_f, arg_names=None,x, result_paths=,",
"traced_for=jit, fun=my_f, arg_names=,x, result_paths=,",
"traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,",
],
tracer_spy=tracer_spy,
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, from x",
],
@ -1626,11 +1633,10 @@ class DebugInfoTest(jtu.JaxTestCase):
0.,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
# TODO(necula): missing result_paths
"traced_for=checkpoint / remat, fun=my_g, arg_names=y, result_paths=",
"traced_for=checkpoint / remat, fun=my_g, arg_names=y, result_paths=result",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y"
])
@ -1650,12 +1656,12 @@ class DebugInfoTest(jtu.JaxTestCase):
0.,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
# TODO(necula): arg_names?
"traced_for=checkpoint / remat, fun=my_g, arg_names=None,None, result_paths=",
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
# TODO(necula): arg_names? result_paths?
"traced_for=checkpoint / remat, fun=my_g, arg_names=,, result_paths=",
],
expected_tracer_debug_infos=[
"traced_for=checkpoint / remat, fun=my_g, arg_names=y"
"traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y",
])
def test_remat_shard_map(self):
@ -1677,13 +1683,12 @@ class DebugInfoTest(jtu.JaxTestCase):
jnp.arange(2, dtype=np.float32),
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
# TODO(necula): arg_names
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
"traced_for=checkpoint / remat, fun=my_f, arg_names=None,None, result_paths=",
"traced_for=shard_map, fun=my_f, arg_names=x, result_paths=",
"traced_for=shard_map, fun=my_f, arg_names=None,None, result_paths=",
# TODO(necula): arg_names, result_paths
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result",
"traced_for=checkpoint / remat, fun=my_f, arg_names=,, result_paths=",
"traced_for=shard_map, fun=my_f, arg_names=x, result_paths=result",
"traced_for=shard_map, fun=my_f, arg_names=,, result_paths=",
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
"None" # TODO(necula): missing
])
@ -1719,11 +1724,11 @@ class DebugInfoTest(jtu.JaxTestCase):
expected_jaxpr_debug_infos=[
# TODO(necula): this should not be pointing into the JAX internals
re.compile(r"traced_for=jit, fun=checked_fun at .*jax._src.checkify.py:.*, arg_names=args\[0\]"),
re.compile(r"traced_for=jit, fun=argsort at .*numpy.sorting.py:.*, arg_names=a, result_paths="),
"traced_for=pmap, fun=my_f, arg_names=my_x, result_paths=[0]",
re.compile(r"traced_for=jit, fun=argsort at .*numpy.sorting.py:.*, arg_names=a, result_paths=result"),
"traced_for=pmap, fun=my_f, arg_names=my_x, result_paths=result[0]",
],
expected_tracer_debug_infos=[
"traced_for=pmap, fun=my_f, arg_names=my_x",
"traced_for=pmap, fun=my_f, arg_names=my_x, from my_x",
],
check_lowering=False, # TODO(necula): warning during lowering
)
@ -1751,12 +1756,12 @@ class DebugInfoTest(jtu.JaxTestCase):
0.,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
"traced_for=custom_dce, fun=my_g, arg_names=x, result_paths=[0],[1]",
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result",
"traced_for=custom_dce, fun=my_g, arg_names=x, result_paths=result[0],result[1]",
],
expected_tracer_debug_infos=[
# TODO(necula): no leaked tracer from my_g_dce?
"traced_for=custom_dce, fun=my_g, arg_names=x",
"traced_for=custom_dce, fun=my_g, arg_names=x, from x",
])
def test_custom_dce_consts(self):
@ -1779,11 +1784,10 @@ class DebugInfoTest(jtu.JaxTestCase):
np.array(1.1234),
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=",
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result",
# TODO(necula): bad arg_names (why is the first one empty?), bad result_paths
'traced_for=custom_dce, fun=my_f, arg_names=None,x, result_paths=[0],[1]',
'traced_for=custom_dce, fun=my_f, arg_names=,x, result_paths=result[0],result[1]',
],
check_tracer_arg_name=True,
expected_tracer_debug_infos=[
# TODO(necula): no leaked tracer from my_rule?
"traced_for=custom_dce, fun=my_f, arg_names=x, from x",
@ -1816,23 +1820,22 @@ class DebugInfoTest(jtu.JaxTestCase):
a, b,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=a,b, result_paths=[0],[1]",
re.compile(r"traced_for=jit, fun=_solve at .*scipy.linalg.py:.*, arg_names=a,b, result_paths="),
re.compile(r"traced_for=jit, fun=solve at .*linalg.py:.*, arg_names=a,b, result_paths="),
re.compile(r"traced_for=jit, fun=_lu_solve at .*linalg.py:.*, arg_names=lu,permutation,b, result_paths="),
"traced_for=jit, fun=<lambda>, arg_names=a,b, result_paths=result[0],result[1]",
re.compile(r"traced_for=jit, fun=_solve at .*scipy.linalg.py:.*, arg_names=a,b, result_paths=result"),
re.compile(r"traced_for=jit, fun=solve at .*linalg.py:.*, arg_names=a,b, result_paths=result"),
re.compile(r"traced_for=jit, fun=_lu_solve at .*linalg.py:.*, arg_names=lu,permutation,b, result_paths=result"),
# TODO(necula): why pointers to internal functions, arg_names, result_paths?
re.compile(r'traced_for=custom_linear_solve solve, fun=<lambda> at .*linalg.py:.*, arg_names=None,None,x, result_paths='),
re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*linalg.py:.*, arg_names=None,None,x, result_paths='),
re.compile(r'traced_for=custom_linear_solve, fun=<lambda> at .*linalg.py:.*, arg_names=None,x, result_paths='),
re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*linalg.py:.*, arg_names=None,x, result_paths='),
'traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=None,b, result_paths=',
'traced_for=custom_linear_solve solve, fun=my_solve, arg_names=None,x, result_paths=',
'traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=None,x, result_paths=',
re.compile(r'traced_for=custom_linear_solve solve, fun=<lambda> at .*linalg.py:.*, arg_names=,,x, result_paths='),
re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*linalg.py:.*, arg_names=,,x, result_paths='),
re.compile(r'traced_for=custom_linear_solve, fun=<lambda> at .*linalg.py:.*, arg_names=,x, result_paths='),
re.compile(r'traced_for=custom_linear_solve transpose_solve, fun=<lambda> at .*linalg.py:.*, arg_names=,x, result_paths='),
"traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=,b, result_paths=result",
"traced_for=custom_linear_solve solve, fun=my_solve, arg_names=,x, result_paths=result",
"traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=,x, result_paths=result",
],
expected_tracer_debug_infos=[
"traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x",
"traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=x",
"traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=b",
"traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x, from x",
"traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=x, from x",
"None", # TODO(necula): there are missing debug info
])
@ -1856,14 +1859,15 @@ class DebugInfoTest(jtu.JaxTestCase):
0.,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=[0],[1]",
"traced_for=jit, fun=<lambda>, arg_names=x, result_paths=result[0],result[1]",
# TODO(necula): internal function?
re.compile(r"traced_for=custom_jvp fun, fun=_custom_root at .*control_flow.solves.py:.*, arg_names=args\[0\], result_paths=\[0\]"),
re.compile(r"traced_for=custom_jvp fun, fun=_custom_root at .*control_flow.solves.py:.*, arg_names=args\[0\], result_paths=result\[0\]"),
],
expected_tracer_debug_infos=[
"traced_for=custom_root, fun=my_f, arg_names=x",
"traced_for=custom_root solve, fun=my_solve, arg_names=x",
"traced_for=custom_root tangent_solve, fun=my_transpose_solve, arg_names=x",
"traced_for=custom_root, fun=my_f, arg_names=x, from x",
"traced_for=custom_root solve, fun=my_solve, arg_names=x, from x",
# TODO(necula): from None
"traced_for=custom_root tangent_solve, fun=my_transpose_solve, arg_names=x, from None",
"None", # TODO(necula): there are missing debug info
])
@ -1891,13 +1895,13 @@ class DebugInfoTest(jtu.JaxTestCase):
jax.jit(my_f), x,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=x, result_paths=",
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, result_paths=[0],[1]",
"traced_for=jit, fun=my_f, arg_names=x, result_paths=result",
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, result_paths=result[0],result[1]",
"traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref, result_paths=",
],
expected_tracer_debug_infos=[
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j",
"traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref",
"traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, from i",
"traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref, from x_ref",
],
check_lowering=False, # We need interpret mode on CPU. TODO(necula)
)
@ -1921,14 +1925,14 @@ class DebugInfoTest(jtu.JaxTestCase):
jnp.arange(4, dtype=jnp.float32) - 2,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_f, arg_names=input, result_paths=",
"traced_for=jit, fun=my_f, arg_names=input, result_paths=result",
# TODO(necula): function source location points in JAX internals
# TODO(necula): arg_names and result_paths are wrong
re.compile(r"traced_for=checkify_pallas, fun=checked_kernel_fn at .*pallas_call.py:.*, arg_names=args\[0\],.*, result_paths="),
re.compile(r"traced_for=pallas_call index_map, fun=<lambda> at .*pallas.core.py:.*, arg_names=, result_paths="),
],
expected_tracer_debug_infos=[
"traced_for=pallas_call, fun=kernel, arg_names=x_ref,y_ref",
"traced_for=pallas_call, fun=kernel, arg_names=x_ref,y_ref, from x_ref",
],
check_lowering=False, # We need interpret mode on CPU. TODO(necula)
)
@ -1948,11 +1952,11 @@ class DebugInfoTest(jtu.JaxTestCase):
jax.jit(my_consts), x,
tracer_spy=tracer_spy,
expected_jaxpr_debug_infos=[
"traced_for=jit, fun=my_consts, arg_names=x, result_paths=",
"traced_for=composite, fun=my_consts, arg_names=x, result_paths=",
"traced_for=jit, fun=my_consts, arg_names=x, result_paths=result",
"traced_for=composite, fun=my_consts, arg_names=x, result_paths=result",
],
expected_tracer_debug_infos=[
"traced_for=composite, fun=my_consts, arg_names=x"])
"traced_for=composite, fun=my_consts, arg_names=x, from x"])
if __name__ == '__main__':

View File

@ -274,13 +274,13 @@ class MutableArrayErrorsTest(jtu.JaxTestCase):
def test_return_from_jit_pytree(self):
with self.assertRaisesRegex(
ValueError,
r"tree path \['hi'\]"):
r"tree path result\['hi'\]"):
jax.jit(lambda x_ref: {'hi': x_ref})(core.mutable_array(jnp.arange(3)))
def test_return_from_jit_closure(self):
with self.assertRaisesRegex(
ValueError,
r"tree path \['hi'\]"):
r"tree path result\['hi'\]"):
x_ref = core.mutable_array(jnp.arange(3))
jax.jit(lambda: {'hi': x_ref})()

View File

@ -6846,7 +6846,7 @@ class PJitErrorTest(jtu.JaxTestCase):
spec = P(resources, None)
mesh_size = str(math.prod([dim[1] for dim in mesh]))
error = re.compile(
r"One of pjit outputs with pytree key path \['rrr'\].*" + spec_regex(spec) + r".*"
r"One of pjit outputs with pytree key path result\['rrr'\].*" + spec_regex(spec) + r".*"
r"implies that the global size of its dimension 0 should be "
r"divisible by " + mesh_size + r", but it is equal to 3", re.M | re.S)
with self.assertRaisesRegex(ValueError, error):