test eval_context works w/ and w/o omnistaging (#4325)

This commit is contained in:
Matthew Johnson 2020-09-17 09:57:43 -07:00 committed by GitHub
parent 40e20242db
commit 11007ba0e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 1 deletions

View File

@ -1493,7 +1493,8 @@ def pp_kv_pairs(kv_pairs):
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
global thread_local_state, call_bind, find_top_trace, initial_style_staging, \
new_main, reset_trace_state, TraceStack, TraceState, extend_axis_env
new_main, reset_trace_state, TraceStack, TraceState, extend_axis_env, \
eval_context
class TraceStack:
upward: List[MainTrace]

View File

@ -1784,6 +1784,14 @@ class APITest(jtu.JaxTestCase):
jaxpr = api.make_jaxpr(lambda: jnp.add(1, 1))()
self.assertLen(jaxpr.jaxpr.eqns, 0)
def test_eval_context(self):
@jit
def f():
with core.eval_context():
assert jnp.add(1, 1) == 2
f() # doesn't crash
class RematTest(jtu.JaxTestCase):