[x64] make scipy_stats_test.py compatible with strict dtype promotion

This commit is contained in:
Jake VanderPlas 2022-06-14 14:47:58 -07:00
parent cd565f8f41
commit f00d706a6d
2 changed files with 133 additions and 100 deletions

View File

@ -32,13 +32,13 @@ def logpdf(x, mean, cov, allow_singular=None):
x, mean, cov = _promote_dtypes_inexact(x, mean, cov)
if not mean.shape:
return (-1/2 * jnp.square(x - mean) / cov
- 1/2 * (np.log(2*np.pi) + jnp.log(cov)))
- 1/2 * (jnp.log(2*np.pi) + jnp.log(cov)))
else:
n = mean.shape[-1]
if not np.shape(cov):
y = x - mean
return (-1/2 * jnp.einsum('...i,...i->...', y, y) / cov
- n/2 * (np.log(2*np.pi) + jnp.log(cov)))
- n/2 * (jnp.log(2*np.pi) + jnp.log(cov)))
else:
if cov.ndim < 2 or cov.shape[-2:] != (n, n):
raise ValueError("multivariate_normal.logpdf got incompatible shapes")
@ -47,7 +47,7 @@ def logpdf(x, mean, cov, allow_singular=None):
partial(lax.linalg.triangular_solve, lower=True, transpose_a=True),
signature="(n,n),(n)->(n)"
)(L, x - mean)
return (-1/2 * jnp.einsum('...i,...i->...', y, y) - n/2*np.log(2*np.pi)
return (-1/2 * jnp.einsum('...i,...i->...', y, y) - n/2 * jnp.log(2*np.pi)
- jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1))
@_wraps(osp_stats.multivariate_normal.pdf, update_doc=False)

View File

@ -34,6 +34,12 @@ one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]
scipy_version = tuple(map(int, osp.version.version.split('.')[:2]))
def _strict_promotion_if_dtypes_match(dtypes):
if all(dtype == dtypes[0] for dtype in dtypes):
return jax.numpy_dtype_promotion('strict')
return jax.numpy_dtype_promotion('standard')
def genNamedParametersNArgs(n):
return parameterized.named_parameters(
jtu.cases_from_list(
@ -58,13 +64,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
k, mu, loc = map(rng, shapes, dtypes)
k = np.floor(k)
# clipping to ensure that rate parameter is strictly positive
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
loc = np.floor(loc)
return [k, mu, loc]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14})
@genNamedParametersNArgs(3)
def testPoissonPmf(self, shapes, dtypes):
@ -76,13 +83,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
k, mu, loc = map(rng, shapes, dtypes)
k = np.floor(k)
# clipping to ensure that rate parameter is strictly positive
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
loc = np.floor(loc)
return [k, mu, loc]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testPoissonCdf(self, shapes, dtypes):
@ -93,12 +101,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
def args_maker():
k, mu, loc = map(rng, shapes, dtypes)
# clipping to ensure that rate parameter is strictly positive
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None)
mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype)
return [k, mu, loc]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
@ -114,9 +123,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
loc = np.floor(loc)
return [x, p, loc]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testGeomLogPmf(self, shapes, dtypes):
@ -131,9 +141,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
loc = np.floor(loc)
return [x, p, loc]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(5)
def testBetaLogPdf(self, shapes, dtypes):
@ -145,10 +156,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
x, a, b, loc, scale = map(rng, shapes, dtypes)
return [x, a, b, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float32: 2e-3, np.float64: 1e-4})
def testBetaLogPdfZero(self):
# Regression test for https://github.com/google/jax/issues/7645
@ -166,12 +178,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@parameterized.named_parameters(
jtu.cases_from_list(
@ -208,9 +221,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
return x, alpha
tol = {np.float32: 1E-3, np.float64: 1e-5}
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=tol)
self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol)
@genNamedParametersNArgs(3)
def testExponLogPdf(self, shapes, dtypes):
@ -222,9 +237,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testGammaLogPdf(self, shapes, dtypes):
@ -236,9 +252,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
x, a, loc, scale = map(rng, shapes, dtypes)
return [x, a, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
def testGammaLogPdfZero(self):
# Regression test for https://github.com/google/jax/issues/7256
@ -255,9 +272,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
x, p = map(rng, shapes, dtypes)
return [x, p]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4, rtol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4, rtol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(2)
def testGenNormCdf(self, shapes, dtypes):
@ -269,9 +287,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
x, p = map(rng, shapes, dtypes)
return [x, p]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4, rtol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4, rtol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testNBinomLogPmf(self, shapes, dtypes):
@ -288,9 +307,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
return [k, n, p, loc]
tol = {np.float32: 1e-6, np.float64: 1e-8}
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol)
@genNamedParametersNArgs(3)
def testLaplaceLogPdf(self, shapes, dtypes):
@ -301,12 +322,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None)
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
def testLaplaceCdf(self, shapes, dtypes):
@ -317,12 +339,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# ensure that scale is not too low
scale = np.clip(scale, a_min=0.1, a_max=None)
scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol={np.float32: 1e-5, np.float64: 1e-6})
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol={np.float32: 1e-5, np.float64: 1e-6})
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(1)
def testLogisticCdf(self, shapes, dtypes):
@ -333,9 +356,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
def args_maker():
return list(map(rng, shapes, dtypes))
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(1)
def testLogisticLogpdf(self, shapes, dtypes):
@ -392,12 +416,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
@ -409,12 +434,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
@ -426,12 +452,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
def args_maker():
x, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-6)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(3)
@ -443,13 +470,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
def args_maker():
q, loc, scale = map(rng, shapes, dtypes)
# ensure probability is between 0 and 1:
q = np.clip(np.abs(q / 3), a_min=None, a_max=1)
q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [q, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4)
@genNamedParametersNArgs(4)
@ -462,9 +490,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
x, b, loc, scale = map(rng, shapes, dtypes)
return [x, b, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
@ -476,13 +505,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
def args_maker():
x, df, loc, scale = map(rng, shapes, dtypes)
# clipping to ensure that scale is not too low
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None)
scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype)
return [x, df, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float64: 1e-14}, atol={np.float64: 1e-14})
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-3)
self._CompileAndCheck(lax_fun, args_maker,
rtol={np.float64: 1e-14}, atol={np.float64: 1e-14})
@genNamedParametersNArgs(3)
@ -495,9 +525,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
x, loc, scale = map(rng, shapes, dtypes)
return [x, loc, np.abs(scale)]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=1e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(4)
def testChi2LogPdf(self, shapes, dtypes):
@ -509,9 +540,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
x, df, loc, scale = map(rng, shapes, dtypes)
return [x, df, loc, scale]
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
with _strict_promotion_if_dtypes_match(dtypes):
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker)
@genNamedParametersNArgs(5)
def testBetaBinomLogPmf(self, shapes, dtypes):
@ -522,16 +554,17 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
k, n, a, b, loc = map(rng, shapes, dtypes)
k = np.floor(k)
n = np.ceil(n)
a = np.clip(a, a_min = 0.1, a_max = None)
b = np.clip(a, a_min = 0.1, a_max = None)
a = np.clip(a, a_min = 0.1, a_max=None).astype(a.dtype)
b = np.clip(a, a_min = 0.1, a_max=None).astype(b.dtype)
loc = np.floor(loc)
return [k, n, a, b, loc]
if scipy_version >= (1, 4):
scipy_fun = osp_stats.betabinom.logpmf
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
with _strict_promotion_if_dtypes_match(dtypes):
if scipy_version >= (1, 4):
scipy_fun = osp_stats.betabinom.logpmf
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False,
tol=5e-4)
self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5)
def testIssue972(self):
self.assertAllClose(
@ -584,11 +617,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
factor = rng(factor_shape, cov_dtype)
args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
return args
return [a.astype(x_dtype) for a in args]
self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf,
lsp_stats.multivariate_normal.logpdf,
args_maker, tol=1e-3)
args_maker, tol=1e-3, check_dtypes=False)
self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker,
rtol=1e-4, atol=1e-4)
@ -631,13 +664,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1])
factor = rng(factor_shape, cov_dtype)
args.append(np.matmul(factor, np.swapaxes(factor, -1, -2)))
return args
return [a.astype(x_dtype) for a in args]
osp_fun = np.vectorize(osp_stats.multivariate_normal.logpdf,
signature="(n),(n),(n,n)->()")
self._CheckAgainstNumpy(osp_fun, lsp_stats.multivariate_normal.logpdf,
args_maker, tol=1e-3)
args_maker, tol=1e-3, check_dtypes=False)
self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker,
rtol=1e-4, atol=1e-4)
@ -658,7 +691,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov)
result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov)
self.assertArraysEqual(result1, result2)
self.assertArraysEqual(result1, result2, check_dtypes=False)
if __name__ == "__main__":