Migrate JAX from producing MHLO to producing StableHLO

As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:
  1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
  2) Fallback lowerings now produce StableHLO ops as well.
  3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).

From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):
  a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
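  b) `XlaLowering.mhlo()` is available as well, but its implementation has changed: instead of returning a module that was lowered to MHLO directly, it now converts the module returned by `XlaLowering.stablehlo()` to MHLO by calling `stablehlo-legalize-to-hlo` underneath.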
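  c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default (used when `dialect` is not specified) has changed to `"stablehlo"` whenever StableHLO lowering is enabled, i.e. when `xla_client.mlir_api_version >= 42`; with older jaxlib versions the default remains `"mhlo"`.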
  d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.

PiperOrigin-RevId: 497978733
This commit is contained in:
Eugene Burmako 2022-12-27 08:52:39 -08:00 committed by jax authors
parent 7d452adfd3
commit a1480c454e
15 changed files with 92 additions and 37 deletions

View File

@ -20,7 +20,7 @@ are arrays, JAX does the following in order:
their shape and element type).
2. **Lower** this specialized, staged-out computation to the XLA compiler's
input language, MHLO.
input language, StableHLO.
3. **Compile** the lowered HLO program to produce an optimized executable for
the target device (CPU, GPU, or TPU).
@ -45,9 +45,9 @@ way. An example:
>>> print(lowered.as_text())
module @jit_f.0 {
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
%0 = mhlo.constant dense<2> : tensor<i32>
%1 = mhlo.multiply %0, %arg0 : tensor<i32>
%2 = mhlo.add %1, %arg1 : tensor<i32>
%0 = stablehlo.constant dense<2> : tensor<i32>
%1 = stablehlo.multiply %0, %arg0 : tensor<i32>
%2 = stablehlo.add %1, %arg1 : tensor<i32>
return %2 : tensor<i32>
}
}
@ -129,8 +129,8 @@ to invoke the resulting compiled function. Continuing with our example above:
>>> print(lowered_with_x.as_text())
module @jit_f.1 {
func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
%0 = mhlo.constant dense<14> : tensor<i32>
%1 = mhlo.add %0, %arg0 : tensor<i32>
%0 = stablehlo.constant dense<14> : tensor<i32>
%1 = stablehlo.add %0, %arg0 : tensor<i32>
return %1 : tensor<i32>
}
}

View File

@ -53,6 +53,7 @@ from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
from jax._src.abstract_arrays import array_types
from jax._src.config import config, flags
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import use_stablehlo
from jax._src.lib import xla_bridge as xb
from jax._src.lib import xla_client as xc
import jax._src.util as util
@ -982,9 +983,20 @@ class XlaComputation(stages.XlaLowering):
use_tuple_args=self.compile_args["tuple_args"])
def mhlo(self) -> ir.Module:
if self.is_trivial():
raise ValueError("A trivial computation has no MHLO")
return self._hlo
if use_stablehlo:
return super().mhlo()
else:
if self.is_trivial():
raise ValueError("A trivial computation has no MHLO")
return self._hlo
def stablehlo(self) -> ir.Module:
if use_stablehlo:
if self.is_trivial():
raise ValueError("A trivial computation has no StableHLO")
return self._hlo
else:
return super().stablehlo()
def compile(self) -> XlaCompiledComputation:
if self._executable is None:

View File

@ -38,7 +38,7 @@ from jax._src.numpy import lax_numpy
import jax._src.util as util
from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import hlo, use_stablehlo
unsafe_map, map = map, safe_map # type: ignore
@ -984,7 +984,7 @@ def _all_to_all_lowering(ctx, x, *,
else:
other_args = {}
return hlo.AllToAllOp(
[x],
x if use_stablehlo else [x],
split_dimension=mlir.i64_attr(split_axis),
concat_dimension=mlir.i64_attr(concat_axis),
split_count=mlir.i64_attr(split_count),

View File

@ -24,6 +24,8 @@ from jax.lib import xla_client
import jaxlib.mlir.dialects.stablehlo as stablehlo
# Alias that is set up to abstract away the transition from MHLO to StableHLO.
# It points to StableHLO when `use_stablehlo` is true (that is, when
# xla_client.mlir_api_version >= 42) and to MHLO otherwise. In the future it
# is expected to point to StableHLO unconditionally.
import jaxlib.mlir.dialects.mhlo as hlo
use_stablehlo = xla_client.mlir_api_version >= 42
if use_stablehlo:
import jaxlib.mlir.dialects.stablehlo as hlo
else:
import jaxlib.mlir.dialects.mhlo as hlo

View File

@ -44,6 +44,7 @@ from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import use_stablehlo
from jax.interpreters import mlir
from jax.interpreters import xla
@ -169,7 +170,8 @@ class Lowering(Protocol):
compiler.
Args:
dialect: Optional string specifying a representation dialect (e.g. "mhlo")
dialect: Optional string specifying a representation dialect
(e.g. "stablehlo")
"""
raise NotImplementedError
@ -264,20 +266,31 @@ class XlaLowering(Lowering):
def mhlo(self) -> ir.Module:
"""Return an MHLO representation of this computation."""
raise NotImplementedError("must override")
if use_stablehlo:
module_str = xla_extension.mlir.stablehlo_to_mhlo(
mlir.module_to_bytecode(self.stablehlo()))
with self.stablehlo().context:
return ir.Module.parse(module_str)
else:
raise NotImplementedError("must override")
def stablehlo(self) -> ir.Module:
"""Return a StableHLO representation of this computation."""
module_str = xla_extension.mlir.mhlo_to_stablehlo(
mlir.module_to_string(self.mhlo()))
with mlir.make_ir_context():
return ir.Module.parse(module_str)
if use_stablehlo:
raise NotImplementedError("must override")
else:
module_str = xla_extension.mlir.mhlo_to_stablehlo(
mlir.module_to_bytecode(self.mhlo()))
with self.mhlo().context:
return ir.Module.parse(module_str)
def compile(self) -> Executable:
raise NotImplementedError("must override")
def as_text(self, dialect: Optional[str] = None) -> str:
if dialect is None or dialect == "mhlo":
if dialect is None:
dialect = "stablehlo" if use_stablehlo else "mhlo"
if dialect == "mhlo":
return str(self.mhlo())
elif dialect == "stablehlo":
return str(self.stablehlo())
@ -287,7 +300,9 @@ class XlaLowering(Lowering):
raise ValueError(f"unknown dialect: {dialect}")
def compiler_ir(self, dialect: Optional[str] = None) -> Any:
if dialect is None or dialect == "mhlo":
if dialect is None:
dialect = "stablehlo" if use_stablehlo else "mhlo"
if dialect == "mhlo":
return self.mhlo()
elif dialect == "stablehlo":
return self.stablehlo()
@ -579,7 +594,7 @@ class Lowered(Stage):
nor reliable serialization. It is relayed directly to external callers.
Args:
dialect: Optional string specifying a lowering dialect (e.g. "mhlo")
dialect: Optional string specifying a lowering dialect (e.g. "stablehlo")
"""
return self._lowering.as_text(dialect)
@ -594,7 +609,7 @@ class Lowered(Stage):
runtime.
Args:
dialect: Optional string specifying a lowering dialect (e.g. "mhlo")
dialect: Optional string specifying a lowering dialect (e.g. "stablehlo")
"""
try:
return self._lowering.compiler_ir(dialect)

View File

@ -77,7 +77,7 @@ from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.lib import pmap_lib
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.lib.mlir.dialects import hlo, use_stablehlo
from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
new_name_stack, wrap_name, assert_unreachable,
tuple_insert, tuple_delete, distributed_debug_log,
@ -1520,7 +1520,16 @@ class PmapComputation(stages.XlaLowering):
use_tuple_args=self.compile_args["tuple_args"])
def mhlo(self) -> ir.Module:
return self._hlo
if use_stablehlo:
return super().mhlo()
else:
return self._hlo
def stablehlo(self) -> ir.Module:
if use_stablehlo:
return self._hlo
else:
return super().stablehlo()
@profiler.annotate_function
def compile(self) -> PmapExecutable:
@ -3187,9 +3196,20 @@ class MeshComputation(stages.XlaLowering):
use_tuple_args=self.compile_args["tuple_args"])
def mhlo(self) -> ir.Module:
if self.is_trivial:
raise ValueError("A trivial computation has no MHLO")
return self._hlo
if use_stablehlo:
return super().mhlo()
else:
if self.is_trivial:
raise ValueError("A trivial computation has no MHLO")
return self._hlo
def stablehlo(self) -> ir.Module:
if use_stablehlo:
if self.is_trivial:
raise ValueError("A trivial computation has no StableHLO")
return self._hlo
else:
return super().stablehlo()
def compile(self,
_allow_propagation_to_outputs : bool = False,

View File

@ -15,7 +15,7 @@
from typing import List
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as hlo
import jaxlib.mlir.dialects.stablehlo as hlo
from .hlo_helpers import custom_call

View File

@ -13,7 +13,7 @@
# limitations under the License.
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as hlo
import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np

View File

@ -18,7 +18,7 @@ from functools import partial
import operator
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as hlo
import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np

View File

@ -16,7 +16,7 @@
from typing import Dict, Optional, Sequence, Union
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as hlo
import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np

View File

@ -16,7 +16,7 @@
# via CustomCallWithLayout.
import jaxlib.mlir.ir as ir
import jaxlib.mlir.dialects.mhlo as hlo
import jaxlib.mlir.dialects.stablehlo as hlo
import numpy as np
from jaxlib import xla_client

View File

@ -209,10 +209,10 @@ class CompilationCacheTest(jtu.JaxTestCase):
cc.initialize_cache(tmpdir)
computation1 = str(jax.jit(lambda x, y: x + y)
.lower(1, 1)
.compiler_ir(dialect="mhlo"))
.compiler_ir())
computation2 = str(jax.jit(lambda x, y: x * y)
.lower(2, 2)
.compiler_ir(dialect="mhlo"))
.compiler_ir())
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
backend = xla_bridge.get_backend()
@ -230,7 +230,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
cc.initialize_cache(tmpdir)
computation = str(jax.jit(lambda x, y: x + y)
.lower(np.int32(1), np.int32(1))
.compiler_ir(dialect="mhlo"))
.compiler_ir())
compile_options = xla_bridge.get_compile_options(
num_replicas=1, num_partitions=1)
backend = xla_bridge.get_backend()

View File

@ -923,6 +923,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertIsInstance(f.as_text(), str)
self.assertIsInstance(f.as_text(dialect='hlo'), str)
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
@jtu.with_mesh([('x', 2), ('y', 2)])
def testLowerCompilerIR(self):
@ -938,6 +939,7 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
@jtu.ignore_warning(category=DeprecationWarning)
@jtu.with_mesh([('x', 2), ('y', 2)])

View File

@ -276,6 +276,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertIsInstance(f.as_text(), str)
self.assertIsInstance(f.as_text(dialect='hlo'), str)
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
def testLowerCompilerIR(self):
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
@ -285,6 +286,7 @@ class PythonPmapTest(jtu.JaxTestCase):
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
@jtu.ignore_warning(category=DeprecationWarning)
def testLowerCompileCompilerIR(self):

View File

@ -744,6 +744,7 @@ class XMapTest(XMapTestCase):
self.assertIsInstance(f.as_text(), str)
self.assertIsInstance(f.as_text(dialect='hlo'), str)
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
def testLowerCompilerIR(self):
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
@ -752,6 +753,7 @@ class XMapTest(XMapTestCase):
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
@jtu.ignore_warning(category=DeprecationWarning)
def testLowerCompileCompilerIR(self):