mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Use cases_from_list to subsample enumerated cases in for_loop_test
PiperOrigin-RevId: 474093596
This commit is contained in:
parent
a2930e6a1e
commit
ad326b99da
@ -914,9 +914,9 @@ jax_test(
|
||||
name = "for_loop_test",
|
||||
srcs = ["for_loop_test.py"],
|
||||
shard_count = {
|
||||
"cpu": 20,
|
||||
"gpu": 20,
|
||||
"tpu": 20,
|
||||
"cpu": 10,
|
||||
"gpu": 10,
|
||||
"tpu": 10,
|
||||
},
|
||||
)
|
||||
|
||||
|
@ -196,7 +196,7 @@ for_reference = for_loop.discharged_for_loop
|
||||
|
||||
class ForLoopTransformationTest(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_f={}_nsteps={}_impl={}".format(
|
||||
for_body_name, nsteps, impl_name),
|
||||
"f": for_body, "body_shapes": body_shapes,
|
||||
@ -210,7 +210,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
|
||||
("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
|
||||
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
|
||||
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
|
||||
])
|
||||
]))
|
||||
def test_for_jvp(self, f, ref, body_shapes, n, for_impl):
|
||||
for_ = for_impl
|
||||
rng = self.rng()
|
||||
@ -225,7 +225,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
|
||||
self.assertAllClose(ans, expected, check_dtypes=True, rtol=tol, atol=tol)
|
||||
jtu.check_grads(partial(for_, n, f), (args,), order=3, modes=["fwd"])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_f={}_nsteps={}_impl={}".format(
|
||||
for_body_name, nsteps, impl_name),
|
||||
"f": for_body, "body_shapes": body_shapes,
|
||||
@ -239,7 +239,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
|
||||
("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
|
||||
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
|
||||
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
|
||||
])
|
||||
]))
|
||||
def test_for_linearize(self, f, ref, body_shapes, n, for_impl):
|
||||
for_ = for_impl
|
||||
rng = self.rng()
|
||||
@ -328,7 +328,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
|
||||
np.testing.assert_allclose(actual_tangents[0], expected_tangents[0])
|
||||
np.testing.assert_allclose(actual_tangents[1], expected_tangents[1])
|
||||
|
||||
@parameterized.named_parameters(
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_f={}_nsteps={}_impl={}".format(
|
||||
for_body_name, nsteps, impl_name),
|
||||
"f": for_body, "body_shapes": body_shapes,
|
||||
@ -343,7 +343,7 @@ class ForLoopTransformationTest(jtu.JaxTestCase):
|
||||
("accum", for_body_accum, accum_ref, [(4,), (4,)], 3),
|
||||
("sin_sq", for_body_sin_sq, sin_sq_ref, [(4,), (4,)], 4),
|
||||
("reverse", for_body_reverse, reverse_ref, [(4,), (4,)], 4),
|
||||
])
|
||||
]))
|
||||
def test_for_grad(self, f, ref, body_shapes, n, for_impl):
|
||||
for_ = for_impl
|
||||
rng = self.rng()
|
||||
|
Loading…
x
Reference in New Issue
Block a user