mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Improve behavior of a number of math functions for extreme inputs.
Call XLA's sqrt instead of defining sqrt to be x**0.5. The two have different behaviors for infinite inputs. Incorporate improvements to acos, sinh, cosh, asinh, and acosh that have previously been made to the versions in the XLA C++ client libraries.
This commit is contained in:
parent
d852830639
commit
6e1ec38a14
@ -1182,7 +1182,7 @@ def batch_matmul(lhs, rhs):
|
||||
|
||||
def sqrt(x):
|
||||
r"""Elementwise square root: :math:`\sqrt{x}`."""
|
||||
return pow(x, _const(x, 0.5))
|
||||
return sqrt_p.bind(x)
|
||||
|
||||
def rsqrt(x):
|
||||
r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`."""
|
||||
@ -1207,8 +1207,11 @@ def asin(x):
|
||||
|
||||
def acos(x):
|
||||
r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`."""
|
||||
return mul(_const(x, 2),
|
||||
atan2(sqrt(sub(_const(x, 1), square(x))), add(_const(x, 1), x)))
|
||||
return select(
|
||||
ne(x, _const(x, -1.0)),
|
||||
mul(_const(x, 2),
|
||||
atan2(sqrt(sub(_const(x, 1), square(x))), add(_const(x, 1), x))),
|
||||
full_like(x, onp.pi))
|
||||
|
||||
def atan(x):
|
||||
r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`."""
|
||||
@ -1216,22 +1219,42 @@ def atan(x):
|
||||
|
||||
def sinh(x):
|
||||
r"""Elementwise hyperbolic sine: :math:`\mathrm{sinh}(x)`."""
|
||||
return mul(_const(x, 0.5), sub(exp(x), exp(neg(x))))
|
||||
log_half = _const(x, onp.log(0.5))
|
||||
# This formulation avoids overflow when e^x is inf but e^x/2 is not inf.
|
||||
return sub(exp(add(log_half, x)), exp(sub(log_half, x)))
|
||||
|
||||
def cosh(x):
|
||||
r"""Elementwise hyperbolic cosine: :math:`\mathrm{cosh}(x)`."""
|
||||
return mul(_const(x, 0.5), add(exp(x), exp(neg(x))))
|
||||
log_half = _const(x, onp.log(0.5))
|
||||
# This formulation avoids overflow when e^x is inf but e^x/2 is not inf.
|
||||
return add(exp(add(log_half, x)), exp(sub(log_half, x)))
|
||||
|
||||
def asinh(x):
|
||||
r"""Elementwise arc hyperbolic sine: :math:`\mathrm{asinh}(x)`."""
|
||||
# asinh(x) = log(x + sqrt(x**2 + 1))
|
||||
return log(add(x, sqrt(add(mul(x, x), _const(x, 1)))))
|
||||
result = log(add(x, sqrt(add(mul(x, x), _const(x, 1)))))
|
||||
if onp.issubdtype(_dtype(result), onp.complexfloating):
|
||||
return result
|
||||
a = abs(x)
|
||||
sqrt_max_value = onp.sqrt(onp.finfo(_dtype(x)).max)
|
||||
return select(lt(a, _const(a, sqrt_max_value)),
|
||||
result,
|
||||
mul(sign(x), add(log(a), _const(a, onp.log(2.)))))
|
||||
|
||||
def acosh(x):
|
||||
r"""Elementwise arc hyperbolic cosine: :math:`\mathrm{acosh}(x)`."""
|
||||
# acosh(x) = log(x + sqrt((x + 1) * (x - 1)))
|
||||
return log(add(x, mul(sqrt(add(x, _const(x, 1))),
|
||||
sqrt(sub(x, _const(x, 1))))))
|
||||
# acosh(x) = log(x + sqrt((x + 1) * (x - 1))) if x < sqrt_max_value
|
||||
# log(x) + log(2) otherwise
|
||||
sqrt_max_value = onp.sqrt(onp.finfo(_dtype(x)).max)
|
||||
result = log(add(x, mul(sqrt(add(x, _const(x, 1))),
|
||||
sqrt(sub(x, _const(x, 1))))))
|
||||
if onp.issubdtype(_dtype(result), onp.complexfloating):
|
||||
return result
|
||||
return select(
|
||||
lt(x, _const(x, sqrt_max_value)),
|
||||
result,
|
||||
add(log(x), _const(x, onp.log(2.))))
|
||||
|
||||
|
||||
def atanh(x):
|
||||
r"""Elementwise arc hyperbolic tangent: :math:`\mathrm{atanh}(x)`."""
|
||||
@ -1488,6 +1511,9 @@ ad.defjvp2(abs_p,
|
||||
_maybe_conj = lambda x: conj(x) if _iscomplex(x) else x
|
||||
_maybe_real = lambda x: real(x) if _iscomplex(x) else x
|
||||
|
||||
sqrt_p = standard_unop(_float | _complex, 'sqrt')
|
||||
ad.defjvp2(sqrt_p, lambda g, ans, x: _safe_mul(g, div(_const(x, 0.5), ans)))
|
||||
|
||||
# TODO handle broadcasting
|
||||
pow_p = standard_binop([_float | _complex, _float | _complex], 'pow')
|
||||
|
||||
|
@ -274,6 +274,7 @@ tanh = _one_to_one_unop(onp.tanh, lax.tanh, True)
|
||||
arcsinh = _one_to_one_unop(onp.arcsinh, lax.asinh, True)
|
||||
arccosh = _one_to_one_unop(onp.arccosh, lax.acosh, True)
|
||||
arctanh = _one_to_one_unop(onp.arctanh, lax.atanh, True)
|
||||
sqrt = _one_to_one_unop(onp.sqrt, lax.sqrt, True)
|
||||
|
||||
|
||||
add = _one_to_one_binop(onp.add, lax.add)
|
||||
@ -460,12 +461,6 @@ def cbrt(x):
|
||||
return lax.sign(x) * power(lax.abs(x), _constant_like(x, 1. / 3.))
|
||||
|
||||
|
||||
@_wraps(onp.sqrt)
|
||||
def sqrt(x):
|
||||
x, = _promote_to_result_dtype(onp.sqrt, x)
|
||||
return power(x, _constant_like(x, 0.5))
|
||||
|
||||
|
||||
@_wraps(onp.square)
|
||||
def square(x):
|
||||
x, = _promote_to_result_dtype(onp.square, x)
|
||||
|
@ -1564,6 +1564,32 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self.assertAllClose(onp.zeros(3,), api.grad(f)(onp.ones(3,)),
|
||||
check_dtypes=True)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix(op, [()], [dtype]),
|
||||
"dtype": dtype, "op": op}
|
||||
for dtype in float_dtypes
|
||||
for op in ("sqrt", "arccos", "arcsin", "arctan", "sin", "cos", "tan",
|
||||
"sinh", "cosh", "tanh", "arccosh", "arcsinh", "arctanh", "exp",
|
||||
"log", "expm1", "log1p")))
|
||||
def testMathSpecialFloatValues(self, op, dtype):
|
||||
onp_op = getattr(onp, op)
|
||||
lnp_op = getattr(lnp, op)
|
||||
dtype = onp.dtype(xla_bridge.canonicalize_dtype(dtype)).type
|
||||
for x in (onp.nan, -onp.inf, -100., -2. -1., 0., 1., 2., 100., onp.inf,
|
||||
onp.finfo(dtype).max, onp.sqrt(onp.finfo(dtype).max),
|
||||
onp.sqrt(onp.finfo(dtype).max) * 2.):
|
||||
if onp.isnan(x) and op in ("cosh", "expm1", "exp"):
|
||||
# TODO(b/133842876, b/)133842870: these return wrong outputs on CPU for
|
||||
# NaN inputs.
|
||||
continue
|
||||
x = dtype(x)
|
||||
expected = onp_op(x)
|
||||
actual = lnp_op(x)
|
||||
if expected != actual:
|
||||
print(x, expected, actual)
|
||||
self.assertAllClose(expected, actual, check_dtypes=True)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main()
|
||||
|
Loading…
x
Reference in New Issue
Block a user