Checkify: support batched while.

This commit is contained in:
lenamartens 2022-09-16 17:46:25 +01:00
parent 9791199488
commit 018e700ead
2 changed files with 70 additions and 1 deletions

View File

@ -662,18 +662,54 @@ def ignore_error_output_jaxpr(jaxpr):
new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[3:])
return core.ClosedJaxpr(new_jaxpr, consts)
def batch_error(err, code, payload, batch_shape):
err = jnp.broadcast_to(err, batch_shape)
code = jnp.broadcast_to(code, batch_shape)
payload = jnp.broadcast_to(payload, batch_shape+(3,))
return err, code, payload
def unbatch_error(err, code, payload):
err = err.ravel()[0]
code = code.ravel()[0]
payload = payload.reshape(-1, 3)[0]
return err, code, payload
def trivial_batched_jaxpr(jaxpr, batch_shape, batched_err):
fun = core.jaxpr_as_fun(jaxpr)
def g(err, code, payload, *a):
err_args = unbatch_error(err, code, payload)
err, code, payload, *out = fun(*err_args, *a)
err, code, payload = batch_error(err, code, payload, batch_shape)
return (err, code, payload, *out)
error_avals = map(lambda x: core.raise_to_shaped(core.get_aval(x)), batched_err)
new_jaxpr, _, literals_out = pe.trace_to_jaxpr_dynamic(
lu.wrap_init(g), [*error_avals, *jaxpr.in_avals[3:]])
return core.ClosedJaxpr(new_jaxpr, literals_out)
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
cond_jaxpr, body_nconsts, body_jaxpr):
batch_shape = cond_jaxpr.out_avals[0].shape
if batch_shape:
err_args = batch_error(error.err, error.code, error.payload, batch_shape)
else:
err_args = [error.err, error.code, error.payload]
c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
# Check if the first cond application will error.
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error,
enabled_errors)
if batch_shape:
checked_cond_jaxpr = trivial_batched_jaxpr(checked_cond_jaxpr, batch_shape, err_args)
cond_err, cond_code, cond_payload, _ = core.jaxpr_as_fun(checked_cond_jaxpr)(
error.err, error.code, error.payload, *c_consts, *carry)
*err_args, *c_consts, *carry)
checked_body_jaxpr_, msgs_body = checkify_while_body_jaxpr(
cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
if batch_shape:
checked_body_jaxpr_ = trivial_batched_jaxpr(checked_body_jaxpr_, batch_shape, err_args)
to_move = [False] * 3 + [True] * body_nconsts + [False] * len(carry)
checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)
@ -686,6 +722,9 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
*new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr)
new_msgs = {**error.msgs, **msgs_body, **msgs_cond}
if batch_shape:
err, code, payload = unbatch_error(err, code, payload)
return out, Error(err, code, new_msgs, payload)
error_checks[lax.while_p] = while_loop_error_check

View File

@ -875,5 +875,35 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
self.assertIsNotNone(err.get())
self.assertIn("should be positive", err.get())
def test_checkify_of_vmap_of_while(self):
@jax.vmap
def fun(n, v):
def while_cond(s):
counter, value = s
checkify.check(value < 6, "value needs to be less than 6!")
return counter > 0
def while_body(s):
counter, value = s
checkify.check(value >= 0, "value needs to be positive!")
return counter/value, value - 1.
_, result = jax.lax.while_loop(while_cond, while_body, (n, v))
return result
checked_f = checkify.checkify(fun, errors=checkify.all_checks)
err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., 4.]))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., -4.]))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "value needs to be positive")
err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([6., 2., -4.]))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "value needs to be less than 6")
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())