Revert "Use lax.erf_inv to implement ndtri. (#2122)" (#2128)

This reverts commit bbcbe23c1ee52cf76542f3a60f8344832a0dd05f.

This change appears to cause test failures in TF probability's JAX backend.
This commit is contained in:
Peter Hawkins 2020-01-30 19:19:41 -05:00 committed by GitHub
parent 664a4e123d
commit 0103929930
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 121 additions and 7 deletions

View File

@ -13,8 +13,6 @@
# limitations under the License.
import math
import numpy as onp
import scipy.special as osp_special
@ -147,7 +145,7 @@ def multigammaln(a, d):
# Normal distributions
# Functions "ndtr" [... is] derived from calculations made in:
# Functions "ndtr" and "ndtri" are derived from calculations made in:
# https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
# In the following email exchange, the author gives his consent to redistribute
# derived works under an Apache 2.0 license.
@ -275,7 +273,7 @@ def ndtri(p):
This is a based on the implementation in netlib.
Args:
p: an array of floating-point type.
p: an array of type `float32`, `float64`.
Returns:
an array with `dtype=p.dtype`.
@ -283,8 +281,122 @@ def ndtri(p):
Raises:
TypeError: if `p` is not floating-type.
"""
p, = _promote_args_inexact("ndtri", p)
return lax.erf_inv(np.asarray(p) * 2 - 1) * math.sqrt(2)
x = np.asarray(p)
dtype = lax.dtype(p)
if dtype not in (np.float32, np.float64):
raise TypeError(
"x.dtype={} is not supported, see docstring for supported types."
.format(dtype))
return _ndtri(p)
def _ndtri(p):
"""Implements ndtri core logic."""
# Constants used in piece-wise rational approximations. Taken from the cephes
# library:
# https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
p0 = list(reversed([-5.99633501014107895267E1,
9.80010754185999661536E1,
-5.66762857469070293439E1,
1.39312609387279679503E1,
-1.23916583867381258016E0]))
q0 = list(reversed([1.0,
1.95448858338141759834E0,
4.67627912898881538453E0,
8.63602421390890590575E1,
-2.25462687854119370527E2,
2.00260212380060660359E2,
-8.20372256168333339912E1,
1.59056225126211695515E1,
-1.18331621121330003142E0]))
p1 = list(reversed([4.05544892305962419923E0,
3.15251094599893866154E1,
5.71628192246421288162E1,
4.40805073893200834700E1,
1.46849561928858024014E1,
2.18663306850790267539E0,
-1.40256079171354495875E-1,
-3.50424626827848203418E-2,
-8.57456785154685413611E-4]))
q1 = list(reversed([1.0,
1.57799883256466749731E1,
4.53907635128879210584E1,
4.13172038254672030440E1,
1.50425385692907503408E1,
2.50464946208309415979E0,
-1.42182922854787788574E-1,
-3.80806407691578277194E-2,
-9.33259480895457427372E-4]))
p2 = list(reversed([3.23774891776946035970E0,
6.91522889068984211695E0,
3.93881025292474443415E0,
1.33303460815807542389E0,
2.01485389549179081538E-1,
1.23716634817820021358E-2,
3.01581553508235416007E-4,
2.65806974686737550832E-6,
6.23974539184983293730E-9]))
q2 = list(reversed([1.0,
6.02427039364742014255E0,
3.67983563856160859403E0,
1.37702099489081330271E0,
2.16236993594496635890E-1,
1.34204006088543189037E-2,
3.28014464682127739104E-4,
2.89247864745380683936E-6,
6.79019408009981274425E-9]))
dtype = lax.dtype(p).type
shape = np.shape(p)
def _create_polynomial(var, coeffs):
"""Compute n_th order polynomial via Horner's method."""
coeffs = onp.array(coeffs, dtype)
if not coeffs.size:
return np.zeros_like(var)
return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var
maybe_complement_p = np.where(p > dtype(-onp.expm1(-2.)), dtype(1.) - p, p)
# Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
# later on. The result from the computation when p == 0 is not used so any
# number that doesn't result in NaNs is fine.
sanitized_mcp = np.where(
maybe_complement_p <= dtype(0.),
np.full(shape, dtype(0.5)),
maybe_complement_p)
# Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
w = sanitized_mcp - dtype(0.5)
ww = lax.square(w)
x_for_big_p = w + w * ww * (_create_polynomial(ww, p0)
/ _create_polynomial(ww, q0))
x_for_big_p *= -dtype(onp.sqrt(2. * onp.pi))
# Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
# where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
# arrays based on whether p < exp(-32).
z = lax.sqrt(dtype(-2.) * lax.log(sanitized_mcp))
first_term = z - lax.log(z) / z
second_term_small_p = (
_create_polynomial(dtype(1.) / z, p2) /
_create_polynomial(dtype(1.) / z, q2) / z)
second_term_otherwise = (
_create_polynomial(dtype(1.) / z, p1) /
_create_polynomial(dtype(1.) / z, q1) / z)
x_for_small_p = first_term - second_term_small_p
x_otherwise = first_term - second_term_otherwise
x = np.where(sanitized_mcp > dtype(onp.exp(-2.)),
x_for_big_p,
np.where(z >= dtype(8.0), x_for_small_p, x_otherwise))
x = np.where(p > dtype(1. - onp.exp(-2.)), x, -x)
infinity = np.full(shape, dtype(onp.inf))
x_nan_replaced = np.where(
p <= dtype(0.0), -infinity, np.where(p >= dtype(1.0), infinity, x))
return x_nan_replaced
@custom_transforms

View File

@ -128,8 +128,10 @@ LAX_OPS = [
jtu.rand_positive, {onp.float64: 1e-14}),
op_record("erf", 1, float_dtypes, jtu.rand_small),
op_record("erfc", 1, float_dtypes, jtu.rand_small),
# TODO(b/142976030): the approximation of erfinf used by XLA is only
# accurate to float32 precision.
op_record("erf_inv", 1, float_dtypes, jtu.rand_small,
{onp.float64: 1e-14}),
{onp.float64: 1e-9}),
op_record("bessel_i0e", 1, float_dtypes, jtu.rand_default),
op_record("bessel_i1e", 1, float_dtypes, jtu.rand_default),