diff --git a/tests/BUILD b/tests/BUILD index 4efd2c072..437b13853 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -914,9 +914,9 @@ jax_test( name = "for_loop_test", srcs = ["for_loop_test.py"], shard_count = { - "cpu": 20, - "gpu": 20, - "tpu": 20, + "cpu": 10, + "gpu": 10, + "tpu": 10, }, ) diff --git a/tests/for_loop_test.py b/tests/for_loop_test.py index 2014b8a8a..3e13d5326 100644 --- a/tests/for_loop_test.py +++ b/tests/for_loop_test.py @@ -196,7 +196,7 @@ for_reference = for_loop.discharged_for_loop class ForLoopTransformationTest(jtu.JaxTestCase): - @parameterized.named_parameters( + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_f={}_nsteps={}_impl={}".format( for_body_name, nsteps, impl_name), "f": for_body, "body_shapes": body_shapes, @@ -210,7 +210,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase): ("accum", for_body_accum, accum_ref, [(4,), (4,)], 3), ("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4), ("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4), - ]) + ])) def test_for_jvp(self, f, ref, body_shapes, n, for_impl): for_ = for_impl rng = self.rng() @@ -225,7 +225,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase): self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol) jtu.check_grads(partial(for_, n, f), (args,), order=3, modes=["fwd"]) - @parameterized.named_parameters( + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_f={}_nsteps={}_impl={}".format( for_body_name, nsteps, impl_name), "f": for_body, "body_shapes": body_shapes, @@ -239,7 +239,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase): ("accum", for_body_accum, accum_ref, [(4,), (4,)], 3), ("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4), ("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4), - ]) + ])) def test_for_linearize(self, f, ref, body_shapes, n, for_impl): for_ = for_impl rng = self.rng() @@ -328,7 +328,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase): np.testing.assert_allclose(actual_tangents[0], expected_tangents[0]) np.testing.assert_allclose(actual_tangents[1], expected_tangents[1]) - @parameterized.named_parameters( + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_f={}_nsteps={}_impl={}".format( for_body_name, nsteps, impl_name), "f": for_body, "body_shapes": body_shapes, @@ -343,7 +343,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase): ("accum", for_body_accum, accum_ref, [(4,), (4,)], 3), ("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4), ("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4), - ]) + ])) def test_for_grad(self, f, ref, body_shapes, n, for_impl): for_ = for_impl rng = self.rng()