Add division by zero check.

This commit is contained in:
Lena Martens 2021-12-20 15:56:50 +00:00 committed by lenamartens
parent eaf7885460
commit bbd127f8fa
2 changed files with 34 additions and 7 deletions

View File

@ -279,6 +279,14 @@ def gather_error_check(error, operand, start_indices, *,
return out, assert_func(error, all_inbounds, msg)
error_checks[lax.gather_p] = gather_error_check
def div_error_check(error, x, y):
"""Checks for division by zero and NaN."""
all_nonzero = jnp.logical_not(jnp.any(jnp.equal(y, 0)))
msg = f'divided by zero at {summary()}'
div_by_zero_err = assert_func(error, all_nonzero, msg)
return nan_error_check(lax.div_p, div_by_zero_err, x, y)
error_checks[lax.div_p] = div_error_check
def cond_error_check(error, index, *ops, branches, linear):
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error) for jxpr in branches)
new_linear = (False, False, *linear)
@ -418,7 +426,6 @@ add_nan_check(lax.integer_pow_p)
add_nan_check(lax.tanh_p)
add_nan_check(lax.log_p)
add_nan_check(lax.atan2_p)
add_nan_check(lax.div_p)
add_nan_check(lax.sin_p)
add_nan_check(lax.cos_p)
add_nan_check(lax.sinh_p)

View File

@ -66,6 +66,26 @@ class CheckifyTransformTests(jtu.JaxTestCase):
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'out-of-bounds indexing')
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_jit={}".format(jit), "jit": jit}
for jit in [False, True]))
def test_jit_div_errors(self, jit):
def f(x, y):
return x/y
f = jax.jit(f) if jit else f
err, _ = checkify.checkify(f)(jnp.ones((3,)), jnp.ones((3,)))
self.assertIs(err.get(), None)
err, _ = checkify.checkify(f)(jnp.ones((3,)), jnp.array([1, 0, 1]))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
err, _ = checkify.checkify(f)(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1]))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'nan generated by primitive div')
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_jit={}".format(jit), "jit": jit}
for jit in [False, True]))
@ -192,7 +212,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
out_carry, outs = f(carry, xs)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
self.assertStartsWith(err.get(), "divided by zero")
self.assertArraysEqual(ch_outs, outs)
self.assertArraysEqual(ch_out_carry, out_carry)
@ -201,7 +221,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
out_carry, outs = f(carry, xs)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
self.assertStartsWith(err.get(), "divided by zero")
self.assertArraysEqual(ch_outs, outs)
self.assertArraysEqual(ch_out_carry, out_carry)
@ -230,7 +250,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
err, ch_out = checkify.checkify(f)(init_val)
out = f(init_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
self.assertStartsWith(err.get(), "divided by zero")
self.assertArraysEqual(ch_out, out)
@jtu.skip_on_devices("tpu")
@ -256,7 +276,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
err, ch_out = checkify.checkify(f)(init_val)
out = f(init_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
self.assertStartsWith(err.get(), "divided by zero")
self.assertArraysEqual(ch_out, out)
@jtu.skip_on_devices("tpu")
@ -274,13 +294,13 @@ class CheckifyTransformTests(jtu.JaxTestCase):
init_val = 0.
err, _ = checkify.checkify(f)(init_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
self.assertStartsWith(err.get(), "divided by zero")
# error on second cond
init_val = 1.
err, _ = checkify.checkify(f)(init_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
self.assertStartsWith(err.get(), "divided by zero")
@jtu.skip_on_devices("tpu")
def test_while_loop_body_and_cond_error(self):