mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
[SDY] add JAX lowering to Shardy ShardingGroupOp
for shard_alike.
PiperOrigin-RevId: 694567084
This commit is contained in:
parent
4d1a1264f0
commit
afd8239ea4
@ -15,6 +15,7 @@
|
||||
from functools import partial
|
||||
import itertools
|
||||
|
||||
from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import mlir
|
||||
@ -24,7 +25,7 @@ from jax._src.interpreters import batching
|
||||
from jax._src.util import safe_zip
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.api_util import shaped_abstractify
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir import dialects, ir
|
||||
|
||||
_next_shard_group_id = itertools.count()
|
||||
|
||||
@ -91,6 +92,11 @@ def _group_shard(
|
||||
) -> tuple[ir.Value, ir.Value]:
|
||||
shard_group_id = next(_next_shard_group_id)
|
||||
|
||||
if config.use_shardy_partitioner.value:
|
||||
dialects.sdy.ShardingGroupOp(x, shard_group_id)
|
||||
dialects.sdy.ShardingGroupOp(y, shard_group_id)
|
||||
return x, y
|
||||
|
||||
unknown_op_sharding = xc.OpSharding()
|
||||
unknown_op_sharding.type = xc.OpSharding.Type.UNKNOWN
|
||||
unknown_op_sharding.is_shard_group = True
|
||||
|
@ -283,6 +283,7 @@ jax_multiplatform_test(
|
||||
"tpu_v3_2x2",
|
||||
"tpu_v5e_4x2",
|
||||
"tpu_v4_2x2",
|
||||
"tpu_v3_2x2_shardy",
|
||||
],
|
||||
deps = [
|
||||
"//jax:experimental",
|
||||
|
@ -18,6 +18,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
import numpy as np
|
||||
from absl.testing import absltest
|
||||
from jax._src import config
|
||||
from jax._src import test_util as jtu
|
||||
from jax.sharding import NamedSharding, PartitionSpec as P
|
||||
from jax.experimental.shard_alike import shard_alike
|
||||
@ -221,18 +222,16 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
np_inp = np.arange(8.)
|
||||
s = NamedSharding(mesh, P('x'))
|
||||
rep_s = NamedSharding(mesh, P())
|
||||
arr = jax.device_put(np_inp, s)
|
||||
arr2 = jax.device_put(np_inp, rep_s)
|
||||
|
||||
def f(x, y):
|
||||
return shard_alike(x, y)
|
||||
|
||||
eager_out1, eager_out2 = f(arr, arr2)
|
||||
eager_out1, eager_out2 = f(arr, np_inp)
|
||||
self.assertEqual(eager_out1.sharding, s)
|
||||
self.assertEqual(eager_out2.sharding, s)
|
||||
|
||||
out1, out2 = jax.jit(f)(arr, arr2)
|
||||
out1, out2 = jax.jit(f)(arr, np_inp)
|
||||
self.assertEqual(out1.sharding, s)
|
||||
self.assertEqual(out2.sharding, s)
|
||||
|
||||
@ -282,6 +281,5 @@ class ShardAlikeTest(jtu.JaxTestCase):
|
||||
_, y = shard_alike(x, jnp.arange(8))
|
||||
self.assertEqual(y.sharding, s)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user