From 7078f81dd00646b0236463cf81e2ad1bfb08a008 Mon Sep 17 00:00:00 2001 From: lenamartens Date: Thu, 22 Sep 2022 15:23:54 +0100 Subject: [PATCH] Checkify: misc improvements. - err.throw == check_error(err) -> meaning they have the same behavior under checkify now - "divided by zero" -> "division by zero" - add validation that check_error only takes args of type Error --- jax/_src/checkify.py | 23 ++++++++++++++++------- tests/checkify_test.py | 28 ++++++++++++++-------------- 2 files changed, 30 insertions(+), 21 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 5a48a14c1..076e3b261 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -110,10 +110,16 @@ class Error: return None def throw(self): - """Throw ValueError with error message if error happened.""" - err = self.get() - if err: - raise ValueError(err) + check_error(self) + + def __str__(self): + return f'Error({self.get()})' + + +def raise_error(error): + err = error.get() + if err: + raise ValueError(err) register_pytree_node(Error, @@ -142,7 +148,6 @@ class CheckifyTracer(core.Tracer): def __init__(self, trace, val): self._trace = trace self.val = val - core.get_aval(val), val aval = property(lambda self: core.get_aval(self.val)) full_lower = lambda self: self @@ -457,6 +462,10 @@ def check_error(error: Error) -> None: >>> # can re-checkify >>> error, _ = checkify.checkify(with_inner_jit)(-1) """ + if not isinstance(error, Error): + raise ValueError('check_error takes an Error as argument, ' + f'got type {type(error)} instead.') + if np.shape(error.err): err, code, payload = _reduce_any_error(error.err, error.code, error.payload) else: @@ -470,7 +479,7 @@ assert_p.multiple_results = True # zero results @assert_p.def_impl def assert_impl(err, code, payload, *, msgs): - Error(err, code, msgs, payload).throw() + raise_error(Error(err, code, msgs, payload)) return [] CheckEffect = object() @@ -564,7 +573,7 @@ def div_error_check(error, enabled_errors, x, y): """Checks for division by zero and NaN.""" if ErrorCategory.DIV in enabled_errors: any_zero = jnp.any(jnp.equal(y, 0)) - msg = f'divided by zero at {summary()}' + msg = f'division by zero at {summary()}' error = assert_func(error, any_zero, msg, None) return nan_error_check(lax.div_p, error, enabled_errors, x, y) error_checks[lax.div_p] = div_error_check diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 68b839ae0..656bb5b19 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -109,7 +109,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): err, _ = checked_f(jnp.ones((3,)), jnp.array([1., 0., 1.])) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "divided by zero") + self.assertStartsWith(err.get(), "division by zero") err, _ = checked_f(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1])) self.assertIsNotNone(err.get()) @@ -281,7 +281,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): err, (ch_out_carry, ch_outs) = checked_f(carry, xs) out_carry, outs = f(carry, xs) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "divided by zero") + self.assertStartsWith(err.get(), "division by zero") self.assertArraysEqual(ch_outs, outs) self.assertArraysEqual(ch_out_carry, out_carry) @@ -290,7 +290,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): err, (ch_out_carry, ch_outs) = checked_f(carry, xs) out_carry, outs = f(carry, xs) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "divided by zero") + self.assertStartsWith(err.get(), "division by zero") self.assertArraysEqual(ch_outs, outs) self.assertArraysEqual(ch_out_carry, out_carry) @@ -321,7 +321,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): err, ch_out = checked_f(init_val) out = f(init_val) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "divided by zero") + self.assertStartsWith(err.get(), "division by zero") self.assertArraysEqual(ch_out, out) @jtu.skip_on_devices("tpu") @@ -349,7 +349,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): err, ch_out = checked_f(init_val) out = f(init_val) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "divided by zero") + self.assertStartsWith(err.get(), "division by zero") self.assertArraysEqual(ch_out, out) @jtu.skip_on_devices("tpu") @@ -369,13 +369,13 @@ class CheckifyTransformTests(jtu.JaxTestCase): init_val = 0. err, _ = checked_f(init_val) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "divided by zero") + self.assertStartsWith(err.get(), "division by zero") # error on second cond init_val = 1. err, _ = checked_f(init_val) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "divided by zero") + self.assertStartsWith(err.get(), "division by zero") @jtu.skip_on_devices("tpu") def test_while_loop_body_and_cond_error(self): @@ -441,9 +441,9 @@ class CheckifyTransformTests(jtu.JaxTestCase): b_err, _ = g(x, x) self.assertIsNotNone(u_err.get()) - self.assertStartsWith(u_err.get(), "divided by zero") + self.assertStartsWith(u_err.get(), "division by zero") self.assertIsNotNone(b_err.get()) - self.assertStartsWith(b_err.get(), "divided by zero") + self.assertStartsWith(b_err.get(), "division by zero") def test_empty_enabled_errors(self): def multi_errors(x): @@ -459,10 +459,10 @@ class CheckifyTransformTests(jtu.JaxTestCase): @parameterized.named_parameters( ("assert", checkify.user_checks, "must be negative!"), - ("div", {checkify.ErrorCategory.DIV}, "divided by zero"), + ("div", {checkify.ErrorCategory.DIV}, "division by zero"), ("nan", {checkify.ErrorCategory.NAN}, "nan generated"), ("oob", checkify.index_checks, "out-of-bounds indexing"), - ("automatic_checks", checkify.automatic_checks, "divided by zero"), + ("automatic_checks", checkify.automatic_checks, "division by zero"), ) @jtu.skip_on_devices("tpu") def test_enabled_errors(self, error_set, expected_error): @@ -625,7 +625,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): checked_f = checkify.checkify(f, errors=checkify.float_checks) err, _ = checked_f(jnp.ones((7, 3))) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "divided by zero") + self.assertStartsWith(err.get(), "division by zero") def test_multiple_payloads(self): def f(x): @@ -651,7 +651,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): cf = checkify.checkify(f, errors=checkify.automatic_checks) errs, _ = jax.vmap(cf)(jnp.ones((2, 1)), jnp.array([0, 100])) self.assertIsNotNone(errs.get()) - self.assertIn("divided by zero", errs.get()) + self.assertIn("division by zero", errs.get()) self.assertIn("index 100", errs.get()) @@ -895,7 +895,7 @@ class AssertPrimitiveTests(jtu.JaxTestCase): err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., 4.])) self.assertIsNotNone(err.get()) - self.assertStartsWith(err.get(), "divided by zero") + self.assertStartsWith(err.get(), "division by zero") err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., -4.])) self.assertIsNotNone(err.get())