Run the stream annotation tests on 2 devices so that it can be tested in TAP

PiperOrigin-RevId: 738113725
This commit is contained in:
Yash Katariya 2025-03-18 13:00:57 -07:00 committed by jax authors
parent 942ff38e36
commit 76d9890bb7

View File

@ -1664,31 +1664,30 @@ class StreamAnnotationTest(jtu.JaxTestCase):
def test_stream_annotation_inside_shmap(self):
if not jtu.test_device_matches(["gpu"]):
self.skipTest("Stream annotation is only supported on GPU.")
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
s = NamedSharding(mesh, P('x', 'y'))
np_inp = np.ones((8, 8))
mesh = jtu.create_mesh((2,), ('x',))
s = NamedSharding(mesh, P('x'))
np_inp = np.ones((8,))
arr1 = jax.device_put(np_inp, s)
arr2 = jax.device_put(np_inp, s)
@compute_on('gpu_stream:1')
@jax.jit
def g(x, y):
return x @ y
return x * y
@compute_on('gpu_stream:2')
@jax.jit
def h(x, y):
return x @ y
return x * y
def f(x, y):
z = g(x, y)
w = h(3 * x, 2 * y)
return z + w
out = jax.jit(shard_map(f, mesh=mesh,
in_specs=(P('x', 'y'), P('x', 'y')),
out_specs=P('x', 'y')))(arr1, arr2)
self.assertArraysEqual(out, arr1 * 28)
out = jax.jit(shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')),
out_specs=P('x')))(arr1, arr2)
self.assertArraysEqual(out, arr1 * 7)
class ActivationOffloadingTest(jtu.JaxTestCase):