Add tests for 0d fully replicated scalar input to pjit.

PiperOrigin-RevId: 420884601
This commit is contained in:
Yash Katariya 2022-01-10 16:23:59 -08:00 committed by jax authors
parent 67723da38b
commit 7bc51879d4

View File

@ -1181,14 +1181,15 @@ class UtilTest(jtu.JaxTestCase):
global_mesh = create_global_mesh((2, 2), ('x', 'y'))
global_in_aval1 = jax.core.ShapedArray((4, 4), jnp.int32)
global_in_aval2 = jax.core.ShapedArray((4, 4, 4), jnp.int32)
in_avals = [global_in_aval1, global_in_aval2]
global_in_aval3 = jax.core.ShapedArray((), jnp.int32)
in_avals = [global_in_aval1, global_in_aval2, global_in_aval3]
_, out_indices, _ = pxla._get_input_metadata(
in_avals, global_mesh, [{}, {}], [False, False])
in_avals, global_mesh, [{}, {}, {}], [False, False, False])
self.assertLen(out_indices, len(in_avals))
self.assertLen(out_indices[0], len(global_mesh.local_devices))
self.assertLen(out_indices[1], len(global_mesh.local_devices))
self.assertTrue(all(len(out) == len(global_mesh.local_devices)
for out in out_indices))
self.assertTrue(all(len(i) == aval.ndim
for out, aval in safe_zip(out_indices, in_avals) for i in out))
self.assertTrue(all(i == (slice(None),) * aval.ndim