Disallow checkify-of-vmap-of-while.

This commit is contained in:
lenamartens 2022-10-17 21:48:38 +01:00
parent 504b3c1b25
commit c2a00a0526
2 changed files with 41 additions and 48 deletions

View File

@ -702,54 +702,26 @@ def ignore_error_output_jaxpr(jaxpr):
new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[3:])
return core.ClosedJaxpr(new_jaxpr, consts)
def batch_error(err, code, payload, batch_shape):
err = jnp.broadcast_to(err, batch_shape)
code = jnp.broadcast_to(code, batch_shape)
payload = jnp.broadcast_to(payload, batch_shape+(3,))
return err, code, payload
def unbatch_error(err, code, payload):
err = err.ravel()[0]
code = code.ravel()[0]
payload = payload.reshape(-1, 3)[0]
return err, code, payload
def trivial_batched_jaxpr(jaxpr, batch_shape, batched_err):
fun = core.jaxpr_as_fun(jaxpr)
def g(err, code, payload, *a):
err_args = unbatch_error(err, code, payload)
err, code, payload, *out = fun(*err_args, *a)
err, code, payload = batch_error(err, code, payload, batch_shape)
return (err, code, payload, *out)
error_avals = map(lambda x: core.raise_to_shaped(core.get_aval(x)), batched_err)
new_jaxpr, _, literals_out = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(g), [*error_avals, *jaxpr.in_avals[3:]])
return core.ClosedJaxpr(new_jaxpr, literals_out)
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
cond_jaxpr, body_nconsts, body_jaxpr):
batch_shape = cond_jaxpr.out_avals[0].shape
if batch_shape:
err_args = batch_error(error.err, error.code, error.payload, batch_shape)
else:
err_args = [error.err, error.code, error.payload]
if cond_jaxpr.out_avals[0].shape:
# TODO(lenamartens, sharadmv): support batched while.
raise ValueError('Checkify does not support batched while-loops '
'(checkify-of-vmap-of-while). \nHint: if possible, move '
'the vmap to the outer level to get '
'vmap-of-checkify-of-while.')
err_args = [error.err, error.code, error.payload]
c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
# Check if the first cond application will error.
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error,
enabled_errors)
if batch_shape:
checked_cond_jaxpr = trivial_batched_jaxpr(checked_cond_jaxpr, batch_shape, err_args)
cond_err, cond_code, cond_payload, _ = core.jaxpr_as_fun(checked_cond_jaxpr)(
*err_args, *c_consts, *carry)
checked_body_jaxpr_, msgs_body = checkify_while_body_jaxpr(
cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
if batch_shape:
checked_body_jaxpr_ = trivial_batched_jaxpr(checked_body_jaxpr_, batch_shape, err_args)
cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
to_move = [False] * 3 + [True] * body_nconsts + [False] * len(carry)
checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)
@ -762,8 +734,6 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
*new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr)
new_msgs = {**error.msgs, **msgs_body, **msgs_cond}
if batch_shape:
err, code, payload = unbatch_error(err, code, payload)
return out, Error(err, code, new_msgs, payload)
error_checks[lax.while_p] = while_loop_error_check

View File

@ -891,7 +891,7 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
self.assertIsNotNone(err.get())
self.assertIn("should be positive", err.get())
def test_checkify_of_vmap_of_while(self):
def test_checkify_of_vmap_of_while_errors(self):
@jax.vmap
def fun(n, v):
def while_cond(s):
@ -909,17 +909,39 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
checked_f = checkify.checkify(fun, errors=checkify.all_checks)
err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., 4.]))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "division by zero")
with self.assertRaisesRegex(ValueError, "checkify-of-vmap-of-while"):
checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., 4.]))
# TODO(lenamartens): reenable assertions below.
# self.assertIsNotNone(err.get())
# self.assertStartsWith(err.get(), "division by zero")
err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., -4.]))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "value needs to be positive")
# err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., -4.]))
# self.assertIsNotNone(err.get())
# self.assertStartsWith(err.get(), "value needs to be positive")
err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([6., 2., -4.]))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "value needs to be less than 6")
# err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([6., 2., -4.]))
# self.assertIsNotNone(err.get())
# self.assertStartsWith(err.get(), "value needs to be less than 6")
def test_checkify_of_vmap_of_while_masked_errors(self):
def cond(x):
return x < 5
def body(x):
# This will only trigger in the masked portion of the batched while.
checkify.check(x < 5, "should never happen")
return x + 1
@jax.vmap
def fun(x):
return lax.while_loop(cond, body, x)
checked_f = checkify.checkify(fun)
with self.assertRaisesRegex(ValueError, "checkify-of-vmap-of-while"):
checked_f(jnp.arange(5))
# TODO(lenamartens): reenable assertions below.
# self.assertIsNone(err.get())
def test_assert_cond_no_data_dependence(self):
def f():
@ -965,5 +987,6 @@ class LowerableChecksTest(jtu.JaxTestCase):
with self.assertRaisesRegex(xla_extension.XlaRuntimeError,
"x needs to be positive"):
f(-1.)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())