Add where argument to logsumexp.

This commit is contained in:
carlosgmartin 2024-04-08 12:57:06 -04:00
parent 29a2762b64
commit e98612e2ab
2 changed files with 20 additions and 6 deletions

View File

@ -30,18 +30,18 @@ import numpy as np
@overload
def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
keepdims: bool = False, return_sign: Literal[False] = False) -> Array: ...
keepdims: bool = False, return_sign: Literal[False] = False, where: ArrayLike | None = None) -> Array: ...
@overload
def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
keepdims: bool = False, *, return_sign: Literal[True]) -> tuple[Array, Array]: ...
keepdims: bool = False, *, return_sign: Literal[True], where: ArrayLike | None = None) -> tuple[Array, Array]: ...
@overload
def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
keepdims: bool = False, return_sign: bool = False) -> Array | tuple[Array, Array]: ...
keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]: ...
def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
keepdims: bool = False, return_sign: bool = False) -> Array | tuple[Array, Array]:
keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]:
r"""Log-sum-exp reduction.
Computes
@ -63,6 +63,7 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
where ``sign`` is the sign of the sums and ``result`` contains the
logarithms of their absolute values. If ``False`` only ``result`` is
returned and it will contain NaN values if the sums are negative.
where: Elements to include in the reduction.
Returns:
Either an array ``result`` or a pair of arrays ``(result, sign)``, depending
@ -75,14 +76,14 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
a_arr, = promote_args_inexact("logsumexp", a)
b_arr = a_arr # for type checking
pos_dims, dims = _reduction_dims(a_arr, axis)
amax = jnp.max(a_arr.real, axis=dims, keepdims=keepdims)
amax = jnp.max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-jnp.inf)
amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype)))
if b is not None:
exp_a = lax.mul(exp_a, b_arr)
sumexp = exp_a.sum(axis=dims, keepdims=keepdims)
sumexp = exp_a.sum(axis=dims, keepdims=keepdims, where=where)
sign = lax.sign(sumexp)
if return_sign or not np.issubdtype(a_arr.dtype, np.complexfloating):
sumexp = abs(sumexp)

View File

@ -190,6 +190,19 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
result = lsp_special.logsumexp(1.0, b=1.0)
self.assertEqual(result, 1.0)
@jtu.sample_product(
shape=[(0,), (1,), (2,), (3,), (4,), (5,)],
dtype=float_dtypes,
)
def testLogSumExpWhere(self, shape, dtype):
rng = jtu.rand_default(self.rng())
x = rng(shape, dtype)
rng = jtu.rand_bool(self.rng())
mask = rng(shape, bool)
y_expected = osp_special.logsumexp(x[mask]) if mask.any() else -jnp.inf
y_actual = lsp_special.logsumexp(x, where=mask)
self.assertAllClose(y_expected, y_actual, check_dtypes=False)
@jtu.sample_product(
shape=all_shapes,
dtype=float_dtypes,