Make sure while_loop cond generates an error even if it returns False.

This commit is contained in:
Lena Martens 2021-12-16 21:48:37 +00:00 committed by lenamartens
parent 961c28c811
commit 98a5461132
2 changed files with 34 additions and 6 deletions

View File

@ -305,8 +305,9 @@ def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error):
cond_f = core.jaxpr_as_fun(cond_jaxpr)
body_f = core.jaxpr_as_fun(body_jaxpr)
def new_body_f(*vals):
_ = cond_f(*vals)
return body_f(*vals)
out = body_f(*vals)
_ = cond_f(*out) # this checks if the next cond application will error
return out
return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, body_jaxpr.in_avals)
def ignore_errors_jaxpr(jaxpr, error):
@ -322,18 +323,22 @@ def ignore_errors_jaxpr(jaxpr, error):
return core.ClosedJaxpr(new_jaxpr, consts)
def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
# TODO(lenamartens): fix when an error occurs in the cond function and it then returns False.
checked_body_jaxpr, msgs_ = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error)
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error)
checked_cond_fun = core.jaxpr_as_fun(checked_cond_jaxpr)
# Check if the first cond application will error.
cond_err, cond_code, _ = checked_cond_fun(error.err, error.code, *in_flat)
checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error)
compat_cond_jaxpr = ignore_errors_jaxpr(cond_jaxpr, error)
c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
new_in_flat = [*c_consts, *b_consts, error.err, error.code, *carry]
new_in_flat = [*c_consts, *b_consts, cond_err, cond_code, *carry]
err, code, *out = control_flow.while_p.bind(
*new_in_flat,
cond_nconsts=cond_nconsts,
cond_jaxpr=compat_cond_jaxpr,
body_nconsts=body_nconsts,
body_jaxpr=checked_body_jaxpr)
new_msgs = {**error.msgs, **msgs_}
new_msgs = {**error.msgs, **msgs_body, **msgs_cond}
return out, Error(err, code, new_msgs)
error_checks[control_flow.while_p] = while_loop_error_check

View File

@ -259,6 +259,29 @@ class CheckifyTransformTests(jtu.JaxTestCase):
self.assertStartsWith(err.get(), "nan generated by primitive sin")
self.assertArraysEqual(ch_out, out)
@jtu.skip_on_devices("tpu")
def test_while_loop_cond_error_and_false(self):
# Tests if an error is generated when cond returns False.
def while_cond(val):
possible_nan = jnp.sin(1./val)
return jnp.logical_not(jnp.isnan(possible_nan))
@jax.jit
def f(init_val):
return lax.while_loop(while_cond, lambda val: val-1, init_val)
# error on first cond
init_val = 0.
err, _ = checkify.checkify(f)(init_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
# error on second cond
init_val = 1.
err, _ = checkify.checkify(f)(init_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
@jtu.skip_on_devices("tpu")
def test_while_loop_body_and_cond_error(self):
def while_cond(val):