mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Disallow checkify-of-vmap-of-while.
This commit is contained in:
parent
504b3c1b25
commit
c2a00a0526
@ -702,54 +702,26 @@ def ignore_error_output_jaxpr(jaxpr):
|
||||
new_jaxpr = jaxpr.replace(outvars=jaxpr.outvars[3:])
|
||||
return core.ClosedJaxpr(new_jaxpr, consts)
|
||||
|
||||
def batch_error(err, code, payload, batch_shape):
|
||||
err = jnp.broadcast_to(err, batch_shape)
|
||||
code = jnp.broadcast_to(code, batch_shape)
|
||||
payload = jnp.broadcast_to(payload, batch_shape+(3,))
|
||||
return err, code, payload
|
||||
|
||||
def unbatch_error(err, code, payload):
|
||||
err = err.ravel()[0]
|
||||
code = code.ravel()[0]
|
||||
payload = payload.reshape(-1, 3)[0]
|
||||
return err, code, payload
|
||||
|
||||
def trivial_batched_jaxpr(jaxpr, batch_shape, batched_err):
|
||||
fun = core.jaxpr_as_fun(jaxpr)
|
||||
|
||||
def g(err, code, payload, *a):
|
||||
err_args = unbatch_error(err, code, payload)
|
||||
err, code, payload, *out = fun(*err_args, *a)
|
||||
err, code, payload = batch_error(err, code, payload, batch_shape)
|
||||
return (err, code, payload, *out)
|
||||
|
||||
error_avals = map(lambda x: core.raise_to_shaped(core.get_aval(x)), batched_err)
|
||||
new_jaxpr, _, literals_out = pe.trace_to_jaxpr_dynamic(
|
||||
lu.wrap_init(g), [*error_avals, *jaxpr.in_avals[3:]])
|
||||
return core.ClosedJaxpr(new_jaxpr, literals_out)
|
||||
|
||||
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
|
||||
cond_jaxpr, body_nconsts, body_jaxpr):
|
||||
batch_shape = cond_jaxpr.out_avals[0].shape
|
||||
if batch_shape:
|
||||
err_args = batch_error(error.err, error.code, error.payload, batch_shape)
|
||||
else:
|
||||
err_args = [error.err, error.code, error.payload]
|
||||
if cond_jaxpr.out_avals[0].shape:
|
||||
# TODO(lenamartens, sharadmv): support batched while.
|
||||
raise ValueError('Checkify does not support batched while-loops '
|
||||
'(checkify-of-vmap-of-while). \nHint: if possible, move '
|
||||
'the vmap to the outer level to get '
|
||||
'vmap-of-checkify-of-while.')
|
||||
err_args = [error.err, error.code, error.payload]
|
||||
|
||||
c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
|
||||
|
||||
# Check if the first cond application will error.
|
||||
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error,
|
||||
enabled_errors)
|
||||
if batch_shape:
|
||||
checked_cond_jaxpr = trivial_batched_jaxpr(checked_cond_jaxpr, batch_shape, err_args)
|
||||
cond_err, cond_code, cond_payload, _ = core.jaxpr_as_fun(checked_cond_jaxpr)(
|
||||
*err_args, *c_consts, *carry)
|
||||
|
||||
checked_body_jaxpr_, msgs_body = checkify_while_body_jaxpr(
|
||||
cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
|
||||
if batch_shape:
|
||||
checked_body_jaxpr_ = trivial_batched_jaxpr(checked_body_jaxpr_, batch_shape, err_args)
|
||||
cond_jaxpr, body_jaxpr, error, enabled_errors, c_consts)
|
||||
to_move = [False] * 3 + [True] * body_nconsts + [False] * len(carry)
|
||||
checked_body_jaxpr = pe.move_binders_to_front(checked_body_jaxpr_, to_move)
|
||||
|
||||
@ -762,8 +734,6 @@ def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
|
||||
*new_in_flat, cond_nconsts=cond_nconsts, cond_jaxpr=compat_cond_jaxpr,
|
||||
body_nconsts=body_nconsts, body_jaxpr=checked_body_jaxpr)
|
||||
new_msgs = {**error.msgs, **msgs_body, **msgs_cond}
|
||||
if batch_shape:
|
||||
err, code, payload = unbatch_error(err, code, payload)
|
||||
|
||||
return out, Error(err, code, new_msgs, payload)
|
||||
error_checks[lax.while_p] = while_loop_error_check
|
||||
|
@ -891,7 +891,7 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertIn("should be positive", err.get())
|
||||
|
||||
def test_checkify_of_vmap_of_while(self):
|
||||
def test_checkify_of_vmap_of_while_errors(self):
|
||||
@jax.vmap
|
||||
def fun(n, v):
|
||||
def while_cond(s):
|
||||
@ -909,17 +909,39 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
|
||||
checked_f = checkify.checkify(fun, errors=checkify.all_checks)
|
||||
|
||||
err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., 4.]))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "division by zero")
|
||||
with self.assertRaisesRegex(ValueError, "checkify-of-vmap-of-while"):
|
||||
checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., 4.]))
|
||||
# TODO(lenamartens): reenable assertions below.
|
||||
# self.assertIsNotNone(err.get())
|
||||
# self.assertStartsWith(err.get(), "division by zero")
|
||||
|
||||
err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., -4.]))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "value needs to be positive")
|
||||
# err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., -4.]))
|
||||
# self.assertIsNotNone(err.get())
|
||||
# self.assertStartsWith(err.get(), "value needs to be positive")
|
||||
|
||||
err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([6., 2., -4.]))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "value needs to be less than 6")
|
||||
# err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([6., 2., -4.]))
|
||||
# self.assertIsNotNone(err.get())
|
||||
# self.assertStartsWith(err.get(), "value needs to be less than 6")
|
||||
|
||||
def test_checkify_of_vmap_of_while_masked_errors(self):
|
||||
def cond(x):
|
||||
return x < 5
|
||||
|
||||
def body(x):
|
||||
# This will only trigger in the masked portion of the batched while.
|
||||
checkify.check(x < 5, "should never happen")
|
||||
return x + 1
|
||||
|
||||
@jax.vmap
|
||||
def fun(x):
|
||||
return lax.while_loop(cond, body, x)
|
||||
|
||||
checked_f = checkify.checkify(fun)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "checkify-of-vmap-of-while"):
|
||||
checked_f(jnp.arange(5))
|
||||
# TODO(lenamartens): reenable assertions below.
|
||||
# self.assertIsNone(err.get())
|
||||
|
||||
def test_assert_cond_no_data_dependence(self):
|
||||
def f():
|
||||
@ -965,5 +987,6 @@ class LowerableChecksTest(jtu.JaxTestCase):
|
||||
with self.assertRaisesRegex(xla_extension.XlaRuntimeError,
|
||||
"x needs to be positive"):
|
||||
f(-1.)
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user