mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
parent
c6b7269480
commit
b81c246a18
@ -365,6 +365,7 @@ class Trace:
|
||||
def full_raise(self, val) -> 'Tracer':
|
||||
if not isinstance(val, Tracer):
|
||||
return self.pure(val)
|
||||
val._assert_live()
|
||||
level = self.level
|
||||
sublevel = self.sublevel
|
||||
if val._trace.main is self.main:
|
||||
@ -742,8 +743,8 @@ def full_lower(val):
|
||||
return val
|
||||
|
||||
def find_top_trace(xs) -> Trace:
|
||||
traces = [x._assert_live() or x._trace.main for x in xs if isinstance(x, Tracer)] # type: ignore
|
||||
top_main = max(traces, default=None, key=attrgetter('level'))
|
||||
top_main = max((x._trace.main for x in xs if isinstance(x, Tracer)),
|
||||
default=None, key=attrgetter('level'))
|
||||
dynamic = thread_local_state.trace_state.trace_stack.dynamic
|
||||
top_main = (dynamic if top_main is None or dynamic.level > top_main.level
|
||||
else top_main)
|
||||
|
@ -1662,13 +1662,10 @@ class APITest(jtu.JaxTestCase):
|
||||
self._saved_tracer = x
|
||||
return x
|
||||
|
||||
def test_escaped_tracers_diffent_top_level_traces(self):
|
||||
def test_escaped_tracers_different_top_level_traces(self):
|
||||
api.jit(self.helper_save_tracer)(0.)
|
||||
with self.assertRaisesRegex(
|
||||
core.UnexpectedTracerError,
|
||||
re.compile(
|
||||
"Encountered an unexpected tracer.*Different traces at same level",
|
||||
re.DOTALL)):
|
||||
core.UnexpectedTracerError, "Encountered an unexpected tracer"):
|
||||
api.jit(lambda x: self._saved_tracer)(0.)
|
||||
|
||||
def test_escaped_tracers_cant_lift_sublevels(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user