Checkify: fix nan check when primitive has multiple results.

PiperOrigin-RevId: 488653856
This commit is contained in:
Lena Martens 2022-11-15 07:35:15 -08:00 committed by jax authors
parent 108bc83520
commit 3116ed52a9
2 changed files with 14 additions and 3 deletions

View File

@ -568,10 +568,10 @@ def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
def isnan(x):
if isinstance(x, prng.PRNGKeyArray):
return False
return jnp.isnan(x)
return jnp.any(jnp.isnan(x))
any_nans = (jnp.any(isnan(x) for x in out)
if prim.multiple_results else jnp.any(isnan(out)))
any_nans = (jnp.any(jnp.array([isnan(x) for x in out]))
if prim.multiple_results else isnan(out))
msg = f'nan generated by primitive {prim.name} at {summary()}'
return out, assert_func(error, any_nans, msg, None)

View File

@ -973,6 +973,17 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "hi!")
def test_psum_nan_check(self):
@partial(jax.vmap, axis_name="i")
def f(x, y):
return lax.psum((x, y), axis_name="i")
cf = checkify.checkify(f, errors=checkify.nan_checks)
err, _ = cf(jnp.array([-jnp.inf, 0, jnp.inf]), jnp.ones((3, 2)))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive psum")
class LowerableChecksTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()