mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Remove the try/except for Shardy imports.
Shardy has been been included in JAX for a while now. PiperOrigin-RevId: 742778405
This commit is contained in:
parent
f4c727abb3
commit
5370ac2ec5
@ -616,9 +616,7 @@ def make_ir_context() -> ir.Context:
|
||||
# we don't do any heavy computation on MLIR modules from Python anyway, so we
|
||||
# just disable threading.
|
||||
context.enable_multithreading(False)
|
||||
# TODO(bartchr): Once JAX is released with SDY, remove the if.
|
||||
if dialects.sdy:
|
||||
dialects.sdy.register_dialect(context)
|
||||
dialects.sdy.register_dialect(context)
|
||||
dialects.mhlo.register_mhlo_dialect(context)
|
||||
dialects.chlo.register_dialect(context)
|
||||
dialects.hlo.register_dialect(context)
|
||||
|
@ -51,11 +51,7 @@ else:
|
||||
])
|
||||
del _lazy
|
||||
|
||||
# TODO(bartchr): Once JAX is released with SDY, remove the try/except.
|
||||
try:
|
||||
from jaxlib.mlir.dialects import sdy as sdy
|
||||
except ImportError:
|
||||
sdy: Any = None # type: ignore[no-redef]
|
||||
from jaxlib.mlir.dialects import sdy
|
||||
|
||||
# Alias that is set up to abstract away the transition from MHLO to StableHLO.
|
||||
from jaxlib.mlir.dialects import stablehlo as hlo
|
||||
|
@ -14,8 +14,4 @@
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
# TODO(bartchr): Once JAX is released with SDY, remove the try/except.
|
||||
try:
|
||||
from jaxlib.mlir.dialects.sdy import *
|
||||
except ImportError:
|
||||
pass
|
||||
from jaxlib.mlir.dialects.sdy import *
|
||||
|
@ -59,7 +59,6 @@ from jax._src.named_sharding import DuplicateSpecError
|
||||
from jax._src import mesh as mesh_lib
|
||||
from jax._src.mesh import AxisType
|
||||
from jax._src.interpreters import pxla
|
||||
from jax._src.lib.mlir import dialects
|
||||
from jax._src import xla_bridge
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension
|
||||
@ -8067,12 +8066,6 @@ class UtilTest(jtu.JaxTestCase):
|
||||
@jtu.with_config(jax_use_shardy_partitioner=True)
|
||||
class ShardyTest(jtu.JaxTestCase):
|
||||
|
||||
# TODO(bartchr): Once JAX is released with SDY, remove setUp.
|
||||
def setUp(self):
|
||||
if not dialects.sdy:
|
||||
raise unittest.SkipTest('Shardy is not available.')
|
||||
super().setUp()
|
||||
|
||||
def test_lowering_input_output_sharding(self):
|
||||
mesh = jtu.create_mesh((4, 2), ('x', 'y'))
|
||||
np_inp = np.arange(16).reshape(8, 2)
|
||||
|
Loading…
x
Reference in New Issue
Block a user