Relax some test tolerances and disable some tests on GPU.

This commit is contained in:
Peter Hawkins 2020-12-01 16:02:06 -05:00
parent bdcf17cdd6
commit c5b260c7f8
3 changed files with 7 additions and 2 deletions

View File

@ -128,7 +128,7 @@ class GmapTest(jtu.JaxTestCase):
x = jnp.arange(800).reshape((8, 10, 10))
self.assertAllClose(gmap(f, schedule)(x), vmap(f)(x))
self.assertAllClose(gmap(f, schedule)(x), vmap(f)(x), rtol=3e-5)
@check_default_schedules(lambda s: not any(c[0] == 'sequential' for c in s))
@skip_insufficient_devices(8)

View File

@ -1633,7 +1633,8 @@ class LaxControlFlowTest(jtu.JaxTestCase):
ans = api.vmap(lambda c, as_: scan(f, c, as_), in_axes)(c, as_)
expected = api.vmap(lambda c, as_: scan_reference(f, c, as_), in_axes)(c, as_)
self.assertAllClose(ans, expected, check_dtypes=False)
self.assertAllClose(ans, expected, check_dtypes=False,
rtol=1e-5, atol=1e-5)
def testScanVmapTuples(self):
def f(c, a):

View File

@ -45,6 +45,8 @@ class TestPolynomial(jtu.JaxTestCase):
for length in [0, 3, 9, 10, 17]
for leading in [0, 1, 2, 3, 5, 7, 10]
for trailing in [0, 1, 2, 3, 5, 7, 10]))
# TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
@jtu.skip_on_devices("gpu")
def testRoots(self, dtype, rng_factory, length, leading, trailing):
rng = rng_factory(np.random.RandomState(0))
@ -67,6 +69,8 @@ class TestPolynomial(jtu.JaxTestCase):
for rng_factory in [jtu.rand_default]
for length in [0, 1, 3, 10]
for trailing in [0, 1, 3, 7]))
# TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
@jtu.skip_on_devices("gpu")
def testRootsNostrip(self, length, dtype, rng_factory, trailing):
rng = rng_factory(np.random.RandomState(0))