Use cases_from_list to subsample enumerated cases in for_loop_test

PiperOrigin-RevId: 474093596
This commit is contained in:
Sharad Vikram 2022-09-13 12:33:40 -07:00 committed by jax authors
parent a2930e6a1e
commit ad326b99da
2 changed files with 9 additions and 9 deletions

View File

@ -914,9 +914,9 @@ jax_test(
name = "for_loop_test",
srcs = ["for_loop_test.py"],
shard_count = {
"cpu": 20,
"gpu": 20,
"tpu": 20,
"cpu": 10,
"gpu": 10,
"tpu": 10,
},
)

View File

@ -196,7 +196,7 @@ for_reference = for_loop.discharged_for_loop
class ForLoopTransformationTest(jtu.JaxTestCase):
@parameterized.named_parameters(
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_f={}_nsteps={}_impl={}".format(
for_body_name, nsteps, impl_name),
"f": for_body, "body_shapes": body_shapes,
@ -210,7 +210,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
])
]))
def test_for_jvp(self, f, ref, body_shapes, n, for_impl):
for_ = for_impl
rng = self.rng()
@ -225,7 +225,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
jtu.check_grads(partial(for_, n, f), (args,), order=3, modes=["fwd"])
@parameterized.named_parameters(
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_f={}_nsteps={}_impl={}".format(
for_body_name, nsteps, impl_name),
"f": for_body, "body_shapes": body_shapes,
@ -239,7 +239,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
])
]))
def test_for_linearize(self, f, ref, body_shapes, n, for_impl):
for_ = for_impl
rng = self.rng()
@ -328,7 +328,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
np.testing.assert_allclose(actual_tangents[0], expected_tangents[0])
np.testing.assert_allclose(actual_tangents[1], expected_tangents[1])
@parameterized.named_parameters(
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_f={}_nsteps={}_impl={}".format(
for_body_name, nsteps, impl_name),
"f": for_body, "body_shapes": body_shapes,
@ -343,7 +343,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
])
]))
def test_for_grad(self, f, ref, body_shapes, n, for_impl):
for_ = for_impl
rng = self.rng()