mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Reverts 0f103d33849ca017e6a199d0f79fa0d83b373995
PiperOrigin-RevId: 659670593
This commit is contained in:
parent
c2c04e054e
commit
e416c6675a
@ -14,7 +14,6 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from typing import overload, Literal
|
||||
|
||||
import jax
|
||||
@ -41,7 +40,6 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
|
||||
def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
|
||||
keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]: ...
|
||||
|
||||
@functools.partial(jax.custom_jvp, nondiff_argnums=(1, 3, 4, 5))
|
||||
def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
|
||||
keepdims: bool = False, return_sign: bool = False, where: ArrayLike | None = None) -> Array | tuple[Array, Array]:
|
||||
r"""Log-sum-exp reduction.
|
||||
@ -97,36 +95,3 @@ def logsumexp(a: ArrayLike, axis: Axis = None, b: ArrayLike | None = None,
|
||||
with jax.debug_nans(False):
|
||||
out = jnp.where(sign < 0, jnp.array(np.nan, dtype=out.dtype), out)
|
||||
return out
|
||||
|
||||
@logsumexp.defjvp
|
||||
def _logsumexp_jvp(axis, keepdims, return_sign, where, primals, tangents):
|
||||
a, b = primals
|
||||
a_dot, b_dot = tangents
|
||||
out = logsumexp(
|
||||
a,
|
||||
axis=axis,
|
||||
b=b,
|
||||
keepdims=keepdims,
|
||||
return_sign=return_sign,
|
||||
where=where,
|
||||
)
|
||||
|
||||
if return_sign:
|
||||
out, sign = out
|
||||
|
||||
if b is None:
|
||||
out_dot = jnp.sum(
|
||||
a_dot * jnp.exp(a - out), axis=axis, keepdims=keepdims, where=where,
|
||||
)
|
||||
else:
|
||||
out_dot = jnp.sum(
|
||||
(b * a_dot + b_dot) * jnp.exp(a - out),
|
||||
axis=axis,
|
||||
keepdims=keepdims,
|
||||
where=where,
|
||||
)
|
||||
if return_sign:
|
||||
sign_dot = jnp.zeros_like(sign)
|
||||
return (out, sign), (out_dot, sign_dot)
|
||||
else:
|
||||
return out, out_dot
|
||||
|
@ -18,7 +18,6 @@ import itertools
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
from absl.testing import parameterized
|
||||
|
||||
import numpy as np
|
||||
import scipy.integrate
|
||||
@ -203,34 +202,6 @@ class LaxBackedScipyTests(jtu.JaxTestCase):
|
||||
y_actual = lsp_special.logsumexp(x, where=mask)
|
||||
self.assertAllClose(y_expected, y_actual, check_dtypes=False)
|
||||
|
||||
def testLogSumExpZerosJac(self):
|
||||
# Regression test for https://github.com/google/jax/issues/22398
|
||||
fun = lambda b: lsp_special.logsumexp(jnp.zeros(2), axis=0, b=b)
|
||||
np.testing.assert_array_equal(
|
||||
jax.jacfwd(fun)(jnp.array([1.0, 0.0])),
|
||||
jnp.ones(2),
|
||||
)
|
||||
|
||||
@parameterized.product(
|
||||
axis=[None, 0], with_b=[False, True], return_sign=[False, True]
|
||||
)
|
||||
def testLogSumExpJac(self, axis, with_b, return_sign):
|
||||
fun = partial(lsp_special.logsumexp, axis=axis, return_sign=return_sign)
|
||||
orig_fun = partial(
|
||||
lsp_special.logsumexp.fun, axis=axis, return_sign=return_sign
|
||||
)
|
||||
tol = 5e-5 if jtu.test_device_matches(["tpu"]) else 1e-06
|
||||
for i in range(100):
|
||||
a = jax.random.normal(jax.random.key(i), (2,)) * 4.2
|
||||
b = None
|
||||
if with_b:
|
||||
b = jax.random.uniform(jax.random.key(i), ())
|
||||
jax.tree.map(
|
||||
lambda x, y: np.testing.assert_allclose(x, y, atol=tol, rtol=tol),
|
||||
jax.jacfwd(fun)(a, b=b),
|
||||
jax.jacfwd(orig_fun)(a, b=b),
|
||||
)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=all_shapes,
|
||||
dtype=float_dtypes,
|
||||
|
Loading…
x
Reference in New Issue
Block a user