mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Skip some Shardy-enabled tests if XLA < 292.
PiperOrigin-RevId: 686133374
This commit is contained in:
parent
2c2c1eebc7
commit
2f2fd8a334
@ -5503,6 +5503,7 @@ class ShardyTest(jtu.JaxTestCase):
|
||||
|
||||
self.assertIn('sdy.sharding = #sdy.sharding', f.lower(arr).as_text())
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 292, "Requires XLA version >=292")
|
||||
def test_lowering_with_sharding_constraint(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
arr = np.arange(16).reshape(4, 2, 2)
|
||||
@ -5528,6 +5529,7 @@ class ShardyTest(jtu.JaxTestCase):
|
||||
self.assertIn('<@mesh, [{"x"}, {?}, {"y"}]>', lowered_str)
|
||||
|
||||
# TODO(bartchr): run on CPU once Shardy is added to the XLA CPU pipeline.
|
||||
@unittest.skipIf(xla_extension_version < 292, "Requires XLA version >=292")
|
||||
@jtu.skip_on_devices('cpu')
|
||||
def test_compile_with_inferred_out_sharding(self):
|
||||
mesh = jtu.create_mesh((2, 2), ('x', 'y'))
|
||||
|
@ -37,6 +37,7 @@ from jax._src import config
|
||||
from jax._src import core
|
||||
from jax._src import prng
|
||||
from jax._src import test_util as jtu
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib.mlir.dialects import sdy
|
||||
from jax._src.util import safe_zip, safe_map, partition_list, merge_lists
|
||||
from jax._src.ad_checkpoint import saved_residuals
|
||||
@ -2640,6 +2641,8 @@ class CustomPartitionerTest(jtu.JaxTestCase):
|
||||
# TODO(phawkins): enable this test unconditionally once shardy is the default.
|
||||
@unittest.skipIf(sdy is None, "shardy is not enabled")
|
||||
class SdyIntegrationTest(jtu.JaxTestCase):
|
||||
|
||||
@unittest.skipIf(xla_extension_version < 292, "Requires XLA version >=292")
|
||||
# Verify we can lower to a `ManualComputationOp`.
|
||||
def test_shardy_collective_permute(self):
|
||||
mesh = jtu.create_mesh((2,), ('x',))
|
||||
|
Loading…
x
Reference in New Issue
Block a user