diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py index 75b0da7b5..0253fb697 100644 --- a/jax/_src/interpreters/pxla.py +++ b/jax/_src/interpreters/pxla.py @@ -1545,11 +1545,11 @@ def manual_proto( tad_shape.append(math.prod([named_mesh_shape[a] for a in replicated_axes])) tad_shape.append(math.prod([named_mesh_shape[a] for a in manual_axes])) - raw_mesh = np.arange(math.prod(mesh_shape)).reshape(mesh_shape) proto = xc.OpSharding() proto.type = xc.OpSharding.Type.OTHER proto.tile_assignment_dimensions = tad_shape - proto.tile_assignment_devices = list(raw_mesh.transpose(tad_perm).reshape(tad_shape).flat) + proto.iota_reshape_dims = mesh_shape + proto.iota_transpose_perm = tad_perm proto.last_tile_dims = [xc.OpSharding.Type.REPLICATED, xc.OpSharding.Type.MANUAL] return proto diff --git a/tests/shard_map_test.py b/tests/shard_map_test.py index 710170ad2..a3ab64049 100644 --- a/tests/shard_map_test.py +++ b/tests/shard_map_test.py @@ -1646,6 +1646,10 @@ class ShardMapTest(jtu.JaxTestCase): v = jnp.arange(32.).reshape(4, 8) v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j'))) + self.assertIn( + 'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual, replicated}}', + f.lower(v).as_text('hlo'), + ) self.assertAllClose(v*v, f(v), check_dtypes=False) def test_partial_auto_error_wsc_manual(self):