From 1be801bac8863a4f588fd6c7eae8f3099fa48036 Mon Sep 17 00:00:00 2001 From: George Necula Date: Tue, 18 Feb 2025 10:09:47 +0100 Subject: [PATCH] [better_errors] Cleanup use of DebugInfo.arg_names and result_paths Previously, we represented a missing arg name with `None`, and a missing result path with the empty string. We now adopt the same convention for arg names and use empty strings. This simplifies the typing, and prevents the string "None" from appearing in error messages. I changed how we encode the result paths. Previously for a function that returns a single array the path was the empty string (the same as for an unknown path). And for a function that returns a pair of arrays it was `([0], [1])`. Now we add the "result" prefix: `("result",)` for a function returning a single array and `(result[0], result[1])` for a function returning a pair of arrays. Finally, in debug_info_test, I removed the `check_tracer_arg_name` so that all spied tracers are printed with the argument name they depend on. --- docs/aot.md | 4 +- docs/export/export.md | 4 +- jax/_src/ad_checkpoint.py | 4 +- jax/_src/api.py | 2 +- jax/_src/api_util.py | 2 +- jax/_src/core.py | 4 +- jax/_src/interpreters/ad.py | 2 +- jax/_src/interpreters/mlir.py | 2 +- jax/_src/interpreters/partial_eval.py | 4 +- jax/_src/interpreters/pxla.py | 2 +- jax/_src/lax/control_flow/loops.py | 2 +- jax/_src/linear_util.py | 59 ++-- jax/_src/pjit.py | 30 +- tests/debug_info_test.py | 388 +++++++++++++------------- tests/mutable_array_test.py | 4 +- tests/pjit_test.py | 2 +- 16 files changed, 266 insertions(+), 249 deletions(-) diff --git a/docs/aot.md b/docs/aot.md index 1f24d64fa..1fcf11ab9 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -56,7 +56,7 @@ some other features along the way. An example: >>> # Print lowered HLO >>> print(lowered.as_text()) module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor, %arg1: tensor) -> (tensor {jax.result_info = ""}) { + func.func public @main(%arg0: tensor, %arg1: tensor) -> (tensor {jax.result_info = "result"}) { %c = stablehlo.constant dense<2> : tensor %0 = stablehlo.multiply %c, %arg0 : tensor %1 = stablehlo.add %0, %arg1 : tensor @@ -140,7 +140,7 @@ to invoke the resulting compiled function. Continuing with our example above: >>> # Lowered HLO, specialized to the *value* of the first argument (7) >>> print(lowered_with_x.as_text()) module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor) -> (tensor {jax.result_info = ""}) { + func.func public @main(%arg0: tensor) -> (tensor {jax.result_info = "result"}) { %c = stablehlo.constant dense<14> : tensor %0 = stablehlo.add %c, %arg0 : tensor return %0 : tensor diff --git a/docs/export/export.md b/docs/export/export.md index bab4723f3..f1542d80d 100644 --- a/docs/export/export.md +++ b/docs/export/export.md @@ -44,7 +44,7 @@ Here is an example: (ShapedArray(float32[]),) >>> print(re.search(r".*@main.*", exported.mlir_module()).group(0)) - func.func public @main(%arg0: tensor loc("x")) -> (tensor {jax.result_info = ""}) { + func.func public @main(%arg0: tensor loc("x")) -> (tensor {jax.result_info = "result"}) { >>> # And you can serialize the Exported to a bytearray. 
>>> serialized: bytearray = exported.serialize() @@ -206,7 +206,7 @@ as in the following example: >>> _ = mlir.register_lowering(new_prim, lambda ctx, o: mlir.custom_call("my_new_prim", operands=[o], result_types=[o.type]).results) >>> print(jax.jit(new_prim.bind).lower(1.).compiler_ir()) module @jit_bind attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} { - func.func public @main(%arg0: tensor) -> (tensor {jax.result_info = ""}) { + func.func public @main(%arg0: tensor) -> (tensor {jax.result_info = "result"}) { %0 = stablehlo.custom_call @my_new_prim(%arg0) {api_version = 2 : i32, backend_config = ""} : (tensor) -> tensor return %0 : tensor } diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index 61576bb50..f1c0078cd 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -458,7 +458,7 @@ def saved_residuals(f: Callable, return _saved_residuals(jaxpr, debug_info.arg_names) def _saved_residuals(jaxpr: core.Jaxpr, - arg_names: tuple[str | None, ...]) -> list[tuple[core.AbstractValue, str]]: + arg_names: Sequence[str]) -> list[tuple[core.AbstractValue, str]]: res_lits = [x for x in jaxpr.outvars if isinstance(x, core.Literal)] res_vars = {x for x in jaxpr.outvars if not isinstance(x, core.Literal)} @@ -473,7 +473,7 @@ def _saved_residuals(jaxpr: core.Jaxpr, for i, v in enumerate(jaxpr.invars): if v in res_vars: - if arg_names[i] is not None: + if arg_names[i]: src = f'from the argument {arg_names[i]}' else: src = 'from the argument at flattened index {i}' diff --git a/jax/_src/api.py b/jax/_src/api.py index 0055f6466..d24ea7cd7 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2274,7 +2274,7 @@ def _check_sharding(aval, s): aval = core.get_token_aval() if not isinstance(s, PmapSharding): pjit.pjit_check_aval_sharding( - (s,), (aval,), None, "device_put args", allow_uneven_sharding=False) + (s,), (aval,), ("",), "device_put args", allow_uneven_sharding=False) s.shard_shape(aval.shape) # should raise an Error if incompatible diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py index a597e8b5b..1fd371034 100644 --- a/jax/_src/api_util.py +++ b/jax/_src/api_util.py @@ -662,7 +662,7 @@ def _non_static_arg_names(fn_signature: inspect.Signature | None, args: Sequence[Any], kwargs: dict[str, Any], static_argnums: Sequence[int], static_argnames: Sequence[str], - ) -> tuple[str | None, ...]: + ) -> tuple[str, ...]: """Returns the names of the non-static arguments. 
If the `fn_signature` is given then we get from it the names of the diff --git a/jax/_src/core.py b/jax/_src/core.py index d8f91789b..cd08d12ef 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -149,8 +149,8 @@ class Jaxpr: debug_info = debug_info or lu._missing_debug_info("core.Jaxpr") self._debug_info = debug_info.resolve_result_paths() # TODO(necula): re-enable these safety checks - # assert (not debug_info or len(debug_info.arg_names) == len(invars)), (debug_info, invars) - # assert (not debug_info or len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) + # assert (len(debug_info.arg_names) == len(invars)), (debug_info, invars) + # assert (len(debug_info.result_paths) == len(outvars)), (debug_info, outvars) def __str__(self): return str(self.pretty_print()) diff --git a/jax/_src/interpreters/ad.py b/jax/_src/interpreters/ad.py index 4a0e6ca46..37ad40d22 100644 --- a/jax/_src/interpreters/ad.py +++ b/jax/_src/interpreters/ad.py @@ -135,7 +135,7 @@ def jvp_subtrace_aux(f, store, tag, primals, tangents): def convert_constvars_jaxpr_constvars_at_end(jaxpr: core.Jaxpr) -> core.Jaxpr: dbg = jaxpr.debug_info._replace( - arg_names=jaxpr.debug_info.arg_names + (None,) * len(jaxpr.constvars)) + arg_names=jaxpr.debug_info.arg_names + ("",) * len(jaxpr.constvars)) return core.Jaxpr(constvars=(), invars=jaxpr.invars + jaxpr.constvars, outvars=jaxpr.outvars, eqns=jaxpr.eqns, diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index b4d7b104a..4c722dedb 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -1113,7 +1113,7 @@ def lower_jaxpr_to_module( result_shardings: Sequence[JSharding | AUTO | None] | None = None, in_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None, out_layouts: Sequence[DeviceLocalLayout | None | AutoLayout] | None = None, - arg_names: Sequence[str | None] | None = None, + arg_names: Sequence[str] | None = None, result_names: Sequence[str] | None = None, num_replicas: int = 1, num_partitions: int = 1, diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py index 8d27f11c0..6fde73705 100644 --- a/jax/_src/interpreters/partial_eval.py +++ b/jax/_src/interpreters/partial_eval.py @@ -842,7 +842,7 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr: """Moves the constvars to the start of invars.""" config.enable_checks.value and core.check_jaxpr(jaxpr) dbg = jaxpr.debug_info._replace( - arg_names=(None,) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names) + arg_names=("",) * len(jaxpr.constvars) + jaxpr.debug_info.arg_names) lifted_jaxpr = Jaxpr(constvars=(), invars=jaxpr.constvars + jaxpr.invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns, @@ -1574,7 +1574,7 @@ class DynamicJaxprTracer(core.Tracer): origin = ("The error occurred while tracing the function " f"{dbg.func_src_info} for {dbg.traced_for}. ") - if invar_pos and dbg.arg_names: + if invar_pos: try: arg_names = [dbg.arg_names[i] for i in invar_pos] except IndexError: diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index f38084120..9baa5e977 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -3260,7 +3260,7 @@ def check_array_xla_sharding_layout_match( from jax._src.array import ArrayImpl # jaxpr_debug_info.arg_names are before DCE, so need to DCE them. 
arg_names = ( - [a for i, a in enumerate(jaxpr_debug_info.arg_names) # type: ignore + [a for i, a in enumerate(jaxpr_debug_info.arg_names) if i in kept_var_idx] ) errors = [] diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index af01dc249..861acdd42 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -1554,7 +1554,7 @@ def _while_loop_jvp(primals, tangents, cond_nconsts, cond_jaxpr, body_nconsts, cond_debug = cond_jaxpr.jaxpr.debug_info augmented_debug = cond_debug and ( cond_debug._replace( - arg_names=cond_debug.arg_names + (None,) * len(init_dot) + arg_names=cond_debug.arg_names + ("",) * len(init_dot) ) ) cond_jaxpr_augmented = core.Jaxpr(cond_jaxpr.jaxpr.constvars, diff --git a/jax/_src/linear_util.py b/jax/_src/linear_util.py index d1439f8a4..b8272fa0f 100644 --- a/jax/_src/linear_util.py +++ b/jax/_src/linear_util.py @@ -276,30 +276,43 @@ class DebugInfo(NamedTuple): """Debugging info about a func, its arguments, and results.""" traced_for: str # e.g. 'jit', 'scan', etc - # e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have - # no source location information. The first word is always the function name, - # which may be ''. func_src_info: str + """e.g. f'{fun.__name__} at {filename}:{lineno}' or {fun.__name__} if we have + no source location information. The first word is always the function name, + which may be ''. + """ - # The paths of the flattened non-static argnames, - # e.g. ('x', 'dict_arg["a"]', ... ). - # Uses `None` for the args that do not correspond to user-named arguments, - # e.g., tangent args in jax.jvp. At the moment, `arg_names` accuracy is - # best-effort. Use `safe_arg_names` to detect and handle an unexpected - # number of elements in `arg_names`. - arg_names: tuple[str | None, ...] + arg_names: tuple[str, ...] + """The paths of the flattened non-static argnames, + e.g. `('x', 'dict_arg["a"]', ... )`. + Uses the empty string for the args that do not correspond to + user-named arguments, e.g., tangent args in `jax.jvp`, or for arguments that + we are not yet tracking properly. + At the moment, `arg_names` accuracy is best-effort. + Use `safe_arg_names` to detect and handle an unexpected + number of elements in `arg_names`. + """ - # The result paths are not available while we are tracing the function, - # instead we keep a thunk. Once we are done tracing, we use - # `self.resolve_result_paths()` to execute the thunk and replace the - # actual result paths. At the moment, `result_paths` accuracy is - # best-effort. Use `safe_result_paths` to detect and handle an unexpected - # number of elements in `result_paths`. - # e.g. ('[0]', '[1]', ...) result_paths: tuple[str, ...] | Callable[[], tuple[str, ...]] | None + """The paths to the flattened results, e.g., `('result[0]', result[1])` for a + function that returns a tuple of arrays, or `(result,)` for a function that + returns a single array. + The result paths are not available while we are tracing the function, + instead we keep a thunk. It is possible for the result paths to be `None` + only when we first create a `DebugInfo`, before we put it in `lu.WrappedFun` + and before we start tracing. + Inside a `lu.WrappedFun` it can be only a thunk or a tuple of strings. + Once we are done tracing, we use + `self.resolve_result_paths()` to execute the thunk and replace the + actual result paths. + At the moment, `result_paths` accuracy is best-effort. 
+ Use `safe_result_paths` to detect and handle an unexpected + number of elements in `result_paths`. + """ def resolve_result_paths(self) -> DebugInfo: """Return a debug info with resolved result paths.""" + assert self.result_paths is not None if callable(self.result_paths): return self._replace(result_paths=tuple(self.result_paths())) return self @@ -308,21 +321,21 @@ class DebugInfo(NamedTuple): def func_name(self) -> str: return self.func_src_info.split(" ")[0] - def safe_arg_names(self, expected: int) -> tuple[str | None, ...]: + def safe_arg_names(self, expected: int) -> tuple[str, ...]: """Get the arg_names with a safety check.""" if len(self.arg_names) == expected: return self.arg_names else: # TODO(necula): this should not happen - return (None,) * expected + return ("",) * expected - def filter_arg_names(self, keep: Sequence[bool]) -> tuple[str | None, ...]: + def filter_arg_names(self, keep: Sequence[bool]) -> tuple[str, ...]: """Keep only the arg_names for which `keep` is True.""" return tuple(v for v, b in zip(self.safe_arg_names(len(keep)), keep) if b) def safe_result_paths(self, expected: int) -> tuple[str, ...]: """Get the result paths with a safety check.""" - assert not callable(self.result_paths), self + assert self.result_paths is not None and not callable(self.result_paths), self if self.result_paths is not None and len(self.result_paths) == expected: return self.result_paths else: @@ -331,7 +344,7 @@ class DebugInfo(NamedTuple): def filter_result_paths(self, keep: Sequence[bool]) -> tuple[str, ...]: """Keep only the result_paths for which `keep` is True.""" - assert not callable(self.result_paths), self + assert self.result_paths is not None and not callable(self.result_paths), self return tuple(v for v, b in zip(self.safe_result_paths(len(keep)), keep) if b) @@ -368,7 +381,7 @@ def _clean_keystr_arg_names(k: KeyPath) -> str: @transformation_with_aux2 def _get_result_paths_thunk(_fun: Callable, _store: Store, *args, **kwargs): ans = _fun(*args, **kwargs) - result_paths = [_clean_keystr_arg_names(path) for path, _ in generate_key_paths(ans)] + result_paths = tuple(f"result{_clean_keystr_arg_names(path)}" for path, _ in generate_key_paths(ans)) if _store: # In some instances a lu.WrappedFun is called multiple times, e.g., # the bwd function in a custom_vjp diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index be44c2fe1..ac8dd8f6e 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -538,7 +538,7 @@ class PjitParams(NamedTuple): in_tree: PyTreeDef out_tree: PyTreeDef donated_invars: tuple[bool, ...] - arg_names: tuple[str | None, ...] + arg_names: tuple[str, ...] num_consts: int attrs_tracked: list[tuple[PyTreeDef, PyTreeDef, tuple[Any, str]]] @@ -663,7 +663,7 @@ def _infer_params_impl( compiler_options_kvs=ji.compiler_options_kvs, ) return PjitParams(consts, params, in_avals, in_tree, out_tree(), - donated_invars, dbg.arg_names if dbg else None, len(consts), + donated_invars, dbg.arg_names, len(consts), attrs_tracked), args_flat @@ -741,13 +741,13 @@ def _infer_input_type(fun: Callable, dbg: core.DebugInfo, for i, x in enumerate(explicit_args): avals.append(core.shaped_abstractify(x)) except OverflowError: - arg_path = f"argument path is {dbg.arg_names[i]}" # type: ignore + arg_path = f"argument path is {dbg.arg_names[i]}" raise OverflowError( "An overflow was encountered while parsing an argument to a jitted " f"computation, whose {arg_path}." 
) from None except TypeError: - arg_description = f"path {dbg.arg_names[i]}" # type: ignore + arg_description = f"path {dbg.arg_names[i]}" raise TypeError( f"Error interpreting argument to {fun} as an abstract array." f" The problematic value is of type {type(x)} and was passed to" @@ -1134,7 +1134,7 @@ def _process_in_axis_resources(in_shardings_treedef, in_shardings_leaves, "pjit in_layouts", in_tree, in_layouts, tupled_args=True) # TODO(dougalm,mattjj): enable debug info with attrs_tracked - attrs_tracked = debug_info and len(debug_info.arg_names) != len(in_avals) + attrs_tracked = len(debug_info.arg_names) != len(in_avals) if not config.dynamic_shapes.value and not attrs_tracked: pjit_check_aval_sharding(in_shardings_flat, in_avals, debug_info.safe_arg_names(len(in_avals)), @@ -1338,11 +1338,11 @@ def _check_and_canonicalize_out_shardings( if not config.dynamic_shapes.value: pjit_check_aval_sharding( out_shardings_flat, out_avals, - debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type] + debug_info.safe_result_paths(len(out_avals)), "pjit outputs", allow_uneven_sharding=False) check_aval_layout_compatibility( out_layouts_flat, out_avals, - debug_info.safe_result_paths(len(out_avals)), # type: ignore[arg-type] + debug_info.safe_result_paths(len(out_avals)), "jit outputs") return out_shardings_flat, out_layouts_flat @@ -1396,10 +1396,9 @@ class IgnoreKey: def pjit_check_aval_sharding( - shardings, flat_avals, names: tuple[str | None, ...] | None, + shardings, flat_avals, names: Sequence[str], what_aval: str, allow_uneven_sharding: bool): - new_names = [None] * len(shardings) if names is None else names - for aval, s, name in zip(flat_avals, shardings, new_names): + for aval, s, name in zip(flat_avals, shardings, names): if isinstance(s, (UnspecifiedValue, AUTO)): continue name_str = f' with pytree key path {name}' if name else '' @@ -1431,9 +1430,8 @@ def pjit_check_aval_sharding( def check_aval_layout_compatibility( - layouts, flat_avals, names: tuple[str, ...] 
| None, what_aval: str): - new_names = [''] * len(layouts) if names is None else names - for aval, l, name in zip(flat_avals, layouts, new_names): + layouts, flat_avals, names: Sequence[str], what_aval: str): + for aval, l, name in zip(flat_avals, layouts, names): if l is None or isinstance(l, AutoLayout): continue name_str = f' with pytree key path {name}' if name else '' @@ -2557,12 +2555,14 @@ def with_sharding_constraint(x, shardings): for s in shardings_flat] pjit_check_aval_sharding( - shardings_flat, x_flat, None, "with_sharding_constraint arguments", + shardings_flat, x_flat, ("",) * len(shardings_flat), + "with_sharding_constraint arguments", allow_uneven_sharding=True) check_shardings_are_auto(shardings_flat) - check_aval_layout_compatibility(user_layouts_flat, x_flat, None, + check_aval_layout_compatibility(user_layouts_flat, x_flat, + ("",) * len(user_layouts_flat), "with_sharding_constraint arguments") outs = [sharding_constraint_p.bind(xf, sharding=s, layout=l, diff --git a/tests/debug_info_test.py b/tests/debug_info_test.py index de5bb0bc9..8ad235daf 100644 --- a/tests/debug_info_test.py +++ b/tests/debug_info_test.py @@ -100,11 +100,12 @@ class TracerSpy: try: # We plan to do boolean conversion and catch the exception, but this works # only for scalars - if t.shape: - t = jnp.sum(t) - if t: + t_scalar = t + while t_scalar.shape: + t_scalar = t_scalar[0] + if t_scalar: pass - assert False, t + assert False, t_scalar except Exception as e: self.tracers.append((t, e)) @@ -118,7 +119,6 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy: TracerSpy | None = None, expected_tracer_debug_infos: list[str | re.Pattern] = [], check_lowering: bool = True, - check_tracer_arg_name: bool = False, expected_lowering_lines: list[str | re.Pattern] = [], **kwargs) -> None: """Checks the expected debug info in all jaxprs, in spied tracers, and StableHLO. @@ -172,11 +172,14 @@ class DebugInfoTest(jtu.JaxTestCase): for t, exc in tracer_spy.tracers: if hasattr(t, "_debug_info"): t_debug_info = _debug_info_to_string(t._debug_info) - if check_tracer_arg_name: - msg = str(exc) - m = re.match(r".* while tracing the function (.+) for ([^.]+)\.", - msg, - re.DOTALL) + msg = str(exc) + m = re.match(r".* while tracing the function (.+) for ([^.]+)\.", + msg, + re.DOTALL) + if m is None: + found_tracer_debug_infos.append( + f"{t_debug_info}, from None") + else: self.assertIsNotNone(m, msg) self.assertEqual(t._debug_info.func_src_info, m.group(1)) self.assertEqual(t._debug_info.traced_for, m.group(2)) @@ -185,8 +188,6 @@ class DebugInfoTest(jtu.JaxTestCase): re.DOTALL) found_tracer_debug_infos.append( f"{t_debug_info}, from {m.group(1) if m else None}") - else: - found_tracer_debug_infos.append(t_debug_info) else: found_tracer_debug_infos.append("None") @@ -646,9 +647,8 @@ class DebugInfoTest(jtu.JaxTestCase): dict(a=1, b=2), 3, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, result_paths=['c'],['d']" + "traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, result_paths=result['c'],result['d']" ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x_dict['a'],x_dict['b'],y, from x_dict['a']", ]) @@ -669,10 +669,10 @@ class DebugInfoTest(jtu.JaxTestCase): 3, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - 'traced_for=jit, fun=my_f, arg_names=a, result_paths=', - 'traced_for=jit, fun=my_g, arg_names=b, result_paths=', + # TODO(necula): result_paths? 
+ "traced_for=jit, fun=my_f, arg_names=a, result_paths=", + "traced_for=jit, fun=my_g, arg_names=b, result_paths=result", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_g, arg_names=b, from b", "traced_for=jit, fun=my_f, arg_names=a, from a", @@ -690,10 +690,9 @@ class DebugInfoTest(jtu.JaxTestCase): jax.jit(f), {"ho": 1.}, {"hi": 2.}, 3., 4., z=5., w=6., expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], result_paths=", + "traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], result_paths=result", ], tracer_spy=tracer_spy, - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], from kwargs['w']", "traced_for=jit, fun=f, arg_names=x['ho'],y['hi'],args[0],args[1],kwargs['w'],kwargs['z'], from args[0]", @@ -703,7 +702,7 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(\"args\[1\]\"\)"), re.compile(r".*func.func public @main\(.*%arg2: tensor loc\(\"kwargs\['w'\]\"\)"), re.compile(r".*func.func public @main\(.*%arg3: tensor loc\(\"kwargs\['z'\]\"\)"), - re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\"\}"), + re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"), ] ) @@ -721,10 +720,9 @@ class DebugInfoTest(jtu.JaxTestCase): (1.,), {"hi": 2.}, 3., 4., 5., 6., # x, y, z, args[0], args[1], args[2] t=11., w=12., # kwargs expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], result_paths=", + "traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], result_paths=result", ], tracer_spy=tracer_spy, - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=y['hi'],z,args[0],args[1],kwargs['t'],kwargs['w'], from kwargs['w']", ], @@ -732,7 +730,7 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"y\['hi'\]\"\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(\"args\[1\]\"\)"), re.compile(r".*func.func public @main\(.*%arg2: tensor loc\(\"kwargs\['t'\]\"\)"), - re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\"\}"), + re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\"\}"), ]) def test_jit_arg_names_static_argnames(self): @@ -748,10 +746,9 @@ class DebugInfoTest(jtu.JaxTestCase): (1.,), {'hi': 2.}, 3., 4., # x, y, args[0], args[1] z=5., w=6., a=7., b=8., # kwargs expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], result_paths=", + "traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], result_paths=result", ], tracer_spy=tracer_spy, - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=f, arg_names=x[0],y['hi'],args[0],args[1],kwargs['b'],kwargs['w'],kwargs['z'], from x[0]", ], @@ -760,7 +757,7 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(\"args\[1\]\"\)"), re.compile(r".*func.func public @main\(.*%arg2: tensor loc\(\"kwargs\['b'\]\"\)"), re.compile(r".*func.func public @main\(.*%arg3: tensor loc\(\"kwargs\['w'\]\"\)"), - re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\"\}"), + re.compile(r".*func.func public 
@main\(.*\{jax.result_info = \"result\"\}"), ]) def test_jit_result_info(self): @@ -770,13 +767,13 @@ class DebugInfoTest(jtu.JaxTestCase): jax.jit(f), 1., (2.,), [3.], expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=f, arg_names=x,y[0],z[0], result_paths=['a'],['b'][0][0]", + "traced_for=jit, fun=f, arg_names=x,y[0],z[0], result_paths=result['a'],result['b'][0][0]", ], expected_lowering_lines=[ re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"x\"\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(\"y\[0\]\"\)"), - re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\['a'\]\"\}"), - re.compile(r".*func.func public @main\(.*\{jax.result_info = \"\['b'\]\[0\]\[0\]\"\}"), + re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\['a'\]\"\}"), + re.compile(r".*func.func public @main\(.*\{jax.result_info = \"result\['b'\]\[0\]\[0\]\"\}"), ]) def test_nested_jit(self): @@ -795,14 +792,35 @@ class DebugInfoTest(jtu.JaxTestCase): 2, 3, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=", - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=[\'c\']" + "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']" ], expected_tracer_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x,y", - "traced_for=jit, fun=my_g, arg_names=u,v" + "traced_for=jit, fun=my_f, arg_names=x,y, from x", + "traced_for=jit, fun=my_g, arg_names=u,v, from u" ]) + def test_nested_jit_with_const_and_unused_args(self): + def my_f(x, y): # y is unused + def my_g(u, v): # v is unused + return v + np.ones(v.shape, v.dtype) + + return x + jax.jit(my_g)(y, x) + + x = y = np.ones((8,), dtype=np.float32) + self._check_tracers_and_jaxprs( + jax.jit(my_f), + x, y, + expected_jaxpr_debug_infos=[ + "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result" + ], + expected_lowering_lines=[ + re.compile(r".*func.func public @main\(%arg0: tensor<8xf..> loc\(\"x\"\)\)"), + re.compile(r".*call @my_g\(%arg.\) : \(tensor<8xf..>\)"), + ] + ) + def test_jvp_of_jit(self): tracer_spy = TracerSpy() def f(x, y, z): @@ -813,11 +831,11 @@ class DebugInfoTest(jtu.JaxTestCase): jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)], expected_jaxpr_debug_infos=[ # TODO(necula): arg_names, result_paths - "traced_for=jit, fun=f, arg_names=None,None,None,None, result_paths=,,,", + "traced_for=jit, fun=f, arg_names=,,,, result_paths=,,,", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ - "traced_for=jit, fun=f, arg_names=x,y[0],z[0]", + "traced_for=jit, fun=f, arg_names=x,y[0],z[0], from x", ], expected_lowering_lines=[ # TODO(necula): missing arg_names @@ -840,13 +858,12 @@ class DebugInfoTest(jtu.JaxTestCase): lambda x, y, z: jax.vjp(jax.jit(my_f), x, y, z)[1](dict(a=x, b=[y])), jnp.float32(1.), (jnp.float32(2.),), [jnp.float32(3.)], expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=", - re.compile(r"traced_for=jit, fun=convert_element_type at .*dispatch.py:.*, arg_names=args\[0\], result_paths="), - # TODO(necula): arg_names? - "traced_for=jit, fun=my_f, arg_names=None,None,None,None, result_paths=['a'],['b'][0][0]", + "traced_for=jit, fun=my_f, arg_names=x,y[0], result_paths=result", + re.compile(r"traced_for=jit, fun=convert_element_type at .*dispatch.py:.*, arg_names=args\[0\], result_paths=result"), + # TODO(necula): arg_names? result_paths? 
+ "traced_for=jit, fun=my_f, arg_names=,,,, result_paths=['a'],['b'][0][0]", ], tracer_spy=tracer_spy, - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x,y[0],z[0], from y[0]", ], @@ -854,6 +871,7 @@ class DebugInfoTest(jtu.JaxTestCase): # TODO(necula): missing arg_names re.compile(r".*func.func public @main\(%arg0: tensor loc\(unknown\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(unknown\)"), + # TODO(necula): result_paths? re.compile(r".*func.func public @main\(.*-> \(tensor {jax.result_info = \"\"}"), ]) @@ -873,13 +891,12 @@ class DebugInfoTest(jtu.JaxTestCase): 2., 3., 0.3, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=[0],[1]", + "traced_for=jit, fun=, arg_names=x,y,res_ct, result_paths=result[0],result[1]", # TODO(necula): result_paths "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=", # TODO(necula): arg_names - "traced_for=jit, fun=my_g, arg_names=None,None,u,v, result_paths=['c'],['d']", + "traced_for=jit, fun=my_g, arg_names=,,u,v, result_paths=result['c'],result['d']", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ # TODO(necula): missing debug info "None", @@ -889,8 +906,8 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"x\"\)"), re.compile(r".*func.func public @main\(.*%arg1: tensor loc\(\"y\"\)"), re.compile(r".*func.func public @main\(.*%arg2: tensor loc\(\"res_ct\"\)"), - re.compile(r".*func.func public @main\(.*jax.result_info = \"\[0\]\"}"), - re.compile(r".*func.func public @main\(.*jax.result_info = \"\[1\]\"}"), + re.compile(r".*func.func public @main\(.*jax.result_info = \"result\[0\]\"}"), + re.compile(r".*func.func public @main\(.*jax.result_info = \"result\[1\]\"}"), ]) def test_vjp_remat(self): @@ -909,11 +926,10 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ # TODO(necula): what are these flat_index components? - "traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=[0],[1][0][0][0][0][0]", - re.compile(r"traced_for=custom_jvp fun, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="), - re.compile(r"traced_for=jit, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths="), + "traced_for=jit, fun=apply_fn, arg_names=inp, result_paths=result[0],result[1][0][0][0][0][0]", + re.compile(r"traced_for=custom_jvp fun, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths=result"), + re.compile(r"traced_for=jit, fun=relu at .*nn.functions.py:.*, arg_names=x, result_paths=result"), ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=checkpoint / remat, fun=to_remat, arg_names=x, from x", "traced_for=jit, fun=apply_fn, arg_names=inp, from inp", @@ -942,11 +958,11 @@ class DebugInfoTest(jtu.JaxTestCase): 42., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=a, result_paths=[0],[1]", - "traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, result_paths=", + "traced_for=jit, fun=, arg_names=a, result_paths=result[0],result[1]", + "traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, result_paths=result", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ + # TODO(necula): from None? 
"traced_for=jit, fun=, arg_names=a, from None", "traced_for=custom_jvp fun, fun=my_fun, arg_names=x,y,c, from y", ]) @@ -976,10 +992,9 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ # TODO(necula): arg_names - "traced_for=jit, fun=, arg_names=None,a,b, result_paths=[0],[1]", - "traced_for=custom_jvp fun, fun=my_g, arg_names=None,xy[0],xy[1], result_paths=", + "traced_for=jit, fun=, arg_names=,a,b, result_paths=result[0],result[1]", + "traced_for=custom_jvp fun, fun=my_g, arg_names=,xy[0],xy[1], result_paths=result", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=custom_jvp fun, fun=my_g, arg_names=xy[0],xy[1], from xy[0]", # TODO(necula): from None @@ -1010,10 +1025,9 @@ class DebugInfoTest(jtu.JaxTestCase): {"a" : 3.}, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=to_diff, arg_names=x['a'], result_paths=['a']", - "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], result_paths=['b']", + "traced_for=jit, fun=to_diff, arg_names=x['a'], result_paths=result['a']", + "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], result_paths=result['b']", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=custom_vjp fun, fun=my_f, arg_names=x['a'], from x['a']", # TODO(necula): from None? @@ -1041,10 +1055,9 @@ class DebugInfoTest(jtu.JaxTestCase): (3., 3.), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=xy[0],xy[1], result_paths=[0],[1]", - "traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], result_paths=", + "traced_for=jit, fun=, arg_names=xy[0],xy[1], result_paths=result[0],result[1]", + "traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], result_paths=result", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=, arg_names=xy[0],xy[1], from xy[0]", "traced_for=custom_vjp fun, fun=app, arg_names=xy[0],xy[1], from xy[0]", @@ -1097,11 +1110,10 @@ class DebugInfoTest(jtu.JaxTestCase): x, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=cond, fun=my_f, arg_names=x['c'], result_paths=", - "traced_for=cond, fun=, arg_names=x['c'], result_paths=", - "traced_for=jit, fun=, arg_names=x, result_paths=[0][0][0],[0][0][1]", + "traced_for=cond, fun=my_f, arg_names=x['c'], result_paths=result", + "traced_for=cond, fun=, arg_names=x['c'], result_paths=result", + "traced_for=jit, fun=, arg_names=x, result_paths=result[0][0][0],result[0][0][1]", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=custom_transpose fun, fun=fn, arg_names=r,x['c'], from r", "traced_for=custom_transpose fun, fun=fn, arg_names=r,x['c'], from x['c']", @@ -1129,11 +1141,10 @@ class DebugInfoTest(jtu.JaxTestCase): x, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x, result_paths=[0]['c']", - "traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], result_paths=['b']", - "traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], result_paths=['c']", + "traced_for=jit, fun=, arg_names=x, result_paths=result[0]['c']", + "traced_for=linear_call fun, fun=fn, arg_names=r,x['c'], result_paths=result['b']", + "traced_for=linear_call fun_transpose, fun=fn_tp, arg_names=r,t['c'], result_paths=result['c']", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ # TODO(necula): from None? 
"traced_for=jit, fun=, arg_names=x, from None", @@ -1167,11 +1178,11 @@ class DebugInfoTest(jtu.JaxTestCase): xy, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], result_paths=['a']", + "traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], result_paths=result['a']", ], expected_tracer_debug_infos=[ - "traced_for=custom_vmap fun, fun=my_f, arg_names=xdict['x'],xdict['y']", - "traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y']" + "traced_for=custom_vmap fun, fun=my_f, arg_names=xdict['x'],xdict['y'], from xdict['x']", + "traced_for=jit, fun=my_f, arg_names=xdict['x'],xdict['y'], from xdict['x']" ]) def test_cond(self): @@ -1192,13 +1203,13 @@ class DebugInfoTest(jtu.JaxTestCase): 0, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=", - "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=", - "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", + "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=result", + "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=result", ], expected_tracer_debug_infos=[ - "traced_for=cond, fun=my_true_branch, arg_names=a,b", - "traced_for=cond, fun=my_false_branch, arg_names=c,d" + "traced_for=cond, fun=my_true_branch, arg_names=a,b, from a", + "traced_for=cond, fun=my_false_branch, arg_names=c,d, from c" ]) def test_switch(self): @@ -1220,15 +1231,15 @@ class DebugInfoTest(jtu.JaxTestCase): 2, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=", - "traced_for=switch, fun=my_branch0, arg_names=x0, result_paths=", - "traced_for=switch, fun=my_branch1, arg_names=x1, result_paths=", - "traced_for=switch, fun=my_branch2, arg_names=x2, result_paths=", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", + "traced_for=switch, fun=my_branch0, arg_names=x0, result_paths=result", + "traced_for=switch, fun=my_branch1, arg_names=x1, result_paths=result", + "traced_for=switch, fun=my_branch2, arg_names=x2, result_paths=result", ], expected_tracer_debug_infos=[ - "traced_for=switch, fun=my_branch0, arg_names=x0", - "traced_for=switch, fun=my_branch1, arg_names=x1", - "traced_for=switch, fun=my_branch2, arg_names=x2" + "traced_for=switch, fun=my_branch0, arg_names=x0, from x0", + "traced_for=switch, fun=my_branch1, arg_names=x1, from x1", + "traced_for=switch, fun=my_branch2, arg_names=x2, from x2" ]) def test_grad_cond_with_remat(self): @@ -1258,18 +1269,19 @@ class DebugInfoTest(jtu.JaxTestCase): 1., 2., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - 'traced_for=jit, fun=my_f, arg_names=x,y, result_paths=', + "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", # TODO(necula): arg_names? result_paths? 
- "traced_for=cond, fun=my_true_branch, arg_names=None, result_paths=,", - "traced_for=cond, fun=my_false_branch, arg_names=None, result_paths=,", - "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=[0],[1]", - "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=[0],[1]", - "traced_for=checkpoint / remat, fun=my_g, arg_names=None,None, result_paths=,", + "traced_for=cond, fun=my_true_branch, arg_names=, result_paths=,", + "traced_for=cond, fun=my_false_branch, arg_names=, result_paths=,", + "traced_for=cond, fun=my_true_branch, arg_names=a,b, result_paths=result[0],result[1]", + "traced_for=cond, fun=my_false_branch, arg_names=c,d, result_paths=result[0],result[1]", + "traced_for=checkpoint / remat, fun=my_g, arg_names=,, result_paths=,", ], expected_tracer_debug_infos=[ - 'traced_for=cond, fun=my_true_branch, arg_names=a,b', - 'traced_for=cond, fun=my_false_branch, arg_names=c,d', - 'traced_for=checkpoint / remat, fun=my_g, arg_names=x,y', + "traced_for=cond, fun=my_true_branch, arg_names=a,b, from a", + "traced_for=cond, fun=my_false_branch, arg_names=c,d, from c", + # TODO(necula): from None + "traced_for=checkpoint / remat, fun=my_g, arg_names=x,y, from None", ]) def test_grad_scan(self): @@ -1302,30 +1314,29 @@ class DebugInfoTest(jtu.JaxTestCase): c, as_, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=[0],[1]", - # TODO(necula): bad result paths + "traced_for=jit, fun=the_grad, arg_names=c,as_, result_paths=result[0],result[1]", + # TODO(necula): arg names, bad result paths "traced_for=jit, fun=my_f, arg_names=x,as_, result_paths=,,", - # TODO(necula): arg_names? "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], result_paths=", - "traced_for=for_loop, fun=f, arg_names=None,None,None, result_paths=,", - "traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None, result_paths=,", - "traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None,None,None,None,None,None, result_paths=", - "traced_for=for_loop, fun=f, arg_names=None,None,None,None,None,None,None,None,None,None,None,None,None,None,None, result_paths=,", - "traced_for=checkpoint / remat, fun=to_remat, arg_names=None,None,None, result_paths=,", - "traced_for=jit, fun=my_f, arg_names=None,None,x,as_, result_paths=", + "traced_for=for_loop, fun=f, arg_names=,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=,,,,,, result_paths=,", + "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,, result_paths=", + "traced_for=for_loop, fun=f, arg_names=,,,,,,,,,,,,,,, result_paths=,", + "traced_for=checkpoint / remat, fun=to_remat, arg_names=,,, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=,,x,as_, result_paths=", ], expected_tracer_debug_infos=[ - "traced_for=jit, fun=the_grad, arg_names=c,as_", - "traced_for=scan, fun=f, arg_names=c,a", - "traced_for=jit, fun=my_f, arg_names=x,as_", - # TODO(necula): arg_names - "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2]", + "traced_for=jit, fun=the_grad, arg_names=c,as_, from c", + "traced_for=scan, fun=f, arg_names=c,a, from c", + "traced_for=jit, fun=my_f, arg_names=x,as_, from x", + # TODO(necula): arg_names, and "from x" + "traced_for=for_loop, fun=f, arg_names=i,refs[0],refs[1],refs[2], from refs[0]", ], expected_lowering_lines=[ re.compile(r".*func.func public @main\(%arg0: tensor loc\(\"c\"\)"), re.compile(r".*func.func public @main\(.*, %arg1: tensor<3x2xf..> loc\(\"as_\"\)"), - re.compile(r".*func.func public @main\(.* -> 
.*tensor {jax.result_info = \"\[0\]\""), - re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"\[1\]\""), + re.compile(r".*func.func public @main\(.* -> .*tensor {jax.result_info = \"result\[0\]\""), + re.compile(r".*func.func public @main\(.* -> .*tensor<3x2xf..> {jax.result_info = \"result\[1\]\""), # TODO(necula): unnamed function? re.compile(r".*func.func private @None"), ]) @@ -1348,11 +1359,10 @@ class DebugInfoTest(jtu.JaxTestCase): 0, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=", - 'traced_for=while_body, fun=my_body, arg_names=b, result_paths=', - 'traced_for=while_cond, fun=my_cond, arg_names=a, result_paths=', + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", + "traced_for=while_body, fun=my_body, arg_names=b, result_paths=result", + "traced_for=while_cond, fun=my_cond, arg_names=a, result_paths=result", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=while_cond, fun=my_cond, arg_names=a, from a", "traced_for=while_body, fun=my_body, arg_names=b, from b", @@ -1370,14 +1380,14 @@ class DebugInfoTest(jtu.JaxTestCase): 3., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x, result_paths=", + "traced_for=jit, fun=, arg_names=x, result_paths=result", # TODO(necula): bad arg_names, result_paths - 'traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], result_paths=[0][0],[0][1]', + "traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], result_paths=result[0][0],result[0][1]", ], expected_tracer_debug_infos=[ # TODO(necula): the arg_names are not right - "traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1]", + "traced_for=scan, fun=my_body, arg_names=loop_carry[0],loop_carry[1], from loop_carry[1]", ] ) @@ -1389,14 +1399,14 @@ class DebugInfoTest(jtu.JaxTestCase): 5, 3., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=ub,x, result_paths=", - re.compile(r'traced_for=while_cond, fun=_fori_cond_fun at .*loops.py:.*, arg_names=loop_carry\[0\],loop_carry\[1\],loop_carry\[2\], result_paths='), + "traced_for=jit, fun=, arg_names=ub,x, result_paths=result", + re.compile(r"traced_for=while_cond, fun=_fori_cond_fun at .*loops.py:.*, arg_names=loop_carry\[0\],loop_carry\[1\],loop_carry\[2\], result_paths="), # TODO(necula): arg_names and result_paths are not right - "traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], result_paths=[0],[1],[2]", + "traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], result_paths=result[0],result[1],result[2]", ], expected_tracer_debug_infos=[ # TODO(necula): the arg_names are not right - "traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2]", + "traced_for=while_body, fun=my_body, arg_names=loop_carry[0],loop_carry[1],loop_carry[2], from loop_carry[2]", ]) def test_scan(self): @@ -1413,10 +1423,9 @@ class DebugInfoTest(jtu.JaxTestCase): np.arange(8, dtype=np.int32), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=[0],[1]", - "traced_for=scan, fun=my_scan_body, arg_names=carry,inp, result_paths=[0],[1]", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result[0],result[1]", + "traced_for=scan, fun=my_scan_body, arg_names=carry,inp, result_paths=result[0],result[1]", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ 
"traced_for=scan, fun=my_scan_body, arg_names=carry,inp, from carry" ]) @@ -1433,7 +1442,7 @@ class DebugInfoTest(jtu.JaxTestCase): tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[], expected_tracer_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x"], + "traced_for=jit, fun=my_f, arg_names=x, from x"], ) def test_vmap_of_nested_jit(self): @@ -1452,13 +1461,13 @@ class DebugInfoTest(jtu.JaxTestCase): np.ones((8,), dtype=np.float32), np.zeros((8,), dtype=np.float32), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=", - "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=['c']", + "traced_for=jit, fun=my_f, arg_names=x,y, result_paths=result", + "traced_for=jit, fun=my_g, arg_names=u,v, result_paths=result['c']", ], expected_tracer_debug_infos=[ # TODO(necula): missing debug info "None", - "traced_for=jit, fun=my_g, arg_names=u,v" + "traced_for=jit, fun=my_g, arg_names=u,v, from u" ]) def test_pmap(self): @@ -1471,11 +1480,11 @@ class DebugInfoTest(jtu.JaxTestCase): jax.pmap(my_f), np.ones((jax.device_count(),), dtype=np.float32), expected_jaxpr_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=x, result_paths=" + "traced_for=pmap, fun=my_f, arg_names=x, result_paths=result" ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=x" + "traced_for=pmap, fun=my_f, arg_names=x, from x" ], ) @@ -1493,10 +1502,9 @@ class DebugInfoTest(jtu.JaxTestCase): 1., x, x, x, # x, y, args[0], args[1] d=x, a=x, b=x, # kwargs expected_jaxpr_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], result_paths=['u'],['v']", + "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], result_paths=result['u'],result['v']", ], tracer_spy=tracer_spy, - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=pmap, fun=my_f, arg_names=y,args[0],args[1],a,kwargs['b'],kwargs['d'], from args[1]", ], @@ -1508,8 +1516,8 @@ class DebugInfoTest(jtu.JaxTestCase): re.compile(r".*func.func public @main\(.*%arg3: tensor<1xf..> loc\(\"a\"\)"), re.compile(r".*func.func public @main\(.*%arg4: tensor<1xf..> loc\(\"kwargs\['b'\]\"\)"), re.compile(r".*func.func public @main\(.*%arg5: tensor<1xf..> loc\(\"kwargs\['d'\]\"\)"), - re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\['u'\]\"\}"), - re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"\['v'\]\"\}"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['u'\]\"\}"), + re.compile(r".*func.func public @main\(.* -> .*\{jax.result_info = \"result\['v'\]\"\}"), ] ) @@ -1523,7 +1531,7 @@ class DebugInfoTest(jtu.JaxTestCase): jax.pmap(jax.grad(my_f)), np.ones((jax.device_count(),), dtype=np.float32), expected_jaxpr_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=x, result_paths=", + "traced_for=pmap, fun=my_f, arg_names=x, result_paths=result", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ @@ -1550,12 +1558,12 @@ class DebugInfoTest(jtu.JaxTestCase): expected_jaxpr_debug_infos=[ # TODO(necula): why this? 
re.compile(r'traced_for=jit, fun=_multi_slice at .*array_methods.py:.*, arg_names=self, result_paths=.*'), - "traced_for=pmap, fun=my_f, arg_names=x,y,args[0],args[1], result_paths=['u'],['v']", + "traced_for=pmap, fun=my_f, arg_names=x,y,args[0],args[1], result_paths=result['u'],result['v']", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ # TODO(necula): missing debug_info - 'None' + "None" ], ) @@ -1574,13 +1582,13 @@ class DebugInfoTest(jtu.JaxTestCase): jax.jit(lambda x, x_tan: jax.jvp(jax.pmap(my_f), (x, x), (x_tan, x_tan))), x, x_tan, expected_jaxpr_debug_infos=[ - 'traced_for=jit, fun=, arg_names=x,x_tan, result_paths=[0],[1]', - "traced_for=pmap, fun=my_f, arg_names=x,y, result_paths=", + "traced_for=jit, fun=, arg_names=x,x_tan, result_paths=result[0],result[1]", + "traced_for=pmap, fun=my_f, arg_names=x,y, result_paths=result", ], tracer_spy=tracer_spy, expected_tracer_debug_infos=[ # TODO(necula): missing debug_info - 'None' + "None" ], ) @@ -1597,13 +1605,12 @@ class DebugInfoTest(jtu.JaxTestCase): jax.jit(jax.hessian(jax.jit(my_f))), x, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", # TODO(necula): arg_names and result_paths? - "traced_for=jit, fun=my_f, arg_names=None,x, result_paths=,", + "traced_for=jit, fun=my_f, arg_names=,x, result_paths=,", "traced_for=jit, fun=my_f, arg_names=x, result_paths=,,,", ], tracer_spy=tracer_spy, - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=jit, fun=my_f, arg_names=x, from x", ], @@ -1626,11 +1633,10 @@ class DebugInfoTest(jtu.JaxTestCase): 0., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", # TODO(necula): missing result_paths - "traced_for=checkpoint / remat, fun=my_g, arg_names=y, result_paths=", + "traced_for=checkpoint / remat, fun=my_g, arg_names=y, result_paths=result", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y" ]) @@ -1650,12 +1656,12 @@ class DebugInfoTest(jtu.JaxTestCase): 0., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=", - # TODO(necula): arg_names? - "traced_for=checkpoint / remat, fun=my_g, arg_names=None,None, result_paths=", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", + # TODO(necula): arg_names? result_paths? 
+ "traced_for=checkpoint / remat, fun=my_g, arg_names=,, result_paths=", ], expected_tracer_debug_infos=[ - "traced_for=checkpoint / remat, fun=my_g, arg_names=y" + "traced_for=checkpoint / remat, fun=my_g, arg_names=y, from y", ]) def test_remat_shard_map(self): @@ -1677,13 +1683,12 @@ class DebugInfoTest(jtu.JaxTestCase): jnp.arange(2, dtype=np.float32), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - # TODO(necula): arg_names - "traced_for=jit, fun=, arg_names=x, result_paths=", - "traced_for=checkpoint / remat, fun=my_f, arg_names=None,None, result_paths=", - "traced_for=shard_map, fun=my_f, arg_names=x, result_paths=", - "traced_for=shard_map, fun=my_f, arg_names=None,None, result_paths=", + # TODO(necula): arg_names, result_paths + "traced_for=jit, fun=, arg_names=x, result_paths=result", + "traced_for=checkpoint / remat, fun=my_f, arg_names=,, result_paths=", + "traced_for=shard_map, fun=my_f, arg_names=x, result_paths=result", + "traced_for=shard_map, fun=my_f, arg_names=,, result_paths=", ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ "None" # TODO(necula): missing ]) @@ -1719,11 +1724,11 @@ class DebugInfoTest(jtu.JaxTestCase): expected_jaxpr_debug_infos=[ # TODO(necula): this should not be pointing into the JAX internals re.compile(r"traced_for=jit, fun=checked_fun at .*jax._src.checkify.py:.*, arg_names=args\[0\]"), - re.compile(r"traced_for=jit, fun=argsort at .*numpy.sorting.py:.*, arg_names=a, result_paths="), - "traced_for=pmap, fun=my_f, arg_names=my_x, result_paths=[0]", + re.compile(r"traced_for=jit, fun=argsort at .*numpy.sorting.py:.*, arg_names=a, result_paths=result"), + "traced_for=pmap, fun=my_f, arg_names=my_x, result_paths=result[0]", ], expected_tracer_debug_infos=[ - "traced_for=pmap, fun=my_f, arg_names=my_x", + "traced_for=pmap, fun=my_f, arg_names=my_x, from my_x", ], check_lowering=False, # TODO(necula): warning during lowering ) @@ -1751,12 +1756,12 @@ class DebugInfoTest(jtu.JaxTestCase): 0., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x, result_paths=", - "traced_for=custom_dce, fun=my_g, arg_names=x, result_paths=[0],[1]", + "traced_for=jit, fun=, arg_names=x, result_paths=result", + "traced_for=custom_dce, fun=my_g, arg_names=x, result_paths=result[0],result[1]", ], expected_tracer_debug_infos=[ # TODO(necula): no leaked tracer from my_g_dce? - "traced_for=custom_dce, fun=my_g, arg_names=x", + "traced_for=custom_dce, fun=my_g, arg_names=x, from x", ]) def test_custom_dce_consts(self): @@ -1779,11 +1784,10 @@ class DebugInfoTest(jtu.JaxTestCase): np.array(1.1234), tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x, result_paths=", + "traced_for=jit, fun=, arg_names=x, result_paths=result", # TODO(necula): bad arg_names (why None), bad result_paths - 'traced_for=custom_dce, fun=my_f, arg_names=None,x, result_paths=[0],[1]', + 'traced_for=custom_dce, fun=my_f, arg_names=,x, result_paths=result[0],result[1]', ], - check_tracer_arg_name=True, expected_tracer_debug_infos=[ # TODO(necula): no leaked tracer from my_rule? 
"traced_for=custom_dce, fun=my_f, arg_names=x, from x", @@ -1816,23 +1820,22 @@ class DebugInfoTest(jtu.JaxTestCase): a, b, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=a,b, result_paths=[0],[1]", - re.compile(r"traced_for=jit, fun=_solve at .*scipy.linalg.py:.*, arg_names=a,b, result_paths="), - re.compile(r"traced_for=jit, fun=solve at .*linalg.py:.*, arg_names=a,b, result_paths="), - re.compile(r"traced_for=jit, fun=_lu_solve at .*linalg.py:.*, arg_names=lu,permutation,b, result_paths="), + "traced_for=jit, fun=, arg_names=a,b, result_paths=result[0],result[1]", + re.compile(r"traced_for=jit, fun=_solve at .*scipy.linalg.py:.*, arg_names=a,b, result_paths=result"), + re.compile(r"traced_for=jit, fun=solve at .*linalg.py:.*, arg_names=a,b, result_paths=result"), + re.compile(r"traced_for=jit, fun=_lu_solve at .*linalg.py:.*, arg_names=lu,permutation,b, result_paths=result"), # TODO(necula): why pointers to internal functions, arg_names, result_paths? - re.compile(r'traced_for=custom_linear_solve solve, fun= at .*linalg.py:.*, arg_names=None,None,x, result_paths='), - re.compile(r'traced_for=custom_linear_solve transpose_solve, fun= at .*linalg.py:.*, arg_names=None,None,x, result_paths='), - re.compile(r'traced_for=custom_linear_solve, fun= at .*linalg.py:.*, arg_names=None,x, result_paths='), - re.compile(r'traced_for=custom_linear_solve transpose_solve, fun= at .*linalg.py:.*, arg_names=None,x, result_paths='), - 'traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=None,b, result_paths=', - 'traced_for=custom_linear_solve solve, fun=my_solve, arg_names=None,x, result_paths=', - 'traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=None,x, result_paths=', + re.compile(r'traced_for=custom_linear_solve solve, fun= at .*linalg.py:.*, arg_names=,,x, result_paths='), + re.compile(r'traced_for=custom_linear_solve transpose_solve, fun= at .*linalg.py:.*, arg_names=,,x, result_paths='), + re.compile(r'traced_for=custom_linear_solve, fun= at .*linalg.py:.*, arg_names=,x, result_paths='), + re.compile(r'traced_for=custom_linear_solve transpose_solve, fun= at .*linalg.py:.*, arg_names=,x, result_paths='), + "traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=,b, result_paths=result", + "traced_for=custom_linear_solve solve, fun=my_solve, arg_names=,x, result_paths=result", + "traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=,x, result_paths=result", ], expected_tracer_debug_infos=[ - "traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x", - "traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=x", - "traced_for=custom_linear_solve, fun=my_high_precision_dot, arg_names=b", + "traced_for=custom_linear_solve solve, fun=my_solve, arg_names=x, from x", + "traced_for=custom_linear_solve transpose_solve, fun=my_tr_solve, arg_names=x, from x", "None", # TODO(necula): there are missing debug info ]) @@ -1856,14 +1859,15 @@ class DebugInfoTest(jtu.JaxTestCase): 0., tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=, arg_names=x, result_paths=[0],[1]", + "traced_for=jit, fun=, arg_names=x, result_paths=result[0],result[1]", # TODO(necula): internal function? 
- re.compile(r"traced_for=custom_jvp fun, fun=_custom_root at .*control_flow.solves.py:.*, arg_names=args\[0\], result_paths=\[0\]"), + re.compile(r"traced_for=custom_jvp fun, fun=_custom_root at .*control_flow.solves.py:.*, arg_names=args\[0\], result_paths=result\[0\]"), ], expected_tracer_debug_infos=[ - "traced_for=custom_root, fun=my_f, arg_names=x", - "traced_for=custom_root solve, fun=my_solve, arg_names=x", - "traced_for=custom_root tangent_solve, fun=my_transpose_solve, arg_names=x", + "traced_for=custom_root, fun=my_f, arg_names=x, from x", + "traced_for=custom_root solve, fun=my_solve, arg_names=x, from x", + # TODO(necula): from None + "traced_for=custom_root tangent_solve, fun=my_transpose_solve, arg_names=x, from None", "None", # TODO(necula): there are missing debug info ]) @@ -1891,13 +1895,13 @@ class DebugInfoTest(jtu.JaxTestCase): jax.jit(my_f), x, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=x, result_paths=", - "traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, result_paths=[0],[1]", + "traced_for=jit, fun=my_f, arg_names=x, result_paths=result", + "traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, result_paths=result[0],result[1]", "traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref, result_paths=", ], expected_tracer_debug_infos=[ - "traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j", - "traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref", + "traced_for=pallas_call index_map, fun=my_index_map, arg_names=i,j, from i", + "traced_for=pallas_call, fun=my_kernel, arg_names=x_ref,y_ref,o_ref, from x_ref", ], check_lowering=False, # We need interpret mode on CPU. TODO(necula) ) @@ -1921,14 +1925,14 @@ class DebugInfoTest(jtu.JaxTestCase): jnp.arange(4, dtype=jnp.float32) - 2, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_f, arg_names=input, result_paths=", + "traced_for=jit, fun=my_f, arg_names=input, result_paths=result", # TODO(necula): function source location points in JAX internals # TODO(necula): arg_names and result_paths are wrong re.compile(r"traced_for=checkify_pallas, fun=checked_kernel_fn at .*pallas_call.py:.*, arg_names=args\[0\],.*, result_paths="), re.compile(r"traced_for=pallas_call index_map, fun= at .*pallas.core.py:.*, arg_names=, result_paths="), ], expected_tracer_debug_infos=[ - "traced_for=pallas_call, fun=kernel, arg_names=x_ref,y_ref", + "traced_for=pallas_call, fun=kernel, arg_names=x_ref,y_ref, from x_ref", ], check_lowering=False, # We need interpret mode on CPU. 
TODO(necula) ) @@ -1948,11 +1952,11 @@ class DebugInfoTest(jtu.JaxTestCase): jax.jit(my_consts), x, tracer_spy=tracer_spy, expected_jaxpr_debug_infos=[ - "traced_for=jit, fun=my_consts, arg_names=x, result_paths=", - "traced_for=composite, fun=my_consts, arg_names=x, result_paths=", + "traced_for=jit, fun=my_consts, arg_names=x, result_paths=result", + "traced_for=composite, fun=my_consts, arg_names=x, result_paths=result", ], expected_tracer_debug_infos=[ - "traced_for=composite, fun=my_consts, arg_names=x"]) + "traced_for=composite, fun=my_consts, arg_names=x, from x"]) if __name__ == '__main__': diff --git a/tests/mutable_array_test.py b/tests/mutable_array_test.py index 7a1d36317..9a6c5c167 100644 --- a/tests/mutable_array_test.py +++ b/tests/mutable_array_test.py @@ -274,13 +274,13 @@ class MutableArrayErrorsTest(jtu.JaxTestCase): def test_return_from_jit_pytree(self): with self.assertRaisesRegex( ValueError, - r"tree path \['hi'\]"): + r"tree path result\['hi'\]"): jax.jit(lambda x_ref: {'hi': x_ref})(core.mutable_array(jnp.arange(3))) def test_return_from_jit_closure(self): with self.assertRaisesRegex( ValueError, - r"tree path \['hi'\]"): + r"tree path result\['hi'\]"): x_ref = core.mutable_array(jnp.arange(3)) jax.jit(lambda: {'hi': x_ref})() diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 9abfd14fc..bdb117166 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -6846,7 +6846,7 @@ class PJitErrorTest(jtu.JaxTestCase): spec = P(resources, None) mesh_size = str(math.prod([dim[1] for dim in mesh])) error = re.compile( - r"One of pjit outputs with pytree key path \['rrr'\].*" + spec_regex(spec) + r".*" + r"One of pjit outputs with pytree key path result\['rrr'\].*" + spec_regex(spec) + r".*" r"implies that the global size of its dimension 0 should be " r"divisible by " + mesh_size + r", but it is equal to 3", re.M | re.S) with self.assertRaisesRegex(ValueError, error):
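
As a quick illustration of the result-path convention this patch introduces (see the commit message and the updated docs/tests above), here is a small standalone sketch; it is not part of the patch, and the exact printed module text will vary by JAX version and backend:

    import jax
    import jax.numpy as jnp

    def f(x, y):
        # Returns a pytree: a dict with a scalar leaf and a nested list/tuple leaf.
        return {"a": x + y, "b": [(x * y,)]}

    lowered = jax.jit(f).lower(jnp.float32(1.0), jnp.float32(2.0))
    print(lowered.as_text())
    # With this change the @main results are annotated as
    #   jax.result_info = "result['a']"   and   jax.result_info = "result['b'][0][0]"
    # instead of the previous "['a']" / "['b'][0][0]"; a function returning a
    # single array is annotated simply as "result".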