mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Add tests for 0d fully replicated scalar input to pjit.
PiperOrigin-RevId: 420884601
This commit is contained in:
parent
67723da38b
commit
7bc51879d4
@ -1181,14 +1181,15 @@ class UtilTest(jtu.JaxTestCase):
|
||||
global_mesh = create_global_mesh((2, 2), ('x', 'y'))
|
||||
global_in_aval1 = jax.core.ShapedArray((4, 4), jnp.int32)
|
||||
global_in_aval2 = jax.core.ShapedArray((4, 4, 4), jnp.int32)
|
||||
in_avals = [global_in_aval1, global_in_aval2]
|
||||
global_in_aval3 = jax.core.ShapedArray((), jnp.int32)
|
||||
in_avals = [global_in_aval1, global_in_aval2, global_in_aval3]
|
||||
|
||||
_, out_indices, _ = pxla._get_input_metadata(
|
||||
in_avals, global_mesh, [{}, {}], [False, False])
|
||||
in_avals, global_mesh, [{}, {}, {}], [False, False, False])
|
||||
|
||||
self.assertLen(out_indices, len(in_avals))
|
||||
self.assertLen(out_indices[0], len(global_mesh.local_devices))
|
||||
self.assertLen(out_indices[1], len(global_mesh.local_devices))
|
||||
self.assertTrue(all(len(out) == len(global_mesh.local_devices)
|
||||
for out in out_indices))
|
||||
self.assertTrue(all(len(i) == aval.ndim
|
||||
for out, aval in safe_zip(out_indices, in_avals) for i in out))
|
||||
self.assertTrue(all(i == (slice(None),) * aval.ndim
|
||||
|
Loading…
x
Reference in New Issue
Block a user