try resetting global tracer state in loops_test.py

attempting to address #2507
This commit is contained in:
Matthew Johnson 2020-03-30 20:10:39 -07:00
parent de37eae628
commit 1f03d48c83

View File

@ -27,6 +27,15 @@ from jax.experimental import loops
from jax.config import config
config.parse_flags_with_absl()
# Attempted fix for https://github.com/google/jax/issues/2507 based on resetting
# the global trace state. It could be that methods like _BodyTracer.end_subtrace
# are not cleaning up global trace state after exceptions because they don't use
# a try/finally pattern. This is just a guess though!
# TODO(mattjj,necula): check this attempted fix
from jax import core
def tearDownModule():
core.trace_state = core.TraceState()
class LoopsTest(jtu.JaxTestCase):
def test_scope_no_loops(self):