Merge pull request #10358 from LenaMartens:changelist/442788164

PiperOrigin-RevId: 443116900
This commit is contained in:
jax authors 2022-04-20 09:34:22 -07:00
commit b2132f7884
3 changed files with 43 additions and 5 deletions

View File

@ -445,6 +445,8 @@ def check_error(error: Error) -> None:
err, code, payload = _reduce_any_error(error.err, error.code, error.payload)
else:
err, code, payload = error.err, error.code, error.payload
err = core.raise_as_much_as_possible(err)
return assert_p.bind(~err, code, payload, msgs=error.msgs)
assert_p = core.Primitive('assert') # TODO: rename to check?

View File

@ -478,6 +478,25 @@ class Trace:
"to handle custom_vjp primitives")
raise NotImplementedError(msg)
def raise_as_much_as_possible(tracer) -> Tracer:
# Find effective bottom of trace stack (highest dynamic Trace on the stack).
trace_stack = thread_local_state.trace_state.trace_stack.stack
idx = next(i for i, m in enumerate(trace_stack) if m is
thread_local_state.trace_state.trace_stack.dynamic)
# Only pay attention to effective part of trace stack.
trace_stack = trace_stack[idx:]
# Lift tracer into everything in the effective stack higher than its level
for trace in trace_stack:
trace = trace.with_cur_sublevel()
if (not isinstance(tracer, Tracer) or tracer._trace.level < trace.level):
tracer = trace.full_raise(tracer)
return tracer
def escaped_tracer_error(tracer, detail=None):
num_frames = FLAGS.jax_tracer_error_num_traceback_frames
msg = ('Encountered an unexpected tracer. A function transformed by JAX '

View File

@ -628,7 +628,7 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
with self.assertRaisesRegex(ValueError, "hi"):
f()
def test_assert_primitive_(self):
def test_assert_primitive_staging(self):
@jax.jit
def f():
checkify.check(False, "hi")
@ -658,15 +658,32 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "must be positive")
def test_assert_discharging_no_data_dependence(self):
@jax.jit
def g(x):
@checkify.checkify
def f():
# Note that x is not an argument to the checkified function.
checkify.check(x > 0, "must be positive!")
return jnp.log(x)
return f()
err, _ = g(1.)
self.assertIsNone(err.get())
err, _ = g(0.)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "must be positive")
def test_check_error(self):
def f(pred): # note: data dependence needed!
checkify.check_error(checkify.Error(~pred, 0, {0: "hi"}))
def f():
checkify.check_error(checkify.Error(True, 0, {0: "hi"}))
with self.assertRaisesRegex(ValueError, "hi"):
f(False)
f()
f = checkify.checkify(f)
err, none = f(False)
err, none = f()
self.assertIsNone(none)
self.assertIsNotNone(err.get())