From c2a00a0526a183a2d5eddbd69e0dd85457bffd63 Mon Sep 17 00:00:00 2001 From: lenamartens Date: Mon, 17 Oct 2022 21:48:38 +0100 Subject: [PATCH] Disallow checkify-of-vmap-of-while. --- jax/_src/checkify.py | 46 ++++++++---------------------------------- tests/checkify_test.py | 43 ++++++++++++++++++++++++++++++--------- 2 files changed, 41 insertions(+), 48 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 9d30160c2..912116d66 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -702,54 +702,26 @@ def ignore_error_output_jaxpr(jaxpr): new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[3:]) return core.ClosedJaxpr(new_jaxpr, consts) -def batch_error(err, code, payload, batch_shape): - err = jnp.broadcast_to(err, batch_shape) - code = jnp.broadcast_to(code, batch_shape) - payload = jnp.broadcast_to(payload, batch_shape+(3,)) - return err, code, payload - -def unbatch_error(err, code, payload): - err = err.ravel()[0] - code = code.ravel()[0] - payload = payload.reshape(-1, 3)[0] - return err, code, payload - -def trivial_batched_jaxpr(jaxpr, batch_shape, batched_err): - fun = core.jaxpr_as_fun(jaxpr) - - def g(err, code, payload, *a): - err_args = unbatch_error(err, code, payload) - err, code, payload, *out = fun(*err_args, *a) - err, code, payload = batch_error(err, code, payload, batch_shape) - return (err, code, payload, *out) - - error_avals = map(lambda x: core.raise_to_shaped(core.get_aval(x)), batched_err) - new_jaxpr, _, literals_out = pe.trace_to_jaxpr_dynamic( - lu.wrap_init(g), [*error_avals, *jaxpr.in_avals[3:]]) - return core.ClosedJaxpr(new_jaxpr, literals_out) - def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr): - batch_shape = cond_jaxpr.out_avals[0].shape - if batch_shape: - err_args = batch_error(error.err, error.code, error.payload, batch_shape) - else: - err_args = [error.err, error.code, error.payload] + if cond_jaxpr.out_avals[0].shape: + # TODO(lenamartens, sharadmv): support batched while. + raise ValueError('Checkify does not support batched while-loops ' + '(checkify-of-vmap-of-while). \nHint: if possible, move ' + 'the vmap to the outer level to get ' + 'vmap-of-checkify-of-while.') + err_args = [error.err, error.code, error.payload] c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts]) # Check if the first cond application will error. checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error, enabled_errors) - if batch_shape: - checked_cond_jaxpr = trivial_batched_jaxpr(checked_cond_jaxpr, batch_shape, err_args) cond_err, cond_code, cond_payload, _ = core.jaxpr_as_fun(checked_cond_jaxpr)( *err_args, *c_consts, *carry) checked_body_jaxpr_, msgs_body = checkify_while_body_jaxpr( - cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts) - if batch_shape: - checked_body_jaxpr_ = trivial_batched_jaxpr(checked_body_jaxpr_, batch_shape, err_args) + cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts) to_move = [False] * 3 + [True] * body_nconsts + [False] * len(carry) checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move) @@ -762,8 +734,6 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, *new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr, body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr) new_msgs = {**error.msgs, **msgs_body, **msgs_cond} - if batch_shape: - err, code, payload = unbatch_error(err, code, payload) return out, Error(err, code, new_msgs, payload) error_checks[lax.while_p] = while_loop_error_check diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 3b8e17933..64fd8d098 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -891,7 +891,7 @@ class AssertPrimitiveTests(jtu.JaxTestCase): self.assertIsNotNone(err.get()) self.assertIn("should be positive", err.get()) - def test_checkify_of_vmap_of_while(self): + def test_checkify_of_vmap_of_while_errors(self): @jax.vmap def fun(n, v): def while_cond(s): @@ -909,17 +909,39 @@ class AssertPrimitiveTests(jtu.JaxTestCase): checked_f = checkify.checkify(fun, errors=checkify.all_checks) - err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., 4.])) - self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "division by zero") + with self.assertRaisesRegex(ValueError, "checkify-of-vmap-of-while"): + checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., 4.])) + # TODO(lenamartens): reenable assertions below. + # self.assertIsNotNone(err.get()) + # self.assertStartsWith(err.get(), "division by zero") - err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., -4.])) - self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "value needs to be positive") + # err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., -4.])) + # self.assertIsNotNone(err.get()) + # self.assertStartsWith(err.get(), "value needs to be positive") - err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([6., 2., -4.])) - self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "value needs to be less than 6") + # err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([6., 2., -4.])) + # self.assertIsNotNone(err.get()) + # self.assertStartsWith(err.get(), "value needs to be less than 6") + + def test_checkify_of_vmap_of_while_masked_errors(self): + def cond(x): + return x < 5 + + def body(x): + # This will only trigger in the masked portion of the batched while. + checkify.check(x < 5, "should never happen") + return x + 1 + + @jax.vmap + def fun(x): + return lax.while_loop(cond, body, x) + + checked_f = checkify.checkify(fun) + + with self.assertRaisesRegex(ValueError, "checkify-of-vmap-of-while"): + checked_f(jnp.arange(5)) + # TODO(lenamartens): reenable assertions below. + # self.assertIsNone(err.get()) def test_assert_cond_no_data_dependence(self): def f(): @@ -965,5 +987,6 @@ class LowerableChecksTest(jtu.JaxTestCase): with self.assertRaisesRegex(xla_extension.XlaRuntimeError, "x needs to be positive"): f(-1.) + if __name__ == "__main__": absltest.main(testLoader=jtu.JaxTestLoader())