Checkify: misc improvements.

- err.throw == check_error(err) -> meaning they have the same behavior
  under checkify now
- "divided by zero" -> "division by zero"
- add validation that check_error only takes args of type Error
This commit is contained in:
lenamartens 2022-09-22 15:23:54 +01:00
parent 640e15fe07
commit 7078f81dd0
2 changed files with 30 additions and 21 deletions

View File

@ -110,10 +110,16 @@ class Error:
return None
def throw(self):
"""Throw ValueError with error message if error happened."""
err = self.get()
if err:
raise ValueError(err)
check_error(self)
def __str__(self):
return f'Error({self.get()})'
def raise_error(error):
err = error.get()
if err:
raise ValueError(err)
register_pytree_node(Error,
@ -142,7 +148,6 @@ class CheckifyTracer(core.Tracer):
def __init__(self, trace, val):
self._trace = trace
self.val = val
core.get_aval(val), val
aval = property(lambda self: core.get_aval(self.val))
full_lower = lambda self: self
@ -457,6 +462,10 @@ def check_error(error: Error) -> None:
>>> # can re-checkify
>>> error, _ = checkify.checkify(with_inner_jit)(-1)
"""
if not isinstance(error, Error):
raise ValueError('check_error takes an Error as argument, '
f'got type {type(error)} instead.')
if np.shape(error.err):
err, code, payload = _reduce_any_error(error.err, error.code, error.payload)
else:
@ -470,7 +479,7 @@ assert_p.multiple_results = True # zero results
@assert_p.def_impl
def assert_impl(err, code, payload, *, msgs):
Error(err, code, msgs, payload).throw()
raise_error(Error(err, code, msgs, payload))
return []
CheckEffect = object()
@ -564,7 +573,7 @@ def div_error_check(error, enabled_errors, x, y):
"""Checks for division by zero and NaN."""
if ErrorCategory.DIV in enabled_errors:
any_zero = jnp.any(jnp.equal(y, 0))
msg = f'divided by zero at {summary()}'
msg = f'division by zero at {summary()}'
error = assert_func(error, any_zero, msg, None)
return nan_error_check(lax.div_p, error, enabled_errors, x, y)
error_checks[lax.div_p] = div_error_check

View File

@ -109,7 +109,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
err, _ = checked_f(jnp.ones((3,)), jnp.array([1., 0., 1.]))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
self.assertStartsWith(err.get(), "division by zero")
err, _ = checked_f(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1]))
self.assertIsNotNone(err.get())
@ -281,7 +281,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
out_carry, outs = f(carry, xs)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
self.assertStartsWith(err.get(), "division by zero")
self.assertArraysEqual(ch_outs, outs)
self.assertArraysEqual(ch_out_carry, out_carry)
@ -290,7 +290,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
out_carry, outs = f(carry, xs)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
self.assertStartsWith(err.get(), "division by zero")
self.assertArraysEqual(ch_outs, outs)
self.assertArraysEqual(ch_out_carry, out_carry)
@ -321,7 +321,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
err, ch_out = checked_f(init_val)
out = f(init_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
self.assertStartsWith(err.get(), "division by zero")
self.assertArraysEqual(ch_out, out)
@jtu.skip_on_devices("tpu")
@ -349,7 +349,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
err, ch_out = checked_f(init_val)
out = f(init_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
self.assertStartsWith(err.get(), "division by zero")
self.assertArraysEqual(ch_out, out)
@jtu.skip_on_devices("tpu")
@ -369,13 +369,13 @@ class CheckifyTransformTests(jtu.JaxTestCase):
init_val = 0.
err, _ = checked_f(init_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
self.assertStartsWith(err.get(), "division by zero")
# error on second cond
init_val = 1.
err, _ = checked_f(init_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
self.assertStartsWith(err.get(), "division by zero")
@jtu.skip_on_devices("tpu")
def test_while_loop_body_and_cond_error(self):
@ -441,9 +441,9 @@ class CheckifyTransformTests(jtu.JaxTestCase):
b_err, _ = g(x, x)
self.assertIsNotNone(u_err.get())
self.assertStartsWith(u_err.get(), "divided by zero")
self.assertStartsWith(u_err.get(), "division by zero")
self.assertIsNotNone(b_err.get())
self.assertStartsWith(b_err.get(), "divided by zero")
self.assertStartsWith(b_err.get(), "division by zero")
def test_empty_enabled_errors(self):
def multi_errors(x):
@ -459,10 +459,10 @@ class CheckifyTransformTests(jtu.JaxTestCase):
@parameterized.named_parameters(
("assert", checkify.user_checks, "must be negative!"),
("div", {checkify.ErrorCategory.DIV}, "divided by zero"),
("div", {checkify.ErrorCategory.DIV}, "division by zero"),
("nan", {checkify.ErrorCategory.NAN}, "nan generated"),
("oob", checkify.index_checks, "out-of-bounds indexing"),
("automatic_checks", checkify.automatic_checks, "divided by zero"),
("automatic_checks", checkify.automatic_checks, "division by zero"),
)
@jtu.skip_on_devices("tpu")
def test_enabled_errors(self, error_set, expected_error):
@ -625,7 +625,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
checked_f = checkify.checkify(f, errors=checkify.float_checks)
err, _ = checked_f(jnp.ones((7, 3)))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
self.assertStartsWith(err.get(), "division by zero")
def test_multiple_payloads(self):
def f(x):
@ -651,7 +651,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
cf = checkify.checkify(f, errors=checkify.automatic_checks)
errs, _ = jax.vmap(cf)(jnp.ones((2, 1)), jnp.array([0, 100]))
self.assertIsNotNone(errs.get())
self.assertIn("divided by zero", errs.get())
self.assertIn("division by zero", errs.get())
self.assertIn("index 100", errs.get())
@ -895,7 +895,7 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., 4.]))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
self.assertStartsWith(err.get(), "division by zero")
err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., -4.]))
self.assertIsNotNone(err.get())