skip some for_loop test cases on gpu due to flakey timeouts

PiperOrigin-RevId: 474168747
This commit is contained in:
Matthew Johnson 2022-09-13 17:51:16 -07:00 committed by jax authors
parent ef0256843f
commit 71b0968f70

View File

@ -221,6 +221,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
]))
@jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts?
def test_for_jvp(self, f, ref, body_shapes, n, for_impl):
for_ = for_impl
rng = self.rng()
@ -233,7 +234,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
expected = jax.jvp(ref, args, args)
self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol, atol=tol)
self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
jtu.check_grads(partial(for_, n, f), (args,), order=3, modes=["fwd"])
jtu.check_grads(partial(for_, n, f), (args,), order=2, modes=["fwd"])
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_f={}_nsteps={}_impl={}".format(
@ -250,6 +251,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
]))
@jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts?
def test_for_linearize(self, f, ref, body_shapes, n, for_impl):
for_ = for_impl
rng = self.rng()
@ -354,6 +356,8 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
]))
@jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts?
@jtu.skip_on_flag("jax_skip_slow_tests", True)
def test_for_grad(self, f, ref, body_shapes, n, for_impl):
for_ = for_impl
rng = self.rng()
@ -368,9 +372,10 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
self.assertAllClose(ans, ans_discharged, check_dtypes=True, rtol=tol,
atol=tol)
self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
jtu.check_grads(lambda *args: for_(n, f, args)[1].sum(), args, order=3,
jtu.check_grads(lambda *args: for_(n, f, args)[1].sum(), args, order=2,
rtol=7e-3, atol=1e-2)
@jtu.skip_on_devices("gpu") # TODO(mattjj,sharadmv): timeouts?
def test_grad_of_triple_nested_for_loop(self):
func = lambda x: jnp.sin(x) + 1.