mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
logsumexp: use NumPy 2.0 convention for complex sign
This commit is contained in:
parent
08837a9919
commit
7d6a134f4e
@ -48,6 +48,9 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
* {func}`jax.numpy.unique` with `return_inverse = True` returns inverse indices
|
* {func}`jax.numpy.unique` with `return_inverse = True` returns inverse indices
|
||||||
reshaped to the dimension of the input, following a similar change to
|
reshaped to the dimension of the input, following a similar change to
|
||||||
{func}`numpy.unique` in NumPy 2.0.
|
{func}`numpy.unique` in NumPy 2.0.
|
||||||
|
* {func}`jax.scipy.special.logsumexp` with `return_sign=True` now uses the NumPy 2.0
|
||||||
|
convention for the complex sign, `x / abs(x)`. This is consistent with the behavior
|
||||||
|
of the function in SciPy v1.13.
|
||||||
|
|
||||||
* Deprecations & Removals
|
* Deprecations & Removals
|
||||||
* A number of previously deprecated functions have been removed, following a
|
* A number of previously deprecated functions have been removed, following a
|
||||||
|
@ -75,33 +75,22 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
|
|||||||
a_arr, = promote_args_inexact("logsumexp", a)
|
a_arr, = promote_args_inexact("logsumexp", a)
|
||||||
b_arr = a_arr # for type checking
|
b_arr = a_arr # for type checking
|
||||||
pos_dims, dims = _reduction_dims(a_arr, axis)
|
pos_dims, dims = _reduction_dims(a_arr, axis)
|
||||||
amax = jnp.max(a_arr, axis=dims, keepdims=keepdims)
|
amax = jnp.max(a_arr.real, axis=dims, keepdims=keepdims)
|
||||||
amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
|
amax = lax.stop_gradient(lax.select(jnp.isfinite(amax), amax, lax.full_like(amax, 0)))
|
||||||
amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
|
amax_with_dims = amax if keepdims else lax.expand_dims(amax, pos_dims)
|
||||||
# fast path if the result cannot be negative.
|
|
||||||
if b is None and not np.issubdtype(a_arr.dtype, np.complexfloating):
|
|
||||||
out = lax.add(lax.log(jnp.sum(lax.exp(lax.sub(a_arr, amax_with_dims)),
|
|
||||||
axis=dims, keepdims=keepdims)),
|
|
||||||
amax)
|
|
||||||
sign = jnp.where(jnp.isnan(out), out, 1.0)
|
|
||||||
sign = jnp.where(jnp.isneginf(out), 0.0, sign).astype(out.dtype)
|
|
||||||
else:
|
|
||||||
expsub = lax.exp(lax.sub(a_arr, amax_with_dims))
|
|
||||||
if b is not None:
|
|
||||||
expsub = lax.mul(expsub, b_arr)
|
|
||||||
sumexp = jnp.sum(expsub, axis=dims, keepdims=keepdims)
|
|
||||||
|
|
||||||
sign = lax.stop_gradient(jnp.sign(sumexp))
|
exp_a = lax.exp(lax.sub(a_arr, amax_with_dims.astype(a_arr.dtype)))
|
||||||
if np.issubdtype(sumexp.dtype, np.complexfloating):
|
if b is not None:
|
||||||
if return_sign:
|
exp_a = lax.mul(exp_a, b_arr)
|
||||||
sumexp = sign*sumexp
|
sumexp = exp_a.sum(axis=dims, keepdims=keepdims)
|
||||||
out = lax.add(lax.log(sumexp), amax)
|
sign = lax.sign(sumexp)
|
||||||
else:
|
if return_sign or not np.issubdtype(a_arr.dtype, np.complexfloating):
|
||||||
out = lax.add(lax.log(lax.abs(sumexp)), amax)
|
sumexp = abs(sumexp)
|
||||||
|
out = lax.add(lax.log(sumexp), amax.astype(sumexp.dtype))
|
||||||
|
|
||||||
if return_sign:
|
if return_sign:
|
||||||
return (out, sign)
|
return (out, sign)
|
||||||
if b is not None:
|
if b is not None and not np.issubdtype(out.dtype, np.complexfloating):
|
||||||
if not np.issubdtype(out.dtype, np.complexfloating):
|
with jax.debug_nans(False):
|
||||||
with jax.debug_nans(False):
|
out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
|
||||||
out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
|
|
||||||
return out
|
return out
|
||||||
|
@ -38,6 +38,8 @@ from jax.scipy import cluster as lsp_cluster
|
|||||||
from jax import config
|
from jax import config
|
||||||
config.parse_flags_with_absl()
|
config.parse_flags_with_absl()
|
||||||
|
|
||||||
|
scipy_version = jtu.parse_version(scipy.version.version)
|
||||||
|
|
||||||
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
|
all_shapes = [(), (4,), (3, 4), (3, 1), (1, 4), (2, 1, 4)]
|
||||||
compatible_shapes = [[(), ()],
|
compatible_shapes = [[(), ()],
|
||||||
[(4,), (3, 4)],
|
[(4,), (3, 4)],
|
||||||
@ -111,6 +113,8 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
|||||||
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
@jax.numpy_rank_promotion('allow') # This test explicitly exercises implicit rank promotion.
|
||||||
def testLogSumExp(self, shapes, dtype, axis,
|
def testLogSumExp(self, shapes, dtype, axis,
|
||||||
keepdims, return_sign, use_b):
|
keepdims, return_sign, use_b):
|
||||||
|
if jnp.issubdtype(dtype, jnp.complexfloating) and scipy_version < (1, 13, 0):
|
||||||
|
self.skipTest("logsumexp of complex input uses scipy 1.13.0 semantics.")
|
||||||
if not jtu.test_device_matches(["cpu"]):
|
if not jtu.test_device_matches(["cpu"]):
|
||||||
rng = jtu.rand_some_inf_and_nan(self.rng())
|
rng = jtu.rand_some_inf_and_nan(self.rng())
|
||||||
else:
|
else:
|
||||||
@ -151,6 +155,17 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
|||||||
tol = {np.float32: 1E-6, np.float64: 1E-14}
|
tol = {np.float32: 1E-6, np.float64: 1E-14}
|
||||||
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
|
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
|
||||||
|
|
||||||
|
def testLogSumExpComplexSign(self):
|
||||||
|
# Tests behavior of complex sign, which changed in SciPy 1.13
|
||||||
|
x = jnp.array([1 + 1j, 2 - 1j, -2 + 3j])
|
||||||
|
logsumexp, sign = lsp_special.logsumexp(x, return_sign=True)
|
||||||
|
expected_sumexp = jnp.exp(x).sum()
|
||||||
|
expected_sign = expected_sumexp / abs(expected_sumexp).astype(x.dtype)
|
||||||
|
self.assertEqual(logsumexp.dtype, sign.real.dtype)
|
||||||
|
tol = 1E-4 if jtu.test_device_matches(['tpu']) else 1E-6
|
||||||
|
self.assertAllClose(sign, expected_sign, rtol=tol)
|
||||||
|
self.assertAllClose(sign * np.exp(logsumexp).astype(x.dtype), expected_sumexp, rtol=tol)
|
||||||
|
|
||||||
def testLogSumExpZeros(self):
|
def testLogSumExpZeros(self):
|
||||||
# Regression test for https://github.com/google/jax/issues/5370
|
# Regression test for https://github.com/google/jax/issues/5370
|
||||||
scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
|
scipy_fun = lambda a, b: osp_special.logsumexp(a, b=b)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user