mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add where argument to logsumexp.
This commit is contained in:
parent
29a2762b64
commit
e98612e2ab
@ -30,18 +30,18 @@ import numpy as np
|
||||
|
||||
@overload
|
||||
def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
|
||||
keepdims: bool = False, return_sign: Literal[False] = False) -> Array: ...
|
||||
keepdims: bool = False, return_sign: Literal[False] = False, where: ArrayLike | None = None) -> Array: ...
|
||||
|
||||
@overload
|
||||
def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
|
||||
keepdims: bool = False, *, return_sign: Literal[True]) -> tuple[Array, Array]: ...
|
||||
keepdims: bool = False, *, return_sign: Literal[True], where: ArrayLike | None = None) -> tuple[Array, Array]: ...
|
||||
|
||||
@overload
|
||||
def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
|
||||
keepdims: bool = False, return_sign: bool = False) -> Array | tuple[Array, Array]: ...
|
||||
keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]: ...
|
||||
|
||||
def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
|
||||
keepdims: bool = False, return_sign: bool = False) -> Array | tuple[Array, Array]:
|
||||
keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]:
|
||||
r"""Log-sum-exp reduction.
|
||||
|
||||
Computes
|
||||
@ -63,6 +63,7 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
|
||||
where ``sign`` is the sign of the sums and ``result`` contains the
|
||||
logarithms of their absolute values. If ``False`` only ``result`` is
|
||||
returned and it will contain NaN values if the sums are negative.
|
||||
where: Elements to include in the reduction.
|
||||
|
||||
Returns:
|
||||
Either an array ``result`` or a pair of arrays ``(result, sign)``, depending
|
||||
@ -75,14 +76,14 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
|
||||
a_arr, = promote_args_inexact("logsumexp", a)
|
||||
b_arr = a_arr # for type checking
|
||||
pos_dims, dims = _reduction_dims(a_arr, axis)
|
||||
amax = jnp.max(a_arr.real, axis=dims, keepdims=keepdims)
|
||||
amax = jnp.max(a_arr.real, axis=dims, keepdims=keepdims, where=where, initial=-jnp.inf)
|
||||
amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
|
||||
amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
|
||||
|
||||
exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype)))
|
||||
if b is not None:
|
||||
exp_a = lax.mul(exp_a, b_arr)
|
||||
sumexp = exp_a.sum(axis=dims, keepdims=keepdims)
|
||||
sumexp = exp_a.sum(axis=dims, keepdims=keepdims, where=where)
|
||||
sign = lax.sign(sumexp)
|
||||
if return_sign or not np.issubdtype(a_arr.dtype, np.complexfloating):
|
||||
sumexp = abs(sumexp)
|
||||
|
@ -190,6 +190,19 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
result = lsp_special.logsumexp(1.0, b=1.0)
|
||||
self.assertEqual(result, 1.0)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(0,), (1,), (2,), (3,), (4,), (5,)],
|
||||
dtype=float_dtypes,
|
||||
)
|
||||
def testLogSumExpWhere(self, shape, dtype):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
x = rng(shape, dtype)
|
||||
rng = jtu.rand_bool(self.rng())
|
||||
mask = rng(shape, bool)
|
||||
y_expected = osp_special.logsumexp(x[mask]) if mask.any() else -jnp.inf
|
||||
y_actual = lsp_special.logsumexp(x, where=mask)
|
||||
self.assertAllClose(y_expected, y_actual, check_dtypes=False)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes,
|
||||
dtype=float_dtypes,
|
||||
|
Loading…
x
Reference in New Issue
Block a user