Better error message when raise_if_error() is called within a traced context

PiperOrigin-RevId: 735557928
This commit is contained in:
Ayaka 2025-03-10 16:54:21 -07:00 committed by jax authors
parent aceae84fab
commit 988a1208a9
2 changed files with 26 additions and 1 deletions

View File

@ -88,12 +88,17 @@ def raise_if_error() -> None:
"""Raise error if an error is set.
This function should be called after the computation is finished. It should
be used outside jit.
not be called within a traced context, such as within a jitted function."
"""
if _error_storage.ref is None: # if not initialized, do nothing
return
error_code = _error_storage.ref[...]
if isinstance(error_code, core.Tracer):
raise ValueError(
"raise_if_error() should not be called within a traced context, such as"
" within a jitted function."
)
if error_code == jnp.uint32(_NO_ERROR):
return
_error_storage.ref[...] = jnp.uint32(_NO_ERROR)

View File

@ -170,6 +170,26 @@ class ErrorCheckTests(jtu.JaxTestCase):
_ = body(init, xs)
error_check.raise_if_error() # should not raise error
@parameterized.product(jit=[True, False])
def test_raise_if_error_fails_in_traced_context(self, jit):
def f(x):
error_check.set_error_if(x <= 0, "x must be greater than 0")
return x + 1
if jit:
f = jax.jit(f)
x = jnp.full((4,), 1, dtype=jnp.int32)
f(x)
with self.assertRaises(
ValueError,
msg=(
"raise_if_error() should not be called within a traced context,"
" such as within a jitted function."
),
):
jax.jit(error_check.raise_if_error)()
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())