mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Run the stream annotation tests on 2 devices so that it can be tested in TAP
PiperOrigin-RevId: 738113725
This commit is contained in:
parent
942ff38e36
commit
76d9890bb7
@ -1664,31 +1664,30 @@ class StreamAnnotationTest(jtu.JaxTestCase):
|
|||||||
def test_stream_annotation_inside_shmap(self):
|
def test_stream_annotation_inside_shmap(self):
|
||||||
if not jtu.test_device_matches(["gpu"]):
|
if not jtu.test_device_matches(["gpu"]):
|
||||||
self.skipTest("Stream annotation is only supported on GPU.")
|
self.skipTest("Stream annotation is only supported on GPU.")
|
||||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
mesh = jtu.create_mesh((2,), ('x',))
|
||||||
s = NamedSharding(mesh, P('x', 'y'))
|
s = NamedSharding(mesh, P('x'))
|
||||||
np_inp = np.ones((8, 8))
|
np_inp = np.ones((8,))
|
||||||
arr1 = jax.device_put(np_inp, s)
|
arr1 = jax.device_put(np_inp, s)
|
||||||
arr2 = jax.device_put(np_inp, s)
|
arr2 = jax.device_put(np_inp, s)
|
||||||
|
|
||||||
@compute_on('gpu_stream:1')
|
@compute_on('gpu_stream:1')
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def g(x, y):
|
def g(x, y):
|
||||||
return x @ y
|
return x * y
|
||||||
|
|
||||||
@compute_on('gpu_stream:2')
|
@compute_on('gpu_stream:2')
|
||||||
@jax.jit
|
@jax.jit
|
||||||
def h(x, y):
|
def h(x, y):
|
||||||
return x @ y
|
return x * y
|
||||||
|
|
||||||
def f(x, y):
|
def f(x, y):
|
||||||
z = g(x, y)
|
z = g(x, y)
|
||||||
w = h(3 * x, 2 * y)
|
w = h(3 * x, 2 * y)
|
||||||
return z + w
|
return z + w
|
||||||
|
|
||||||
out = jax.jit(shard_map(f, mesh=mesh,
|
out = jax.jit(shard_map(f, mesh=mesh, in_specs=(P('x'), P('x')),
|
||||||
in_specs=(P('x', 'y'), P('x', 'y')),
|
out_specs=P('x')))(arr1, arr2)
|
||||||
out_specs=P('x', 'y')))(arr1, arr2)
|
self.assertArraysEqual(out, arr1 * 7)
|
||||||
self.assertArraysEqual(out, arr1 * 28)
|
|
||||||
|
|
||||||
|
|
||||||
class ActivationOffloadingTest(jtu.JaxTestCase):
|
class ActivationOffloadingTest(jtu.JaxTestCase):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user