mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Use iota_reshape_dims and iota_transpose_perm in pxla, which is more efficient than tile_assignment_devices.
HloSharding V1 -> HloSharding V2. PiperOrigin-RevId: 639210975
This commit is contained in:
parent
1edb94ec46
commit
dfd7d17c1d
@ -1545,11 +1545,11 @@ def manual_proto(
|
||||
tad_shape.append(math.prod([named_mesh_shape[a] for a in replicated_axes]))
|
||||
tad_shape.append(math.prod([named_mesh_shape[a] for a in manual_axes]))
|
||||
|
||||
raw_mesh = np.arange(math.prod(mesh_shape)).reshape(mesh_shape)
|
||||
proto = xc.OpSharding()
|
||||
proto.type = xc.OpSharding.Type.OTHER
|
||||
proto.tile_assignment_dimensions = tad_shape
|
||||
proto.tile_assignment_devices = list(raw_mesh.transpose(tad_perm).reshape(tad_shape).flat)
|
||||
proto.iota_reshape_dims = mesh_shape
|
||||
proto.iota_transpose_perm = tad_perm
|
||||
proto.last_tile_dims = [xc.OpSharding.Type.REPLICATED, xc.OpSharding.Type.MANUAL]
|
||||
return proto
|
||||
|
||||
|
@ -1646,6 +1646,10 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
|
||||
v = jnp.arange(32.).reshape(4, 8)
|
||||
v = jax.device_put(v, jax.sharding.NamedSharding(mesh, P('i', 'j')))
|
||||
self.assertIn(
|
||||
'sharding={devices=[1,1,2,2]<=[4] last_tile_dims={manual, replicated}}',
|
||||
f.lower(v).as_text('hlo'),
|
||||
)
|
||||
self.assertAllClose(v*v, f(v), check_dtypes=False)
|
||||
|
||||
def test_partial_auto_error_wsc_manual(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user