mirror of https://github.com/ROCm/jax.git (synced 2025-04-18 21:06:06 +00:00)

commit 9956ad2f89 (parent a4e366394b)

Add more pjit tests, and make some tests go via actual computations rather than trivial ones.

PiperOrigin-RevId: 482919649
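Note: a minimal sketch of the distinction the commit message draws, assuming nothing beyond stock JAX (jax.jit stands in for pjit so it runs on a single CPU device). An identity function exercises only a trivial round-trip of the input, while multiplying by a constant forces a real compiled computation whose output values can be checked against an independent expectation.

import jax
import numpy as np

inp = np.arange(16, dtype=np.float32).reshape(8, 2)

trivial = jax.jit(lambda x: x)      # identity: output is indistinguishable from input
actual = jax.jit(lambda x: x * 2)   # real computation with checkable output values

np.testing.assert_array_equal(trivial(inp), inp)
np.testing.assert_array_equal(actual(inp), inp * 2)
np.testing.assert_array_equal(actual(actual(inp)), inp * 4)  # mirrors the chained f(f(x)) check below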
@@ -797,6 +797,8 @@ class JaxTestCase(parameterized.TestCase):
   def tearDown(self):
     for key, value in self._original_config.items():
       config.update(key, value)
+    # TODO(parkers): Remove this when a real fix for most_recent_entry lands.
+    dispatch.xla_callable.most_recent_entry()
     super().tearDown()
 
   def rng(self):
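Note: the two added lines are a temporary workaround for a cache-entry issue (see the TODO). The surrounding save/restore idiom is the part worth imitating; here is a hedged sketch of it with a plain dict standing in for JAX's global config object (only the _original_config name and the restore loop come from the hunk, the rest is illustrative).

import unittest

_config = {'jax_array': False}  # stand-in for the real global config (assumption)

class ConfigRestoringCase(unittest.TestCase):

  def setUp(self):
    super().setUp()
    # Snapshot the config so individual tests can mutate it freely.
    self._original_config = dict(_config)

  def tearDown(self):
    # Restore every key, mirroring the loop in the hunk above.
    for key, value in self._original_config.items():
      _config[key] = value
    super().tearDown()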
@@ -1818,6 +1818,10 @@ class ArrayPjitTest(jtu.JaxTestCase):
 
   @jax_array(True)
   def test_unspecified_out_axis_resources(self):
+    # TODO(https://github.com/google/jax/issues/12927): Fix cloud TPU SE backend.
+    if (xla_bridge.get_backend().runtime_type == 'stream_executor' and
+        jtu.device_under_test() == 'tpu'):
+      self.skipTest('Does not work with the cloud TPU SE runtime.')
 
     def _checks(out, input_data):
       self.assertIsInstance(out, array.ArrayImpl)
@@ -1835,26 +1839,27 @@ class ArrayPjitTest(jtu.JaxTestCase):
 
     input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)
 
-    f = pjit(lambda x: x)
+    f = pjit(lambda x: x * 2)
 
     out = f(input_array)
-    _checks(out, input_data)
+    _checks(out, input_data * 2)
 
     out2 = f(out)
-    _checks(out2, input_data)
+    _checks(out2, input_data * 4)
 
   @parameterized.named_parameters(
-      ('mesh1', (4, 2), (2, 1), (2, 2), (1, 2), (8, 2)),
-      ('mesh2', (2, 2), (4, 1), (4, 2), (2, 2), (8, 2)),
-      ('mesh3', (2, 1), (4, 2), (4, 2), (4, 2), (8, 2)),
+      ('mesh1', (4, 2), (2, 8), (2, 2), (1, 2), (8, 2)),
+      ('mesh2', (2, 2), (4, 8), (4, 2), (2, 2), (8, 2)),
+      ('mesh3', (2, 1), (4, 8), (4, 2), (4, 2), (8, 2)),
   )
   @jax_array(True)
   def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape,
                                                s2_shape, s3_shape, s4_shape):
-    # Disable on SE runtime type because XLA sharding propagation is not
-    # supported.
-    if xla_bridge.get_backend().runtime_type == 'se':
-      raise unittest.SkipTest('Needs TFRT runtime.')
+    # TODO(https://github.com/google/jax/issues/12927): Fix cloud TPU SE backend.
+    if (xla_bridge.get_backend().runtime_type == 'stream_executor' and
+        jtu.device_under_test() == 'tpu'):
+      self.skipTest('Does not work with the cloud TPU SE runtime.')
+
     global_mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y'))
     global_input_shape = (8, 2)
 
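Note: the updated 'mesh1'..'mesh3' rows reflect that out1 is now globally (8, 8) instead of (8, 2). A hedged sketch of the shard-shape arithmetic behind the new s1_shape values, assuming XLA propagates a P('x', None) sharding to out1 (the helper below is illustrative, not part of the test file):

def shard_shape(global_shape, mesh_shape, spec):
  """Per-device shard shape; spec maps each array dim to a mesh axis index or None."""
  return tuple(
      dim if axis is None else dim // mesh_shape[axis]
      for dim, axis in zip(global_shape, spec))

assert shard_shape((8, 8), (4, 2), (0, None)) == (2, 8)  # new 'mesh1' s1_shape
assert shard_shape((8, 2), (4, 2), (0, 1)) == (2, 1)     # the P('x', 'y') input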
@@ -1870,14 +1875,15 @@ class ArrayPjitTest(jtu.JaxTestCase):
     @pjit
     def f(tree):
       return tree
-    out_tree = f((a1, (a2, (a3, a4))))
+    out_tree = f((a1 @ a1.T, (a2, (a3 * 2, a4))))
     (out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree)
 
     self.assertIsInstance(out1, array.ArrayImpl)
-    self.assertEqual(out1.shape, (8, 2))
+    self.assertEqual(out1.shape, (8, 8))
     self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape)
     for s in out1.addressable_shards:
-      self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
+      self.assertArraysEqual(
+          s.data._arrays[0], (input_data @ input_data.T)[s.index])
 
     self.assertIsInstance(out2, array.ArrayImpl)
     self.assertEqual(out2.shape, (8, 2))
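Note: a quick plain-NumPy sanity check of the new expectations in this hunk; for an (8, 2) input, x @ x.T is the (8, 8) matrix of pairwise row inner products.

import numpy as np

input_data = np.arange(16).reshape(8, 2)
out1 = input_data @ input_data.T
assert out1.shape == (8, 8)  # hence the (8, 2) -> (8, 8) update above
assert out1[0, 1] == np.dot(input_data[0], input_data[1])  # row-row inner product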
@@ -1889,7 +1895,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
     self.assertEqual(out3.shape, (8, 2))
     self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape)
     for s in out3.addressable_shards:
-      self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
+      self.assertArraysEqual(s.data._arrays[0], (input_data * 2)[s.index])
 
     self.assertIsInstance(out4, array.ArrayImpl)
     self.assertEqual(out4.shape, (8, 2))
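Note: s.index in these assertions is the tuple of global-index slices a shard covers, so indexing the expected global result with it yields that shard's expected local block. A pure-NumPy sketch of the idea, with the slices built by hand rather than taken from a real Shard:

import numpy as np

global_expected = np.arange(16).reshape(8, 2) * 2   # expected out3
# On a (4, 2) mesh with P('x', 'y'), the shard at mesh position (0, 1)
# covers rows 0:2 and column 1:2 of the global array.
shard_index = (slice(0, 2), slice(1, 2))            # hand-built stand-in for s.index
assert global_expected[shard_index].shape == (2, 1)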
@@ -2176,7 +2182,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
   @jax_array(True)
   def test_grad_of_pjit_single_device_sharding(self):
     a = jnp.array(16, dtype=jnp.float32)
-    f = lambda x: x
+    f = lambda x: x * 3
     out = jax.grad(pjit(f))(a)
     self.assertIsInstance(out, array.ArrayImpl)
     self.assertArraysEqual(out, jax.grad(f)(a))
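Note: the substance of the gradient check is unchanged, but f(x) = 3x gives the more informative expectation df/dx = 3 everywhere (the identity's gradient is 1, which a broken grad path could return by accident). Runnable against stock JAX, with jax.jit standing in for pjit:

import jax
import jax.numpy as jnp

a = jnp.array(16, dtype=jnp.float32)
assert jax.grad(lambda x: x * 3)(a) == 3.0
assert jax.grad(jax.jit(lambda x: x * 3))(a) == 3.0  # jit stands in for pjit here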
@@ -2488,6 +2494,31 @@ class ArrayPjitTest(jtu.JaxTestCase):
     for i, x, y in zip(range(n), xs, ys):
       self.assertAllClose(x + i, y)
 
+  @jax_array(True)
+  def test_trivial_computation(self):
+    shape = (8, 2)
+    mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
+    s = MeshPspecSharding(mesh, P('x', 'y'))
+    inp_data = np.arange(prod(shape)).reshape(shape)
+    arr = jax.device_put(inp_data, s)
+    out = pjit(lambda x: x)(arr)
+    self.assertArraysEqual(out, inp_data)
+
+  @jax_array(True)
+  def test_multi_device_pjit_mul(self):
+    shape = (8, 2)
+    mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
+    inp_data = np.arange(prod(shape)).reshape(shape)
+    arr1 = jax.device_put(inp_data, MeshPspecSharding(mesh, P('x', 'y')))
+    arr2 = jax.device_put(inp_data, MeshPspecSharding(mesh, P(None, 'y')))
+
+    out1, out2 = pjit(lambda x, y: (x @ x.T, y * 2))(arr1, arr2)
+
+    self.assertArraysEqual(out1, inp_data @ inp_data.T)
+    self.assertEqual(out1.shape, (8, 8))
+    self.assertArraysEqual(out2, inp_data * 2)
+    self.assertEqual(out2.shape, (8, 2))
+
 
 
 class TempSharding(Sharding):
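Note: for readers replaying the two new tests on a current JAX, a hedged sketch of the same pattern: MeshPspecSharding was later renamed NamedSharding, jax.jit stands in for pjit, and the sketch assumes at least 8 devices (e.g. XLA_FLAGS=--xla_force_host_platform_device_count=8 on CPU).

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices()[:8]).reshape(4, 2)  # assumes >= 8 devices
mesh = Mesh(devices, ('x', 'y'))
inp_data = np.arange(16.).reshape(8, 2)
arr1 = jax.device_put(inp_data, NamedSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(inp_data, NamedSharding(mesh, P(None, 'y')))

out1, out2 = jax.jit(lambda x, y: (x @ x.T, y * 2))(arr1, arr2)
np.testing.assert_array_equal(out1, inp_data @ inp_data.T)
np.testing.assert_array_equal(out2, inp_data * 2)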