mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
jnp.logaddexp2: simplify implementation
This commit is contained in:
parent
19a51de2ab
commit
d823f1720d
@ -2630,16 +2630,6 @@ def logaddexp(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
return lax_other.logaddexp(x1, x2)
|
||||
|
||||
|
||||
def _wrap_between(x, _a):
|
||||
"""Wraps `x` between `[-a, a]`."""
|
||||
a = _constant_like(x, _a)
|
||||
two_a = _constant_like(x, 2 * _a)
|
||||
zero = _constant_like(x, 0)
|
||||
rem = lax.rem(lax.add(x, a), two_a)
|
||||
rem = lax.select(lax.lt(rem, zero), lax.add(rem, two_a), rem)
|
||||
return lax.sub(rem, a)
|
||||
|
||||
|
||||
@jit
|
||||
def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
"""Logarithm of the sum of exponentials of inputs in base-2 avoiding overflow.
|
||||
@ -2668,33 +2658,8 @@ def logaddexp2(x1: ArrayLike, x2: ArrayLike, /) -> Array:
|
||||
Array(True, dtype=bool)
|
||||
"""
|
||||
x1, x2 = promote_args_inexact("logaddexp2", x1, x2)
|
||||
return _logaddexp2(x1, x2)
|
||||
|
||||
|
||||
@custom_jvp
|
||||
def _logaddexp2(x1, x2):
|
||||
amax = lax.max(x1, x2)
|
||||
if dtypes.issubdtype(x1.dtype, np.floating):
|
||||
delta = lax.sub(x1, x2)
|
||||
return lax.select(lax._isnan(delta),
|
||||
lax.add(x1, x2), # NaNs or infinities of the same sign.
|
||||
lax.add(amax, lax.div(lax.log1p(exp2(lax.neg(lax.abs(delta)))),
|
||||
_constant_like(x1, np.log(2)))))
|
||||
else:
|
||||
delta = lax.sub(lax.add(x1, x2), lax.mul(amax, _constant_like(amax, 2)))
|
||||
out = lax.add(amax, lax.div(lax.log1p(exp2(delta)), _constant_like(x1, np.log(2))))
|
||||
return lax.complex(lax.real(out), _wrap_between(lax.imag(out), np.pi / np.log(2)))
|
||||
|
||||
|
||||
@_logaddexp2.defjvp
|
||||
def _logaddexp2_jvp(primals, tangents):
|
||||
x1, x2 = primals
|
||||
t1, t2 = tangents
|
||||
x1, x2, t1, t2 = promote_args_inexact("logaddexp2_jvp", x1, x2, t1, t2)
|
||||
primal_out = logaddexp2(x1, x2)
|
||||
tangent_out = lax.add(lax.mul(t1, exp2(lax.sub(_replace_inf(x1), _replace_inf(primal_out)))),
|
||||
lax.mul(t2, exp2(lax.sub(_replace_inf(x2), _replace_inf(primal_out)))))
|
||||
return primal_out, tangent_out
|
||||
ln2 = float(np.log(2))
|
||||
return logaddexp(x1 * ln2, x2 * ln2) / ln2
|
||||
|
||||
|
||||
@partial(jit, inline=True)
|
||||
|
Loading…
x
Reference in New Issue
Block a user