Add more pjit tests and make some tests go via actual computations rather than trivial computation.

PiperOrigin-RevId: 482919649
This commit is contained in:
Yash Katariya 2022-10-21 16:53:14 -07:00 committed by jax authors
parent a4e366394b
commit 9956ad2f89
2 changed files with 48 additions and 15 deletions

View File

@ -797,6 +797,8 @@ class JaxTestCase(parameterized.TestCase):
def tearDown(self):
for key, value in self._original_config.items():
config.update(key, value)
# TODO(parkers): Remove this when a real fix for most_recent_entry lands.
dispatch.xla_callable.most_recent_entry()
super().tearDown()
def rng(self):

View File

@ -1818,6 +1818,10 @@ class ArrayPjitTest(jtu.JaxTestCase):
@jax_array(True)
def test_unspecified_out_axis_resources(self):
# TODO(https://github.com/google/jax/issues/12927): Fix cloud TPU SE backend.
if (xla_bridge.get_backend().runtime_type == 'stream_executor' and
jtu.device_under_test() == 'tpu'):
self.skipTest('Does not work with the cloud TPU SE runtime.')
def _checks(out, input_data):
self.assertIsInstance(out, array.ArrayImpl)
@ -1835,26 +1839,27 @@ class ArrayPjitTest(jtu.JaxTestCase):
input_array, input_data = create_array(global_input_shape, global_mesh, mesh_axes)
f = pjit(lambda x: x)
f = pjit(lambda x: x * 2)
out = f(input_array)
_checks(out, input_data)
_checks(out, input_data * 2)
out2 = f(out)
_checks(out2, input_data)
_checks(out2, input_data * 4)
@parameterized.named_parameters(
('mesh1', (4, 2), (2, 1), (2, 2), (1, 2), (8, 2)),
('mesh2', (2, 2), (4, 1), (4, 2), (2, 2), (8, 2)),
('mesh3', (2, 1), (4, 2), (4, 2), (4, 2), (8, 2)),
('mesh1', (4, 2), (2, 8), (2, 2), (1, 2), (8, 2)),
('mesh2', (2, 2), (4, 8), (4, 2), (2, 2), (8, 2)),
('mesh3', (2, 1), (4, 8), (4, 2), (4, 2), (8, 2)),
)
@jax_array(True)
def test_pjit_array_multi_input_multi_output(self, mesh_shape, s1_shape,
s2_shape, s3_shape, s4_shape):
# Disable on SE runtime type because XLA sharding propagation is not
# supported.
if xla_bridge.get_backend().runtime_type == 'se':
raise unittest.SkipTest('Needs TFRT runtime.')
# TODO(https://github.com/google/jax/issues/12927): Fix cloud TPU SE backend.
if (xla_bridge.get_backend().runtime_type == 'stream_executor' and
jtu.device_under_test() == 'tpu'):
self.skipTest('Does not work with the cloud TPU SE runtime.')
global_mesh = jtu.create_global_mesh(mesh_shape, ('x', 'y'))
global_input_shape = (8, 2)
@ -1870,14 +1875,15 @@ class ArrayPjitTest(jtu.JaxTestCase):
@pjit
def f(tree):
return tree
out_tree = f((a1, (a2, (a3, a4))))
out_tree = f((a1 @ a1.T, (a2, (a3 * 2, a4))))
(out1, out2, out3, out4), _ = jax.tree_util.tree_flatten(out_tree)
self.assertIsInstance(out1, array.ArrayImpl)
self.assertEqual(out1.shape, (8, 2))
self.assertEqual(out1.shape, (8, 8))
self.assertEqual(out1.addressable_shards[0].data.shape, s1_shape)
for s in out1.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
self.assertArraysEqual(
s.data._arrays[0], (input_data @ input_data.T)[s.index])
self.assertIsInstance(out2, array.ArrayImpl)
self.assertEqual(out2.shape, (8, 2))
@ -1889,7 +1895,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
self.assertEqual(out3.shape, (8, 2))
self.assertEqual(out3.addressable_shards[0].data.shape, s3_shape)
for s in out3.addressable_shards:
self.assertArraysEqual(s.data._arrays[0], input_data[s.index])
self.assertArraysEqual(s.data._arrays[0], (input_data * 2)[s.index])
self.assertIsInstance(out4, array.ArrayImpl)
self.assertEqual(out4.shape, (8, 2))
@ -2176,7 +2182,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
@jax_array(True)
def test_grad_of_pjit_single_device_sharding(self):
a = jnp.array(16, dtype=jnp.float32)
f = lambda x: x
f = lambda x: x * 3
out = jax.grad(pjit(f))(a)
self.assertIsInstance(out, array.ArrayImpl)
self.assertArraysEqual(out, jax.grad(f)(a))
@ -2488,6 +2494,31 @@ class ArrayPjitTest(jtu.JaxTestCase):
for i, x, y in zip(range(n), xs, ys):
self.assertAllClose(x + i, y)
@jax_array(True)
def test_trivial_computation(self):
shape = (8, 2)
mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
s = MeshPspecSharding(mesh, P('x', 'y'))
inp_data = np.arange(prod(shape)).reshape(shape)
arr = jax.device_put(inp_data, s)
out = pjit(lambda x: x)(arr)
self.assertArraysEqual(out, inp_data)
@jax_array(True)
def test_multi_device_pjit_mul(self):
shape = (8, 2)
mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
inp_data = np.arange(prod(shape)).reshape(shape)
arr1 = jax.device_put(inp_data, MeshPspecSharding(mesh, P('x', 'y')))
arr2 = jax.device_put(inp_data, MeshPspecSharding(mesh, P(None, 'y')))
out1, out2 = pjit(lambda x, y: (x @ x.T, y * 2))(arr1, arr2)
self.assertArraysEqual(out1, inp_data @ inp_data.T)
self.assertEqual(out1.shape, (8, 8))
self.assertArraysEqual(out2, inp_data * 2)
self.assertEqual(out2.shape, (8, 2))
class TempSharding(Sharding):