Fix issue with infinities in logsumexp

This commit is contained in:
Jake VanderPlas 2021-08-02 15:27:24 -07:00
parent 6984f30d5e
commit e131343274
2 changed files with 9 additions and 1 deletions

View File

@ -121,7 +121,7 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
axis=dims, keepdims=keepdims)),
amax)
sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
sign = jnp.where(out == -np.inf, 0.0, sign)
sign = jnp.where(jnp.isneginf(out), 0.0, sign)
else:
expsub = lax.exp(lax.sub(a, amax_with_dims))
if b is not None:

View File

@ -25,6 +25,7 @@ from absl.testing import parameterized
import numpy as np
import scipy.special as osp_special
import jax
from jax._src import api
from jax import numpy as jnp
from jax import lax
@ -205,6 +206,13 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker)
def testLogSumExpOnes(self):
# Regression test for https://github.com/google/jax/issues/7390
args_maker = lambda: [np.ones(4, dtype='float32')]
with jax.debug_infs(True):
self._CheckAgainstNumpy(osp_special.logsumexp, lsp_special.logsumexp, args_maker)
self._CompileAndCheck(lsp_special.logsumexp, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(