[SDY] add JAX lowering to Shardy ShardingGroupOp for shard_alike.

PiperOrigin-RevId: 694567084
This commit is contained in:
Bill Varcho 2024-11-08 11:02:09 -08:00 committed by jax authors
parent 4d1a1264f0
commit afd8239ea4
3 changed files with 11 additions and 6 deletions

View File

@ -15,6 +15,7 @@
from functools import partial
import itertools
from jax._src import config
from jax._src import core
from jax._src.interpreters import ad
from jax._src.interpreters import mlir
@ -24,7 +25,7 @@ from jax._src.interpreters import batching
from jax._src.util import safe_zip
from jax._src.lib import xla_client as xc
from jax._src.api_util import shaped_abstractify
from jax._src.lib.mlir import ir
from jax._src.lib.mlir import dialects, ir
_next_shard_group_id = itertools.count()
@ -91,6 +92,11 @@ def _group_shard(
) -> tuple[ir.Value, ir.Value]:
shard_group_id = next(_next_shard_group_id)
if config.use_shardy_partitioner.value:
dialects.sdy.ShardingGroupOp(x, shard_group_id)
dialects.sdy.ShardingGroupOp(y, shard_group_id)
return x, y
unknown_op_sharding = xc.OpSharding()
unknown_op_sharding.type = xc.OpSharding.Type.UNKNOWN
unknown_op_sharding.is_shard_group = True

View File

@ -283,6 +283,7 @@ jax_multiplatform_test(
"tpu_v3_2x2",
"tpu_v5e_4x2",
"tpu_v4_2x2",
"tpu_v3_2x2_shardy",
],
deps = [
"//jax:experimental",

View File

@ -18,6 +18,7 @@ import jax
import jax.numpy as jnp
import numpy as np
from absl.testing import absltest
from jax._src import config
from jax._src import test_util as jtu
from jax.sharding import NamedSharding, PartitionSpec as P
from jax.experimental.shard_alike import shard_alike
@ -221,18 +222,16 @@ class ShardAlikeTest(jtu.JaxTestCase):
mesh = jtu.create_mesh((2,), ('x',))
np_inp = np.arange(8.)
s = NamedSharding(mesh, P('x'))
rep_s = NamedSharding(mesh, P())
arr = jax.device_put(np_inp, s)
arr2 = jax.device_put(np_inp, rep_s)
def f(x, y):
return shard_alike(x, y)
eager_out1, eager_out2 = f(arr, arr2)
eager_out1, eager_out2 = f(arr, np_inp)
self.assertEqual(eager_out1.sharding, s)
self.assertEqual(eager_out2.sharding, s)
out1, out2 = jax.jit(f)(arr, arr2)
out1, out2 = jax.jit(f)(arr, np_inp)
self.assertEqual(out1.sharding, s)
self.assertEqual(out2.sharding, s)
@ -282,6 +281,5 @@ class ShardAlikeTest(jtu.JaxTestCase):
_, y = shard_alike(x, jnp.arange(8))
self.assertEqual(y.sharding, s)
if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())