mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
improve an escaped tracer error message (#4312)
* improve an escaped tracer error message Before this commit, encountering an escaped tracer in a specific way would lead to a bad internal error. This change 1. raises an UnexpectedTracerError instead, and 2. includes in the error message the user source line which created the tracer. * deflake * replace _live propety with _assert_live method Thanks @jekbradbury !
This commit is contained in:
parent
e18a973198
commit
325d3bc71d
15
jax/core.py
15
jax/core.py
@ -422,12 +422,14 @@ class Trace:
|
||||
del primitive, fwd, bwd, out_trees # Unused.
|
||||
return fun.call_wrapped(*tracers)
|
||||
|
||||
def escaped_tracer_error(detail):
|
||||
def escaped_tracer_error(detail=None):
|
||||
msg = ("Encountered an unexpected tracer. Perhaps this tracer escaped "
|
||||
"through global state from a previously traced function.\n"
|
||||
"The functions being transformed should not save traced values to "
|
||||
"global state.\nDetails: {}.")
|
||||
return UnexpectedTracerError(msg.format(detail))
|
||||
"global state.")
|
||||
if detail:
|
||||
msg += " Detail: {}.".format(detail)
|
||||
return UnexpectedTracerError(msg)
|
||||
|
||||
class UnexpectedTracerError(Exception): pass
|
||||
|
||||
@ -462,6 +464,9 @@ class Tracer:
|
||||
def aval(self):
|
||||
raise NotImplementedError("must override")
|
||||
|
||||
def _assert_live(self) -> None:
|
||||
pass # Override for liveness checking
|
||||
|
||||
# Python looks up special methods only on classes, not instances. This means
|
||||
# these methods needs to be defined explicitly rather than relying on
|
||||
# __getattr__.
|
||||
@ -737,8 +742,8 @@ def full_lower(val):
|
||||
return val
|
||||
|
||||
def find_top_trace(xs) -> Trace:
|
||||
top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
|
||||
default=None, key=attrgetter('level'))
|
||||
traces = [x._assert_live() or x._trace.main for x in xs if isinstance(x, Tracer)] # type: ignore
|
||||
top_main = max(traces, default=None, key=attrgetter('level'))
|
||||
dynamic = thread_local_state.trace_state.trace_stack.dynamic
|
||||
top_main = (dynamic if top_main is None or dynamic.level > top_main.level
|
||||
else top_main)
|
||||
|
@ -830,6 +830,11 @@ class DynamicJaxprTracer(core.Tracer):
|
||||
f"{self._trace.main.source_info}.")
|
||||
return origin
|
||||
|
||||
def _assert_live(self) -> None:
|
||||
if not self._trace.main.jaxpr_stack: # type: ignore
|
||||
msg = f"tracer created on line {source_info_util.summarize(self.line_info)}"
|
||||
raise core.escaped_tracer_error(msg)
|
||||
|
||||
class JaxprStackFrame:
|
||||
__slots__ = ['newvar', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
|
||||
'tracers', 'eqns']
|
||||
@ -896,16 +901,18 @@ class DynamicJaxprTrace(core.Trace):
|
||||
__slots__ = [] # type: ignore
|
||||
|
||||
@property
|
||||
def frame(self): return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error
|
||||
def frame(self):
|
||||
return self.main.jaxpr_stack[-1] # pytype: disable=attribute-error
|
||||
|
||||
def new_arg(self, aval):
|
||||
tracer = DynamicJaxprTracer(self, aval)
|
||||
tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
|
||||
self.frame.tracers.append(tracer)
|
||||
self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(aval)
|
||||
return tracer
|
||||
|
||||
def new_const(self, val):
|
||||
tracer = DynamicJaxprTracer(self, raise_to_shaped(get_aval(val), weak_type=dtypes.is_python_scalar(val)))
|
||||
aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_python_scalar(val))
|
||||
tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
|
||||
self.frame.tracers.append(tracer)
|
||||
var = self.frame.tracer_to_var[id(tracer)] = self.getconstvar(val)
|
||||
self.frame.constvar_to_val[var] = val
|
||||
@ -937,11 +944,11 @@ class DynamicJaxprTrace(core.Trace):
|
||||
avals = [t.aval for t in tracers]
|
||||
out_avals = primitive.abstract_eval(*avals, **params)
|
||||
out_avals = [out_avals] if not primitive.multiple_results else out_avals
|
||||
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
|
||||
source_info = source_info_util.current()
|
||||
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
|
||||
invars = map(self.getvar, tracers)
|
||||
outvars = map(self.getvar, out_tracers)
|
||||
eqn = new_jaxpr_eqn(invars, outvars, primitive, params,
|
||||
source_info_util.current())
|
||||
eqn = new_jaxpr_eqn(invars, outvars, primitive, params, source_info)
|
||||
self.frame.eqns.append(eqn)
|
||||
return out_tracers if primitive.multiple_results else out_tracers.pop()
|
||||
|
||||
@ -950,7 +957,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
|
||||
if not jaxpr.eqns:
|
||||
return core.eval_jaxpr(jaxpr, consts, *tracers)
|
||||
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
|
||||
source_info = source_info_util.current()
|
||||
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
|
||||
invars = map(self.getvar, tracers)
|
||||
outvars = map(self.getvar, out_tracers)
|
||||
constvars = map(self.getvar, map(self.instantiate_const, consts))
|
||||
@ -958,8 +966,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
update_params = call_param_updaters.get(call_primitive)
|
||||
if update_params:
|
||||
new_params = update_params(new_params, [True] * len(tracers))
|
||||
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive, new_params,
|
||||
source_info_util.current())
|
||||
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive,
|
||||
new_params, source_info)
|
||||
self.frame.eqns.append(eqn)
|
||||
return out_tracers
|
||||
|
||||
@ -975,7 +983,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
|
||||
f, self.main, reduced_in_avals)
|
||||
out_avals = [core.unmapped_aval(params['axis_size'], a) for a in reduced_out_avals]
|
||||
out_tracers = [DynamicJaxprTracer(self, a) for a in out_avals]
|
||||
source_info = source_info_util.current()
|
||||
out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
|
||||
invars = map(self.getvar, tracers)
|
||||
outvars = map(self.getvar, out_tracers)
|
||||
constvars = map(self.getvar, map(self.instantiate_const, consts))
|
||||
@ -985,7 +994,8 @@ class DynamicJaxprTrace(core.Trace):
|
||||
update_params = call_param_updaters.get(map_primitive)
|
||||
if update_params:
|
||||
new_params = update_params(new_params, [True] * len(tracers))
|
||||
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive, new_params)
|
||||
eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
|
||||
new_params, source_info)
|
||||
self.frame.eqns.append(eqn)
|
||||
return out_tracers
|
||||
|
||||
|
@ -1641,7 +1641,7 @@ class APITest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(
|
||||
core.UnexpectedTracerError,
|
||||
re.compile(
|
||||
"Encountered an unexpected tracer.*Can't lift sublevels 1 to 0",
|
||||
"Encountered an unexpected tracer",
|
||||
re.DOTALL)):
|
||||
api.jit(lambda x: x)(self._saved_tracer)
|
||||
|
||||
@ -1688,6 +1688,30 @@ class APITest(jtu.JaxTestCase):
|
||||
re.DOTALL)):
|
||||
api.jit(func1)(2.)
|
||||
|
||||
def test_escaped_tracer_omnistaging(self):
|
||||
if not config.omnistaging_enabled:
|
||||
raise unittest.SkipTest("test is omnistaging-specific")
|
||||
|
||||
count = 1
|
||||
|
||||
@jit
|
||||
def f():
|
||||
nonlocal count
|
||||
count = jnp.add(count, 1)
|
||||
f() # leaked a tracer! but currently undetected
|
||||
|
||||
def f(x, c):
|
||||
jnp.add(count, 1)
|
||||
return None, None
|
||||
|
||||
@jit
|
||||
def g():
|
||||
lax.scan(f, None, None, length=2)
|
||||
|
||||
with self.assertRaisesRegex(core.UnexpectedTracerError,
|
||||
"tracer created on line"):
|
||||
g()
|
||||
|
||||
def test_pmap_static_kwarg_error_message(self):
|
||||
# https://github.com/google/jax/issues/3007
|
||||
def f(a, b):
|
||||
|
Loading…
x
Reference in New Issue
Block a user