mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Fix issue with infinities in logsumexp
This commit is contained in:
parent
6984f30d5e
commit
e131343274
@ -121,7 +121,7 @@ def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
|
||||
axis=dims, keepdims=keepdims)),
|
||||
amax)
|
||||
sign = jnp.where(jnp.isnan(out), np.nan, 1.0).astype(out.dtype)
|
||||
sign = jnp.where(out == -np.inf, 0.0, sign)
|
||||
sign = jnp.where(jnp.isneginf(out), 0.0, sign)
|
||||
else:
|
||||
expsub = lax.exp(lax.sub(a, amax_with_dims))
|
||||
if b is not None:
|
||||
|
@ -25,6 +25,7 @@ from absl.testing import parameterized
|
||||
import numpy as np
|
||||
import scipy.special as osp_special
|
||||
|
||||
import jax
|
||||
from jax._src import api
|
||||
from jax import numpy as jnp
|
||||
from jax import lax
|
||||
@ -205,6 +206,13 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
def testLogSumExpOnes(self):
|
||||
# Regression test for https://github.com/google/jax/issues/7390
|
||||
args_maker = lambda: [np.ones(4, dtype='float32')]
|
||||
with jax.debug_infs(True):
|
||||
self._CheckAgainstNumpy(osp_special.logsumexp, lsp_special.logsumexp, args_maker)
|
||||
self._CompileAndCheck(lsp_special.logsumexp, args_maker)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(
|
||||
|
Loading…
x
Reference in New Issue
Block a user