diff --git a/jax/experimental/checkify.py b/jax/experimental/checkify.py index 7a08dbfce..9e5614a71 100644 --- a/jax/experimental/checkify.py +++ b/jax/experimental/checkify.py @@ -305,8 +305,9 @@ def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error): cond_f = core.jaxpr_as_fun(cond_jaxpr) body_f = core.jaxpr_as_fun(body_jaxpr) def new_body_f(*vals): - _ = cond_f(*vals) - return body_f(*vals) + out = body_f(*vals) + _ = cond_f(*out) # this checks if the next cond application will error + return out return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, body_jaxpr.in_avals) def ignore_errors_jaxpr(jaxpr, error): @@ -322,18 +323,22 @@ def ignore_errors_jaxpr(jaxpr, error): return core.ClosedJaxpr(new_jaxpr, consts) def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr): - # TODO(lenamartens): fix when an error occurs in the cond function and it then returns False. - checked_body_jaxpr, msgs_ = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error) + checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error) + checked_cond_fun = core.jaxpr_as_fun(checked_cond_jaxpr) + # Check if the first cond application will error. + cond_err, cond_code, _ = checked_cond_fun(error.err, error.code, *in_flat) + + checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error) compat_cond_jaxpr = ignore_errors_jaxpr(cond_jaxpr, error) c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts]) - new_in_flat = [*c_consts, *b_consts, error.err, error.code, *carry] + new_in_flat = [*c_consts, *b_consts, cond_err, cond_code, *carry] err, code, *out = control_flow.while_p.bind( *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr, body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr) - new_msgs = {**error.msgs, **msgs_} + new_msgs = {**error.msgs, **msgs_body, **msgs_cond} return out, Error(err, code, new_msgs) error_checks[control_flow.while_p] = while_loop_error_check diff --git a/tests/checkify_test.py b/tests/checkify_test.py index db9247ed0..ac311f3fe 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -259,6 +259,29 @@ class CheckifyTransformTests(jtu.JaxTestCase): self.assertStartsWith(err.get(), "nan generated by primitive sin") self.assertArraysEqual(ch_out, out) + @jtu.skip_on_devices("tpu") + def test_while_loop_cond_error_and_false(self): + # Tests if an error is generated when cond returns False. + def while_cond(val): + possible_nan = jnp.sin(1./val) + return jnp.logical_not(jnp.isnan(possible_nan)) + + @jax.jit + def f(init_val): + return lax.while_loop(while_cond, lambda val: val-1, init_val) + + # error on first cond + init_val = 0. + err, _ = checkify.checkify(f)(init_val) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), "nan generated by primitive sin") + + # error on second cond + init_val = 1. + err, _ = checkify.checkify(f)(init_val) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), "nan generated by primitive sin") + @jtu.skip_on_devices("tpu") def test_while_loop_body_and_cond_error(self): def while_cond(val):