mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
logsumexp: use NumPy 2.0 convention for complex sign
This commit is contained in:
parent
08837a9919
commit
7d6a134f4e
@ -48,6 +48,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* {func}`jax.numpy.unique` with `return_inverse = True` returns inverse indices
|
||||
reshaped to the dimension of the input, following a similar change to
|
||||
{func}`numpy.unique` in NumPy 2.0.
|
||||
* {func}`jax.scipy.special.logsumexp` with `return_sign=True` now uses the NumPy 2.0
|
||||
convention for the complex sign, `x / abs(x)`. This is consistent with the behavior
|
||||
of the function in SciPy v1.13.
|
||||
|
||||
* Deprecations & Removals
|
||||
* A number of previously deprecated functions have been removed, following a
|
||||
|
@ -75,33 +75,22 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
|
||||
a_arr, = promote_args_inexact("logsumexp", a)
|
||||
b_arr = a_arr # for type checking
|
||||
pos_dims, dims = _reduction_dims(a_arr, axis)
|
||||
amax = jnp.max(a_arr, axis=dims, keepdims=keepdims)
|
||||
amax = jnp.max(a_arr.real, axis=dims, keepdims=keepdims)
|
||||
amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
|
||||
amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
|
||||
# fast path if the result cannot be negative.
|
||||
if b is None and not np.issubdtype(a_arr.dtype, np.complexfloating):
|
||||
out = lax.add(lax.log(jnp.sum(lax.exp(lax.sub(a_arr, amax_with_dims)),
|
||||
axis=dims, keepdims=keepdims)),
|
||||
amax)
|
||||
sign = jnp.where(jnp.isnan(out), out, 1.0)
|
||||
sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
|
||||
else:
|
||||
expsub = lax.exp(lax.sub(a_arr, amax_with_dims))
|
||||
if b is not None:
|
||||
expsub = lax.mul(expsub, b_arr)
|
||||
sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)
|
||||
|
||||
sign = lax.stop_gradient(jnp.sign(sumexp))
|
||||
if np.issubdtype(sumexp.dtype, np.complexfloating):
|
||||
if return_sign:
|
||||
sumexp = sign*sumexp
|
||||
out = lax.add(lax.log(sumexp), amax)
|
||||
else:
|
||||
out = lax.add(lax.log(lax.abs(sumexp)), amax)
|
||||
exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype)))
|
||||
if b is not None:
|
||||
exp_a = lax.mul(exp_a, b_arr)
|
||||
sumexp = exp_a.sum(axis=dims, keepdims=keepdims)
|
||||
sign = lax.sign(sumexp)
|
||||
if return_sign or not np.issubdtype(a_arr.dtype, np.complexfloating):
|
||||
sumexp = abs(sumexp)
|
||||
out = lax.add(lax.log(sumexp), amax.astype(sumexp.dtype))
|
||||
|
||||
if return_sign:
|
||||
return (out, sign)
|
||||
if b is not None:
|
||||
if not np.issubdtype(out.dtype, np.complexfloating):
|
||||
with jax.debug_nans(False):
|
||||
out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
|
||||
if b is not None and not np.issubdtype(out.dtype, np.complexfloating):
|
||||
with jax.debug_nans(False):
|
||||
out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
|
||||
return out
|
||||
|
@ -38,6 +38,8 @@ from jax.scipy import cluster as lsp_cluster
|
||||
from jax import config
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
scipy_version = jtu.parse_version(scipy.version.version)
|
||||
|
||||
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
|
||||
compatible_shapes = [[(), ()],
|
||||
[(4,), (3, 4)],
|
||||
@ -111,6 +113,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||
def testLogSumExp(self, shapes, dtype, axis,
|
||||
keepdims, return_sign, use_b):
|
||||
if jnp.issubdtype(dtype, jnp.complexfloating) and scipy_version < (1, 13, 0):
|
||||
self.skipTest("logsumexp of complex input uses scipy 1.13.0 semantics.")
|
||||
if not jtu.test_device_matches(["cpu"]):
|
||||
rng = jtu.rand_some_inf_and_nan(self.rng())
|
||||
else:
|
||||
@ -151,6 +155,17 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
tol = {np.float32: 1E-6, np.float64: 1E-14}
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
def testLogSumExpComplexSign(self):
|
||||
# Tests behavior of complex sign, which changed in SciPy 1.13
|
||||
x = jnp.array([1 + 1j, 2 - 1j, -2 + 3j])
|
||||
logsumexp, sign = lsp_special.logsumexp(x, return_sign=True)
|
||||
expected_sumexp = jnp.exp(x).sum()
|
||||
expected_sign = expected_sumexp / abs(expected_sumexp).astype(x.dtype)
|
||||
self.assertEqual(logsumexp.dtype, sign.real.dtype)
|
||||
tol = 1E-4 if jtu.test_device_matches(['tpu']) else 1E-6
|
||||
self.assertAllClose(sign, expected_sign, rtol=tol)
|
||||
self.assertAllClose(sign * np.exp(logsumexp).astype(x.dtype), expected_sumexp, rtol=tol)
|
||||
|
||||
def testLogSumExpZeros(self):
|
||||
# Regression test for https://github.com/google/jax/issues/5370
|
||||
scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
|
||||
|
Loading…
x
Reference in New Issue
Block a user