[JAX] Use iota_reshape_dims and iota_transpose_perm in pxla, which is more efficient than tile_assignment_devices.

HloSharding V1 -> HloSharding V2.

PiperOrigin-RevId: 639210975
This commit is contained in:
Zixuan Jiang 2024-05-31 18:14:04 -07:00 committed by jax authors
parent 1edb94ec46
commit dfd7d17c1d
2 changed files with 6 additions and 2 deletions

View File

@ -1545,11 +1545,11 @@ def manual_proto(
tad_shape.append(math.prod([named_mesh_shape[a] for a in replicated_axes]))
tad_shape.append(math.prod([named_mesh_shape[a] for a in manual_axes]))
raw_mesh = np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
proto = xc.OpSharding()
proto.type = xc.OpSharding.Type.OTHER
proto.tile_assignment_dimensions = tad_shape
proto.tile_assignment_devices = list(raw_mesh.transpose(tad_perm).reshape(tad_shape).flat)
proto.iota_reshape_dims = mesh_shape
proto.iota_transpose_perm = tad_perm
proto.last_tile_dims = [xc.OpSharding.Type.REPLICATED, xc.OpSharding.Type.MANUAL]
return proto

View File

@ -1646,6 +1646,10 @@ class ShardMapTest(jtu.JaxTestCase):
v = jnp.arange(32.).reshape(4, 8)
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
self.assertIn(
'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual, replicated}}',
f.lower(v).as_text('hlo'),
)
self.assertAllClose(v*v, f(v), check_dtypes=False)
def test_partial_auto_error_wsc_manual(self):