From e13134327438c3107b2f837545cbbf13b22f15bf Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 2 Aug 2021 15:27:24 -0700 Subject: [PATCH] Fix issue with infinities in logsumexp --- jax/_src/scipy/special.py | 2 +- tests/lax_scipy_test.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index 7c39db27f..546a2ae05 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -121,7 +121,7 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False): axis=dims, keepdims=keepdims)), amax) sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype) - sign = jnp.where(out == -np.inf, 0.0, sign) + sign = jnp.where(jnp.isneginf(out), 0.0, sign) else: expsub = lax.exp(lax.sub(a, amax_with_dims)) if b is not None: diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 25c5215a3..67ef1b2db 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -25,6 +25,7 @@ from absl.testing import parameterized import numpy as np import scipy.special as osp_special +import jax from jax._src import api from jax import numpy as jnp from jax import lax @@ -205,6 +206,13 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker) + def testLogSumExpOnes(self): + # Regression test for https://github.com/google/jax/issues/7390 + args_maker = lambda: [np.ones(4, dtype='float32')] + with jax.debug_infs(True): + self._CheckAgainstNumpy(osp_special.logsumexp, lsp_special.logsumexp, args_maker) + self._CompileAndCheck(lsp_special.logsumexp, args_maker) + @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix(