mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[x64] make scipy_stats_test.py compatible with strict dtype promotion
This commit is contained in:
parent
cd565f8f41
commit
f00d706a6d
@ -32,13 +32,13 @@ def logpdf(x, mean, cov, allow_singular=None):
|
||||
x, mean, cov = _promote_dtypes_inexact(x, mean, cov)
|
||||
if not mean.shape:
|
||||
return (-1/2 * jnp.square(x - mean) / cov
|
||||
- 1/2 * (np.log(2*np.pi) + jnp.log(cov)))
|
||||
- 1/2 * (jnp.log(2*np.pi) + jnp.log(cov)))
|
||||
else:
|
||||
n = mean.shape[-1]
|
||||
if not np.shape(cov):
|
||||
y = x - mean
|
||||
return (-1/2 * jnp.einsum('...i,...i->...', y, y) / cov
|
||||
- n/2 * (np.log(2*np.pi) + jnp.log(cov)))
|
||||
- n/2 * (jnp.log(2*np.pi) + jnp.log(cov)))
|
||||
else:
|
||||
if cov.ndim < 2 or cov.shape[-2:] != (n, n):
|
||||
raise ValueError("multivariate_normal.logpdf got incompatible shapes")
|
||||
@ -47,7 +47,7 @@ def logpdf(x, mean, cov, allow_singular=None):
|
||||
partial(lax.linalg.triangular_solve, lower=True, transpose_a=True),
|
||||
signature="(n,n),(n)->(n)"
|
||||
)(L, x - mean)
|
||||
return (-1/2 * jnp.einsum('...i,...i->...', y, y) - n/2*np.log(2*np.pi)
|
||||
return (-1/2 * jnp.einsum('...i,...i->...', y, y) - n/2 * jnp.log(2*np.pi)
|
||||
- jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1))
|
||||
|
||||
@_wraps(osp_stats.multivariate_normal.pdf, update_doc=False)
|
||||
|
@ -34,6 +34,12 @@ one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]
|
||||
scipy_version = tuple(map(int, osp.version.version.split('.')[:2]))
|
||||
|
||||
|
||||
def _strict_promotion_if_dtypes_match(dtypes):
|
||||
if all(dtype == dtypes[0] for dtype in dtypes):
|
||||
return jax.numpy_dtype_promotion('strict')
|
||||
return jax.numpy_dtype_promotion('standard')
|
||||
|
||||
|
||||
def genNamedParametersNArgs(n):
|
||||
return parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
@ -58,13 +64,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
k, mu, loc = map(rng, shapes, dtypes)
|
||||
k = np.floor(k)
|
||||
# clipping to ensure that rate parameter is strictly positive
|
||||
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
|
||||
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
|
||||
loc = np.floor(loc)
|
||||
return [k, mu, loc]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testPoissonPmf(self, shapes, dtypes):
|
||||
@ -76,13 +83,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
k, mu, loc = map(rng, shapes, dtypes)
|
||||
k = np.floor(k)
|
||||
# clipping to ensure that rate parameter is strictly positive
|
||||
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
|
||||
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
|
||||
loc = np.floor(loc)
|
||||
return [k, mu, loc]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testPoissonCdf(self, shapes, dtypes):
|
||||
@ -93,12 +101,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
k, mu, loc = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that rate parameter is strictly positive
|
||||
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
|
||||
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
|
||||
return [k, mu, loc]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
@ -114,9 +123,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
loc = np.floor(loc)
|
||||
return [x, p, loc]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testGeomLogPmf(self, shapes, dtypes):
|
||||
@ -131,9 +141,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
loc = np.floor(loc)
|
||||
return [x, p, loc]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testBetaLogPdf(self, shapes, dtypes):
|
||||
@ -145,10 +156,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, a, b, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, a, b, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker,
|
||||
rtol={np.float32: 2e-3, np.float64: 1e-4})
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker,
|
||||
rtol={np.float32: 2e-3, np.float64: 1e-4})
|
||||
|
||||
def testBetaLogPdfZero(self):
|
||||
# Regression test for https://github.com/google/jax/issues/7645
|
||||
@ -166,12 +178,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
@ -208,9 +221,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
return x, alpha
|
||||
|
||||
tol = {np.float32: 1E-3, np.float64: 1e-5}
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=tol)
|
||||
self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testExponLogPdf(self, shapes, dtypes):
|
||||
@ -222,9 +237,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testGammaLogPdf(self, shapes, dtypes):
|
||||
@ -236,9 +252,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, a, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, a, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
def testGammaLogPdfZero(self):
|
||||
# Regression test for https://github.com/google/jax/issues/7256
|
||||
@ -255,9 +272,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, p = map(rng, shapes, dtypes)
|
||||
return [x, p]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4, rtol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4, rtol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(2)
|
||||
def testGenNormCdf(self, shapes, dtypes):
|
||||
@ -269,9 +287,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, p = map(rng, shapes, dtypes)
|
||||
return [x, p]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4, rtol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4, rtol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testNBinomLogPmf(self, shapes, dtypes):
|
||||
@ -288,9 +307,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
return [k, n, p, loc]
|
||||
|
||||
tol = {np.float32: 1e-6, np.float64: 1e-8}
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testLaplaceLogPdf(self, shapes, dtypes):
|
||||
@ -301,12 +322,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None)
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
def testLaplaceCdf(self, shapes, dtypes):
|
||||
@ -317,12 +339,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# ensure that scale is not too low
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None)
|
||||
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol={np.float32: 1e-5, np.float64: 1e-6})
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol={np.float32: 1e-5, np.float64: 1e-6})
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(1)
|
||||
def testLogisticCdf(self, shapes, dtypes):
|
||||
@ -333,9 +356,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
return list(map(rng, shapes, dtypes))
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(1)
|
||||
def testLogisticLogpdf(self, shapes, dtypes):
|
||||
@ -392,12 +416,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
@ -409,12 +434,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
@ -426,12 +452,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-6)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
@ -443,13 +470,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
q, loc, scale = map(rng, shapes, dtypes)
|
||||
# ensure probability is between 0 and 1:
|
||||
q = np.clip(np.abs(q / 3), a_min=None, a_max=1)
|
||||
q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [q, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
@ -462,9 +490,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, b, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, b, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
@ -476,13 +505,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
def args_maker():
|
||||
x, df, loc, scale = map(rng, shapes, dtypes)
|
||||
# clipping to ensure that scale is not too low
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
|
||||
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
|
||||
return [x, df, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker,
|
||||
rtol={np.float64: 1e-14}, atol={np.float64: 1e-14})
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-3)
|
||||
self._CompileAndCheck(lax_fun, args_maker,
|
||||
rtol={np.float64: 1e-14}, atol={np.float64: 1e-14})
|
||||
|
||||
|
||||
@genNamedParametersNArgs(3)
|
||||
@ -495,9 +525,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, loc, np.abs(scale)]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=1e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(4)
|
||||
def testChi2LogPdf(self, shapes, dtypes):
|
||||
@ -509,9 +540,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
x, df, loc, scale = map(rng, shapes, dtypes)
|
||||
return [x, df, loc, scale]
|
||||
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker)
|
||||
|
||||
@genNamedParametersNArgs(5)
|
||||
def testBetaBinomLogPmf(self, shapes, dtypes):
|
||||
@ -522,16 +554,17 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
k, n, a, b, loc = map(rng, shapes, dtypes)
|
||||
k = np.floor(k)
|
||||
n = np.ceil(n)
|
||||
a = np.clip(a, a_min = 0.1, a_max = None)
|
||||
b = np.clip(a, a_min = 0.1, a_max = None)
|
||||
a = np.clip(a, a_min = 0.1, a_max=None).astype(a.dtype)
|
||||
b = np.clip(a, a_min = 0.1, a_max=None).astype(b.dtype)
|
||||
loc = np.floor(loc)
|
||||
return [k, n, a, b, loc]
|
||||
|
||||
if scipy_version >= (1, 4):
|
||||
scipy_fun = osp_stats.betabinom.logpmf
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
|
||||
with _strict_promotion_if_dtypes_match(dtypes):
|
||||
if scipy_version >= (1, 4):
|
||||
scipy_fun = osp_stats.betabinom.logpmf
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
|
||||
tol=5e-4)
|
||||
self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
|
||||
|
||||
def testIssue972(self):
|
||||
self.assertAllClose(
|
||||
@ -584,11 +617,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
|
||||
factor = rng(factor_shape, cov_dtype)
|
||||
args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
|
||||
return args
|
||||
return [a.astype(x_dtype) for a in args]
|
||||
|
||||
self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf,
|
||||
lsp_stats.multivariate_normal.logpdf,
|
||||
args_maker, tol=1e-3)
|
||||
args_maker, tol=1e-3, check_dtypes=False)
|
||||
self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker,
|
||||
rtol=1e-4, atol=1e-4)
|
||||
|
||||
@ -631,13 +664,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
|
||||
factor = rng(factor_shape, cov_dtype)
|
||||
args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
|
||||
return args
|
||||
return [a.astype(x_dtype) for a in args]
|
||||
|
||||
osp_fun = np.vectorize(osp_stats.multivariate_normal.logpdf,
|
||||
signature="(n),(n),(n,n)->()")
|
||||
|
||||
self._CheckAgainstNumpy(osp_fun, lsp_stats.multivariate_normal.logpdf,
|
||||
args_maker, tol=1e-3)
|
||||
args_maker, tol=1e-3, check_dtypes=False)
|
||||
self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker,
|
||||
rtol=1e-4, atol=1e-4)
|
||||
|
||||
@ -658,7 +691,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
|
||||
result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
|
||||
result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
|
||||
self.assertArraysEqual(result1, result2)
|
||||
self.assertArraysEqual(result1, result2, check_dtypes=False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user