mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Make sure while_loop cond generates an error even if it returns False.
This commit is contained in:
parent
961c28c811
commit
98a5461132
@ -305,8 +305,9 @@ def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error):
|
||||
cond_f = core.jaxpr_as_fun(cond_jaxpr)
|
||||
body_f = core.jaxpr_as_fun(body_jaxpr)
|
||||
def new_body_f(*vals):
|
||||
_ = cond_f(*vals)
|
||||
return body_f(*vals)
|
||||
out = body_f(*vals)
|
||||
_ = cond_f(*out) # this checks if the next cond application will error
|
||||
return out
|
||||
return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, body_jaxpr.in_avals)
|
||||
|
||||
def ignore_errors_jaxpr(jaxpr, error):
|
||||
@ -322,18 +323,22 @@ def ignore_errors_jaxpr(jaxpr, error):
|
||||
return core.ClosedJaxpr(new_jaxpr, consts)
|
||||
|
||||
def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
|
||||
# TODO(lenamartens): fix when an error occurs in the cond function and it then returns False.
|
||||
checked_body_jaxpr, msgs_ = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error)
|
||||
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error)
|
||||
checked_cond_fun = core.jaxpr_as_fun(checked_cond_jaxpr)
|
||||
# Check if the first cond application will error.
|
||||
cond_err, cond_code, _ = checked_cond_fun(error.err, error.code, *in_flat)
|
||||
|
||||
checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error)
|
||||
compat_cond_jaxpr = ignore_errors_jaxpr(cond_jaxpr, error)
|
||||
c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
|
||||
new_in_flat = [*c_consts, *b_consts, error.err, error.code, *carry]
|
||||
new_in_flat = [*c_consts, *b_consts, cond_err, cond_code, *carry]
|
||||
err, code, *out = control_flow.while_p.bind(
|
||||
*new_in_flat,
|
||||
cond_nconsts=cond_nconsts,
|
||||
cond_jaxpr=compat_cond_jaxpr,
|
||||
body_nconsts=body_nconsts,
|
||||
body_jaxpr=checked_body_jaxpr)
|
||||
new_msgs = {**error.msgs, **msgs_}
|
||||
new_msgs = {**error.msgs, **msgs_body, **msgs_cond}
|
||||
return out, Error(err, code, new_msgs)
|
||||
error_checks[control_flow.while_p] = while_loop_error_check
|
||||
|
||||
|
@ -259,6 +259,29 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
self.assertArraysEqual(ch_out, out)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_while_loop_cond_error_and_false(self):
|
||||
# Tests if an error is generated when cond returns False.
|
||||
def while_cond(val):
|
||||
possible_nan = jnp.sin(1./val)
|
||||
return jnp.logical_not(jnp.isnan(possible_nan))
|
||||
|
||||
@jax.jit
|
||||
def f(init_val):
|
||||
return lax.while_loop(while_cond, lambda val: val-1, init_val)
|
||||
|
||||
# error on first cond
|
||||
init_val = 0.
|
||||
err, _ = checkify.checkify(f)(init_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
|
||||
# error on second cond
|
||||
init_val = 1.
|
||||
err, _ = checkify.checkify(f)(init_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_while_loop_body_and_cond_error(self):
|
||||
def while_cond(val):
|
||||
|
Loading…
x
Reference in New Issue
Block a user