From 76d9890bb7b2fca21a1061af08a915b9d1b275ef Mon Sep 17 00:00:00 2001 From: Yash Katariya Date: Tue, 18 Mar 2025 13:00:57 -0700 Subject: [PATCH] Run the stream annotation tests on 2 devices so that it can be tested in TAP PiperOrigin-RevId: 738113725 --- tests/memories_test.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/tests/memories_test.py b/tests/memories_test.py index a08c5f36c..0ca973c4d 100644 --- a/tests/memories_test.py +++ b/tests/memories_test.py @@ -1664,31 +1664,30 @@ class StreamAnnotationTest(jtu.JaxTestCase): def test_stream_annotation_inside_shmap(self): if not jtu.test_device_matches(["gpu"]): self.skipTest("Stream annotation is only supported on GPU.") - mesh = jtu.create_mesh((2, 2), ('x', 'y')) - s = NamedSharding(mesh, P('x', 'y')) - np_inp = np.ones((8, 8)) + mesh = jtu.create_mesh((2,), ('x',)) + s = NamedSharding(mesh, P('x')) + np_inp = np.ones((8,)) arr1 = jax.device_put(np_inp, s) arr2 = jax.device_put(np_inp, s) @compute_on('gpu_stream:1') @jax.jit def g(x, y): - return x @ y + return x * y @compute_on('gpu_stream:2') @jax.jit def h(x, y): - return x @ y + return x * y def f(x, y): z = g(x, y) w = h(3 * x, 2 * y) return z + w - out = jax.jit(shard_map(f, mesh=mesh, - in_specs=(P('x', 'y'), P('x', 'y')), - out_specs=P('x', 'y')))(arr1, arr2) - self.assertArraysEqual(out, arr1 * 28) + out = jax.jit(shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')), + out_specs=P('x')))(arr1, arr2) + self.assertArraysEqual(out, arr1 * 7) class ActivationOffloadingTest(jtu.JaxTestCase):