move the trace liveness check from #4312 (#4315)

This commit is contained in:
Matthew Johnson 2020-09-16 23:59:58 -07:00 committed by GitHub
parent c6b7269480
commit b81c246a18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 7 deletions

View File

@ -365,6 +365,7 @@ class Trace:
def full_raise(self, val) -> 'Tracer':
if not isinstance(val, Tracer):
return self.pure(val)
val._assert_live()
level = self.level
sublevel = self.sublevel
if val._trace.main is self.main:
@ -742,8 +743,8 @@ def full_lower(val):
return val
def find_top_trace(xs) -> Trace:
traces = [x._assert_live() or x._trace.main for x in xs if isinstance(x, Tracer)] # type: ignore
top_main = max(traces, default=None, key=attrgetter('level'))
top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
default=None, key=attrgetter('level'))
dynamic = thread_local_state.trace_state.trace_stack.dynamic
top_main = (dynamic if top_main is None or dynamic.level > top_main.level
else top_main)

View File

@ -1662,13 +1662,10 @@ class APITest(jtu.JaxTestCase):
self._saved_tracer = x
return x
def test_escaped_tracers_diffent_top_level_traces(self):
def test_escaped_tracers_different_top_level_traces(self):
api.jit(self.helper_save_tracer)(0.)
with self.assertRaisesRegex(
core.UnexpectedTracerError,
re.compile(
"Encountered an unexpected tracer.*Different traces at same level",
re.DOTALL)):
core.UnexpectedTracerError, "Encountered an unexpected tracer"):
api.jit(lambda x: self._saved_tracer)(0.)
def test_escaped_tracers_cant_lift_sublevels(self):