From 2c623d583719245a521e758402ff6fe1d8f72d12 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 26 Feb 2021 17:05:32 -0800 Subject: [PATCH] jax.scipy.special.logsumexp: fix b=0 corner case --- jax/_src/scipy/special.py | 3 ++- tests/lax_scipy_test.py | 8 ++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py index e7088ff56..14ace4e87 100644 --- a/jax/_src/scipy/special.py +++ b/jax/_src/scipy/special.py @@ -101,7 +101,8 @@ expit.defjvps(lambda g, ans, x: g * ans * (lax._const(ans, 1) - ans)) @_wraps(osp_special.logsumexp) def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False): if b is not None: - a, b = jnp.broadcast_arrays(a, b) + a, b = _promote_args_inexact("logsumexp", a, b) + a = jnp.where(b != 0, a, -jnp.inf) pos_dims, dims = _reduction_dims(a, axis) amax = jnp.max(a, axis=dims, keepdims=keepdims) amax = lax.stop_gradient(lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0))) diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py index 09ff6c0f8..a93861f95 100644 --- a/tests/lax_scipy_test.py +++ b/tests/lax_scipy_test.py @@ -146,6 +146,14 @@ class LaxBackedScipyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker) self._CompileAndCheck(lax_fun, args_maker) + def testLogSumExpZeros(self): + # Regression test for https://github.com/google/jax/issues/5370 + scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b) + lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b) + args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])] + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker) + self._CompileAndCheck(lax_fun, args_maker) + @parameterized.named_parameters(itertools.chain.from_iterable( jtu.cases_from_list( {"testcase_name": jtu.format_test_name_suffix(