mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
logsumexp: fix issue with debug_nans
This commit is contained in:
parent
52476d64fa
commit
730ae33e03
@ -139,7 +139,9 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
|
||||
return (out, sign)
|
||||
if b is not None:
|
||||
if not np.issubdtype(out.dtype, np.complexfloating):
|
||||
out = jnp.where(sign < 0, np.nan, out)
|
||||
# Use jnp.array(nan) to avoid false positives in debug_nans
|
||||
# (see https://github.com/google/jax/issues/7634)
|
||||
out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
|
||||
return out
|
||||
|
||||
|
||||
|
@ -219,7 +219,10 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
with jax.debug_nans(True):
|
||||
with jax.disable_jit():
|
||||
result = lsp_special.logsumexp(1.0)
|
||||
self.assertEqual(result, 1.0)
|
||||
self.assertEqual(result, 1.0)
|
||||
|
||||
result = lsp_special.logsumexp(1.0, b=1.0)
|
||||
self.assertEqual(result, 1.0)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
|
Loading…
x
Reference in New Issue
Block a user