From 5370ac2ec59c1acb347eb68771beec2487c8de64 Mon Sep 17 00:00:00 2001 From: Vladimir Belitskiy Date: Tue, 1 Apr 2025 11:33:00 -0700 Subject: [PATCH] Remove the try/except for Shardy imports. Shardy has been been included in JAX for a while now. PiperOrigin-RevId: 742778405 --- jax/_src/interpreters/mlir.py | 4 +--- jax/_src/lib/mlir/dialects/__init__.py | 6 +----- jax/extend/mlir/dialects/sdy.py | 6 +----- tests/pjit_test.py | 7 ------- 4 files changed, 3 insertions(+), 20 deletions(-) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index 23d1b5dd9..a1b37876f 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -616,9 +616,7 @@ def make_ir_context() -> ir.Context: # we don't do any heavy computation on MLIR modules from Python anyway, so we # just disable threading. context.enable_multithreading(False) - # TODO(bartchr): Once JAX is released with SDY, remove the if. - if dialects.sdy: - dialects.sdy.register_dialect(context) + dialects.sdy.register_dialect(context) dialects.mhlo.register_mhlo_dialect(context) dialects.chlo.register_dialect(context) dialects.hlo.register_dialect(context) diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index a9bae8821..be5317824 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -51,11 +51,7 @@ else: ]) del _lazy -# TODO(bartchr): Once JAX is released with SDY, remove the try/except. -try: - from jaxlib.mlir.dialects import sdy as sdy -except ImportError: - sdy: Any = None # type: ignore[no-redef] +from jaxlib.mlir.dialects import sdy # Alias that is set up to abstract away the transition from MHLO to StableHLO. from jaxlib.mlir.dialects import stablehlo as hlo diff --git a/jax/extend/mlir/dialects/sdy.py b/jax/extend/mlir/dialects/sdy.py index 48586cc26..d83fd90ec 100644 --- a/jax/extend/mlir/dialects/sdy.py +++ b/jax/extend/mlir/dialects/sdy.py @@ -14,8 +14,4 @@ # ruff: noqa: F403 -# TODO(bartchr): Once JAX is released with SDY, remove the try/except. -try: - from jaxlib.mlir.dialects.sdy import * -except ImportError: - pass +from jaxlib.mlir.dialects.sdy import * diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 0b2daee8c..ee4a8cd3e 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -59,7 +59,6 @@ from jax._src.named_sharding import DuplicateSpecError from jax._src import mesh as mesh_lib from jax._src.mesh import AxisType from jax._src.interpreters import pxla -from jax._src.lib.mlir import dialects from jax._src import xla_bridge from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension @@ -8067,12 +8066,6 @@ class UtilTest(jtu.JaxTestCase): @jtu.with_config(jax_use_shardy_partitioner=True) class ShardyTest(jtu.JaxTestCase): - # TODO(bartchr): Once JAX is released with SDY, remove setUp. - def setUp(self): - if not dialects.sdy: - raise unittest.SkipTest('Shardy is not available.') - super().setUp() - def test_lowering_input_output_sharding(self): mesh = jtu.create_mesh((4, 2), ('x', 'y')) np_inp = np.arange(16).reshape(8, 2)