mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 12:56:07 +00:00
Fix test failures under SciPy 1.10.0.
This commit is contained in:
parent
7eb7baa6f1
commit
27da460f25
@ -885,7 +885,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
)
|
||||
def testKde(self, inshape, dtype, outsize, weights, method, func):
|
||||
if method == "callable":
|
||||
method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4))
|
||||
method = lambda kde: kde.neff ** -1./(kde.d+4)
|
||||
|
||||
def scipy_fun(dataset, points, w):
|
||||
w = np.abs(w) if weights else None
|
||||
@ -915,11 +915,12 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
rng(inshape, dtype), rng(outshape, dtype), rng(inshape[-1:], dtype)]
|
||||
self._CheckAgainstNumpy(
|
||||
scipy_fun, lax_fun, args_maker, tol={
|
||||
np.float32: 1e-2 if jtu.device_under_test() == "tpu" else 1e-3,
|
||||
np.float64: 1e-14
|
||||
np.float32: 2e-2 if jtu.device_under_test() == "tpu" else 1e-3,
|
||||
np.float64: 3e-14
|
||||
})
|
||||
self._CompileAndCheck(
|
||||
lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
|
||||
lax_fun, args_maker, rtol={np.float32: 3e-5, np.float64: 3e-14},
|
||||
atol={np.float32: 3e-4, np.float64: 3e-14})
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=[(15,), (3, 15), (1, 12)],
|
||||
|
Loading…
x
Reference in New Issue
Block a user