Reverts 0f103d33849ca017e6a199d0f79fa0d83b373995

PiperOrigin-RevId: 659670593
Sergei Lebedev (2024-08-05 13:51:14 -07:00), committed by jax authors
parent c2c04e054e
commit e416c6675a
2 changed files with 0 additions and 64 deletions


@@ -14,7 +14,6 @@
from __future__ import annotations
import functools
from typing import overload, Literal
import jax
@@ -41,7 +40,6 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
              keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]: ...
@functools.partial(jax.custom_jvp, nondiff_argnums=(1, 3, 4, 5))
def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
              keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]:
  r"""Log-sum-exp reduction.
@@ -97,36 +95,3 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
  with jax.debug_nans(False):
    out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
  return out


@logsumexp.defjvp
def _logsumexp_jvp(axis, keepdims, return_sign, where, primals, tangents):
  a, b = primals
  a_dot, b_dot = tangents
  out = logsumexp(
      a,
      axis=axis,
      b=b,
      keepdims=keepdims,
      return_sign=return_sign,
      where=where,
  )
  if return_sign:
    out, sign = out
  if b is None:
    out_dot = jnp.sum(
        a_dot * jnp.exp(a - out), axis=axis, keepdims=keepdims, where=where,
    )
  else:
    out_dot = jnp.sum(
        (b * a_dot + b_dot) * jnp.exp(a - out),
        axis=axis,
        keepdims=keepdims,
        where=where,
    )
  if return_sign:
    sign_dot = jnp.zeros_like(sign)
    return (out, sign), (out_dot, sign_dot)
  else:
    return out, out_dot
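The rule removed above implements the tangent of out = log(sum_i b_i * exp(a_i)) as out_dot = sum_i (b_i * a_dot_i + b_dot_i) * exp(a_i - out); with b omitted this reduces to a softmax-weighted sum of the input tangents, i.e. the Jacobian of logsumexp is softmax. A minimal standalone check of that identity (illustrative input values, not taken from the change):

import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

a = jnp.array([0.5, -1.0, 2.0])
# Forward-mode Jacobian of the scalar reduction should match softmax(a).
print(jnp.allclose(jax.jacfwd(logsumexp)(a), jax.nn.softmax(a)))  # True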


@@ -18,7 +18,6 @@ import itertools
import unittest
from absl.testing import absltest
from absl.testing import parameterized
import numpy as np
import scipy.integrate
@@ -203,34 +202,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
    y_actual = lsp_special.logsumexp(x, where=mask)
    self.assertAllClose(y_expected, y_actual, check_dtypes=False)

  def testLogSumExpZerosJac(self):
    # Regression test for https://github.com/google/jax/issues/22398
    fun = lambda b: lsp_special.logsumexp(jnp.zeros(2), axis=0, b=b)
    np.testing.assert_array_equal(
        jax.jacfwd(fun)(jnp.array([1.0, 0.0])),
        jnp.ones(2),
    )
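Worked expectation for testLogSumExpZerosJac above: with a = jnp.zeros(2), logsumexp(a, axis=0, b=b) = log(b[0] + b[1]), so each partial derivative is 1 / (b[0] + b[1]), which equals 1.0 at b = [1.0, 0.0]; hence the expected Jacobian of jnp.ones(2). The b[1] = 0.0 entry is the case reported in https://github.com/google/jax/issues/22398.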

  @parameterized.product(
      axis=[None, 0], with_b=[False, True], return_sign=[False, True]
  )
  def testLogSumExpJac(self, axis, with_b, return_sign):
    fun = partial(lsp_special.logsumexp, axis=axis, return_sign=return_sign)
    orig_fun = partial(
        lsp_special.logsumexp.fun, axis=axis, return_sign=return_sign
    )
    tol = 5e-5 if jtu.test_device_matches(["tpu"]) else 1e-06
    for i in range(100):
      a = jax.random.normal(jax.random.key(i), (2,)) * 4.2
      b = None
      if with_b:
        b = jax.random.uniform(jax.random.key(i), ())
      jax.tree.map(
          lambda x, y: np.testing.assert_allclose(x, y, atol=tol, rtol=tol),
          jax.jacfwd(fun)(a, b=b),
          jax.jacfwd(orig_fun)(a, b=b),
      )
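In testLogSumExpJac, lsp_special.logsumexp.fun is the undecorated function held by the jax.custom_jvp wrapper, so the assertions compare the hand-written JVP rule against forward-mode differentiation straight through the original implementation; the looser 5e-5 tolerance on TPU presumably accounts for its lower default precision.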

  @jtu.sample_product(
      shape=all_shapes,
      dtype=float_dtypes,