diff --git a/docs/aot.md b/docs/aot.md index 1f24d64fa..1fcf11ab9 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -56,7 +56,7 @@ some other features along the way. An example: >>> # Print lowered HLO >>> print(lowered.as_text()) module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor, %arg1: tensor) -> (tensor {jax.result_info = ""}) { + func.func public @main(%arg0: tensor, %arg1: tensor) -> (tensor {jax.result_info = "result"}) { %c = stablehlo.constant dense<2> : tensor %0 = stablehlo.multiply %c, %arg0 : tensor %1 = stablehlo.add %0, %arg1 : tensor @@ -140,7 +140,7 @@ to invoke the resulting compiled function. Continuing with our example above: >>> # Lowered HLO, specialized to the *value* of the first argument (7) >>> print(lowered_with_x.as_text()) module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor) -> (tensor {jax.result_info = ""}) { + func.func public @main(%arg0: tensor) -> (tensor {jax.result_info = "result"}) { %c = stablehlo.constant dense<14> : tensor %0 = stablehlo.add %c, %arg0 : tensor return %0 : tensor diff --git a/docs/export/export.md b/docs/export/export.md index bab4723f3..f1542d80d 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -44,7 +44,7 @@ Here is an example: (ShapedArray(float32[]),) >>> print(re.search(r".*@main.*", exported.mlir_module()).group(0)) - func.func public @main(%arg0: tensor loc("x")) -> (tensor {jax.result_info = ""}) { + func.func public @main(%arg0: tensor loc("x")) -> (tensor {jax.result_info = "result"}) { >>> # And you can serialize the Exported to a bytearray. >>> serialized: bytearray = exported.serialize() @@ -206,7 +206,7 @@ as in the following example: >>> _ = mlir.register_lowering(new_prim, lambda ctx, o: mlir.custom_call("my_new_prim", operands=[o], result_types=[o.type]).results) >>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir()) module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor) -> (tensor {jax.result_info = ""}) { + func.func public @main(%arg0: tensor) -> (tensor {jax.result_info = "result"}) { %0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor) -> tensor return %0 : tensor } diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 61576bb50..f1c0078cd 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -458,7 +458,7 @@ def saved_residuals(f: Callable, return _saved_residuals(jaxpr, debug_info.arg_names) def _saved_residuals(jaxpr: core.Jaxpr, - arg_names: tuple[str | None, ...]) -> list[tuple[core.AbstractValue, str]]: + arg_names: Sequence[str]) -> list[tuple[core.AbstractValue, str]]: res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)] res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)} @@ -473,7 +473,7 @@ def _saved_residuals(jaxpr: core.Jaxpr, for i, v in enumerate(jaxpr.invars): if v in res_vars: - if arg_names[i] is not None: + if arg_names[i]: src = f'from the argument {arg_names[i]}' else: src = 'from the argument at flattened index {i}' diff --git a/jax/_src/api.py b/jax/_src/api.py index 0055f6466..d24ea7cd7 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2274,7 +2274,7 @@ def _check_sharding(aval, s): aval = core.get_token_aval() if not isinstance(s, PmapSharding): pjit.pjit_check_aval_sharding( - (s,), (aval,), None, "device_put args", allow_uneven_sharding=False) + (s,), (aval,), ("",), "device_put args", allow_uneven_sharding=False) s.shard_shape(aval.shape) # should raise an Error if incompatible diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index a597e8b5b..1fd371034 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -662,7 +662,7 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None, args: Sequence[Any], kwargs: dict[str, Any], static_argnums: Sequence[int], static_argnames: Sequence[str], - ) -> tuple[str | None, ...]: + ) -> tuple[str, ...]: """Returns the names of the non-static arguments. If the `fn_signature` is given then we get from it the names of the diff --git a/jax/_src/core.py b/jax/_src/core.py index d8f91789b..cd08d12ef 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -149,8 +149,8 @@ class Jaxpr: debug_info = debug_info or lu._missing_debug_info("core.Jaxpr") self._debug_info = debug_info.resolve_result_paths() # TODO(necula): re-enable these safety checks - # assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars) - # assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) + # assert (len(debug_info.arg_names) == len(invars)), (debug_info, invars) + # assert (len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) def __str__(self): return str(self.pretty_print()) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 4a0e6ca46..37ad40d22 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -135,7 +135,7 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents): def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr: dbg = jaxpr.debug_info._replace( - arg_names=jaxpr.debug_info.arg_names + (None,) * len(jaxpr.constvars)) + arg_names=jaxpr.debug_info.arg_names + ("",) * len(jaxpr.constvars)) return core.Jaxpr(constvars=(), invars=jaxpr.invars + jaxpr.constvars, outvars=jaxpr.outvars, eqns=jaxpr.eqns, diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index b4d7b104a..4c722dedb 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1113,7 +1113,7 @@ def lower_jaxpr_to_module( result_shardings: Sequence[JSharding | AUTO | None] | None = None, in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None, out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None, - arg_names: Sequence[str | None] | None = None, + arg_names: Sequence[str] | None = None, result_names: Sequence[str] | None = None, num_replicas: int = 1, num_partitions: int = 1, diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 8d27f11c0..6fde73705 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -842,7 +842,7 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr: """Moves the constvars to the start of invars.""" config.enable_checks.value and core.check_jaxpr(jaxpr) dbg = jaxpr.debug_info._replace( - arg_names=(None,) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names) + arg_names=("",) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names) lifted_jaxpr = Jaxpr(constvars=(), invars=jaxpr.constvars + jaxpr.invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns, @@ -1574,7 +1574,7 @@ class DynamicJaxprTracer(core.Tracer): origin = ("The error occurred while tracing the function " f"{dbg.func_src_info} for {dbg.traced_for}. ") - if invar_pos and dbg.arg_names: + if invar_pos: try: arg_names = [dbg.arg_names[i] for i in invar_pos] except IndexError: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index f38084120..9baa5e977 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -3260,7 +3260,7 @@ def check_array_xla_sharding_layout_match( from jax._src.array import ArrayImpl # jaxpr_debug_info.arg_names are before DCE, so need to DCE them. arg_names = ( - [a for i, a in enumerate(jaxpr_debug_info.arg_names) # type: ignore + [a for i, a in enumerate(jaxpr_debug_info.arg_names) if i in kept_var_idx] ) errors = [] diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index af01dc249..861acdd42 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1554,7 +1554,7 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts, cond_debug = cond_jaxpr.jaxpr.debug_info augmented_debug = cond_debug and ( cond_debug._replace( - arg_names=cond_debug.arg_names + (None,) * len(init_dot) + arg_names=cond_debug.arg_names + ("",) * len(init_dot) ) ) cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars, diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index d1439f8a4..b8272fa0f 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -276,30 +276,43 @@ class DebugInfo(NamedTuple): """Debugging info about a func, its arguments, and results.""" traced_for: str # e.g. 'jit', 'scan', etc - # e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have - # no source location information. The first word is always the function name, - # which may be ''. func_src_info: str + """e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have + no source location information. The first word is always the function name, + which may be ''. + """ - # The paths of the flattened non-static argnames, - # e.g. ('x', 'dict_arg["a"]', ... ). - # Uses `None` for the args that do not correspond to user-named arguments, - # e.g., tangent args in jax.jvp. At the moment, `arg_names` accuracy is - # best-effort. Use `safe_arg_names` to detect and handle an unexpected - # number of elements in `arg_names`. - arg_names: tuple[str | None, ...] + arg_names: tuple[str, ...] + """The paths of the flattened non-static argnames, + e.g. `('x', 'dict_arg["a"]', ... )`. + Uses the empty string for the args that do not correspond to + user-named arguments, e.g., tangent args in `jax.jvp`, or for arguments that + we are not yet tracking properly. + At the moment, `arg_names` accuracy is best-effort. + Use `safe_arg_names` to detect and handle an unexpected + number of elements in `arg_names`. + """ - # The result paths are not available while we are tracing the function, - # instead we keep a thunk. Once we are done tracing, we use - # `self.resolve_result_paths()` to execute the thunk and replace the - # actual result paths. At the moment, `result_paths` accuracy is - # best-effort. Use `safe_result_paths` to detect and handle an unexpected - # number of elements in `result_paths`. - # e.g. ('[0]', '[1]', ...) result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None + """The paths to the flattened results, e.g., `('result[0]', result[1])` for a + function that returns a tuple of arrays, or `(result,)` for a function that + returns a single array. + The result paths are not available while we are tracing the function, + instead we keep a thunk. It is possible for the result paths to be `None` + only when we first create a `DebugInfo`, before we put it in `lu.WrappedFun` + and before we start tracing. + Inside a `lu.WrappedFun` it can be only a thunk or a tuple of strings. + Once we are done tracing, we use + `self.resolve_result_paths()` to execute the thunk and replace the + actual result paths. + At the moment, `result_paths` accuracy is best-effort. + Use `safe_result_paths` to detect and handle an unexpected + number of elements in `result_paths`. + """ def resolve_result_paths(self) -> DebugInfo: """Return a debug info with resolved result paths.""" + assert self.result_paths is not None if callable(self.result_paths): return self._replace(result_paths=tuple(self.result_paths())) return self @@ -308,21 +321,21 @@ class DebugInfo(NamedTuple): def func_name(self) -> str: return self.func_src_info.split(" ")[0] - def safe_arg_names(self, expected: int) -> tuple[str | None, ...]: + def safe_arg_names(self, expected: int) -> tuple[str, ...]: """Get the arg_names with a safety check.""" if len(self.arg_names) == expected: return self.arg_names else: # TODO(necula): this should not happen - return (None,) * expected + return ("",) * expected - def filter_arg_names(self, keep: Sequence[bool]) -> tuple[str | None, ...]: + def filter_arg_names(self, keep: Sequence[bool]) -> tuple[str, ...]: """Keep only the arg_names for which `keep` is True.""" return tuple(v for v, b in zip(self.safe_arg_names(len(keep)), keep) if b) def safe_result_paths(self, expected: int) -> tuple[str, ...]: """Get the result paths with a safety check.""" - assert not callable(self.result_paths), self + assert self.result_paths is not None and not callable(self.result_paths), self if self.result_paths is not None and len(self.result_paths) == expected: return self.result_paths else: @@ -331,7 +344,7 @@ class DebugInfo(NamedTuple): def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...]: """Keep only the result_paths for which `keep` is True.""" - assert not callable(self.result_paths), self + assert self.result_paths is not None and not callable(self.result_paths), self return tuple(v for v, b in zip(self.safe_result_paths(len(keep)), keep) if b) @@ -368,7 +381,7 @@ def _clean_keystr_arg_names(k: KeyPath) -> str: @transformation_with_aux2 def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs): ans = _fun(*args, **kwargs) - result_paths = [_clean_keystr_arg_names(path) for path, _ in generate_key_paths(ans)] + result_paths = tuple(f"result{_clean_keystr_arg_names(path)}" for path, _ in generate_key_paths(ans)) if _store: # In some instances a lu.WrappedFun is called multiple times, e.g., # the bwd function in a custom_vjp diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index be44c2fe1..ac8dd8f6e 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -538,7 +538,7 @@ class PjitParams(NamedTuple): in_tree: PyTreeDef out_tree: PyTreeDef donated_invars: tuple[bool, ...] - arg_names: tuple[str | None, ...] + arg_names: tuple[str, ...] num_consts: int attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]] @@ -663,7 +663,7 @@ def _infer_params_impl( compiler_options_kvs=ji.compiler_options_kvs, ) return PjitParams(consts, params, in_avals, in_tree, out_tree(), - donated_invars, dbg.arg_names if dbg else None, len(consts), + donated_invars, dbg.arg_names, len(consts), attrs_tracked), args_flat @@ -741,13 +741,13 @@ def _infer_input_type(fun: Callable, dbg: core.DebugInfo, for i, x in enumerate(explicit_args): avals.append(core.shaped_abstractify(x)) except OverflowError: - arg_path = f"argument path is {dbg.arg_names[i]}" # type: ignore + arg_path = f"argument path is {dbg.arg_names[i]}" raise OverflowError( "An overflow was encountered while parsing an argument to a jitted " f"computation, whose {arg_path}." ) from None except TypeError: - arg_description = f"path {dbg.arg_names[i]}" # type: ignore + arg_description = f"path {dbg.arg_names[i]}" raise TypeError( f"Error interpreting argument to {fun} as an abstract array." f" The problematic value is of type {type(x)} and was passed to" @@ -1134,7 +1134,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, "pjit in_layouts", in_tree, in_layouts, tupled_args=True) # TODO(dougalm,mattjj): enable debug info with attrs_tracked - attrs_tracked = debug_info and len(debug_info.arg_names) != len(in_avals) + attrs_tracked = len(debug_info.arg_names) != len(in_avals) if not config.dynamic_shapes.value and not attrs_tracked: pjit_check_aval_sharding(in_shardings_flat, in_avals, debug_info.safe_arg_names(len(in_avals)), @@ -1338,11 +1338,11 @@ def _check_and_canonicalize_out_shardings( if not config.dynamic_shapes.value: pjit_check_aval_sharding( out_shardings_flat, out_avals, - debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type] + debug_info.safe_result_paths(len(out_avals)), "pjit outputs", allow_uneven_sharding=False) check_aval_layout_compatibility( out_layouts_flat, out_avals, - debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type] + debug_info.safe_result_paths(len(out_avals)), "jit outputs") return out_shardings_flat, out_layouts_flat @@ -1396,10 +1396,9 @@ class IgnoreKey: def pjit_check_aval_sharding( - shardings, flat_avals, names: tuple[str | None, ...] | None, + shardings, flat_avals, names: Sequence[str], what_aval: str, allow_uneven_sharding: bool): - new_names = [None] * len(shardings) if names is None else names - for aval, s, name in zip(flat_avals, shardings, new_names): + for aval, s, name in zip(flat_avals, shardings, names): if isinstance(s, (UnspecifiedValue, AUTO)): continue name_str = f' with pytree key path {name}' if name else '' @@ -1431,9 +1430,8 @@ def pjit_check_aval_sharding( def check_aval_layout_compatibility( - layouts, flat_avals, names: tuple[str, ...] | None, what_aval: str): - new_names = [''] * len(layouts) if names is None else names - for aval, l, name in zip(flat_avals, layouts, new_names): + layouts, flat_avals, names: Sequence[str], what_aval: str): + for aval, l, name in zip(flat_avals, layouts, names): if l is None or isinstance(l, AutoLayout): continue name_str = f' with pytree key path {name}' if name else '' @@ -2557,12 +2555,14 @@ def with_sharding_constraint(x, shardings): for s in shardings_flat] pjit_check_aval_sharding( - shardings_flat, x_flat, None, "with_sharding_constraint arguments", + shardings_flat, x_flat, ("",) * len(shardings_flat), + "with_sharding_constraint arguments", allow_uneven_sharding=True) check_shardings_are_auto(shardings_flat) - check_aval_layout_compatibility(user_layouts_flat, x_flat, None, + check_aval_layout_compatibility(user_layouts_flat, x_flat, + ("",) * len(user_layouts_flat), "with_sharding_constraint arguments") outs = [sharding_constraint_p.bind(xf, sharding=s, layout=l, diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index de5bb0bc9..8ad235daf 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -100,11 +100,12 @@ class TracerSpy: try: # We plan to do boolean conversion and catch the exception, but this works # only for scalars - if t.shape: - t = jnp.sum(t) - if t: + t_scalar = t + while t_scalar.shape: + t_scalar = t_scalar[0] + if t_scalar: pass - assert False, t + assert False, t_scalar except Exception as e: self.tracers.append((t, e)) @@ -118,7 +119,6 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy: TracerSpy | None = None, expected_tracer_debug_infos: list[str | re.Pattern] = [], check_lowering: bool = True, - check_tracer_arg_name: bool = False, expected_lowering_lines: list[str | re.Pattern] = [], **kwargs) -> None: """Checks the expected debug info in all jaxprs, in spied tracers, and StableHLO. @@ -172,11 +172,14 @@ class DebugInfoTest(jtu.JaxTestCase): for t, exc in tracer_spy.tracers: if hasattr(t, "_debug_info"): t_debug_info = _debug_info_to_string(t._debug_info) - if check_tracer_arg_name: - msg = str(exc) - m = re.match(r".* while tracing the function (.+) for ([^.]+)\.", - msg, - re.DOTALL) + msg = str(exc) + m = re.match(r".* while tracing the function (.+) for ([^.]+)\.", + msg, + re.DOTALL) + if m is None: + found_tracer_debug_infos.append( + f"{t_debug_info}, from None") + else: self.assertIsNotNone(m, msg) self.assertEqual(t._debug_info.func_src_info, m.group(1)) self.assertEqual(t._debug_info.traced_for, m.group(2)) @@ -185,8 +188,6 @@ class DebugInfoTest(jtu.JaxTestCase): re.DOTALL) found_tracer_debug_infos.append( f"{t_debug_info}, from {m.group(1) if m else None}") - else: - found_tracer_debug_infos.append(t_debug_info) else: found_tracer_debug_infos.append("None") @@ -646,9 +647,8 @@ class DebugInfoTest(jtu.JaxTestCase): dict(a=1, b=2), 3, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, result_paths=['c'],['d']" + "traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, result_paths=result['c'],result['d']" ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, from x_dict['a']", ]) @@ -669,10 +669,10 @@ class DebugInfoTest(jtu.JaxTestCase): 3, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - 'traced_for=jit, fun=my_f, arg_names=a, result_paths=', - 'traced_for=jit, fun=my_g, arg_names=b, result_paths=', + # TODO(necula): result_paths? + "traced_for=jit, fun=my_f, arg_names=a, result_paths=", + "traced_for=jit, fun=my_g, arg_names=b, result_paths=result", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_g, arg_names=b, from b", "traced_for=jit, fun=my_f, arg_names=a, from a", @@ -690,10 +690,9 @@ class DebugInfoTest(jtu.JaxTestCase): jax.jit(f), {"ho": 1.}, {"hi": 2.}, 3., 4., z=5., w=6., expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], result_paths=", + "traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], result_paths=result", ], tracer_spy=tracer_spy, - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], from kwargs['w']", "traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], from args[0]", @@ -703,7 +702,7 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(\"args\[1\]\"\)"), re.compile(r".*func.func public @main\(.*%arg2: tensor loc\(\"kwargs\['w'\]\"\)"), re.compile(r".*func.func public @main\(.*%arg3: tensor loc\(\"kwargs\['z'\]\"\)"), - re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\"\}"), + re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"), ] ) @@ -721,10 +720,9 @@ class DebugInfoTest(jtu.JaxTestCase): (1.,), {"hi": 2.}, 3., 4., 5., 6., # x, y, z, args[0], args[1], args[2] t=11., w=12., # kwargs expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], result_paths=", + "traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], result_paths=result", ], tracer_spy=tracer_spy, - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], from kwargs['w']", ], @@ -732,7 +730,7 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"y\['hi'\]\"\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(\"args\[1\]\"\)"), re.compile(r".*func.func public @main\(.*%arg2: tensor loc\(\"kwargs\['t'\]\"\)"), - re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\"\}"), + re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"), ]) def test_jit_arg_names_static_argnames(self): @@ -748,10 +746,9 @@ class DebugInfoTest(jtu.JaxTestCase): (1.,), {'hi': 2.}, 3., 4., # x, y, args[0], args[1] z=5., w=6., a=7., b=8., # kwargs expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], result_paths=", + "traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], result_paths=result", ], tracer_spy=tracer_spy, - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], from x[0]", ], @@ -760,7 +757,7 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(\"args\[1\]\"\)"), re.compile(r".*func.func public @main\(.*%arg2: tensor loc\(\"kwargs\['b'\]\"\)"), re.compile(r".*func.func public @main\(.*%arg3: tensor loc\(\"kwargs\['w'\]\"\)"), - re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\"\}"), + re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"), ]) def test_jit_result_info(self): @@ -770,13 +767,13 @@ class DebugInfoTest(jtu.JaxTestCase): jax.jit(f), 1., (2.,), [3.], expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=f, arg_names=x,y[0],z[0], result_paths=['a'],['b'][0][0]", + "traced_for=jit, fun=f, arg_names=x,y[0],z[0], result_paths=result['a'],result['b'][0][0]", ], expected_lowering_lines=[ re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"x\"\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(\"y\[0\]\"\)"), - re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\['a'\]\"\}"), - re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\['b'\]\[0\]\[0\]\"\}"), + re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\['a'\]\"\}"), + re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\['b'\]\[0\]\[0\]\"\}"), ]) def test_nested_jit(self): @@ -795,14 +792,35 @@ class DebugInfoTest(jtu.JaxTestCase): 2, 3, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=", - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=[\'c\']" + "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']" ], expected_tracer_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x,y", - "traced_for=jit, fun=my_g, arg_names=u,v" + "traced_for=jit, fun=my_f, arg_names=x,y, from x", + "traced_for=jit, fun=my_g, arg_names=u,v, from u" ]) + def test_nested_jit_with_const_and_unused_args(self): + def my_f(x, y): # y is unused + def my_g(u, v): # v is unused + return v + np.ones(v.shape, v.dtype) + + return x + jax.jit(my_g)(y, x) + + x = y = np.ones((8,), dtype=np.float32) + self._check_tracers_and_jaxprs( + jax.jit(my_f), + x, y, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result" + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor<8xf..> loc\(\"x\"\)\)"), + re.compile(r".*call @my_g\(%arg.\) : \(tensor<8xf..>\)"), + ] + ) + def test_jvp_of_jit(self): tracer_spy = TracerSpy() def f(x, y, z): @@ -813,11 +831,11 @@ class DebugInfoTest(jtu.JaxTestCase): jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)], expected_jaxpr_debug_infos=[ # TODO(necula): arg_names, result_paths - "traced_for=jit, fun=f, arg_names=None,None,None,None, result_paths=,,,", + "traced_for=jit, fun=f, arg_names=,,,, result_paths=,,,", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ - "traced_for=jit, fun=f, arg_names=x,y[0],z[0]", + "traced_for=jit, fun=f, arg_names=x,y[0],z[0], from x", ], expected_lowering_lines=[ # TODO(necula): missing arg_names @@ -840,13 +858,12 @@ class DebugInfoTest(jtu.JaxTestCase): lambda x, y, z: jax.vjp(jax.jit(my_f), x, y, z)[1](dict(a=x, b=[y])), jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)], expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=", - re.compile(r"traced_for=jit, fun=convert_element_type at .*dispatch.py:.*, arg_names=args\[0\], result_paths="), - # TODO(necula): arg_names? - "traced_for=jit, fun=my_f, arg_names=None,None,None,None, result_paths=['a'],['b'][0][0]", + "traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=result", + re.compile(r"traced_for=jit, fun=convert_element_type at .*dispatch.py:.*, arg_names=args\[0\], result_paths=result"), + # TODO(necula): arg_names? result_paths? + "traced_for=jit, fun=my_f, arg_names=,,,, result_paths=['a'],['b'][0][0]", ], tracer_spy=tracer_spy, - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y[0],z[0], from y[0]", ], @@ -854,6 +871,7 @@ class DebugInfoTest(jtu.JaxTestCase): # TODO(necula): missing arg_names re.compile(r".*func.func public @main\(%arg0: tensor loc\(unknown\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(unknown\)"), + # TODO(necula): result_paths? re.compile(r".*func.func public @main\(.*-> \(tensor {jax.result_info = \"\"}"), ]) @@ -873,13 +891,12 @@ class DebugInfoTest(jtu.JaxTestCase): 2., 3., 0.3, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=[0],[1]", + "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=result[0],result[1]", # TODO(necula): result_paths "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=", # TODO(necula): arg_names - "traced_for=jit, fun=my_g, arg_names=None,None,u,v, result_paths=['c'],['d']", + "traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ # TODO(necula): missing debug info "None", @@ -889,8 +906,8 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"x\"\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(\"y\"\)"), re.compile(r".*func.func public @main\(.*%arg2: tensor loc\(\"res_ct\"\)"), - re.compile(r".*func.func public @main\(.*jax.result_info = \"\[0\]\"}"), - re.compile(r".*func.func public @main\(.*jax.result_info = \"\[1\]\"}"), + re.compile(r".*func.func public @main\(.*jax.result_info = \"result\[0\]\"}"), + re.compile(r".*func.func public @main\(.*jax.result_info = \"result\[1\]\"}"), ]) def test_vjp_remat(self): @@ -909,11 +926,10 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ # TODO(necula): what are these flat_index components? - "traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=[0],[1][0][0][0][0][0]", - re.compile(r"traced_for=custom_jvp fun, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="), - re.compile(r"traced_for=jit, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="), + "traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=result[0],result[1][0][0][0][0][0]", + re.compile(r"traced_for=custom_jvp fun, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths=result"), + re.compile(r"traced_for=jit, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths=result"), ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=checkpoint / remat, fun=to_remat, arg_names=x, from x", "traced_for=jit, fun=apply_fn, arg_names=inp, from inp", @@ -942,11 +958,11 @@ class DebugInfoTest(jtu.JaxTestCase): 42., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=a, result_paths=[0],[1]", - "traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, result_paths=", + "traced_for=jit, fun=, arg_names=a, result_paths=result[0],result[1]", + "traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, result_paths=result", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ + # TODO(necula): from None? "traced_for=jit, fun=, arg_names=a, from None", "traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, from y", ]) @@ -976,10 +992,9 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ # TODO(necula): arg_names - "traced_for=jit, fun=, arg_names=None,a,b, result_paths=[0],[1]", - "traced_for=custom_jvp fun, fun=my_g, arg_names=None,xy[0],xy[1], result_paths=", + "traced_for=jit, fun=, arg_names=,a,b, result_paths=result[0],result[1]", + "traced_for=custom_jvp fun, fun=my_g, arg_names=,xy[0],xy[1], result_paths=result", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=custom_jvp fun, fun=my_g, arg_names=xy[0],xy[1], from xy[0]", # TODO(necula): from None @@ -1010,10 +1025,9 @@ class DebugInfoTest(jtu.JaxTestCase): {"a" : 3.}, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=to_diff, arg_names=x['a'], result_paths=['a']", - "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], result_paths=['b']", + "traced_for=jit, fun=to_diff, arg_names=x['a'], result_paths=result['a']", + "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], result_paths=result['b']", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], from x['a']", # TODO(necula): from None? @@ -1041,10 +1055,9 @@ class DebugInfoTest(jtu.JaxTestCase): (3., 3.), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=xy[0],xy[1], result_paths=[0],[1]", - "traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], result_paths=", + "traced_for=jit, fun=, arg_names=xy[0],xy[1], result_paths=result[0],result[1]", + "traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], result_paths=result", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=, arg_names=xy[0],xy[1], from xy[0]", "traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], from xy[0]", @@ -1097,11 +1110,10 @@ class DebugInfoTest(jtu.JaxTestCase): x, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=cond, fun=my_f, arg_names=x['c'], result_paths=", - "traced_for=cond, fun=, arg_names=x['c'], result_paths=", - "traced_for=jit, fun=, arg_names=x, result_paths=[0][0][0],[0][0][1]", + "traced_for=cond, fun=my_f, arg_names=x['c'], result_paths=result", + "traced_for=cond, fun=, arg_names=x['c'], result_paths=result", + "traced_for=jit, fun=, arg_names=x, result_paths=result[0][0][0],result[0][0][1]", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=custom_transpose fun, fun=fn, arg_names=r,x['c'], from r", "traced_for=custom_transpose fun, fun=fn, arg_names=r,x['c'], from x['c']", @@ -1129,11 +1141,10 @@ class DebugInfoTest(jtu.JaxTestCase): x, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x, result_paths=[0]['c']", - "traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], result_paths=['b']", - "traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], result_paths=['c']", + "traced_for=jit, fun=, arg_names=x, result_paths=result[0]['c']", + "traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], result_paths=result['b']", + "traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], result_paths=result['c']", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ # TODO(necula): from None? "traced_for=jit, fun=, arg_names=x, from None", @@ -1167,11 +1178,11 @@ class DebugInfoTest(jtu.JaxTestCase): xy, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], result_paths=['a']", + "traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], result_paths=result['a']", ], expected_tracer_debug_infos=[ - "traced_for=custom_vmap fun, fun=my_f, arg_names=xdict['x'],xdict['y']", - "traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y']" + "traced_for=custom_vmap fun, fun=my_f, arg_names=xdict['x'],xdict['y'], from xdict['x']", + "traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], from xdict['x']" ]) def test_cond(self): @@ -1192,13 +1203,13 @@ class DebugInfoTest(jtu.JaxTestCase): 0, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=", - "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=", - "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", + "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=result", + "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=result", ], expected_tracer_debug_infos=[ - "traced_for=cond, fun=my_true_branch, arg_names=a,b", - "traced_for=cond, fun=my_false_branch, arg_names=c,d" + "traced_for=cond, fun=my_true_branch, arg_names=a,b, from a", + "traced_for=cond, fun=my_false_branch, arg_names=c,d, from c" ]) def test_switch(self): @@ -1220,15 +1231,15 @@ class DebugInfoTest(jtu.JaxTestCase): 2, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=", - "traced_for=switch, fun=my_branch0, arg_names=x0, result_paths=", - "traced_for=switch, fun=my_branch1, arg_names=x1, result_paths=", - "traced_for=switch, fun=my_branch2, arg_names=x2, result_paths=", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", + "traced_for=switch, fun=my_branch0, arg_names=x0, result_paths=result", + "traced_for=switch, fun=my_branch1, arg_names=x1, result_paths=result", + "traced_for=switch, fun=my_branch2, arg_names=x2, result_paths=result", ], expected_tracer_debug_infos=[ - "traced_for=switch, fun=my_branch0, arg_names=x0", - "traced_for=switch, fun=my_branch1, arg_names=x1", - "traced_for=switch, fun=my_branch2, arg_names=x2" + "traced_for=switch, fun=my_branch0, arg_names=x0, from x0", + "traced_for=switch, fun=my_branch1, arg_names=x1, from x1", + "traced_for=switch, fun=my_branch2, arg_names=x2, from x2" ]) def test_grad_cond_with_remat(self): @@ -1258,18 +1269,19 @@ class DebugInfoTest(jtu.JaxTestCase): 1., 2., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - 'traced_for=jit, fun=my_f, arg_names=x,y, result_paths=', + "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", # TODO(necula): arg_names? result_paths? - "traced_for=cond, fun=my_true_branch, arg_names=None, result_paths=,", - "traced_for=cond, fun=my_false_branch, arg_names=None, result_paths=,", - "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=[0],[1]", - "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=[0],[1]", - "traced_for=checkpoint / remat, fun=my_g, arg_names=None,None, result_paths=,", + "traced_for=cond, fun=my_true_branch, arg_names=, result_paths=,", + "traced_for=cond, fun=my_false_branch, arg_names=, result_paths=,", + "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=result[0],result[1]", + "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=result[0],result[1]", + "traced_for=checkpoint / remat, fun=my_g, arg_names=,, result_paths=,", ], expected_tracer_debug_infos=[ - 'traced_for=cond, fun=my_true_branch, arg_names=a,b', - 'traced_for=cond, fun=my_false_branch, arg_names=c,d', - 'traced_for=checkpoint / remat, fun=my_g, arg_names=x,y', + "traced_for=cond, fun=my_true_branch, arg_names=a,b, from a", + "traced_for=cond, fun=my_false_branch, arg_names=c,d, from c", + # TODO(necula): from None + "traced_for=checkpoint / remat, fun=my_g, arg_names=x,y, from None", ]) def test_grad_scan(self): @@ -1302,30 +1314,29 @@ class DebugInfoTest(jtu.JaxTestCase): c, as_, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=[0],[1]", - # TODO(necula): bad result paths + "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]", + # TODO(necula): arg names, bad result paths "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,", - # TODO(necula): arg_names? "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", - "traced_for=for_loop, fun=f, arg_names=None,None,None, result_paths=,", - "traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None, result_paths=,", - "traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None,None,None,None,None,None, result_paths=", - "traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None,None,None,None,None,None,None,None,None,None, result_paths=,", - "traced_for=checkpoint / remat, fun=to_remat, arg_names=None,None,None, result_paths=,", - "traced_for=jit, fun=my_f, arg_names=None,None,x,as_, result_paths=", + "traced_for=for_loop, fun=f, arg_names=,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", + "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", + "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=", ], expected_tracer_debug_infos=[ - "traced_for=jit, fun=the_grad, arg_names=c,as_", - "traced_for=scan, fun=f, arg_names=c,a", - "traced_for=jit, fun=my_f, arg_names=x,as_", - # TODO(necula): arg_names - "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2]", + "traced_for=jit, fun=the_grad, arg_names=c,as_, from c", + "traced_for=scan, fun=f, arg_names=c,a, from c", + "traced_for=jit, fun=my_f, arg_names=x,as_, from x", + # TODO(necula): arg_names, and "from x" + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], from refs[0]", ], expected_lowering_lines=[ re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"c\"\)"), re.compile(r".*func.func public @main\(.*, %arg1: tensor<3x2xf..> loc\(\"as_\"\)"), - re.compile(r".*func.func public @main\(.* -> .*tensor {jax.result_info = \"\[0\]\""), - re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"\[1\]\""), + re.compile(r".*func.func public @main\(.* -> .*tensor {jax.result_info = \"result\[0\]\""), + re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"result\[1\]\""), # TODO(necula): unnamed function? re.compile(r".*func.func private @None"), ]) @@ -1348,11 +1359,10 @@ class DebugInfoTest(jtu.JaxTestCase): 0, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=", - 'traced_for=while_body, fun=my_body, arg_names=b, result_paths=', - 'traced_for=while_cond, fun=my_cond, arg_names=a, result_paths=', + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", + "traced_for=while_body, fun=my_body, arg_names=b, result_paths=result", + "traced_for=while_cond, fun=my_cond, arg_names=a, result_paths=result", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=while_cond, fun=my_cond, arg_names=a, from a", "traced_for=while_body, fun=my_body, arg_names=b, from b", @@ -1370,14 +1380,14 @@ class DebugInfoTest(jtu.JaxTestCase): 3., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x, result_paths=", + "traced_for=jit, fun=, arg_names=x, result_paths=result", # TODO(necula): bad arg_names, result_paths - 'traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], result_paths=[0][0],[0][1]', + "traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], result_paths=result[0][0],result[0][1]", ], expected_tracer_debug_infos=[ # TODO(necula): the arg_names are not right - "traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1]", + "traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], from loop_carry[1]", ] ) @@ -1389,14 +1399,14 @@ class DebugInfoTest(jtu.JaxTestCase): 5, 3., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=ub,x, result_paths=", - re.compile(r'traced_for=while_cond, fun=_fori_cond_fun at .*loops.py:.*, arg_names=loop_carry\[0\],loop_carry\[1\],loop_carry\[2\], result_paths='), + "traced_for=jit, fun=, arg_names=ub,x, result_paths=result", + re.compile(r"traced_for=while_cond, fun=_fori_cond_fun at .*loops.py:.*, arg_names=loop_carry\[0\],loop_carry\[1\],loop_carry\[2\], result_paths="), # TODO(necula): arg_names and result_paths are not right - "traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], result_paths=[0],[1],[2]", + "traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], result_paths=result[0],result[1],result[2]", ], expected_tracer_debug_infos=[ # TODO(necula): the arg_names are not right - "traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2]", + "traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], from loop_carry[2]", ]) def test_scan(self): @@ -1413,10 +1423,9 @@ class DebugInfoTest(jtu.JaxTestCase): np.arange(8, dtype=np.int32), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=[0],[1]", - "traced_for=scan, fun=my_scan_body, arg_names=carry,inp, result_paths=[0],[1]", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result[0],result[1]", + "traced_for=scan, fun=my_scan_body, arg_names=carry,inp, result_paths=result[0],result[1]", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=scan, fun=my_scan_body, arg_names=carry,inp, from carry" ]) @@ -1433,7 +1442,7 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[], expected_tracer_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x"], + "traced_for=jit, fun=my_f, arg_names=x, from x"], ) def test_vmap_of_nested_jit(self): @@ -1452,13 +1461,13 @@ class DebugInfoTest(jtu.JaxTestCase): np.ones((8,), dtype=np.float32), np.zeros((8,), dtype=np.float32), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=", - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=['c']", + "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']", ], expected_tracer_debug_infos=[ # TODO(necula): missing debug info "None", - "traced_for=jit, fun=my_g, arg_names=u,v" + "traced_for=jit, fun=my_g, arg_names=u,v, from u" ]) def test_pmap(self): @@ -1471,11 +1480,11 @@ class DebugInfoTest(jtu.JaxTestCase): jax.pmap(my_f), np.ones((jax.device_count(),), dtype=np.float32), expected_jaxpr_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=x, result_paths=" + "traced_for=pmap, fun=my_f, arg_names=x, result_paths=result" ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=x" + "traced_for=pmap, fun=my_f, arg_names=x, from x" ], ) @@ -1493,10 +1502,9 @@ class DebugInfoTest(jtu.JaxTestCase): 1., x, x, x, # x, y, args[0], args[1] d=x, a=x, b=x, # kwargs expected_jaxpr_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], result_paths=['u'],['v']", + "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], result_paths=result['u'],result['v']", ], tracer_spy=tracer_spy, - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]", ], @@ -1508,8 +1516,8 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(.*%arg3: tensor<1xf..> loc\(\"a\"\)"), re.compile(r".*func.func public @main\(.*%arg4: tensor<1xf..> loc\(\"kwargs\['b'\]\"\)"), re.compile(r".*func.func public @main\(.*%arg5: tensor<1xf..> loc\(\"kwargs\['d'\]\"\)"), - re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\['u'\]\"\}"), - re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\['v'\]\"\}"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['u'\]\"\}"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['v'\]\"\}"), ] ) @@ -1523,7 +1531,7 @@ class DebugInfoTest(jtu.JaxTestCase): jax.pmap(jax.grad(my_f)), np.ones((jax.device_count(),), dtype=np.float32), expected_jaxpr_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=x, result_paths=", + "traced_for=pmap, fun=my_f, arg_names=x, result_paths=result", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ @@ -1550,12 +1558,12 @@ class DebugInfoTest(jtu.JaxTestCase): expected_jaxpr_debug_infos=[ # TODO(necula): why this? re.compile(r'traced_for=jit, fun=_multi_slice at .*array_methods.py:.*, arg_names=self, result_paths=.*'), - "traced_for=pmap, fun=my_f, arg_names=x,y,args[0],args[1], result_paths=['u'],['v']", + "traced_for=pmap, fun=my_f, arg_names=x,y,args[0],args[1], result_paths=result['u'],result['v']", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ # TODO(necula): missing debug_info - 'None' + "None" ], ) @@ -1574,13 +1582,13 @@ class DebugInfoTest(jtu.JaxTestCase): jax.jit(lambda x, x_tan: jax.jvp(jax.pmap(my_f), (x, x), (x_tan, x_tan))), x, x_tan, expected_jaxpr_debug_infos=[ - 'traced_for=jit, fun=, arg_names=x,x_tan, result_paths=[0],[1]', - "traced_for=pmap, fun=my_f, arg_names=x,y, result_paths=", + "traced_for=jit, fun=, arg_names=x,x_tan, result_paths=result[0],result[1]", + "traced_for=pmap, fun=my_f, arg_names=x,y, result_paths=result", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ # TODO(necula): missing debug_info - 'None' + "None" ], ) @@ -1597,13 +1605,12 @@ class DebugInfoTest(jtu.JaxTestCase): jax.jit(jax.hessian(jax.jit(my_f))), x, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", # TODO(necula): arg_names and result_paths? - "traced_for=jit, fun=my_f, arg_names=None,x, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,", "traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,", ], tracer_spy=tracer_spy, - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, from x", ], @@ -1626,11 +1633,10 @@ class DebugInfoTest(jtu.JaxTestCase): 0., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", # TODO(necula): missing result_paths - "traced_for=checkpoint / remat, fun=my_g, arg_names=y, result_paths=", + "traced_for=checkpoint / remat, fun=my_g, arg_names=y, result_paths=result", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y" ]) @@ -1650,12 +1656,12 @@ class DebugInfoTest(jtu.JaxTestCase): 0., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=", - # TODO(necula): arg_names? - "traced_for=checkpoint / remat, fun=my_g, arg_names=None,None, result_paths=", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", + # TODO(necula): arg_names? result_paths? + "traced_for=checkpoint / remat, fun=my_g, arg_names=,, result_paths=", ], expected_tracer_debug_infos=[ - "traced_for=checkpoint / remat, fun=my_g, arg_names=y" + "traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y", ]) def test_remat_shard_map(self): @@ -1677,13 +1683,12 @@ class DebugInfoTest(jtu.JaxTestCase): jnp.arange(2, dtype=np.float32), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - # TODO(necula): arg_names - "traced_for=jit, fun=, arg_names=x, result_paths=", - "traced_for=checkpoint / remat, fun=my_f, arg_names=None,None, result_paths=", - "traced_for=shard_map, fun=my_f, arg_names=x, result_paths=", - "traced_for=shard_map, fun=my_f, arg_names=None,None, result_paths=", + # TODO(necula): arg_names, result_paths + "traced_for=jit, fun=, arg_names=x, result_paths=result", + "traced_for=checkpoint / remat, fun=my_f, arg_names=,, result_paths=", + "traced_for=shard_map, fun=my_f, arg_names=x, result_paths=result", + "traced_for=shard_map, fun=my_f, arg_names=,, result_paths=", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "None" # TODO(necula): missing ]) @@ -1719,11 +1724,11 @@ class DebugInfoTest(jtu.JaxTestCase): expected_jaxpr_debug_infos=[ # TODO(necula): this should not be pointing into the JAX internals re.compile(r"traced_for=jit, fun=checked_fun at .*jax._src.checkify.py:.*, arg_names=args\[0\]"), - re.compile(r"traced_for=jit, fun=argsort at .*numpy.sorting.py:.*, arg_names=a, result_paths="), - "traced_for=pmap, fun=my_f, arg_names=my_x, result_paths=[0]", + re.compile(r"traced_for=jit, fun=argsort at .*numpy.sorting.py:.*, arg_names=a, result_paths=result"), + "traced_for=pmap, fun=my_f, arg_names=my_x, result_paths=result[0]", ], expected_tracer_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=my_x", + "traced_for=pmap, fun=my_f, arg_names=my_x, from my_x", ], check_lowering=False, # TODO(necula): warning during lowering ) @@ -1751,12 +1756,12 @@ class DebugInfoTest(jtu.JaxTestCase): 0., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x, result_paths=", - "traced_for=custom_dce, fun=my_g, arg_names=x, result_paths=[0],[1]", + "traced_for=jit, fun=, arg_names=x, result_paths=result", + "traced_for=custom_dce, fun=my_g, arg_names=x, result_paths=result[0],result[1]", ], expected_tracer_debug_infos=[ # TODO(necula): no leaked tracer from my_g_dce? - "traced_for=custom_dce, fun=my_g, arg_names=x", + "traced_for=custom_dce, fun=my_g, arg_names=x, from x", ]) def test_custom_dce_consts(self): @@ -1779,11 +1784,10 @@ class DebugInfoTest(jtu.JaxTestCase): np.array(1.1234), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x, result_paths=", + "traced_for=jit, fun=, arg_names=x, result_paths=result", # TODO(necula): bad arg_names (why None), bad result_paths - 'traced_for=custom_dce, fun=my_f, arg_names=None,x, result_paths=[0],[1]', + 'traced_for=custom_dce, fun=my_f, arg_names=,x, result_paths=result[0],result[1]', ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ # TODO(necula): no leaked tracer from my_rule? "traced_for=custom_dce, fun=my_f, arg_names=x, from x", @@ -1816,23 +1820,22 @@ class DebugInfoTest(jtu.JaxTestCase): a, b, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=a,b, result_paths=[0],[1]", - re.compile(r"traced_for=jit, fun=_solve at .*scipy.linalg.py:.*, arg_names=a,b, result_paths="), - re.compile(r"traced_for=jit, fun=solve at .*linalg.py:.*, arg_names=a,b, result_paths="), - re.compile(r"traced_for=jit, fun=_lu_solve at .*linalg.py:.*, arg_names=lu,permutation,b, result_paths="), + "traced_for=jit, fun=, arg_names=a,b, result_paths=result[0],result[1]", + re.compile(r"traced_for=jit, fun=_solve at .*scipy.linalg.py:.*, arg_names=a,b, result_paths=result"), + re.compile(r"traced_for=jit, fun=solve at .*linalg.py:.*, arg_names=a,b, result_paths=result"), + re.compile(r"traced_for=jit, fun=_lu_solve at .*linalg.py:.*, arg_names=lu,permutation,b, result_paths=result"), # TODO(necula): why pointers to internal functions, arg_names, result_paths? - re.compile(r'traced_for=custom_linear_solve solve, fun= at .*linalg.py:.*, arg_names=None,None,x, result_paths='), - re.compile(r'traced_for=custom_linear_solve transpose_solve, fun= at .*linalg.py:.*, arg_names=None,None,x, result_paths='), - re.compile(r'traced_for=custom_linear_solve, fun= at .*linalg.py:.*, arg_names=None,x, result_paths='), - re.compile(r'traced_for=custom_linear_solve transpose_solve, fun= at .*linalg.py:.*, arg_names=None,x, result_paths='), - 'traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=None,b, result_paths=', - 'traced_for=custom_linear_solve solve, fun=my_solve, arg_names=None,x, result_paths=', - 'traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=None,x, result_paths=', + re.compile(r'traced_for=custom_linear_solve solve, fun= at .*linalg.py:.*, arg_names=,,x, result_paths='), + re.compile(r'traced_for=custom_linear_solve transpose_solve, fun= at .*linalg.py:.*, arg_names=,,x, result_paths='), + re.compile(r'traced_for=custom_linear_solve, fun= at .*linalg.py:.*, arg_names=,x, result_paths='), + re.compile(r'traced_for=custom_linear_solve transpose_solve, fun= at .*linalg.py:.*, arg_names=,x, result_paths='), + "traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=,b, result_paths=result", + "traced_for=custom_linear_solve solve, fun=my_solve, arg_names=,x, result_paths=result", + "traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=,x, result_paths=result", ], expected_tracer_debug_infos=[ - "traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x", - "traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=x", - "traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=b", + "traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x, from x", + "traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=x, from x", "None", # TODO(necula): there are missing debug info ]) @@ -1856,14 +1859,15 @@ class DebugInfoTest(jtu.JaxTestCase): 0., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x, result_paths=[0],[1]", + "traced_for=jit, fun=, arg_names=x, result_paths=result[0],result[1]", # TODO(necula): internal function? - re.compile(r"traced_for=custom_jvp fun, fun=_custom_root at .*control_flow.solves.py:.*, arg_names=args\[0\], result_paths=\[0\]"), + re.compile(r"traced_for=custom_jvp fun, fun=_custom_root at .*control_flow.solves.py:.*, arg_names=args\[0\], result_paths=result\[0\]"), ], expected_tracer_debug_infos=[ - "traced_for=custom_root, fun=my_f, arg_names=x", - "traced_for=custom_root solve, fun=my_solve, arg_names=x", - "traced_for=custom_root tangent_solve, fun=my_transpose_solve, arg_names=x", + "traced_for=custom_root, fun=my_f, arg_names=x, from x", + "traced_for=custom_root solve, fun=my_solve, arg_names=x, from x", + # TODO(necula): from None + "traced_for=custom_root tangent_solve, fun=my_transpose_solve, arg_names=x, from None", "None", # TODO(necula): there are missing debug info ]) @@ -1891,13 +1895,13 @@ class DebugInfoTest(jtu.JaxTestCase): jax.jit(my_f), x, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=", - "traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, result_paths=[0],[1]", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", + "traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, result_paths=result[0],result[1]", "traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref, result_paths=", ], expected_tracer_debug_infos=[ - "traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j", - "traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref", + "traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, from i", + "traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref, from x_ref", ], check_lowering=False, # We need interpret mode on CPU. TODO(necula) ) @@ -1921,14 +1925,14 @@ class DebugInfoTest(jtu.JaxTestCase): jnp.arange(4, dtype=jnp.float32) - 2, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=input, result_paths=", + "traced_for=jit, fun=my_f, arg_names=input, result_paths=result", # TODO(necula): function source location points in JAX internals # TODO(necula): arg_names and result_paths are wrong re.compile(r"traced_for=checkify_pallas, fun=checked_kernel_fn at .*pallas_call.py:.*, arg_names=args\[0\],.*, result_paths="), re.compile(r"traced_for=pallas_call index_map, fun= at .*pallas.core.py:.*, arg_names=, result_paths="), ], expected_tracer_debug_infos=[ - "traced_for=pallas_call, fun=kernel, arg_names=x_ref,y_ref", + "traced_for=pallas_call, fun=kernel, arg_names=x_ref,y_ref, from x_ref", ], check_lowering=False, # We need interpret mode on CPU. TODO(necula) ) @@ -1948,11 +1952,11 @@ class DebugInfoTest(jtu.JaxTestCase): jax.jit(my_consts), x, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_consts, arg_names=x, result_paths=", - "traced_for=composite, fun=my_consts, arg_names=x, result_paths=", + "traced_for=jit, fun=my_consts, arg_names=x, result_paths=result", + "traced_for=composite, fun=my_consts, arg_names=x, result_paths=result", ], expected_tracer_debug_infos=[ - "traced_for=composite, fun=my_consts, arg_names=x"]) + "traced_for=composite, fun=my_consts, arg_names=x, from x"]) if __name__ == '__main__': diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 7a1d36317..9a6c5c167 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -274,13 +274,13 @@ class MutableArrayErrorsTest(jtu.JaxTestCase): def test_return_from_jit_pytree(self): with self.assertRaisesRegex( ValueError, - r"tree path \['hi'\]"): + r"tree path result\['hi'\]"): jax.jit(lambda x_ref: {'hi': x_ref})(core.mutable_array(jnp.arange(3))) def test_return_from_jit_closure(self): with self.assertRaisesRegex( ValueError, - r"tree path \['hi'\]"): + r"tree path result\['hi'\]"): x_ref = core.mutable_array(jnp.arange(3)) jax.jit(lambda: {'hi': x_ref})() diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 9abfd14fc..bdb117166 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -6846,7 +6846,7 @@ class PJitErrorTest(jtu.JaxTestCase): spec = P(resources, None) mesh_size = str(math.prod([dim[1] for dim in mesh])) error = re.compile( - r"One of pjit outputs with pytree key path \['rrr'\].*" + spec_regex(spec) + r".*" + r"One of pjit outputs with pytree key path result\['rrr'\].*" + spec_regex(spec) + r".*" r"implies that the global size of its dimension 0 should be " r"divisible by " + mesh_size + r", but it is equal to 3", re.M | re.S) with self.assertRaisesRegex(ValueError, error):