mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add division by zero check.
This commit is contained in:
parent
eaf7885460
commit
bbd127f8fa
@ -279,6 +279,14 @@ def gather_error_check(error, operand, start_indices, *,
|
||||
return out, assert_func(error, all_inbounds, msg)
|
||||
error_checks[lax.gather_p] = gather_error_check
|
||||
|
||||
def div_error_check(error, x, y):
|
||||
"""Checks for division by zero and NaN."""
|
||||
all_nonzero = jnp.logical_not(jnp.any(jnp.equal(y, 0)))
|
||||
msg = f'divided by zero at {summary()}'
|
||||
div_by_zero_err = assert_func(error, all_nonzero, msg)
|
||||
return nan_error_check(lax.div_p, div_by_zero_err, x, y)
|
||||
error_checks[lax.div_p] = div_error_check
|
||||
|
||||
def cond_error_check(error, index, *ops, branches, linear):
|
||||
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error) for jxpr in branches)
|
||||
new_linear = (False, False, *linear)
|
||||
@ -418,7 +426,6 @@ add_nan_check(lax.integer_pow_p)
|
||||
add_nan_check(lax.tanh_p)
|
||||
add_nan_check(lax.log_p)
|
||||
add_nan_check(lax.atan2_p)
|
||||
add_nan_check(lax.div_p)
|
||||
add_nan_check(lax.sin_p)
|
||||
add_nan_check(lax.cos_p)
|
||||
add_nan_check(lax.sinh_p)
|
||||
|
@ -66,6 +66,26 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'out-of-bounds indexing')
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_jit={}".format(jit), "jit": jit}
|
||||
for jit in [False, True]))
|
||||
def test_jit_div_errors(self, jit):
|
||||
def f(x, y):
|
||||
return x/y
|
||||
|
||||
f = jax.jit(f) if jit else f
|
||||
|
||||
err, _ = checkify.checkify(f)(jnp.ones((3,)), jnp.ones((3,)))
|
||||
self.assertIs(err.get(), None)
|
||||
|
||||
err, _ = checkify.checkify(f)(jnp.ones((3,)), jnp.array([1, 0, 1]))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
|
||||
err, _ = checkify.checkify(f)(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1]))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive div')
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_jit={}".format(jit), "jit": jit}
|
||||
for jit in [False, True]))
|
||||
@ -192,7 +212,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
|
||||
out_carry, outs = f(carry, xs)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
self.assertArraysEqual(ch_outs, outs)
|
||||
self.assertArraysEqual(ch_out_carry, out_carry)
|
||||
|
||||
@ -201,7 +221,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
|
||||
out_carry, outs = f(carry, xs)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
self.assertArraysEqual(ch_outs, outs)
|
||||
self.assertArraysEqual(ch_out_carry, out_carry)
|
||||
|
||||
@ -230,7 +250,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
err, ch_out = checkify.checkify(f)(init_val)
|
||||
out = f(init_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
self.assertArraysEqual(ch_out, out)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@ -256,7 +276,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
err, ch_out = checkify.checkify(f)(init_val)
|
||||
out = f(init_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
self.assertArraysEqual(ch_out, out)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@ -274,13 +294,13 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
init_val = 0.
|
||||
err, _ = checkify.checkify(f)(init_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
|
||||
# error on second cond
|
||||
init_val = 1.
|
||||
err, _ = checkify.checkify(f)(init_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_while_loop_body_and_cond_error(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user