1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 05:16:06 +00:00

Skip some Shardy-enabled tests if XLA < 292.

PiperOrigin-RevId: 686133374
This commit is contained in:
Vladimir Belitskiy 2024-10-15 09:28:26 -07:00 committed by jax authors
parent 2c2c1eebc7
commit 2f2fd8a334
2 changed files with 5 additions and 0 deletions

@ -5503,6 +5503,7 @@ class ShardyTest(jtu.JaxTestCase):
self.assertIn('sdy.sharding = #sdy.sharding', f.lower(arr).as_text())
@unittest.skipIf(xla_extension_version < 292, "Requires XLA version >=292")
def test_lowering_with_sharding_constraint(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
arr = np.arange(16).reshape(4, 2, 2)
@ -5528,6 +5529,7 @@ class ShardyTest(jtu.JaxTestCase):
self.assertIn('<@mesh, [{"x"}, {?}, {"y"}]>', lowered_str)
# TODO(bartchr): run on CPU once Shardy is added to the XLA CPU pipeline.
@unittest.skipIf(xla_extension_version < 292, "Requires XLA version >=292")
@jtu.skip_on_devices('cpu')
def test_compile_with_inferred_out_sharding(self):
mesh = jtu.create_mesh((2, 2), ('x', 'y'))

@ -37,6 +37,7 @@ from jax._src import config
from jax._src import core
from jax._src import prng
from jax._src import test_util as jtu
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir.dialects import sdy
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
from jax._src.ad_checkpoint import saved_residuals
@ -2640,6 +2641,8 @@ class CustomPartitionerTest(jtu.JaxTestCase):
# TODO(phawkins): enable this test unconditionally once shardy is the default.
@unittest.skipIf(sdy is None, "shardy is not enabled")
class SdyIntegrationTest(jtu.JaxTestCase):
@unittest.skipIf(xla_extension_version < 292, "Requires XLA version >=292")
# Verify we can lower to a `ManualComputationOp`.
def test_shardy_collective_permute(self):
mesh = jtu.create_mesh((2,), ('x',))