logsumexp: use NumPy 2.0 convention for complex sign

This commit is contained in:
Jake VanderPlas 2024-01-16 16:15:06 -08:00
parent 08837a9919
commit 7d6a134f4e
3 changed files with 31 additions and 24 deletions

View File

@ -48,6 +48,9 @@ Remember to align the itemized text with the first line of an item within a list
* {func}`jax.numpy.unique` with `return_inverse = True` returns inverse indices
reshaped to the dimension of the input, following a similar change to
{func}`numpy.unique` in NumPy 2.0.
* {func}`jax.scipy.special.logsumexp` with `return_sign=True` now uses the NumPy 2.0
convention for the complex sign, `x / abs(x)`. This is consistent with the behavior
of the function in SciPy v1.13.
* Deprecations & Removals
* A number of previously deprecated functions have been removed, following a

View File

@ -75,33 +75,22 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
a_arr, = promote_args_inexact("logsumexp", a)
b_arr = a_arr # for type checking
pos_dims, dims = _reduction_dims(a_arr, axis)
amax = jnp.max(a_arr, axis=dims, keepdims=keepdims)
amax = jnp.max(a_arr.real, axis=dims, keepdims=keepdims)
amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
# fast path if the result cannot be negative.
if b is None and not np.issubdtype(a_arr.dtype, np.complexfloating):
out = lax.add(lax.log(jnp.sum(lax.exp(lax.sub(a_arr, amax_with_dims)),
axis=dims, keepdims=keepdims)),
amax)
sign = jnp.where(jnp.isnan(out), out, 1.0)
sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
else:
expsub = lax.exp(lax.sub(a_arr, amax_with_dims))
if b is not None:
expsub = lax.mul(expsub, b_arr)
sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)
sign = lax.stop_gradient(jnp.sign(sumexp))
if np.issubdtype(sumexp.dtype, np.complexfloating):
if return_sign:
sumexp = sign*sumexp
out = lax.add(lax.log(sumexp), amax)
else:
out = lax.add(lax.log(lax.abs(sumexp)), amax)
exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype)))
if b is not None:
exp_a = lax.mul(exp_a, b_arr)
sumexp = exp_a.sum(axis=dims, keepdims=keepdims)
sign = lax.sign(sumexp)
if return_sign or not np.issubdtype(a_arr.dtype, np.complexfloating):
sumexp = abs(sumexp)
out = lax.add(lax.log(sumexp), amax.astype(sumexp.dtype))
if return_sign:
return (out, sign)
if b is not None:
if not np.issubdtype(out.dtype, np.complexfloating):
with jax.debug_nans(False):
out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
if b is not None and not np.issubdtype(out.dtype, np.complexfloating):
with jax.debug_nans(False):
out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
return out

View File

@ -38,6 +38,8 @@ from jax.scipy import cluster as lsp_cluster
from jax import config
config.parse_flags_with_absl()
scipy_version = jtu.parse_version(scipy.version.version)
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
compatible_shapes = [[(), ()],
[(4,), (3, 4)],
@ -111,6 +113,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
def testLogSumExp(self, shapes, dtype, axis,
keepdims, return_sign, use_b):
if jnp.issubdtype(dtype, jnp.complexfloating) and scipy_version < (1, 13, 0):
self.skipTest("logsumexp of complex input uses scipy 1.13.0 semantics.")
if not jtu.test_device_matches(["cpu"]):
rng = jtu.rand_some_inf_and_nan(self.rng())
else:
@ -151,6 +155,17 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
tol = {np.float32: 1E-6, np.float64: 1E-14}
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
def testLogSumExpComplexSign(self):
# Tests behavior of complex sign, which changed in SciPy 1.13
x = jnp.array([1 + 1j, 2 - 1j, -2 + 3j])
logsumexp, sign = lsp_special.logsumexp(x, return_sign=True)
expected_sumexp = jnp.exp(x).sum()
expected_sign = expected_sumexp / abs(expected_sumexp).astype(x.dtype)
self.assertEqual(logsumexp.dtype, sign.real.dtype)
tol = 1E-4 if jtu.test_device_matches(['tpu']) else 1E-6
self.assertAllClose(sign, expected_sign, rtol=tol)
self.assertAllClose(sign * np.exp(logsumexp).astype(x.dtype), expected_sumexp, rtol=tol)
def testLogSumExpZeros(self):
# Regression test for https://github.com/google/jax/issues/5370
scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)