mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
looser TPU precision
This commit is contained in:
parent
0788d5708a
commit
dc2a50ff21
@ -34,12 +34,12 @@ one_and_two_dim_shapes = [(4,), (3, 4), (3, 1), (1, 4)]
|
||||
|
||||
|
||||
def genNamedParametersNArgs(n):
|
||||
return parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
|
||||
"shapes": shapes, "dtypes": dtypes}
|
||||
for shapes in itertools.combinations_with_replacement(all_shapes, n)
|
||||
for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, n)))
|
||||
return parameterized.named_parameters(
|
||||
jtu.cases_from_list(
|
||||
{"testcase_name": jtu.format_test_name_suffix("", shapes, dtypes),
|
||||
"shapes": shapes, "dtypes": dtypes}
|
||||
for shapes in itertools.combinations_with_replacement(all_shapes, n)
|
||||
for dtypes in itertools.combinations_with_replacement(jtu.dtypes.floating, n)))
|
||||
|
||||
|
||||
# Allow implicit rank promotion in these tests, as virtually every test exercises it.
|
||||
@ -704,7 +704,7 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
for func in [None, "evaluate", "logpdf", "pdf"]))
|
||||
def testKde(self, inshape, dtype, outsize, weights, method, func):
|
||||
if method == "callable":
|
||||
method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4))
|
||||
method = lambda kde: jax.numpy.power(kde.neff, -1./(kde.d+4))
|
||||
|
||||
def scipy_fun(dataset, points, w):
|
||||
w = np.abs(w) if weights else None
|
||||
@ -732,8 +732,11 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
args_maker = lambda: [
|
||||
rng(inshape, dtype), rng(outshape, dtype), rng(inshape[-1:], dtype)]
|
||||
self._CheckAgainstNumpy(scipy_fun, lax_fun, args_maker,
|
||||
tol={np.float32: 1e-3, np.float64: 1e-14})
|
||||
self._CheckAgainstNumpy(
|
||||
scipy_fun, lax_fun, args_maker, tol={
|
||||
np.float32: 1e-2 if jtu.device_under_test() == "tpu" else 1e-3,
|
||||
np.float64: 1e-14
|
||||
})
|
||||
self._CompileAndCheck(
|
||||
lax_fun, args_maker, rtol={np.float32: 3e-07, np.float64: 4e-15})
|
||||
|
||||
@ -887,4 +890,4 @@ class LaxBackedScipyStatsTests(jtu.JaxTestCase):
|
||||
self.assertAllClose(evaluate_kde(kde, x), kde.evaluate(x))
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user