mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Fix logsumexp issue with debug_nans and disable_jit
This commit is contained in:
parent
476642578b
commit
0e256ddeb7
@ -120,8 +120,8 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
|
||||
out = lax.add(lax.log(jnp.sum(lax.exp(lax.sub(a, amax_with_dims)),
|
||||
axis=dims, keepdims=keepdims)),
|
||||
amax)
|
||||
sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
|
||||
sign = jnp.where(jnp.isneginf(out), 0.0, sign)
|
||||
sign = jnp.where(jnp.isnan(out), out, 1.0)
|
||||
sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
|
||||
else:
|
||||
expsub = lax.exp(lax.sub(a, amax_with_dims))
|
||||
if b is not None:
|
||||
|
@ -214,6 +214,13 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(osp_special.logsumexp, lsp_special.logsumexp, args_maker)
|
||||
self._CompileAndCheck(lsp_special.logsumexp, args_maker)
|
||||
|
||||
def testLogSumExpNans(self):
|
||||
# Regression test for https://github.com/google/jax/issues/7634
|
||||
with jax.debug_nans(True):
|
||||
with jax.disable_jit():
|
||||
result = lsp_special.logsumexp(1.0)
|
||||
self.assertEqual(result, 1.0)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(
|
||||
|
Loading…
x
Reference in New Issue
Block a user