From 018e700ead0612a56f1e7702220012661161f4fc Mon Sep 17 00:00:00 2001 From: lenamartens Date: Fri, 16 Sep 2022 17:46:25 +0100 Subject: [PATCH] Checkify: support batched while. --- jax/_src/checkify.py | 41 ++++++++++++++++++++++++++++++++++++++++- tests/checkify_test.py | 30 ++++++++++++++++++++++++++++++ 2 files changed, 70 insertions(+), 1 deletion(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 4243a9209..5a48a14c1 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -662,18 +662,54 @@ def ignore_error_output_jaxpr(jaxpr): new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[3:]) return core.ClosedJaxpr(new_jaxpr, consts) +def batch_error(err, code, payload, batch_shape): + err = jnp.broadcast_to(err, batch_shape) + code = jnp.broadcast_to(code, batch_shape) + payload = jnp.broadcast_to(payload, batch_shape+(3,)) + return err, code, payload + +def unbatch_error(err, code, payload): + err = err.ravel()[0] + code = code.ravel()[0] + payload = payload.reshape(-1, 3)[0] + return err, code, payload + +def trivial_batched_jaxpr(jaxpr, batch_shape, batched_err): + fun = core.jaxpr_as_fun(jaxpr) + + def g(err, code, payload, *a): + err_args = unbatch_error(err, code, payload) + err, code, payload, *out = fun(*err_args, *a) + err, code, payload = batch_error(err, code, payload, batch_shape) + return (err, code, payload, *out) + + error_avals = map(lambda x: core.raise_to_shaped(core.get_aval(x)), batched_err) + new_jaxpr, _, literals_out = pe.trace_to_jaxpr_dynamic( + lu.wrap_init(g), [*error_avals, *jaxpr.in_avals[3:]]) + return core.ClosedJaxpr(new_jaxpr, literals_out) + def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr): + batch_shape = cond_jaxpr.out_avals[0].shape + if batch_shape: + err_args = batch_error(error.err, error.code, error.payload, batch_shape) + else: + err_args = [error.err, error.code, error.payload] + c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts]) # Check if the first cond application will error. checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error, enabled_errors) + if batch_shape: + checked_cond_jaxpr = trivial_batched_jaxpr(checked_cond_jaxpr, batch_shape, err_args) cond_err, cond_code, cond_payload, _ = core.jaxpr_as_fun(checked_cond_jaxpr)( - error.err, error.code, error.payload, *c_consts, *carry) + *err_args, *c_consts, *carry) checked_body_jaxpr_, msgs_body = checkify_while_body_jaxpr( cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts) + if batch_shape: + checked_body_jaxpr_ = trivial_batched_jaxpr(checked_body_jaxpr_, batch_shape, err_args) to_move = [False] * 3 + [True] * body_nconsts + [False] * len(carry) checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move) @@ -686,6 +722,9 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr, body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr) new_msgs = {**error.msgs, **msgs_body, **msgs_cond} + if batch_shape: + err, code, payload = unbatch_error(err, code, payload) + return out, Error(err, code, new_msgs, payload) error_checks[lax.while_p] = while_loop_error_check diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 66ddd69de..68b839ae0 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -875,5 +875,35 @@ class AssertPrimitiveTests(jtu.JaxTestCase): self.assertIsNotNone(err.get()) self.assertIn("should be positive", err.get()) + def test_checkify_of_vmap_of_while(self): + @jax.vmap + def fun(n, v): + def while_cond(s): + counter, value = s + checkify.check(value < 6, "value needs to be less than 6!") + return counter > 0 + + def while_body(s): + counter, value = s + checkify.check(value >= 0, "value needs to be positive!") + return counter/value, value - 1. + + _, result = jax.lax.while_loop(while_cond, while_body, (n, v)) + return result + + checked_f = checkify.checkify(fun, errors=checkify.all_checks) + + err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., 4.])) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), "divided by zero") + + err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., -4.])) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), "value needs to be positive") + + err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([6., 2., -4.])) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), "value needs to be less than 6") + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())