Improve behavior of a number of math functions for extreme inputs.

Call XLA's sqrt instead of defining sqrt to be x**0.5. The two have different behaviors for infinite inputs.

Incorporate improvements to acos, sinh, cosh, asinh, and acosh that have previously been made to the versions in the XLA C++ client libraries.
This commit is contained in:
Peter Hawkins 2019-05-29 12:51:24 -04:00
parent d852830639
commit 6e1ec38a14
3 changed files with 62 additions and 15 deletions

View File

@ -1182,7 +1182,7 @@ def batch_matmul(lhs, rhs):
def sqrt(x):
  r"""Elementwise square root: :math:`\sqrt{x}`.

  Binds the dedicated ``sqrt`` primitive instead of evaluating ``x ** 0.5``
  through ``pow``; the two formulations behave differently for infinite
  inputs, and the primitive is the one whose behavior is wanted here.
  """
  # The earlier pow-based return has been removed: left in place ahead of this
  # line it made the primitive binding unreachable.
  return sqrt_p.bind(x)
def rsqrt(x):
r"""Elementwise reciprocal square root: :math:`1 \over \sqrt{x}`."""
@ -1207,8 +1207,11 @@ def asin(x):
def acos(x):
  r"""Elementwise arc cosine: :math:`\mathrm{acos}(x)`.

  Evaluated as ``2 * atan2(sqrt(1 - x**2), 1 + x)``. At ``x == -1`` both
  ``atan2`` arguments are zero, so that formula cannot be used there and the
  exact answer ``pi`` is selected instead.
  """
  # Only the select-based form is kept; the unguarded return that used to
  # precede it bypassed the x == -1 handling entirely.
  half_angle = atan2(sqrt(sub(_const(x, 1), square(x))),
                     add(_const(x, 1), x))
  return select(
      ne(x, _const(x, -1.0)),
      mul(_const(x, 2), half_angle),
      full_like(x, onp.pi))
def atan(x):
r"""Elementwise arc tangent: :math:`\mathrm{atan}(x)`."""
@ -1216,22 +1219,42 @@ def atan(x):
def sinh(x):
  r"""Elementwise hyperbolic sine: :math:`\mathrm{sinh}(x)`.

  Computed as ``exp(x + log(1/2)) - exp(-x + log(1/2))``, i.e. the factor of
  one half is folded into the exponent.
  """
  # Folding log(0.5) into the exponent avoids overflow when e^x is inf but
  # e^x / 2 is not inf. The plain 0.5 * (e^x - e^-x) return that used to sit
  # ahead of this code made it unreachable and has been dropped.
  log_half = _const(x, onp.log(0.5))
  return sub(exp(add(log_half, x)), exp(sub(log_half, x)))
def cosh(x):
  r"""Elementwise hyperbolic cosine: :math:`\mathrm{cosh}(x)`.

  Computed as ``exp(x + log(1/2)) + exp(-x + log(1/2))``, i.e. the factor of
  one half is folded into the exponent.
  """
  # Folding log(0.5) into the exponent avoids overflow when e^x is inf but
  # e^x / 2 is not inf. The plain 0.5 * (e^x + e^-x) return that used to sit
  # ahead of this code made it unreachable and has been dropped.
  log_half = _const(x, onp.log(0.5))
  return add(exp(add(log_half, x)), exp(sub(log_half, x)))
def asinh(x):
  r"""Elementwise arc hyperbolic sine: :math:`\mathrm{asinh}(x)`.

  Complex inputs use ``log(x + sqrt(x**2 + 1))`` directly. Real inputs are
  evaluated on ``|x|`` and the sign is restored afterwards (asinh is odd):

  * ``|x| < sqrt(max)``:  ``log(|x| + sqrt(|x|**2 + 1))``
  * otherwise:            ``log(|x|) + log(2)`` (``x**2`` would overflow)
  """
  if onp.issubdtype(_dtype(x), onp.complexfloating):
    # asinh(x) = log(x + sqrt(x**2 + 1))
    return log(add(x, sqrt(add(mul(x, x), _const(x, 1)))))
  a = abs(x)
  sqrt_max_value = onp.sqrt(onp.finfo(_dtype(x)).max)
  # Working on |x| avoids the cancellation in x + sqrt(x**2 + 1) for negative
  # x, where the sum of two nearly opposite terms can round to 0 and give -inf.
  moderate = log(add(a, sqrt(add(mul(a, a), _const(a, 1)))))
  # For huge |x|, sqrt(x**2 + 1) ~= |x|, so asinh(|x|) ~= log(2 * |x|).
  huge = add(log(a), _const(a, onp.log(2.)))
  return mul(sign(x),
             select(lt(a, _const(a, sqrt_max_value)), moderate, huge))
def acosh(x):
  r"""Elementwise arc hyperbolic cosine: :math:`\mathrm{acosh}(x)`.

  Uses::

    acosh(x) = log(x + sqrt((x + 1) * (x - 1)))   if x < sqrt_max_value
               log(x) + log(2)                    otherwise

  where ``sqrt_max_value`` is the square root of the largest finite value of
  the input dtype. Complex inputs always use the first form.
  """
  # The product is taken as sqrt(x + 1) * sqrt(x - 1) rather than
  # sqrt((x + 1) * (x - 1)).
  result = log(add(x, mul(sqrt(add(x, _const(x, 1))),
                          sqrt(sub(x, _const(x, 1))))))
  if onp.issubdtype(_dtype(result), onp.complexfloating):
    return result
  sqrt_max_value = onp.sqrt(onp.finfo(_dtype(x)).max)
  # The unconditional return of the first form that used to precede this
  # select has been dropped; it made the large-input branch unreachable.
  return select(
      lt(x, _const(x, sqrt_max_value)),
      result,
      add(log(x), _const(x, onp.log(2.))))
def atanh(x):
r"""Elementwise arc hyperbolic tangent: :math:`\mathrm{atanh}(x)`."""
@ -1488,6 +1511,9 @@ ad.defjvp2(abs_p,
_maybe_conj = lambda x: conj(x) if _iscomplex(x) else x
_maybe_real = lambda x: real(x) if _iscomplex(x) else x
# Dedicated square-root primitive, declared for float and complex operands.
sqrt_p = standard_unop(_float | _complex, 'sqrt')
# JVP rule: d/dx sqrt(x) = 0.5 / sqrt(x). `ans` is presumably the primal
# output sqrt(x), so the tangent is g * (0.5 / ans).
# NOTE(review): _safe_mul is defined elsewhere in this module; confirm how it
# treats the x == 0 case, where 0.5 / ans is infinite.
ad.defjvp2(sqrt_p, lambda g, ans, x: _safe_mul(g, div(_const(x, 0.5), ans)))
# TODO handle broadcasting
pow_p = standard_binop([_float | _complex, _float | _complex], 'pow')

View File

@ -274,6 +274,7 @@ tanh = _one_to_one_unop(onp.tanh, lax.tanh, True)
# NumPy-named wrappers, each pairing a NumPy function with its lax counterpart.
# NOTE(review): the meaning of the third positional argument (True) is defined
# by _one_to_one_unop, which is not visible here — confirm before changing it.
arcsinh = _one_to_one_unop(onp.arcsinh, lax.asinh, True)
arccosh = _one_to_one_unop(onp.arccosh, lax.acosh, True)
arctanh = _one_to_one_unop(onp.arctanh, lax.atanh, True)
# sqrt is routed to lax.sqrt like the other unary math functions above.
sqrt = _one_to_one_unop(onp.sqrt, lax.sqrt, True)
add = _one_to_one_binop(onp.add, lax.add)
@ -460,12 +461,6 @@ def cbrt(x):
return lax.sign(x) * power(lax.abs(x), _constant_like(x, 1. / 3.))
@_wraps(onp.sqrt)
def sqrt(x):
  # Promote the argument to the dtype NumPy's sqrt would produce, then defer
  # to lax.sqrt. The previous body computed power(x, 0.5), which does not
  # behave like a true square root for infinite inputs, and because this
  # definition comes after the `sqrt = _one_to_one_unop(...)` binding it was
  # the one callers actually got.
  x, = _promote_to_result_dtype(onp.sqrt, x)
  return lax.sqrt(x)
@_wraps(onp.square)
def square(x):
x, = _promote_to_result_dtype(onp.square, x)

View File

@ -1564,6 +1564,32 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self.assertAllClose(onp.zeros(3,), api.grad(f)(onp.ones(3,)),
check_dtypes=True)
@parameterized.named_parameters(
    jtu.cases_from_list(
        {"testcase_name": jtu.format_test_name_suffix(op, [()], [dtype]),
         "dtype": dtype, "op": op}
        for dtype in float_dtypes
        for op in ("sqrt", "arccos", "arcsin", "arctan", "sin", "cos", "tan",
                   "sinh", "cosh", "tanh", "arccosh", "arcsinh", "arctanh",
                   "exp", "log", "expm1", "log1p")))
def testMathSpecialFloatValues(self, op, dtype):
  """Compares each unary math op against NumPy on special scalar inputs.

  The inputs cover NaN, both infinities, small and moderate magnitudes of
  either sign, the dtype's largest finite value, and values at and just above
  the square root of that maximum.
  """
  onp_op = getattr(onp, op)
  lnp_op = getattr(lnp, op)
  dtype = onp.dtype(xla_bridge.canonicalize_dtype(dtype)).type
  # Note the comma between -2. and -1.: without it the two literals collapse
  # into the single value -3. and neither -2 nor -1 is exercised.
  for x in (onp.nan, -onp.inf, -100., -2., -1., 0., 1., 2., 100., onp.inf,
            onp.finfo(dtype).max, onp.sqrt(onp.finfo(dtype).max),
            onp.sqrt(onp.finfo(dtype).max) * 2.):
    if onp.isnan(x) and op in ("cosh", "expm1", "exp"):
      # TODO(b/133842876, b/133842870): these return wrong outputs on CPU for
      # NaN inputs.
      continue
    x = dtype(x)
    expected = onp_op(x)
    actual = lnp_op(x)
    self.assertAllClose(expected, actual, check_dtypes=True)
# Run the test suite through absltest when this file is executed as a script.
if __name__ == "__main__":
  absltest.main()