mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Run the stream annotation tests on 2 devices so that it can be tested in TAP
PiperOrigin-RevId: 738113725
This commit is contained in:
parent
942ff38e36
commit
76d9890bb7
@ -1664,31 +1664,30 @@ class StreamAnnotationTest(jtu.JaxTestCase):
|
||||
def test_stream_annotation_inside_shmap(self):
|
||||
if not jtu.test_device_matches(["gpu"]):
|
||||
self.skipTest("Stream annotation is only supported on GPU.")
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
s = NamedSharding(mesh, P('x', 'y'))
|
||||
np_inp = np.ones((8, 8))
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
np_inp = np.ones((8,))
|
||||
arr1 = jax.device_put(np_inp, s)
|
||||
arr2 = jax.device_put(np_inp, s)
|
||||
|
||||
@compute_on('gpu_stream:1')
|
||||
@jax.jit
|
||||
def g(x, y):
|
||||
return x @ y
|
||||
return x * y
|
||||
|
||||
@compute_on('gpu_stream:2')
|
||||
@jax.jit
|
||||
def h(x, y):
|
||||
return x @ y
|
||||
return x * y
|
||||
|
||||
def f(x, y):
|
||||
z = g(x, y)
|
||||
w = h(3 * x, 2 * y)
|
||||
return z + w
|
||||
|
||||
out = jax.jit(shard_map(f, mesh=mesh,
|
||||
in_specs=(P('x', 'y'), P('x', 'y')),
|
||||
out_specs=P('x', 'y')))(arr1, arr2)
|
||||
self.assertArraysEqual(out, arr1 * 28)
|
||||
out = jax.jit(shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')),
|
||||
out_specs=P('x')))(arr1, arr2)
|
||||
self.assertArraysEqual(out, arr1 * 7)
|
||||
|
||||
|
||||
class ActivationOffloadingTest(jtu.JaxTestCase):
|
||||
|
Loading…
x
Reference in New Issue
Block a user