mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
test eval_context works w/ and w/o omnistaging (#4325)
This commit is contained in:
parent
40e20242db
commit
11007ba0e3
@ -1493,7 +1493,8 @@ def pp_kv_pairs(kv_pairs):
|
||||
@config.register_omnistaging_disabler
|
||||
def omnistaging_disabler() -> None:
|
||||
global thread_local_state, call_bind, find_top_trace, initial_style_staging, \
|
||||
new_main, reset_trace_state, TraceStack, TraceState, extend_axis_env
|
||||
new_main, reset_trace_state, TraceStack, TraceState, extend_axis_env, \
|
||||
eval_context
|
||||
|
||||
class TraceStack:
|
||||
upward: List[MainTrace]
|
||||
|
@ -1784,6 +1784,14 @@ class APITest(jtu.JaxTestCase):
|
||||
jaxpr = api.make_jaxpr(lambda: jnp.add(1, 1))()
|
||||
self.assertLen(jaxpr.jaxpr.eqns, 0)
|
||||
|
||||
def test_eval_context(self):
|
||||
@jit
|
||||
def f():
|
||||
with core.eval_context():
|
||||
assert jnp.add(1, 1) == 2
|
||||
|
||||
f() # doesn't crash
|
||||
|
||||
|
||||
class RematTest(jtu.JaxTestCase):
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user