diff --git a/jax/_src/lax/control_flow.py b/jax/_src/lax/control_flow.py index 4e08a79c4..2220cc6b6 100644 --- a/jax/_src/lax/control_flow.py +++ b/jax/_src/lax/control_flow.py @@ -63,14 +63,18 @@ T = TypeVar('T') Array = Any @cache() -def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals): +def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals, + transform_name: str = ""): wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree) - jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals) + jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals, + transform_name=transform_name) return jaxpr, consts, out_tree() @cache() -def _initial_style_jaxpr(fun: Callable, in_tree, in_avals): - jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals) +def _initial_style_jaxpr(fun: Callable, in_tree, in_avals, + transform_name: str = ""): + jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals, + transform_name) closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) return closed_jaxpr, consts, out_tree @@ -266,8 +270,8 @@ def while_loop(cond_fun: Callable[[T], bool], def _create_jaxpr(init_val): init_vals, in_tree = tree_flatten((init_val,)) init_avals = tuple(_map(_abstractify, init_vals)) - cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals) - body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals) + cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals, "while_cond") + body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals, "while_loop") if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1: msg = "cond_fun must return a boolean scalar, but got pytree {}." raise TypeError(msg.format(cond_tree)) @@ -1256,7 +1260,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]], in_flat, in_tree = tree_flatten((init, xs)) carry_avals = tuple(_map(_abstractify, init_flat)) - jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals) + jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals, "scan") out_tree_children = out_tree.children() if len(out_tree_children) != 2: msg = "scan body output must be a pair, got {}." diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 66d5d01dc..24fd68bee 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -1169,9 +1169,11 @@ def _memoize(thunk): return memoized -def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]): +def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, + in_avals: Sequence[AbstractValue], + transform_name: str = ""): with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore - main.source_info = fun_sourceinfo(fun.f) # type: ignore + main.source_info = fun_sourceinfo(fun.f, transform_name) # type: ignore main.jaxpr_stack = () # type: ignore jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals) del main, fun @@ -1198,9 +1200,11 @@ def extend_jaxpr_stack(main, frame): assert frame is main.jaxpr_stack[-1] main.jaxpr_stack = main.jaxpr_stack[:-1] -def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]): +def trace_to_jaxpr_final(fun: lu.WrappedFun, + in_avals: Sequence[AbstractValue], + transform_name: str = ""): with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore - main.source_info = fun_sourceinfo(fun.f) # type: ignore + main.source_info = fun_sourceinfo(fun.f, transform_name) # type: ignore main.jaxpr_stack = () # type: ignore jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals) del fun, main @@ -1214,12 +1218,15 @@ def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[Partial with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore return trace_to_jaxpr(fun, in_pvals) -def fun_sourceinfo(fun): +def fun_sourceinfo(fun, transform_name: str = ""): if isinstance(fun, functools.partial): fun = fun.func try: filename = fun.__code__.co_filename lineno = fun.__code__.co_firstlineno - return f"{fun.__name__} at {filename}:{lineno}" + line_info = f"{fun.__name__} at {filename}:{lineno}" + if transform_name: + line_info += f', transformed by {transform_name}.' + return line_info except AttributeError: return "" diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index b355ff9c7..aa45011f8 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -685,7 +685,7 @@ def parallel_callable(fun: lu.WrappedFun, logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals) with core.extend_axis_env(axis_name, global_axis_size, None): # type: ignore - jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals) + jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals, transform_name="pmap") jaxpr = xla.apply_outfeed_rewriter(jaxpr) out_axes = out_axes_thunk() diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index 0bdbbb62e..e1dea3d54 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -649,7 +649,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar "got device={} and backend={}".format(device, backend)) abstract_args, arg_devices = unzip2(arg_specs) - jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args) + jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit") if any(isinstance(c, core.Tracer) for c in consts): raise core.UnexpectedTracerError("Encountered an unexpected tracer.") map(prefetch, it.chain(consts, jaxpr_literals(jaxpr))) diff --git a/tests/api_test.py b/tests/api_test.py index 47fd8b5bb..be84bdff7 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -2070,6 +2070,17 @@ class APITest(jtu.JaxTestCase): # level, which is no longer live. jax.jit(jnp.add)(jnp.ones(()), count) + def test_escaped_tracer_transform_name(self): + with self.assertRaisesRegex(core.UnexpectedTracerError, + "transformed by jit"): + jax.jit(self.helper_save_tracer)(1) + _ = self._saved_tracer+1 + + with self.assertRaisesRegex(core.UnexpectedTracerError, + "transformed by pmap"): + jax.pmap(self.helper_save_tracer)(jnp.ones((1, 2))) + _ = self._saved_tracer+1 + def test_pmap_static_kwarg_error_message(self): # https://github.com/google/jax/issues/3007 def f(a, b): diff --git a/tests/lax_control_flow_test.py b/tests/lax_control_flow_test.py index ee01e0a3f..2b21a7925 100644 --- a/tests/lax_control_flow_test.py +++ b/tests/lax_control_flow_test.py @@ -2602,6 +2602,24 @@ class LaxControlFlowTest(jtu.JaxTestCase): self.assertAllClose(deriv(my_pow)(3.0, 1), 1.0, check_dtypes=False) + def test_unexpected_tracer_error(self): + with self.assertRaisesRegex(core.UnexpectedTracerError, + "transformed by while_loop"): + lst = [] + def side_effecting_body(val): + lst.append(val) + return val+1 + lax.while_loop(lambda x: x < 2, side_effecting_body, 1) + lst[0] += 1 + + with self.assertRaisesRegex(core.UnexpectedTracerError, + "transformed by scan"): + lst = [] + def side_effecting_scan(carry, val): + lst.append(val) + return carry, val+1 + lax.scan(side_effecting_scan, None, jnp.ones((2, 2))) + lst[0] += 1 if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())