mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
This reverts commit bbcbe23c1ee52cf76542f3a60f8344832a0dd05f. This change appears to cause test failures in TF probability's JAX backend.
This commit is contained in:
parent
664a4e123d
commit
0103929930
@ -13,8 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import math
|
||||
|
||||
import numpy as onp
|
||||
import scipy.special as osp_special
|
||||
|
||||
@ -147,7 +145,7 @@ def multigammaln(a, d):
|
||||
|
||||
# Normal distributions
|
||||
|
||||
# Functions "ndtr" [... is] derived from calculations made in:
|
||||
# Functions "ndtr" and "ndtri" are derived from calculations made in:
|
||||
# https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
|
||||
# In the following email exchange, the author gives his consent to redistribute
|
||||
# derived works under an Apache 2.0 license.
|
||||
@ -275,7 +273,7 @@ def ndtri(p):
|
||||
This is a based on the implementation in netlib.
|
||||
|
||||
Args:
|
||||
p: an array of floating-point type.
|
||||
p: an array of type `float32`, `float64`.
|
||||
|
||||
Returns:
|
||||
an array with `dtype=p.dtype`.
|
||||
@ -283,8 +281,122 @@ def ndtri(p):
|
||||
Raises:
|
||||
TypeError: if `p` is not floating-type.
|
||||
"""
|
||||
p, = _promote_args_inexact("ndtri", p)
|
||||
return lax.erf_inv(np.asarray(p) * 2 - 1) * math.sqrt(2)
|
||||
x = np.asarray(p)
|
||||
dtype = lax.dtype(p)
|
||||
if dtype not in (np.float32, np.float64):
|
||||
raise TypeError(
|
||||
"x.dtype={} is not supported, see docstring for supported types."
|
||||
.format(dtype))
|
||||
return _ndtri(p)
|
||||
|
||||
|
||||
def _ndtri(p):
|
||||
"""Implements ndtri core logic."""
|
||||
|
||||
# Constants used in piece-wise rational approximations. Taken from the cephes
|
||||
# library:
|
||||
# https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
|
||||
p0 = list(reversed([-5.99633501014107895267E1,
|
||||
9.80010754185999661536E1,
|
||||
-5.66762857469070293439E1,
|
||||
1.39312609387279679503E1,
|
||||
-1.23916583867381258016E0]))
|
||||
q0 = list(reversed([1.0,
|
||||
1.95448858338141759834E0,
|
||||
4.67627912898881538453E0,
|
||||
8.63602421390890590575E1,
|
||||
-2.25462687854119370527E2,
|
||||
2.00260212380060660359E2,
|
||||
-8.20372256168333339912E1,
|
||||
1.59056225126211695515E1,
|
||||
-1.18331621121330003142E0]))
|
||||
p1 = list(reversed([4.05544892305962419923E0,
|
||||
3.15251094599893866154E1,
|
||||
5.71628192246421288162E1,
|
||||
4.40805073893200834700E1,
|
||||
1.46849561928858024014E1,
|
||||
2.18663306850790267539E0,
|
||||
-1.40256079171354495875E-1,
|
||||
-3.50424626827848203418E-2,
|
||||
-8.57456785154685413611E-4]))
|
||||
q1 = list(reversed([1.0,
|
||||
1.57799883256466749731E1,
|
||||
4.53907635128879210584E1,
|
||||
4.13172038254672030440E1,
|
||||
1.50425385692907503408E1,
|
||||
2.50464946208309415979E0,
|
||||
-1.42182922854787788574E-1,
|
||||
-3.80806407691578277194E-2,
|
||||
-9.33259480895457427372E-4]))
|
||||
p2 = list(reversed([3.23774891776946035970E0,
|
||||
6.91522889068984211695E0,
|
||||
3.93881025292474443415E0,
|
||||
1.33303460815807542389E0,
|
||||
2.01485389549179081538E-1,
|
||||
1.23716634817820021358E-2,
|
||||
3.01581553508235416007E-4,
|
||||
2.65806974686737550832E-6,
|
||||
6.23974539184983293730E-9]))
|
||||
q2 = list(reversed([1.0,
|
||||
6.02427039364742014255E0,
|
||||
3.67983563856160859403E0,
|
||||
1.37702099489081330271E0,
|
||||
2.16236993594496635890E-1,
|
||||
1.34204006088543189037E-2,
|
||||
3.28014464682127739104E-4,
|
||||
2.89247864745380683936E-6,
|
||||
6.79019408009981274425E-9]))
|
||||
|
||||
dtype = lax.dtype(p).type
|
||||
shape = np.shape(p)
|
||||
|
||||
def _create_polynomial(var, coeffs):
|
||||
"""Compute n_th order polynomial via Horner's method."""
|
||||
coeffs = onp.array(coeffs, dtype)
|
||||
if not coeffs.size:
|
||||
return np.zeros_like(var)
|
||||
return coeffs[0] + _create_polynomial(var, coeffs[1:]) * var
|
||||
|
||||
|
||||
maybe_complement_p = np.where(p > dtype(-onp.expm1(-2.)), dtype(1.) - p, p)
|
||||
# Write in an arbitrary value in place of 0 for p since 0 will cause NaNs
|
||||
# later on. The result from the computation when p == 0 is not used so any
|
||||
# number that doesn't result in NaNs is fine.
|
||||
sanitized_mcp = np.where(
|
||||
maybe_complement_p <= dtype(0.),
|
||||
np.full(shape, dtype(0.5)),
|
||||
maybe_complement_p)
|
||||
|
||||
# Compute x for p > exp(-2): x/sqrt(2pi) = w + w**3 P0(w**2)/Q0(w**2).
|
||||
w = sanitized_mcp - dtype(0.5)
|
||||
ww = lax.square(w)
|
||||
x_for_big_p = w + w * ww * (_create_polynomial(ww, p0)
|
||||
/ _create_polynomial(ww, q0))
|
||||
x_for_big_p *= -dtype(onp.sqrt(2. * onp.pi))
|
||||
|
||||
# Compute x for p <= exp(-2): x = z - log(z)/z - (1/z) P(1/z) / Q(1/z),
|
||||
# where z = sqrt(-2. * log(p)), and P/Q are chosen between two different
|
||||
# arrays based on whether p < exp(-32).
|
||||
z = lax.sqrt(dtype(-2.) * lax.log(sanitized_mcp))
|
||||
first_term = z - lax.log(z) / z
|
||||
second_term_small_p = (
|
||||
_create_polynomial(dtype(1.) / z, p2) /
|
||||
_create_polynomial(dtype(1.) / z, q2) / z)
|
||||
second_term_otherwise = (
|
||||
_create_polynomial(dtype(1.) / z, p1) /
|
||||
_create_polynomial(dtype(1.) / z, q1) / z)
|
||||
x_for_small_p = first_term - second_term_small_p
|
||||
x_otherwise = first_term - second_term_otherwise
|
||||
|
||||
x = np.where(sanitized_mcp > dtype(onp.exp(-2.)),
|
||||
x_for_big_p,
|
||||
np.where(z >= dtype(8.0), x_for_small_p, x_otherwise))
|
||||
|
||||
x = np.where(p > dtype(1. - onp.exp(-2.)), x, -x)
|
||||
infinity = np.full(shape, dtype(onp.inf))
|
||||
x_nan_replaced = np.where(
|
||||
p <= dtype(0.0), -infinity, np.where(p >= dtype(1.0), infinity, x))
|
||||
return x_nan_replaced
|
||||
|
||||
|
||||
@custom_transforms
|
||||
|
@ -128,8 +128,10 @@ LAX_OPS = [
|
||||
jtu.rand_positive, {onp.float64: 1e-14}),
|
||||
op_record("erf", 1, float_dtypes, jtu.rand_small),
|
||||
op_record("erfc", 1, float_dtypes, jtu.rand_small),
|
||||
# TODO(b/142976030): the approximation of erfinf used by XLA is only
|
||||
# accurate to float32 precision.
|
||||
op_record("erf_inv", 1, float_dtypes, jtu.rand_small,
|
||||
{onp.float64: 1e-14}),
|
||||
{onp.float64: 1e-9}),
|
||||
op_record("bessel_i0e", 1, float_dtypes, jtu.rand_default),
|
||||
op_record("bessel_i1e", 1, float_dtypes, jtu.rand_default),
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user