mirror of
https://github.com/ROCm/jax.git
synced 2025-04-18 04:46:06 +00:00
Relax some test tolerances and disable some tests on GPU.
This commit is contained in:
parent
bdcf17cdd6
commit
c5b260c7f8
@ -128,7 +128,7 @@ class GmapTest(jtu.JaxTestCase):
|
||||
|
||||
x = jnp.arange(800).reshape((8, 10, 10))
|
||||
|
||||
self.assertAllClose(gmap(f, schedule)(x), vmap(f)(x))
|
||||
self.assertAllClose(gmap(f, schedule)(x), vmap(f)(x), rtol=3e-5)
|
||||
|
||||
@check_default_schedules(lambda s: not any(c[0] == 'sequential' for c in s))
|
||||
@skip_insufficient_devices(8)
|
||||
|
@ -1633,7 +1633,8 @@ class LaxControlFlowTest(jtu.JaxTestCase):
|
||||
|
||||
ans = api.vmap(lambda c, as_: scan(f, c, as_), in_axes)(c, as_)
|
||||
expected = api.vmap(lambda c, as_: scan_reference(f, c, as_), in_axes)(c, as_)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False)
|
||||
self.assertAllClose(ans, expected, check_dtypes=False,
|
||||
rtol=1e-5, atol=1e-5)
|
||||
|
||||
def testScanVmapTuples(self):
|
||||
def f(c, a):
|
||||
|
@ -45,6 +45,8 @@ class TestPolynomial(jtu.JaxTestCase):
|
||||
for length in [0, 3, 9, 10, 17]
|
||||
for leading in [0, 1, 2, 3, 5, 7, 10]
|
||||
for trailing in [0, 1, 2, 3, 5, 7, 10]))
|
||||
# TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
|
||||
@jtu.skip_on_devices("gpu")
|
||||
def testRoots(self, dtype, rng_factory, length, leading, trailing):
|
||||
rng = rng_factory(np.random.RandomState(0))
|
||||
|
||||
@ -67,6 +69,8 @@ class TestPolynomial(jtu.JaxTestCase):
|
||||
for rng_factory in [jtu.rand_default]
|
||||
for length in [0, 1, 3, 10]
|
||||
for trailing in [0, 1, 3, 7]))
|
||||
# TODO(phawkins): no nonsymmetric eigendecomposition implementation on GPU.
|
||||
@jtu.skip_on_devices("gpu")
|
||||
def testRootsNostrip(self, length, dtype, rng_factory, trailing):
|
||||
rng = rng_factory(np.random.RandomState(0))
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user