jax.scipy.special.logsumexp: fix b=0 corner case

This commit is contained in:
Jake VanderPlas 2021-02-26 17:05:32 -08:00
parent a0c5a80971
commit 2c623d5837
2 changed files with 10 additions and 1 deletions

View File

@ -101,7 +101,8 @@ expit.defjvps(lambda g, ans, x: g * ans * (lax._const(ans, 1) - ans))
@_wraps(osp_special.logsumexp)
def logsumexp(a, axis=None, b=None, keepdims=False, return_sign=False):
if b is not None:
a, b = jnp.broadcast_arrays(a, b)
a, b = _promote_args_inexact("logsumexp", a, b)
a = jnp.where(b != 0, a, -jnp.inf)
pos_dims, dims = _reduction_dims(a, axis)
amax = jnp.max(a, axis=dims, keepdims=keepdims)
amax = lax.stop_gradient(lax.select(lax.is_finite(amax), amax, lax.full_like(amax, 0)))

View File

@ -146,6 +146,14 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker)
def testLogSumExpZeros(self):
# Regression test for https://github.com/google/jax/issues/5370
scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
lax_fun = lambda a, b: lsp_special.logsumexp(a, b=b)
args_maker = lambda: [np.array([-1000, -2]), np.array([1, 0])]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker)
self._CompileAndCheck(lax_fun, args_maker)
@parameterized.named_parameters(itertools.chain.from_iterable(
jtu.cases_from_list(
{"testcase_name": jtu.format_test_name_suffix(