mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Checkify: misc improvements.
- err.throw == check_error(err) -> meaning they have the same behavior under checkify now - "divided by zero" -> "division by zero" - add validation that check_error only takes args of type Error
This commit is contained in:
parent
640e15fe07
commit
7078f81dd0
@ -110,10 +110,16 @@ class Error:
|
||||
return None
|
||||
|
||||
def throw(self):
|
||||
"""Throw ValueError with error message if error happened."""
|
||||
err = self.get()
|
||||
if err:
|
||||
raise ValueError(err)
|
||||
check_error(self)
|
||||
|
||||
def __str__(self):
|
||||
return f'Error({self.get()})'
|
||||
|
||||
|
||||
def raise_error(error):
|
||||
err = error.get()
|
||||
if err:
|
||||
raise ValueError(err)
|
||||
|
||||
|
||||
register_pytree_node(Error,
|
||||
@ -142,7 +148,6 @@ class CheckifyTracer(core.Tracer):
|
||||
def __init__(self, trace, val):
|
||||
self._trace = trace
|
||||
self.val = val
|
||||
core.get_aval(val), val
|
||||
aval = property(lambda self: core.get_aval(self.val))
|
||||
full_lower = lambda self: self
|
||||
|
||||
@ -457,6 +462,10 @@ def check_error(error: Error) -> None:
|
||||
>>> # can re-checkify
|
||||
>>> error, _ = checkify.checkify(with_inner_jit)(-1)
|
||||
"""
|
||||
if not isinstance(error, Error):
|
||||
raise ValueError('check_error takes an Error as argument, '
|
||||
f'got type {type(error)} instead.')
|
||||
|
||||
if np.shape(error.err):
|
||||
err, code, payload = _reduce_any_error(error.err, error.code, error.payload)
|
||||
else:
|
||||
@ -470,7 +479,7 @@ assert_p.multiple_results = True # zero results
|
||||
|
||||
@assert_p.def_impl
|
||||
def assert_impl(err, code, payload, *, msgs):
|
||||
Error(err, code, msgs, payload).throw()
|
||||
raise_error(Error(err, code, msgs, payload))
|
||||
return []
|
||||
|
||||
CheckEffect = object()
|
||||
@ -564,7 +573,7 @@ def div_error_check(error, enabled_errors, x, y):
|
||||
"""Checks for division by zero and NaN."""
|
||||
if ErrorCategory.DIV in enabled_errors:
|
||||
any_zero = jnp.any(jnp.equal(y, 0))
|
||||
msg = f'divided by zero at {summary()}'
|
||||
msg = f'division by zero at {summary()}'
|
||||
error = assert_func(error, any_zero, msg, None)
|
||||
return nan_error_check(lax.div_p, error, enabled_errors, x, y)
|
||||
error_checks[lax.div_p] = div_error_check
|
||||
|
@ -109,7 +109,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
|
||||
err, _ = checked_f(jnp.ones((3,)), jnp.array([1., 0., 1.]))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
self.assertStartsWith(err.get(), "division by zero")
|
||||
|
||||
err, _ = checked_f(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1]))
|
||||
self.assertIsNotNone(err.get())
|
||||
@ -281,7 +281,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
|
||||
out_carry, outs = f(carry, xs)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
self.assertStartsWith(err.get(), "division by zero")
|
||||
self.assertArraysEqual(ch_outs, outs)
|
||||
self.assertArraysEqual(ch_out_carry, out_carry)
|
||||
|
||||
@ -290,7 +290,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
|
||||
out_carry, outs = f(carry, xs)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
self.assertStartsWith(err.get(), "division by zero")
|
||||
self.assertArraysEqual(ch_outs, outs)
|
||||
self.assertArraysEqual(ch_out_carry, out_carry)
|
||||
|
||||
@ -321,7 +321,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
err, ch_out = checked_f(init_val)
|
||||
out = f(init_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
self.assertStartsWith(err.get(), "division by zero")
|
||||
self.assertArraysEqual(ch_out, out)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@ -349,7 +349,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
err, ch_out = checked_f(init_val)
|
||||
out = f(init_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
self.assertStartsWith(err.get(), "division by zero")
|
||||
self.assertArraysEqual(ch_out, out)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@ -369,13 +369,13 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
init_val = 0.
|
||||
err, _ = checked_f(init_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
self.assertStartsWith(err.get(), "division by zero")
|
||||
|
||||
# error on second cond
|
||||
init_val = 1.
|
||||
err, _ = checked_f(init_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
self.assertStartsWith(err.get(), "division by zero")
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_while_loop_body_and_cond_error(self):
|
||||
@ -441,9 +441,9 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
b_err, _ = g(x, x)
|
||||
|
||||
self.assertIsNotNone(u_err.get())
|
||||
self.assertStartsWith(u_err.get(), "divided by zero")
|
||||
self.assertStartsWith(u_err.get(), "division by zero")
|
||||
self.assertIsNotNone(b_err.get())
|
||||
self.assertStartsWith(b_err.get(), "divided by zero")
|
||||
self.assertStartsWith(b_err.get(), "division by zero")
|
||||
|
||||
def test_empty_enabled_errors(self):
|
||||
def multi_errors(x):
|
||||
@ -459,10 +459,10 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("assert", checkify.user_checks, "must be negative!"),
|
||||
("div", {checkify.ErrorCategory.DIV}, "divided by zero"),
|
||||
("div", {checkify.ErrorCategory.DIV}, "division by zero"),
|
||||
("nan", {checkify.ErrorCategory.NAN}, "nan generated"),
|
||||
("oob", checkify.index_checks, "out-of-bounds indexing"),
|
||||
("automatic_checks", checkify.automatic_checks, "divided by zero"),
|
||||
("automatic_checks", checkify.automatic_checks, "division by zero"),
|
||||
)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_enabled_errors(self, error_set, expected_error):
|
||||
@ -625,7 +625,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
err, _ = checked_f(jnp.ones((7, 3)))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
self.assertStartsWith(err.get(), "division by zero")
|
||||
|
||||
def test_multiple_payloads(self):
|
||||
def f(x):
|
||||
@ -651,7 +651,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
cf = checkify.checkify(f, errors=checkify.automatic_checks)
|
||||
errs, _ = jax.vmap(cf)(jnp.ones((2, 1)), jnp.array([0, 100]))
|
||||
self.assertIsNotNone(errs.get())
|
||||
self.assertIn("divided by zero", errs.get())
|
||||
self.assertIn("division by zero", errs.get())
|
||||
self.assertIn("index 100", errs.get())
|
||||
|
||||
|
||||
@ -895,7 +895,7 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
|
||||
err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., 4.]))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
self.assertStartsWith(err.get(), "division by zero")
|
||||
|
||||
err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., -4.]))
|
||||
self.assertIsNotNone(err.get())
|
||||
|
Loading…
x
Reference in New Issue
Block a user