Fix test failures under SciPy 1.10.0.

This commit is contained in:
Peter Hawkins 2023-01-30 20:34:01 +00:00
parent 7eb7baa6f1
commit 27da460f25

View File

@ -885,7 +885,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
)
def testKde(self, inshape, dtype, outsize, weights, method, func):
if method == "callable":
method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4))
method = lambda kde: kde.neff ** -1./(kde.d+4)
def scipy_fun(dataset, points, w):
w = np.abs(w) if weights else None
@ -915,11 +915,12 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
rng(inshape, dtype), rng(outshape, dtype), rng(inshape[-1:], dtype)]
self._CheckAgainstNumpy(
scipy_fun, lax_fun, args_maker, tol={
np.float32: 1e-2 if jtu.device_under_test() == "tpu" else 1e-3,
np.float64: 1e-14
np.float32: 2e-2 if jtu.device_under_test() == "tpu" else 1e-3,
np.float64: 3e-14
})
self._CompileAndCheck(
lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
lax_fun, args_maker, rtol={np.float32: 3e-5, np.float64: 3e-14},
atol={np.float32: 3e-4, np.float64: 3e-14})
@jtu.sample_product(
shape=[(15,), (3, 15), (1, 12)],