Fix logsumexp issue with debug_nans and disable_jit

This commit is contained in:
Jake VanderPlas 2021-08-17 13:47:06 -07:00
parent 476642578b
commit 0e256ddeb7
2 changed files with 9 additions and 2 deletions

View File

@ -120,8 +120,8 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
out = lax.add(lax.log(jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
axis=dims, keepdims=keepdims)),
amax)
sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
sign = jnp.where(jnp.isneginf(out), 0.0, sign)
sign = jnp.where(jnp.isnan(out), out, 1.0)
sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
else:
expsub = lax.exp(lax.sub(a, amax_with_dims))
if b is not None:

View File

@ -214,6 +214,13 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(osp_special.logsumexp, lsp_special.logsumexp, args_maker)
self._CompileAndCheck(lsp_special.logsumexp, args_maker)
def testLogSumExpNans(self):
# Regression test for https://github.com/google/jax/issues/7634
with jax.debug_nans(True):
with jax.disable_jit():
result = lsp_special.logsumexp(1.0)
self.assertEqual(result, 1.0)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(