diff --git a/docs/aot.md b/docs/aot.md index 766f4b583..efcfe3ea6 100644 --- a/docs/aot.md +++ b/docs/aot.md @@ -20,7 +20,7 @@ are arrays, JAX does the following in order: their shape and element type). 2. **Lower** this specialized, staged-out computation to the XLA compiler's - input language, MHLO. + input language, StableHLO. 3. **Compile** the lowered HLO program to produce an optimized executable for the target device (CPU, GPU, or TPU). @@ -45,9 +45,9 @@ way. An example: >>> print(lowered.as_text()) module @jit_f.0 { func.func public @main(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = mhlo.constant dense<2> : tensor - %1 = mhlo.multiply %0, %arg0 : tensor - %2 = mhlo.add %1, %arg1 : tensor + %0 = stablehlo.constant dense<2> : tensor + %1 = stablehlo.multiply %0, %arg0 : tensor + %2 = stablehlo.add %1, %arg1 : tensor return %2 : tensor } } @@ -129,8 +129,8 @@ to invoke the resulting compiled function. Continuing with our example above: >>> print(lowered_with_x.as_text()) module @jit_f.1 { func.func public @main(%arg0: tensor) -> tensor { - %0 = mhlo.constant dense<14> : tensor - %1 = mhlo.add %0, %arg0 : tensor + %0 = stablehlo.constant dense<14> : tensor + %1 = stablehlo.add %0, %arg0 : tensor return %1 : tensor } } diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 6a4a351ec..776d7d833 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -53,6 +53,7 @@ from jax._src.sharding import (PmapSharding, SingleDeviceSharding, from jax._src.abstract_arrays import array_types from jax._src.config import config, flags from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import use_stablehlo from jax._src.lib import xla_bridge as xb from jax._src.lib import xla_client as xc import jax._src.util as util @@ -982,9 +983,20 @@ class XlaComputation(stages.XlaLowering): use_tuple_args=self.compile_args["tuple_args"]) def mhlo(self) -> ir.Module: - if self.is_trivial(): - raise ValueError("A trivial computation has no MHLO") - return self._hlo + if use_stablehlo: + return super().mhlo() + else: + if self.is_trivial(): + raise ValueError("A trivial computation has no MHLO") + return self._hlo + + def stablehlo(self) -> ir.Module: + if use_stablehlo: + if self.is_trivial(): + raise ValueError("A trivial computation has no StableHLO") + return self._hlo + else: + return super().stablehlo() def compile(self) -> XlaCompiledComputation: if self._executable is None: diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index d17116f65..cb1019b2b 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -38,7 +38,7 @@ from jax._src.numpy import lax_numpy import jax._src.util as util from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import hlo +from jax._src.lib.mlir.dialects import hlo, use_stablehlo unsafe_map, map = map, safe_map # type: ignore @@ -984,7 +984,7 @@ def _all_to_all_lowering(ctx, x, *, else: other_args = {} return hlo.AllToAllOp( - [x], + x if use_stablehlo else [x], split_dimension=mlir.i64_attr(split_axis), concat_dimension=mlir.i64_attr(concat_axis), split_count=mlir.i64_attr(split_count), diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index 564cd0bd8..9cf5f79c9 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -24,6 +24,8 @@ from jax.lib import xla_client import jaxlib.mlir.dialects.stablehlo as stablehlo # Alias that is set up to abstract away the transition from MHLO to StableHLO. -# At the moment, it points to MHLO, but in the future it will start to -# conditionally and then unconditionally point to StableHLO. -import jaxlib.mlir.dialects.mhlo as hlo +use_stablehlo = xla_client.mlir_api_version >= 42 +if use_stablehlo: + import jaxlib.mlir.dialects.stablehlo as hlo +else: + import jaxlib.mlir.dialects.mhlo as hlo diff --git a/jax/_src/stages.py b/jax/_src/stages.py index 6aa8468c7..3c6b543b7 100644 --- a/jax/_src/stages.py +++ b/jax/_src/stages.py @@ -44,6 +44,7 @@ from jax._src import source_info_util from jax._src import traceback_util from jax._src import util from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import use_stablehlo from jax.interpreters import mlir from jax.interpreters import xla @@ -169,7 +170,8 @@ class Lowering(Protocol): compiler. Args: - dialect: Optional string specifying a representation dialect (e.g. "mhlo") + dialect: Optional string specifying a representation dialect + (e.g. "stablehlo") """ raise NotImplementedError @@ -264,20 +266,31 @@ class XlaLowering(Lowering): def mhlo(self) -> ir.Module: """Return an MHLO representation of this computation.""" - raise NotImplementedError("must override") + if use_stablehlo: + module_str = xla_extension.mlir.stablehlo_to_mhlo( + mlir.module_to_bytecode(self.stablehlo())) + with self.stablehlo().context: + return ir.Module.parse(module_str) + else: + raise NotImplementedError("must override") def stablehlo(self) -> ir.Module: """Return a StableHLO representation of this computation.""" - module_str = xla_extension.mlir.mhlo_to_stablehlo( - mlir.module_to_string(self.mhlo())) - with mlir.make_ir_context(): - return ir.Module.parse(module_str) + if use_stablehlo: + raise NotImplementedError("must override") + else: + module_str = xla_extension.mlir.mhlo_to_stablehlo( + mlir.module_to_bytecode(self.mhlo())) + with self.mhlo().context: + return ir.Module.parse(module_str) def compile(self) -> Executable: raise NotImplementedError("must override") def as_text(self, dialect: Optional[str] = None) -> str: - if dialect is None or dialect == "mhlo": + if dialect is None: + dialect = "stablehlo" if use_stablehlo else "mhlo" + if dialect == "mhlo": return str(self.mhlo()) elif dialect == "stablehlo": return str(self.stablehlo()) @@ -287,7 +300,9 @@ class XlaLowering(Lowering): raise ValueError(f"unknown dialect: {dialect}") def compiler_ir(self, dialect: Optional[str] = None) -> Any: - if dialect is None or dialect == "mhlo": + if dialect is None: + dialect = "stablehlo" if use_stablehlo else "mhlo" + if dialect == "mhlo": return self.mhlo() elif dialect == "stablehlo": return self.stablehlo() @@ -579,7 +594,7 @@ class Lowered(Stage): nor reliable serialization. It is relayed directly to external callers. Args: - dialect: Optional string specifying a lowering dialect (e.g. "mhlo") + dialect: Optional string specifying a lowering dialect (e.g. "stablehlo") """ return self._lowering.as_text(dialect) @@ -594,7 +609,7 @@ class Lowered(Stage): runtime. Args: - dialect: Optional string specifying a lowering dialect (e.g. "mhlo") + dialect: Optional string specifying a lowering dialect (e.g. "stablehlo") """ try: return self._lowering.compiler_ir(dialect) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 0abcf9853..de4d843cf 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -77,7 +77,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib import xla_extension_version from jax._src.lib import pmap_lib from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import hlo +from jax._src.lib.mlir.dialects import hlo, use_stablehlo from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list, new_name_stack, wrap_name, assert_unreachable, tuple_insert, tuple_delete, distributed_debug_log, @@ -1520,7 +1520,16 @@ class PmapComputation(stages.XlaLowering): use_tuple_args=self.compile_args["tuple_args"]) def mhlo(self) -> ir.Module: - return self._hlo + if use_stablehlo: + return super().mhlo() + else: + return self._hlo + + def stablehlo(self) -> ir.Module: + if use_stablehlo: + return self._hlo + else: + return super().stablehlo() @profiler.annotate_function def compile(self) -> PmapExecutable: @@ -3187,9 +3196,20 @@ class MeshComputation(stages.XlaLowering): use_tuple_args=self.compile_args["tuple_args"]) def mhlo(self) -> ir.Module: - if self.is_trivial: - raise ValueError("A trivial computation has no MHLO") - return self._hlo + if use_stablehlo: + return super().mhlo() + else: + if self.is_trivial: + raise ValueError("A trivial computation has no MHLO") + return self._hlo + + def stablehlo(self) -> ir.Module: + if use_stablehlo: + if self.is_trivial: + raise ValueError("A trivial computation has no StableHLO") + return self._hlo + else: + return super().stablehlo() def compile(self, _allow_propagation_to_outputs : bool = False, diff --git a/jaxlib/ducc_fft.py b/jaxlib/ducc_fft.py index fd1c63d98..e96a4596f 100644 --- a/jaxlib/ducc_fft.py +++ b/jaxlib/ducc_fft.py @@ -15,7 +15,7 @@ from typing import List import jaxlib.mlir.ir as ir -import jaxlib.mlir.dialects.mhlo as hlo +import jaxlib.mlir.dialects.stablehlo as hlo from .hlo_helpers import custom_call diff --git a/jaxlib/gpu_rnn.py b/jaxlib/gpu_rnn.py index 40654dd1e..f647d5ba5 100644 --- a/jaxlib/gpu_rnn.py +++ b/jaxlib/gpu_rnn.py @@ -13,7 +13,7 @@ # limitations under the License. import jaxlib.mlir.ir as ir -import jaxlib.mlir.dialects.mhlo as hlo +import jaxlib.mlir.dialects.stablehlo as hlo import numpy as np diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 7f6ca0414..6a102715b 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -18,7 +18,7 @@ from functools import partial import operator import jaxlib.mlir.ir as ir -import jaxlib.mlir.dialects.mhlo as hlo +import jaxlib.mlir.dialects.stablehlo as hlo import numpy as np diff --git a/jaxlib/hlo_helpers.py b/jaxlib/hlo_helpers.py index e029df9b8..38100bf1f 100644 --- a/jaxlib/hlo_helpers.py +++ b/jaxlib/hlo_helpers.py @@ -16,7 +16,7 @@ from typing import Dict, Optional, Sequence, Union import jaxlib.mlir.ir as ir -import jaxlib.mlir.dialects.mhlo as hlo +import jaxlib.mlir.dialects.stablehlo as hlo import numpy as np diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index f728cff0b..eb37972a8 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -16,7 +16,7 @@ # via CustomCallWithLayout. import jaxlib.mlir.ir as ir -import jaxlib.mlir.dialects.mhlo as hlo +import jaxlib.mlir.dialects.stablehlo as hlo import numpy as np from jaxlib import xla_client diff --git a/tests/compilation_cache_test.py b/tests/compilation_cache_test.py index f99a62256..6c11da508 100644 --- a/tests/compilation_cache_test.py +++ b/tests/compilation_cache_test.py @@ -209,10 +209,10 @@ class CompilationCacheTest(jtu.JaxTestCase): cc.initialize_cache(tmpdir) computation1 = str(jax.jit(lambda x, y: x + y) .lower(1, 1) - .compiler_ir(dialect="mhlo")) + .compiler_ir()) computation2 = str(jax.jit(lambda x, y: x * y) .lower(2, 2) - .compiler_ir(dialect="mhlo")) + .compiler_ir()) compile_options = xla_bridge.get_compile_options( num_replicas=1, num_partitions=1) backend = xla_bridge.get_backend() @@ -230,7 +230,7 @@ class CompilationCacheTest(jtu.JaxTestCase): cc.initialize_cache(tmpdir) computation = str(jax.jit(lambda x, y: x + y) .lower(np.int32(1), np.int32(1)) - .compiler_ir(dialect="mhlo")) + .compiler_ir()) compile_options = xla_bridge.get_compile_options( num_replicas=1, num_partitions=1) backend = xla_bridge.get_backend() diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 8f3e43714..a78c851aa 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -923,6 +923,7 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertIsInstance(f.as_text(), str) self.assertIsInstance(f.as_text(dialect='hlo'), str) self.assertIsInstance(f.as_text(dialect='mhlo'), str) + self.assertIsInstance(f.as_text(dialect='stablehlo'), str) @jtu.with_mesh([('x', 2), ('y', 2)]) def testLowerCompilerIR(self): @@ -938,6 +939,7 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertIsNotNone(f.compiler_ir()) self.assertIsNotNone(f.compiler_ir(dialect='hlo')) self.assertIsNotNone(f.compiler_ir(dialect='mhlo')) + self.assertIsNotNone(f.compiler_ir(dialect='stablehlo')) @jtu.ignore_warning(category=DeprecationWarning) @jtu.with_mesh([('x', 2), ('y', 2)]) diff --git a/tests/pmap_test.py b/tests/pmap_test.py index 7da08530b..a488092ae 100644 --- a/tests/pmap_test.py +++ b/tests/pmap_test.py @@ -276,6 +276,7 @@ class PythonPmapTest(jtu.JaxTestCase): self.assertIsInstance(f.as_text(), str) self.assertIsInstance(f.as_text(dialect='hlo'), str) self.assertIsInstance(f.as_text(dialect='mhlo'), str) + self.assertIsInstance(f.as_text(dialect='stablehlo'), str) def testLowerCompilerIR(self): f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i') @@ -285,6 +286,7 @@ class PythonPmapTest(jtu.JaxTestCase): self.assertIsNotNone(f.compiler_ir()) self.assertIsNotNone(f.compiler_ir(dialect='hlo')) self.assertIsNotNone(f.compiler_ir(dialect='mhlo')) + self.assertIsNotNone(f.compiler_ir(dialect='stablehlo')) @jtu.ignore_warning(category=DeprecationWarning) def testLowerCompileCompilerIR(self): diff --git a/tests/xmap_test.py b/tests/xmap_test.py index ce0763ef5..0e773e5a4 100644 --- a/tests/xmap_test.py +++ b/tests/xmap_test.py @@ -744,6 +744,7 @@ class XMapTest(XMapTestCase): self.assertIsInstance(f.as_text(), str) self.assertIsInstance(f.as_text(dialect='hlo'), str) self.assertIsInstance(f.as_text(dialect='mhlo'), str) + self.assertIsInstance(f.as_text(dialect='stablehlo'), str) def testLowerCompilerIR(self): f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...]) @@ -752,6 +753,7 @@ class XMapTest(XMapTestCase): self.assertIsNotNone(f.compiler_ir()) self.assertIsNotNone(f.compiler_ir(dialect='hlo')) self.assertIsNotNone(f.compiler_ir(dialect='mhlo')) + self.assertIsNotNone(f.compiler_ir(dialect='stablehlo')) @jtu.ignore_warning(category=DeprecationWarning) def testLowerCompileCompilerIR(self):