mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #10358 from LenaMartens:changelist/442788164
PiperOrigin-RevId: 443116900
This commit is contained in:
commit
b2132f7884
@ -445,6 +445,8 @@ def check_error(error: Error) -> None:
|
||||
err, code, payload = _reduce_any_error(error.err, error.code, error.payload)
|
||||
else:
|
||||
err, code, payload = error.err, error.code, error.payload
|
||||
|
||||
err = core.raise_as_much_as_possible(err)
|
||||
return assert_p.bind(~err, code, payload, msgs=error.msgs)
|
||||
|
||||
assert_p = core.Primitive('assert') # TODO: rename to check?
|
||||
|
19
jax/core.py
19
jax/core.py
@ -478,6 +478,25 @@ class Trace:
|
||||
"to handle custom_vjp primitives")
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
|
||||
def raise_as_much_as_possible(tracer) -> Tracer:
|
||||
# Find effective bottom of trace stack (highest dynamic Trace on the stack).
|
||||
trace_stack = thread_local_state.trace_state.trace_stack.stack
|
||||
idx = next(i for i, m in enumerate(trace_stack) if m is
|
||||
thread_local_state.trace_state.trace_stack.dynamic)
|
||||
|
||||
# Only pay attention to effective part of trace stack.
|
||||
trace_stack = trace_stack[idx:]
|
||||
|
||||
# Lift tracer into everything in the effective stack higher than its level
|
||||
for trace in trace_stack:
|
||||
trace = trace.with_cur_sublevel()
|
||||
if (not isinstance(tracer, Tracer) or tracer._trace.level < trace.level):
|
||||
tracer = trace.full_raise(tracer)
|
||||
|
||||
return tracer
|
||||
|
||||
|
||||
def escaped_tracer_error(tracer, detail=None):
|
||||
num_frames = FLAGS.jax_tracer_error_num_traceback_frames
|
||||
msg = ('Encountered an unexpected tracer. A function transformed by JAX '
|
||||
|
@ -628,7 +628,7 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(ValueError, "hi"):
|
||||
f()
|
||||
|
||||
def test_assert_primitive_(self):
|
||||
def test_assert_primitive_staging(self):
|
||||
@jax.jit
|
||||
def f():
|
||||
checkify.check(False, "hi")
|
||||
@ -658,15 +658,32 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "must be positive")
|
||||
|
||||
def test_assert_discharging_no_data_dependence(self):
|
||||
@jax.jit
|
||||
def g(x):
|
||||
@checkify.checkify
|
||||
def f():
|
||||
# Note that x is not an argument to the checkified function.
|
||||
checkify.check(x > 0, "must be positive!")
|
||||
return jnp.log(x)
|
||||
return f()
|
||||
|
||||
err, _ = g(1.)
|
||||
self.assertIsNone(err.get())
|
||||
|
||||
err, _ = g(0.)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "must be positive")
|
||||
|
||||
def test_check_error(self):
|
||||
def f(pred): # note: data dependence needed!
|
||||
checkify.check_error(checkify.Error(~pred, 0, {0: "hi"}))
|
||||
def f():
|
||||
checkify.check_error(checkify.Error(True, 0, {0: "hi"}))
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "hi"):
|
||||
f(False)
|
||||
f()
|
||||
|
||||
f = checkify.checkify(f)
|
||||
err, none = f(False)
|
||||
err, none = f()
|
||||
|
||||
self.assertIsNone(none)
|
||||
self.assertIsNotNone(err.get())
|
||||
|
Loading…
x
Reference in New Issue
Block a user