From 0e256ddeb7845434bc6ac031962c355a33ee1adf Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 17 Aug 2021 13:47:06 -0700 Subject: [PATCH] Fix logsumexp issue with debug_nans and disable_jit --- jax/_src/scipy/special.py | 4 ++-- tests/lax_scipy_test.py | 7 +++++++ 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index cd264ba11..084b23baa 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -120,8 +120,8 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False): out = lax.add(lax.log(jnp.sum(lax.exp(lax.sub(a, amax_with_dims)), axis=dims, keepdims=keepdims)), amax) - sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype) - sign = jnp.where(jnp.isneginf(out), 0.0, sign) + sign = jnp.where(jnp.isnan(out), out, 1.0) + sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype) else: expsub = lax.exp(lax.sub(a, amax_with_dims)) if b is not None: diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index d7a22e536..f16ea1ff0 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -214,6 +214,13 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(osp_special.logsumexp, lsp_special.logsumexp, args_maker) self._CompileAndCheck(lsp_special.logsumexp, args_maker) + def testLogSumExpNans(self): + # Regression test for https://github.com/google/jax/issues/7634 + with jax.debug_nans(True): + with jax.disable_jit(): + result = lsp_special.logsumexp(1.0) + self.assertEqual(result, 1.0) + @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix(