diff --git a/jax/_src/scipy/stats/multivariate_normal.py b/jax/_src/scipy/stats/multivariate_normal.py index 7e89d5903..ffcaa00ad 100644 --- a/jax/_src/scipy/stats/multivariate_normal.py +++ b/jax/_src/scipy/stats/multivariate_normal.py @@ -32,13 +32,13 @@ def logpdf(x, mean, cov, allow_singular=None): x, mean, cov = _promote_dtypes_inexact(x, mean, cov) if not mean.shape: return (-1/2 * jnp.square(x - mean) / cov - - 1/2 * (np.log(2*np.pi) + jnp.log(cov))) + - 1/2 * (jnp.log(2*np.pi) + jnp.log(cov))) else: n = mean.shape[-1] if not np.shape(cov): y = x - mean return (-1/2 * jnp.einsum('...i,...i->...', y, y) / cov - - n/2 * (np.log(2*np.pi) + jnp.log(cov))) + - n/2 * (jnp.log(2*np.pi) + jnp.log(cov))) else: if cov.ndim < 2 or cov.shape[-2:] != (n, n): raise ValueError("multivariate_normal.logpdf got incompatible shapes") @@ -47,7 +47,7 @@ def logpdf(x, mean, cov, allow_singular=None): partial(lax.linalg.triangular_solve, lower=True, transpose_a=True), signature="(n,n),(n)->(n)" )(L, x - mean) - return (-1/2 * jnp.einsum('...i,...i->...', y, y) - n/2*np.log(2*np.pi) + return (-1/2 * jnp.einsum('...i,...i->...', y, y) - n/2 * jnp.log(2*np.pi) - jnp.log(L.diagonal(axis1=-1, axis2=-2)).sum(-1)) @_wraps(osp_stats.multivariate_normal.pdf, update_doc=False) diff --git a/tests/scipy_stats_test.py b/tests/scipy_stats_test.py index fb609ce13..dcad58481 100644 --- a/tests/scipy_stats_test.py +++ b/tests/scipy_stats_test.py @@ -34,6 +34,12 @@ one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)] scipy_version = tuple(map(int, osp.version.version.split('.')[:2])) +def _strict_promotion_if_dtypes_match(dtypes): + if all(dtype == dtypes[0] for dtype in dtypes): + return jax.numpy_dtype_promotion('strict') + return jax.numpy_dtype_promotion('standard') + + def genNamedParametersNArgs(n): return parameterized.named_parameters( jtu.cases_from_list( @@ -58,13 +64,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): k, mu, loc = map(rng, shapes, dtypes) k = np.floor(k) # clipping to ensure that rate parameter is strictly positive - mu = np.clip(np.abs(mu), a_min=0.1, a_max=None) + mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype) loc = np.floor(loc) return [k, mu, loc] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14}) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker, rtol={np.float64: 1e-14}) @genNamedParametersNArgs(3) def testPoissonPmf(self, shapes, dtypes): @@ -76,13 +83,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): k, mu, loc = map(rng, shapes, dtypes) k = np.floor(k) # clipping to ensure that rate parameter is strictly positive - mu = np.clip(np.abs(mu), a_min=0.1, a_max=None) + mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype) loc = np.floor(loc) return [k, mu, loc] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testPoissonCdf(self, shapes, dtypes): @@ -93,12 +101,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): def args_maker(): k, mu, loc = map(rng, shapes, dtypes) # clipping to ensure that rate parameter is strictly positive - mu = np.clip(np.abs(mu), a_min=0.1, a_max=None) + mu = np.clip(np.abs(mu), a_min=0.1, a_max=None).astype(mu.dtype) return [k, mu, loc] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) @@ -114,9 +123,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): loc = np.floor(loc) return [x, p, loc] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testGeomLogPmf(self, shapes, dtypes): @@ -131,9 +141,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): loc = np.floor(loc) return [x, p, loc] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(5) def testBetaLogPdf(self, shapes, dtypes): @@ -145,10 +156,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, a, b, loc, scale = map(rng, shapes, dtypes) return [x, a, b, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker, - rtol={np.float32: 2e-3, np.float64: 1e-4}) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker, + rtol={np.float32: 2e-3, np.float64: 1e-4}) def testBetaLogPdfZero(self): # Regression test for https://github.com/google/jax/issues/7645 @@ -166,12 +178,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low - scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker) @parameterized.named_parameters( jtu.cases_from_list( @@ -208,9 +221,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): return x, alpha tol = {np.float32: 1E-3, np.float64: 1e-5} - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=tol) - self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol) + + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=tol) + self._CompileAndCheck(lax_fun, args_maker, atol=tol, rtol=tol) @genNamedParametersNArgs(3) def testExponLogPdf(self, shapes, dtypes): @@ -222,9 +237,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, loc, scale = map(rng, shapes, dtypes) return [x, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testGammaLogPdf(self, shapes, dtypes): @@ -236,9 +252,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, a, loc, scale = map(rng, shapes, dtypes) return [x, a, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=5e-4) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) def testGammaLogPdfZero(self): # Regression test for https://github.com/google/jax/issues/7256 @@ -255,9 +272,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, p = map(rng, shapes, dtypes) return [x, p] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-4, rtol=1e-3) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4, rtol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(2) def testGenNormCdf(self, shapes, dtypes): @@ -269,9 +287,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, p = map(rng, shapes, dtypes) return [x, p] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-4, rtol=1e-3) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4, rtol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testNBinomLogPmf(self, shapes, dtypes): @@ -288,9 +307,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): return [k, n, p, loc] tol = {np.float32: 1e-6, np.float64: 1e-8} - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=5e-4) - self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol) + + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=tol, atol=tol) @genNamedParametersNArgs(3) def testLaplaceLogPdf(self, shapes, dtypes): @@ -301,12 +322,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low - scale = np.clip(scale, a_min=0.1, a_max=None) + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) def testLaplaceCdf(self, shapes, dtypes): @@ -317,12 +339,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # ensure that scale is not too low - scale = np.clip(scale, a_min=0.1, a_max=None) + scale = np.clip(scale, a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol={np.float32: 1e-5, np.float64: 1e-6}) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol={np.float32: 1e-5, np.float64: 1e-6}) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1) def testLogisticCdf(self, shapes, dtypes): @@ -333,9 +356,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): def args_maker(): return list(map(rng, shapes, dtypes)) - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-6) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-6) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(1) def testLogisticLogpdf(self, shapes, dtypes): @@ -392,12 +416,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low - scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) @@ -409,12 +434,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low - scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) @@ -426,12 +452,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): def args_maker(): x, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low - scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [x, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-6) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-6) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(3) @@ -443,13 +470,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): def args_maker(): q, loc, scale = map(rng, shapes, dtypes) # ensure probability is between 0 and 1: - q = np.clip(np.abs(q / 3), a_min=None, a_max=1) + q = np.clip(np.abs(q / 3), a_min=None, a_max=1).astype(q.dtype) # clipping to ensure that scale is not too low - scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [q, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=3e-4) @genNamedParametersNArgs(4) @@ -462,9 +490,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, b, loc, scale = map(rng, shapes, dtypes) return [x, b, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) @@ -476,13 +505,14 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): def args_maker(): x, df, loc, scale = map(rng, shapes, dtypes) # clipping to ensure that scale is not too low - scale = np.clip(np.abs(scale), a_min=0.1, a_max=None) + scale = np.clip(np.abs(scale), a_min=0.1, a_max=None).astype(scale.dtype) return [x, df, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-3) - self._CompileAndCheck(lax_fun, args_maker, - rtol={np.float64: 1e-14}, atol={np.float64: 1e-14}) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-3) + self._CompileAndCheck(lax_fun, args_maker, + rtol={np.float64: 1e-14}, atol={np.float64: 1e-14}) @genNamedParametersNArgs(3) @@ -495,9 +525,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, loc, scale = map(rng, shapes, dtypes) return [x, loc, np.abs(scale)] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=1e-4) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=1e-4) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(4) def testChi2LogPdf(self, shapes, dtypes): @@ -509,9 +540,10 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): x, df, loc, scale = map(rng, shapes, dtypes) return [x, df, loc, scale] - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=5e-4) - self._CompileAndCheck(lax_fun, args_maker) + with _strict_promotion_if_dtypes_match(dtypes): + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker) @genNamedParametersNArgs(5) def testBetaBinomLogPmf(self, shapes, dtypes): @@ -522,16 +554,17 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): k, n, a, b, loc = map(rng, shapes, dtypes) k = np.floor(k) n = np.ceil(n) - a = np.clip(a, a_min = 0.1, a_max = None) - b = np.clip(a, a_min = 0.1, a_max = None) + a = np.clip(a, a_min = 0.1, a_max=None).astype(a.dtype) + b = np.clip(a, a_min = 0.1, a_max=None).astype(b.dtype) loc = np.floor(loc) return [k, n, a, b, loc] - if scipy_version >= (1, 4): - scipy_fun = osp_stats.betabinom.logpmf - self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, - tol=5e-4) - self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5) + with _strict_promotion_if_dtypes_match(dtypes): + if scipy_version >= (1, 4): + scipy_fun = osp_stats.betabinom.logpmf + self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker, check_dtypes=False, + tol=5e-4) + self._CompileAndCheck(lax_fun, args_maker, rtol=1e-5, atol=1e-5) def testIssue972(self): self.assertAllClose( @@ -584,11 +617,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1]) factor = rng(factor_shape, cov_dtype) args.append(np.matmul(factor, np.swapaxes(factor, -1, -2))) - return args + return [a.astype(x_dtype) for a in args] self._CheckAgainstNumpy(osp_stats.multivariate_normal.logpdf, lsp_stats.multivariate_normal.logpdf, - args_maker, tol=1e-3) + args_maker, tol=1e-3, check_dtypes=False) self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker, rtol=1e-4, atol=1e-4) @@ -631,13 +664,13 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): factor_shape = (*cov_shape[:-1], 2 * cov_shape[-1]) factor = rng(factor_shape, cov_dtype) args.append(np.matmul(factor, np.swapaxes(factor, -1, -2))) - return args + return [a.astype(x_dtype) for a in args] osp_fun = np.vectorize(osp_stats.multivariate_normal.logpdf, signature="(n),(n),(n,n)->()") self._CheckAgainstNumpy(osp_fun, lsp_stats.multivariate_normal.logpdf, - args_maker, tol=1e-3) + args_maker, tol=1e-3, check_dtypes=False) self._CompileAndCheck(lsp_stats.multivariate_normal.logpdf, args_maker, rtol=1e-4, atol=1e-4) @@ -658,7 +691,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase): result1 = lsp_stats.multivariate_normal.logpdf(x, mean, cov) result2 = jax.vmap(lsp_stats.multivariate_normal.logpdf)(x, mean, cov) - self.assertArraysEqual(result1, result2) + self.assertArraysEqual(result1, result2, check_dtypes=False) if __name__ == "__main__":