mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
try resetting global tracer state in loops_test.py
attempting to address #2507
This commit is contained in:
parent
de37eae628
commit
1f03d48c83
@ -27,6 +27,15 @@ from jax.experimental import loops
|
||||
from jax.config import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
# Attempted fix for https://github.com/google/jax/issues/2507 based on resetting
|
||||
# the global trace state. It could be that methods like _BodyTracer.end_subtrace
|
||||
# are not cleaning up global trace state after exceptions because they don't use
|
||||
# a try/finally pattern. This is just a guess though!
|
||||
# TODO(mattjj,necula): check this attempted fix
|
||||
from jax import core
|
||||
def tearDownModule():
|
||||
core.trace_state = core.TraceState()
|
||||
|
||||
class LoopsTest(jtu.JaxTestCase):
|
||||
|
||||
def test_scope_no_loops(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user