diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9ec2c140c..12d33b922 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -48,6 +48,9 @@ Remember to align the itemized text with the first line of an item within a list
   * {func}`jax.numpy.unique` with `return_inverse = True` returns inverse indices
     reshaped to the dimension of the input, following a similar change to
     {func}`numpy.unique` in NumPy 2.0.
+  * {func}`jax.scipy.special.logsumexp` with `return_sign=True` now uses the NumPy 2.0
+    convention for the complex sign, `x / abs(x)`. This is consistent with the behavior
+    of the function in SciPy v1.13.
 
 * Deprecations & Removals
   * A number of previously deprecated functions have been removed, following a
diff --git a/jax/_src/ops/special.py b/jax/_src/ops/special.py
index 525e492ba..72826621a 100644
--- a/jax/_src/ops/special.py
+++ b/jax/_src/ops/special.py
@@ -75,33 +75,22 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
     a_arr, = promote_args_inexact("logsumexp", a)
     b_arr = a_arr  # for type checking
   pos_dims, dims = _reduction_dims(a_arr, axis)
-  amax = jnp.max(a_arr, axis=dims, keepdims=keepdims)
+  amax = jnp.max(a_arr.real, axis=dims, keepdims=keepdims)
   amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax,
                                       lax.full_like(amax, 0)))
   amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
-  # fast path if the result cannot be negative.
-  if b is None and not np.issubdtype(a_arr.dtype, np.complexfloating):
-    out = lax.add(lax.log(jnp.sum(lax.exp(lax.sub(a_arr, amax_with_dims)),
-                                  axis=dims, keepdims=keepdims)),
-                  amax)
-    sign = jnp.where(jnp.isnan(out), out, 1.0)
-    sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
-  else:
-    expsub = lax.exp(lax.sub(a_arr, amax_with_dims))
-    if b is not None:
-      expsub = lax.mul(expsub, b_arr)
-    sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)
-    sign = lax.stop_gradient(jnp.sign(sumexp))
-    if np.issubdtype(sumexp.dtype, np.complexfloating):
-      if return_sign:
-        sumexp = sign*sumexp
-      out = lax.add(lax.log(sumexp), amax)
-    else:
-      out = lax.add(lax.log(lax.abs(sumexp)), amax)
+
+  exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype)))
+  if b is not None:
+    exp_a = lax.mul(exp_a, b_arr)
+  sumexp = exp_a.sum(axis=dims, keepdims=keepdims)
+  sign = lax.sign(sumexp)
+  if return_sign or not np.issubdtype(a_arr.dtype, np.complexfloating):
+    sumexp = abs(sumexp)
+  out = lax.add(lax.log(sumexp), amax.astype(sumexp.dtype))
   if return_sign:
     return (out, sign)
-  if b is not None:
-    if not np.issubdtype(out.dtype, np.complexfloating):
-      with jax.debug_nans(False):
-        out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
+  if b is not None and not np.issubdtype(out.dtype, np.complexfloating):
+    with jax.debug_nans(False):
+      out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
   return out
diff --git a/tests/lax_scipy_test.py b/tests/lax_scipy_test.py
index e4d0e9015..71bcbe4cd 100644
--- a/tests/lax_scipy_test.py
+++ b/tests/lax_scipy_test.py
@@ -38,6 +38,8 @@ from jax.scipy import cluster as lsp_cluster
 from jax import config
 config.parse_flags_with_absl()
 
+scipy_version = jtu.parse_version(scipy.version.version)
+
 all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
 compatible_shapes = [[(), ()],
                      [(4,), (3, 4)],
@@ -111,6 +113,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
   )
   @jax.numpy_rank_promotion('allow')  # This test explicitly exercises implicit rank promotion.
   def testLogSumExp(self, shapes, dtype, axis, keepdims, return_sign, use_b):
+    if jnp.issubdtype(dtype, jnp.complexfloating) and scipy_version < (1, 13, 0):
+      self.skipTest("logsumexp of complex input uses scipy 1.13.0 semantics.")
     if not jtu.test_device_matches(["cpu"]):
       rng = jtu.rand_some_inf_and_nan(self.rng())
     else:
@@ -151,6 +155,17 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
     tol = {np.float32: 1E-6, np.float64: 1E-14}
     self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
 
+  def testLogSumExpComplexSign(self):
+    # Tests behavior of complex sign, which changed in SciPy 1.13
+    x = jnp.array([1 + 1j, 2 - 1j, -2 + 3j])
+    logsumexp, sign = lsp_special.logsumexp(x, return_sign=True)
+    expected_sumexp = jnp.exp(x).sum()
+    expected_sign = expected_sumexp / abs(expected_sumexp).astype(x.dtype)
+    self.assertEqual(logsumexp.dtype, sign.real.dtype)
+    tol = 1E-4 if jtu.test_device_matches(['tpu']) else 1E-6
+    self.assertAllClose(sign, expected_sign, rtol=tol)
+    self.assertAllClose(sign * np.exp(logsumexp).astype(x.dtype), expected_sumexp, rtol=tol)
+
   def testLogSumExpZeros(self):
     # Regression test for https://github.com/google/jax/issues/5370
     scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
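
To make the changelog entry concrete, here is a small usage sketch (not part of the patch) of the new complex-sign convention; it mirrors the new `testLogSumExpComplexSign` test above. With `return_sign=True` on complex input, `out` is real-valued (the log-magnitude of the sum) and `sign` is the unit-modulus phase `s / |s|`:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

x = jnp.array([1 + 1j, 2 - 1j, -2 + 3j])
out, sign = logsumexp(x, return_sign=True)

# Reconstruct the complex sum directly and compare conventions.
direct = jnp.exp(x).sum()
assert jnp.allclose(sign, direct / jnp.abs(direct), rtol=1e-5)  # sign == s / |s|
assert jnp.allclose(sign * jnp.exp(out), direct, rtol=1e-5)     # sign * exp(out) == s
```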
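For reference, the streamlined implementation in `special.py` amounts to a single max-shift code path for all input types. The NumPy-only sketch below is illustrative, not the library code: `logsumexp_sketch` is a hypothetical name, it handles only a full (scalar) reduction with no `axis`/`keepdims`, and it omits the NaN propagation for negative real sums that the JAX version applies when `b` is given:

```python
import numpy as np

def logsumexp_sketch(a, b=None, return_sign=False):
    # Hypothetical reference implementation of the max-shift algorithm.
    a = np.asarray(a)
    # Shift by max(Re a): since |exp(a)| == exp(Re a), every summand
    # has magnitude <= 1 after the shift, preventing overflow.
    amax = np.max(a.real)
    amax = amax if np.isfinite(amax) else 0.0
    exp_a = np.exp(a - amax)
    if b is not None:
        exp_a = exp_a * b
    s = exp_a.sum()
    if np.iscomplexobj(s):
        sign = s / abs(s)  # NumPy 2.0 / SciPy 1.13 convention (NaN if s == 0)
    else:
        sign = np.sign(s)
    out = np.log(abs(s)) + amax  # real-valued log-magnitude
    return (out, sign) if return_sign else out
```

Shifting by `max(a.real)` rather than `max(a)` is what lets the complex and real cases share this path: the maximum of a complex array is not well defined, but the real part alone determines each summand's magnitude.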