Merge pull request #24864 from jakevdp:logaddexp2

PiperOrigin-RevId: 696602883
This commit is contained in:
jax authors 2024-11-14 11:56:42 -08:00
commit 303b792ac2

View File

@ -2630,16 +2630,6 @@ def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
return lax_other.logaddexp(x1, x2)
def _wrap_between(x, _a):
"""Wraps `x` between `[-a, a]`."""
a = _constant_like(x, _a)
two_a = _constant_like(x, 2 * _a)
zero = _constant_like(x, 0)
rem = lax.rem(lax.add(x, a), two_a)
rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem)
return lax.sub(rem, a)
@jit
def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
"""Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow.
@ -2668,33 +2658,8 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
Array(True, dtype=bool)
"""
x1, x2 = promote_args_inexact("logaddexp2", x1, x2)
return _logaddexp2(x1, x2)
@custom_jvp
def _logaddexp2(x1, x2):
amax = lax.max(x1, x2)
if dtypes.issubdtype(x1.dtype, np.floating):
delta = lax.sub(x1, x2)
return lax.select(lax._isnan(delta),
lax.add(x1, x2), # NaNs or infinities of the same sign.
lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))),
_constant_like(x1, np.log(2)))))
else:
delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), _constant_like(x1, np.log(2))))
return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2)))
@_logaddexp2.defjvp
def _logaddexp2_jvp(primals, tangents):
x1, x2 = primals
t1, t2 = tangents
x1, x2, t1, t2 = promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2)
primal_out = logaddexp2(x1, x2)
tangent_out = lax.add(lax.mul(t1, exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
lax.mul(t2, exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
return primal_out, tangent_out
ln2 = float(np.log(2))
return logaddexp(x1 * ln2, x2 * ln2) / ln2
@partial(jit, inline=True)