Remove the try/except for Shardy imports.

Shardy has been been included in JAX for a while now.

PiperOrigin-RevId: 742778405
This commit is contained in:
Vladimir Belitskiy 2025-04-01 11:33:00 -07:00 committed by jax authors
parent f4c727abb3
commit 5370ac2ec5
4 changed files with 3 additions and 20 deletions

View File

@ -616,9 +616,7 @@ def make_ir_context() -> ir.Context:
# we don't do any heavy computation on MLIR modules from Python anyway, so we
# just disable threading.
context.enable_multithreading(False)
# TODO(bartchr): Once JAX is released with SDY, remove the if.
if dialects.sdy:
dialects.sdy.register_dialect(context)
dialects.sdy.register_dialect(context)
dialects.mhlo.register_mhlo_dialect(context)
dialects.chlo.register_dialect(context)
dialects.hlo.register_dialect(context)

View File

@ -51,11 +51,7 @@ else:
])
del _lazy
# TODO(bartchr): Once JAX is released with SDY, remove the try/except.
try:
from jaxlib.mlir.dialects import sdy as sdy
except ImportError:
sdy: Any = None # type: ignore[no-redef]
from jaxlib.mlir.dialects import sdy
# Alias that is set up to abstract away the transition from MHLO to StableHLO.
from jaxlib.mlir.dialects import stablehlo as hlo

View File

@ -14,8 +14,4 @@
# ruff: noqa: F403
# TODO(bartchr): Once JAX is released with SDY, remove the try/except.
try:
from jaxlib.mlir.dialects.sdy import *
except ImportError:
pass
from jaxlib.mlir.dialects.sdy import *

View File

@ -59,7 +59,6 @@ from jax._src.named_sharding import DuplicateSpecError
from jax._src import mesh as mesh_lib
from jax._src.mesh import AxisType
from jax._src.interpreters import pxla
from jax._src.lib.mlir import dialects
from jax._src import xla_bridge
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension
@ -8067,12 +8066,6 @@ class UtilTest(jtu.JaxTestCase):
@jtu.with_config(jax_use_shardy_partitioner=True)
class ShardyTest(jtu.JaxTestCase):
# TODO(bartchr): Once JAX is released with SDY, remove setUp.
def setUp(self):
if not dialects.sdy:
raise unittest.SkipTest('Shardy is not available.')
super().setUp()
def test_lowering_input_output_sharding(self):
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
np_inp = np.arange(16).reshape(8, 2)