mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
jax.scipy.special.logsumexp: fix b=0 corner case
This commit is contained in:
parent
a0c5a80971
commit
2c623d5837
@ -101,7 +101,8 @@ expit.defjvps(lambda g, ans, x: g * ans * (lax._const(ans, 1) - ans))
|
||||
@_wraps(osp_special.logsumexp)
|
||||
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
|
||||
if b is not None:
|
||||
a, b = jnp.broadcast_arrays(a, b)
|
||||
a, b = _promote_args_inexact("logsumexp", a, b)
|
||||
a = jnp.where(b != 0, a, -jnp.inf)
|
||||
pos_dims, dims = _reduction_dims(a, axis)
|
||||
amax = jnp.max(a, axis=dims, keepdims=keepdims)
|
||||
amax = lax.stop_gradient(lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))
|
||||
|
@ -146,6 +146,14 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
def testLogSumExpZeros(self):
|
||||
# Regression test for https://github.com/google/jax/issues/5370
|
||||
scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
|
||||
lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b)
|
||||
args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])]
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(itertools.chain.from_iterable(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(
|
||||
|
Loading…
x
Reference in New Issue
Block a user