From 3116ed52a94eb996db63a67d16109c03b5d94845 Mon Sep 17 00:00:00 2001 From: Lena Martens Date: Tue, 15 Nov 2022 07:35:15 -0800 Subject: [PATCH] Checkify: fix nan check when primitive has multiple results. PiperOrigin-RevId: 488653856 --- jax/_src/checkify.py | 6 +++--- tests/checkify_test.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 6926b8700..47e05f6df 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -568,10 +568,10 @@ def nan_error_check(prim, error, enabled_errors, *in_vals, **params): def isnan(x): if isinstance(x, prng.PRNGKeyArray): return False - return jnp.isnan(x) + return jnp.any(jnp.isnan(x)) - any_nans = (jnp.any(isnan(x) for x in out) - if prim.multiple_results else jnp.any(isnan(out))) + any_nans = (jnp.any(jnp.array([isnan(x) for x in out])) + if prim.multiple_results else isnan(out)) msg = f'nan generated by primitive {prim.name} at {summary()}' return out, assert_func(error, any_nans, msg, None) diff --git a/tests/checkify_test.py b/tests/checkify_test.py index a13192a44..733a63661 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -973,6 +973,17 @@ class AssertPrimitiveTests(jtu.JaxTestCase): self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), "hi!") + def test_psum_nan_check(self): + @partial(jax.vmap, axis_name="i") + def f(x, y): + return lax.psum((x, y), axis_name="i") + + cf = checkify.checkify(f, errors=checkify.nan_checks) + err, _ = cf(jnp.array([-jnp.inf, 0, jnp.inf]), jnp.ones((3, 2))) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), "nan generated by primitive psum") + + class LowerableChecksTest(jtu.JaxTestCase): def setUp(self): super().setUp()