mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Better error message when raise_if_error()
is called within a traced context
PiperOrigin-RevId: 735557928
This commit is contained in:
parent
aceae84fab
commit
988a1208a9
@ -88,12 +88,17 @@ def raise_if_error() -> None:
|
||||
"""Raise error if an error is set.
|
||||
|
||||
This function should be called after the computation is finished. It should
|
||||
be used outside jit.
|
||||
not be called within a traced context, such as within a jitted function."
|
||||
"""
|
||||
if _error_storage.ref is None: # if not initialized, do nothing
|
||||
return
|
||||
|
||||
error_code = _error_storage.ref[...]
|
||||
if isinstance(error_code, core.Tracer):
|
||||
raise ValueError(
|
||||
"raise_if_error() should not be called within a traced context, such as"
|
||||
" within a jitted function."
|
||||
)
|
||||
if error_code == jnp.uint32(_NO_ERROR):
|
||||
return
|
||||
_error_storage.ref[...] = jnp.uint32(_NO_ERROR)
|
||||
|
@ -170,6 +170,26 @@ class ErrorCheckTests(jtu.JaxTestCase):
|
||||
_ = body(init, xs)
|
||||
error_check.raise_if_error() # should not raise error
|
||||
|
||||
@parameterized.product(jit=[True, False])
|
||||
def test_raise_if_error_fails_in_traced_context(self, jit):
|
||||
def f(x):
|
||||
error_check.set_error_if(x <= 0, "x must be greater than 0")
|
||||
return x + 1
|
||||
|
||||
if jit:
|
||||
f = jax.jit(f)
|
||||
|
||||
x = jnp.full((4,), 1, dtype=jnp.int32)
|
||||
f(x)
|
||||
with self.assertRaises(
|
||||
ValueError,
|
||||
msg=(
|
||||
"raise_if_error() should not be called within a traced context,"
|
||||
" such as within a jitted function."
|
||||
),
|
||||
):
|
||||
jax.jit(error_check.raise_if_error)()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user