improve an escaped tracer error message (#4312)

* improve an escaped tracer error message

Before this commit, encountering an escaped tracer in a specific way
would lead to a bad internal error. This change
1. raises an UnexpectedTracerError instead, and
2. includes in the error message the user source line which created the
tracer.

* deflake

* replace _live propety with _assert_live method

Thanks @jekbradbury !
This commit is contained in:
Matthew Johnson 2020-09-16 15:59:50 -07:00 committed by GitHub
parent e18a973198
commit 325d3bc71d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 56 additions and 17 deletions

View File

@ -422,12 +422,14 @@ class Trace:
del primitive, fwd, bwd, out_trees # Unused.
return fun.call_wrapped(*tracers)
def escaped_tracer_error(detail):
def escaped_tracer_error(detail=None):
msg = ("Encountered an unexpected tracer. Perhaps this tracer escaped "
"through global state from a previously traced function.\n"
"The functions being transformed should not save traced values to "
"global state.\nDetails: {}.")
return UnexpectedTracerError(msg.format(detail))
"global state.")
if detail:
msg += " Detail: {}.".format(detail)
return UnexpectedTracerError(msg)
class UnexpectedTracerError(Exception): pass
@ -462,6 +464,9 @@ class Tracer:
def aval(self):
raise NotImplementedError("must override")
def _assert_live(self) -> None:
pass # Override for liveness checking
# Python looks up special methods only on classes, not instances. This means
# these methods needs to be defined explicitly rather than relying on
# __getattr__.
@ -737,8 +742,8 @@ def full_lower(val):
return val
def find_top_trace(xs) -> Trace:
top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
default=None, key=attrgetter('level'))
traces = [x._assert_live() or x._trace.main for x in xs if isinstance(x, Tracer)] # type: ignore
top_main = max(traces, default=None, key=attrgetter('level'))
dynamic = thread_local_state.trace_state.trace_stack.dynamic
top_main = (dynamic if top_main is None or dynamic.level > top_main.level
else top_main)

View File

@ -830,6 +830,11 @@ class DynamicJaxprTracer(core.Tracer):
f"{self._trace.main.source_info}.")
return origin
def _assert_live(self) -> None:
if not self._trace.main.jaxpr_stack: # type: ignore
msg = f"tracer created on line {source_info_util.summarize(self.line_info)}"
raise core.escaped_tracer_error(msg)
class JaxprStackFrame:
__slots__ = ['newvar', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
'tracers', 'eqns']
@ -896,16 +901,18 @@ class DynamicJaxprTrace(core.Trace):
__slots__ = [] # type: ignore
@property
def frame(self): return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error
def frame(self):
return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error
def new_arg(self, aval):
tracer = DynamicJaxprTracer(self, aval)
tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
self.frame.tracers.append(tracer)
self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(aval)
return tracer
def new_const(self, val):
tracer = DynamicJaxprTracer(self, raise_to_shaped(get_aval(val), weak_type=dtypes.is_python_scalar(val)))
aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_python_scalar(val))
tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
self.frame.tracers.append(tracer)
var = self.frame.tracer_to_var[id(tracer)] = self.getconstvar(val)
self.frame.constvar_to_val[var] = val
@ -937,11 +944,11 @@ class DynamicJaxprTrace(core.Trace):
avals = [t.aval for t in tracers]
out_avals = primitive.abstract_eval(*avals, **params)
out_avals = [out_avals] if not primitive.multiple_results else out_avals
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
source_info = source_info_util.current()
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
invars = map(self.getvar, tracers)
outvars = map(self.getvar, out_tracers)
eqn = new_jaxpr_eqn(invars, outvars, primitive, params,
source_info_util.current())
eqn = new_jaxpr_eqn(invars, outvars, primitive, params, source_info)
self.frame.eqns.append(eqn)
return out_tracers if primitive.multiple_results else out_tracers.pop()
@ -950,7 +957,8 @@ class DynamicJaxprTrace(core.Trace):
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
if not jaxpr.eqns:
return core.eval_jaxpr(jaxpr, consts, *tracers)
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
source_info = source_info_util.current()
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
invars = map(self.getvar, tracers)
outvars = map(self.getvar, out_tracers)
constvars = map(self.getvar, map(self.instantiate_const, consts))
@ -958,8 +966,8 @@ class DynamicJaxprTrace(core.Trace):
update_params = call_param_updaters.get(call_primitive)
if update_params:
new_params = update_params(new_params, [True] * len(tracers))
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive, new_params,
source_info_util.current())
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive,
new_params, source_info)
self.frame.eqns.append(eqn)
return out_tracers
@ -975,7 +983,8 @@ class DynamicJaxprTrace(core.Trace):
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
f, self.main, reduced_in_avals)
out_avals = [core.unmapped_aval(params['axis_size'], a) for a in reduced_out_avals]
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
source_info = source_info_util.current()
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
invars = map(self.getvar, tracers)
outvars = map(self.getvar, out_tracers)
constvars = map(self.getvar, map(self.instantiate_const, consts))
@ -985,7 +994,8 @@ class DynamicJaxprTrace(core.Trace):
update_params = call_param_updaters.get(map_primitive)
if update_params:
new_params = update_params(new_params, [True] * len(tracers))
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive, new_params)
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
new_params, source_info)
self.frame.eqns.append(eqn)
return out_tracers

View File

@ -1641,7 +1641,7 @@ class APITest(jtu.JaxTestCase):
with self.assertRaisesRegex(
core.UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Can't lift sublevels 1 to 0",
"Encountered an unexpected tracer",
re.DOTALL)):
api.jit(lambda x: x)(self._saved_tracer)
@ -1688,6 +1688,30 @@ class APITest(jtu.JaxTestCase):
re.DOTALL)):
api.jit(func1)(2.)
def test_escaped_tracer_omnistaging(self):
if not config.omnistaging_enabled:
raise unittest.SkipTest("test is omnistaging-specific")
count = 1
@jit
def f():
nonlocal count
count = jnp.add(count, 1)
f() # leaked a tracer! but currently undetected
def f(x, c):
jnp.add(count, 1)
return None, None
@jit
def g():
lax.scan(f, None, None, length=2)
with self.assertRaisesRegex(core.UnexpectedTracerError,
"tracer created on line"):
g()
def test_pmap_static_kwarg_error_message(self):
# https://github.com/google/jax/issues/3007
def f(a, b):