Thread through transform name.

This commit is contained in:
Lena Martens 2021-03-23 19:47:58 +00:00 committed by lenamartens
parent 88f5e26482
commit be70820ca1
6 changed files with 55 additions and 15 deletions

View File

@ -63,14 +63,18 @@ T = TypeVar('T')
Array = Any
@cache()
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals):
def _initial_style_open_jaxpr(fun: Callable, in_tree, in_avals,
transform_name: str = ""):
wrapped_fun, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals)
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(wrapped_fun, in_avals,
transform_name=transform_name)
return jaxpr, consts, out_tree()
@cache()
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals):
jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals)
def _initial_style_jaxpr(fun: Callable, in_tree, in_avals,
transform_name: str = ""):
jaxpr, consts, out_tree = _initial_style_open_jaxpr(fun, in_tree, in_avals,
transform_name)
closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
return closed_jaxpr, consts, out_tree
@ -266,8 +270,8 @@ def while_loop(cond_fun: Callable[[T], bool],
def _create_jaxpr(init_val):
init_vals, in_tree = tree_flatten((init_val,))
init_avals = tuple(_map(_abstractify, init_vals))
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals)
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals)
cond_jaxpr, cond_consts, cond_tree = _initial_style_jaxpr(cond_fun, in_tree, init_avals, "while_cond")
body_jaxpr, body_consts, body_tree = _initial_style_jaxpr(body_fun, in_tree, init_avals, "while_loop")
if not treedef_is_leaf(cond_tree) or len(cond_jaxpr.out_avals) != 1:
msg = "cond_fun must return a boolean scalar, but got pytree {}."
raise TypeError(msg.format(cond_tree))
@ -1256,7 +1260,7 @@ def scan(f: Callable[[Carry, X], Tuple[Carry, Y]],
in_flat, in_tree = tree_flatten((init, xs))
carry_avals = tuple(_map(_abstractify, init_flat))
jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals)
jaxpr, consts, out_tree = _initial_style_jaxpr(f, in_tree, carry_avals + x_avals, "scan")
out_tree_children = out_tree.children()
if len(out_tree_children) != 2:
msg = "scan body output must be a pair, got {}."

View File

@ -1169,9 +1169,11 @@ def _memoize(thunk):
return memoized
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
def trace_to_jaxpr_dynamic(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
transform_name: str = ""):
with core.new_main(DynamicJaxprTrace, dynamic=True) as main: # type: ignore
main.source_info = fun_sourceinfo(fun.f) # type: ignore
main.source_info = fun_sourceinfo(fun.f, transform_name) # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
del main, fun
@ -1198,9 +1200,11 @@ def extend_jaxpr_stack(main, frame):
assert frame is main.jaxpr_stack[-1]
main.jaxpr_stack = main.jaxpr_stack[:-1]
def trace_to_jaxpr_final(fun: lu.WrappedFun, in_avals: Sequence[AbstractValue]):
def trace_to_jaxpr_final(fun: lu.WrappedFun,
in_avals: Sequence[AbstractValue],
transform_name: str = ""):
with core.new_base_main(DynamicJaxprTrace) as main: # type: ignore
main.source_info = fun_sourceinfo(fun.f) # type: ignore
main.source_info = fun_sourceinfo(fun.f, transform_name) # type: ignore
main.jaxpr_stack = () # type: ignore
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, main, in_avals)
del fun, main
@ -1214,12 +1218,15 @@ def partial_eval_to_jaxpr_dynamic(fun: lu.WrappedFun, in_pvals: Sequence[Partial
with core.new_main(core.EvalTrace, dynamic=True) as _: # type: ignore
return trace_to_jaxpr(fun, in_pvals)
def fun_sourceinfo(fun):
def fun_sourceinfo(fun, transform_name: str = ""):
if isinstance(fun, functools.partial):
fun = fun.func
try:
filename = fun.__code__.co_filename
lineno = fun.__code__.co_firstlineno
return f"{fun.__name__} at {filename}:{lineno}"
line_info = f"{fun.__name__} at {filename}:{lineno}"
if transform_name:
line_info += f', transformed by {transform_name}.'
return line_info
except AttributeError:
return "<unknown>"

View File

@ -685,7 +685,7 @@ def parallel_callable(fun: lu.WrappedFun,
logging.vlog(2, "global_sharded_avals: %s", global_sharded_avals)
with core.extend_axis_env(axis_name, global_axis_size, None): # type: ignore
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals)
jaxpr, out_sharded_avals, consts = pe.trace_to_jaxpr_final(fun, global_sharded_avals, transform_name="pmap")
jaxpr = xla.apply_outfeed_rewriter(jaxpr)
out_axes = out_axes_thunk()

View File

@ -649,7 +649,7 @@ def _xla_callable(fun: lu.WrappedFun, device, backend, name, donated_invars, *ar
"got device={} and backend={}".format(device, backend))
abstract_args, arg_devices = unzip2(arg_specs)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(fun, abstract_args, transform_name="jit")
if any(isinstance(c, core.Tracer) for c in consts):
raise core.UnexpectedTracerError("Encountered an unexpected tracer.")
map(prefetch, it.chain(consts, jaxpr_literals(jaxpr)))

View File

@ -2070,6 +2070,17 @@ class APITest(jtu.JaxTestCase):
# level, which is no longer live.
jax.jit(jnp.add)(jnp.ones(()), count)
def test_escaped_tracer_transform_name(self):
with self.assertRaisesRegex(core.UnexpectedTracerError,
"transformed by jit"):
jax.jit(self.helper_save_tracer)(1)
_ = self._saved_tracer+1
with self.assertRaisesRegex(core.UnexpectedTracerError,
"transformed by pmap"):
jax.pmap(self.helper_save_tracer)(jnp.ones((1, 2)))
_ = self._saved_tracer+1
def test_pmap_static_kwarg_error_message(self):
# https://github.com/google/jax/issues/3007
def f(a, b):

View File

@ -2602,6 +2602,24 @@ class LaxControlFlowTest(jtu.JaxTestCase):
self.assertAllClose(deriv(my_pow)(3.0, 1), 1.0, check_dtypes=False)
def test_unexpected_tracer_error(self):
with self.assertRaisesRegex(core.UnexpectedTracerError,
"transformed by while_loop"):
lst = []
def side_effecting_body(val):
lst.append(val)
return val+1
lax.while_loop(lambda x: x < 2, side_effecting_body, 1)
lst[0] += 1
with self.assertRaisesRegex(core.UnexpectedTracerError,
"transformed by scan"):
lst = []
def side_effecting_scan(carry, val):
lst.append(val)
return carry, val+1
lax.scan(side_effecting_scan, None, jnp.ones((2, 2)))
lst[0] += 1
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())