diff --git a/tests/pjit_test.py b/tests/pjit_test.py index c7aaa3397..f87e5bcc5 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -1181,14 +1181,15 @@ class UtilTest(jtu.JaxTestCase): global_mesh = create_global_mesh((2, 2), ('x', 'y')) global_in_aval1 = jax.core.ShapedArray((4, 4), jnp.int32) global_in_aval2 = jax.core.ShapedArray((4, 4, 4), jnp.int32) - in_avals = [global_in_aval1, global_in_aval2] + global_in_aval3 = jax.core.ShapedArray((), jnp.int32) + in_avals = [global_in_aval1, global_in_aval2, global_in_aval3] _, out_indices, _ = pxla._get_input_metadata( - in_avals, global_mesh, [{}, {}], [False, False]) + in_avals, global_mesh, [{}, {}, {}], [False, False, False]) self.assertLen(out_indices, len(in_avals)) - self.assertLen(out_indices[0], len(global_mesh.local_devices)) - self.assertLen(out_indices[1], len(global_mesh.local_devices)) + self.assertTrue(all(len(out) == len(global_mesh.local_devices) + for out in out_indices)) self.assertTrue(all(len(i) == aval.ndim for out, aval in safe_zip(out_indices, in_avals) for i in out)) self.assertTrue(all(i == (slice(None),) * aval.ndim