logsumexp: fix issue with debug_nans

This commit is contained in:
Jake VanderPlas 2021-08-18 11:54:05 -07:00
parent 52476d64fa
commit 730ae33e03
2 changed files with 7 additions and 2 deletions

View File

@ -139,7 +139,9 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
return (out, sign)
if b is not None:
if not np.issubdtype(out.dtype, np.complexfloating):
out = jnp.where(sign < 0, np.nan, out)
# Use jnp.array(nan) to avoid false positives in debug_nans
# (see https://github.com/google/jax/issues/7634)
out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
return out

View File

@ -219,7 +219,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
with jax.debug_nans(True):
with jax.disable_jit():
result = lsp_special.logsumexp(1.0)
self.assertEqual(result, 1.0)
self.assertEqual(result, 1.0)
result = lsp_special.logsumexp(1.0, b=1.0)
self.assertEqual(result, 1.0)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(