From dfd7d17c1d6f51f651839aff462ae94edefc7dc7 Mon Sep 17 00:00:00 2001
From: Zixuan Jiang
Date: Fri, 31 May 2024 18:14:04 -0700
Subject: [PATCH] [JAX] Use iota_reshape_dims and iota_transpose_perm in pxla, which is more efficient than tile_assignment_devices.

HloSharding V1 -> HloSharding V2.

PiperOrigin-RevId: 639210975
---
 jax/_src/interpreters/pxla.py | 4 ++--
 tests/shard_map_test.py       | 4 ++++
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 75b0da7b5..0253fb697 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -1545,11 +1545,11 @@ def manual_proto(
   tad_shape.append(math.prod([named_mesh_shape[a] for a in replicated_axes]))
   tad_shape.append(math.prod([named_mesh_shape[a] for a in manual_axes]))
 
-  raw_mesh = np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
   proto = xc.OpSharding()
   proto.type = xc.OpSharding.Type.OTHER
   proto.tile_assignment_dimensions = tad_shape
-  proto.tile_assignment_devices = list(raw_mesh.transpose(tad_perm).reshape(tad_shape).flat)
+  proto.iota_reshape_dims = mesh_shape
+  proto.iota_transpose_perm = tad_perm
   proto.last_tile_dims = [xc.OpSharding.Type.REPLICATED, xc.OpSharding.Type.MANUAL]
   return proto
 
diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py
index 710170ad2..a3ab64049 100644
--- a/tests/shard_map_test.py
+++ b/tests/shard_map_test.py
@@ -1646,6 +1646,10 @@ class ShardMapTest(jtu.JaxTestCase):
 
     v = jnp.arange(32.).reshape(4, 8)
     v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
+    self.assertIn(
+        'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual, replicated}}',
+        f.lower(v).as_text('hlo'),
+    )
     self.assertAllClose(v*v, f(v), check_dtypes=False)
 
   def test_partial_auto_error_wsc_manual(self):