mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Migrate JAX from producing MHLO to producing StableHLO
As discussed over the last few months, it is desirable to migrate JAX from producing MHLO to producing StableHLO, and this CL makes this happen. More specifically:

1) MLIR lowerings now produce StableHLO ops instead of MHLO ops.
2) Fallback lowerings now produce StableHLO ops as well.
3) Occurrences of "MHLO" in prose have been changed to "StableHLO", unless the documents are immutable (changelog, JEPs).

From time to time, it might be useful to produce MHLO directly, so MHLO is not going away and is still within arm's reach (although compatibility guarantees will only be provided for StableHLO and not for MHLO):

a) `from jax._src.lib.mlir.dialects import mhlo` still does the same thing.
b) `XlaLowering.mhlo()` is available as well, but its implementation has changed - it calls `stablehlo-legalize-to-hlo` underneath.
c) `Lowering.as_text()/compiler_ir()` still support `dialect="mhlo"`, but the default has changed to "stablehlo".
d) We're still using `mhlo.is_same_data_across_replicas` and `mhlo.sharding` because StableHLO currently lacks comparable functionality. https://github.com/openxla/stablehlo/issues/744 tracks the corresponding work, but it is not a blocker - we can use these attributes with StableHLO without any issues.

PiperOrigin-RevId: 497978733
This commit is contained in:
parent
7d452adfd3
commit
a1480c454e
12
docs/aot.md
12
docs/aot.md
@ -20,7 +20,7 @@ are arrays, JAX does the following in order:
|
||||
their shape and element type).
|
||||
|
||||
2. **Lower** this specialized, staged-out computation to the XLA compiler's
|
||||
input language, MHLO.
|
||||
input language, StableHLO.
|
||||
|
||||
3. **Compile** the lowered HLO program to produce an optimized executable for
|
||||
the target device (CPU, GPU, or TPU).
|
||||
@ -45,9 +45,9 @@ way. An example:
|
||||
>>> print(lowered.as_text())
|
||||
module @jit_f.0 {
|
||||
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
|
||||
%0 = mhlo.constant dense<2> : tensor<i32>
|
||||
%1 = mhlo.multiply %0, %arg0 : tensor<i32>
|
||||
%2 = mhlo.add %1, %arg1 : tensor<i32>
|
||||
%0 = stablehlo.constant dense<2> : tensor<i32>
|
||||
%1 = stablehlo.multiply %0, %arg0 : tensor<i32>
|
||||
%2 = stablehlo.add %1, %arg1 : tensor<i32>
|
||||
return %2 : tensor<i32>
|
||||
}
|
||||
}
|
||||
@ -129,8 +129,8 @@ to invoke the resulting compiled function. Continuing with our example above:
|
||||
>>> print(lowered_with_x.as_text())
|
||||
module @jit_f.1 {
|
||||
func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
|
||||
%0 = mhlo.constant dense<14> : tensor<i32>
|
||||
%1 = mhlo.add %0, %arg0 : tensor<i32>
|
||||
%0 = stablehlo.constant dense<14> : tensor<i32>
|
||||
%1 = stablehlo.add %0, %arg0 : tensor<i32>
|
||||
return %1 : tensor<i32>
|
||||
}
|
||||
}
|
||||
|
@ -53,6 +53,7 @@ from jax._src.sharding import (PmapSharding, SingleDeviceSharding,
|
||||
from jax._src.abstract_arrays import array_types
|
||||
from jax._src.config import config, flags
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import use_stablehlo
|
||||
from jax._src.lib import xla_bridge as xb
|
||||
from jax._src.lib import xla_client as xc
|
||||
import jax._src.util as util
|
||||
@ -982,9 +983,20 @@ class XlaComputation(stages.XlaLowering):
|
||||
use_tuple_args=self.compile_args["tuple_args"])
|
||||
|
||||
def mhlo(self) -> ir.Module:
|
||||
if self.is_trivial():
|
||||
raise ValueError("A trivial computation has no MHLO")
|
||||
return self._hlo
|
||||
if use_stablehlo:
|
||||
return super().mhlo()
|
||||
else:
|
||||
if self.is_trivial():
|
||||
raise ValueError("A trivial computation has no MHLO")
|
||||
return self._hlo
|
||||
|
||||
def stablehlo(self) -> ir.Module:
|
||||
if use_stablehlo:
|
||||
if self.is_trivial():
|
||||
raise ValueError("A trivial computation has no StableHLO")
|
||||
return self._hlo
|
||||
else:
|
||||
return super().stablehlo()
|
||||
|
||||
def compile(self) -> XlaCompiledComputation:
|
||||
if self._executable is None:
|
||||
|
@ -38,7 +38,7 @@ from jax._src.numpy import lax_numpy
|
||||
import jax._src.util as util
|
||||
from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib.mlir.dialects import hlo, use_stablehlo
|
||||
|
||||
unsafe_map, map = map, safe_map # type: ignore
|
||||
|
||||
@ -984,7 +984,7 @@ def _all_to_all_lowering(ctx, x, *,
|
||||
else:
|
||||
other_args = {}
|
||||
return hlo.AllToAllOp(
|
||||
[x],
|
||||
x if use_stablehlo else [x],
|
||||
split_dimension=mlir.i64_attr(split_axis),
|
||||
concat_dimension=mlir.i64_attr(concat_axis),
|
||||
split_count=mlir.i64_attr(split_count),
|
||||
|
@ -24,6 +24,8 @@ from jax.lib import xla_client
|
||||
import jaxlib.mlir.dialects.stablehlo as stablehlo
|
||||
|
||||
# Alias that is set up to abstract away the transition from MHLO to StableHLO.
|
||||
# It points to StableHLO when `use_stablehlo` (defined below) is true and to
|
||||
# MHLO otherwise; in the future it will unconditionally point to StableHLO.
|
||||
import jaxlib.mlir.dialects.mhlo as hlo
|
||||
use_stablehlo = xla_client.mlir_api_version >= 42
|
||||
if use_stablehlo:
|
||||
import jaxlib.mlir.dialects.stablehlo as hlo
|
||||
else:
|
||||
import jaxlib.mlir.dialects.mhlo as hlo
|
||||
|
@ -44,6 +44,7 @@ from jax._src import source_info_util
|
||||
from jax._src import traceback_util
|
||||
from jax._src import util
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import use_stablehlo
|
||||
from jax.interpreters import mlir
|
||||
from jax.interpreters import xla
|
||||
|
||||
@ -169,7 +170,8 @@ class Lowering(Protocol):
|
||||
compiler.
|
||||
|
||||
Args:
|
||||
dialect: Optional string specifying a representation dialect (e.g. "mhlo")
|
||||
dialect: Optional string specifying a representation dialect
|
||||
(e.g. "stablehlo")
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -264,20 +266,31 @@ class XlaLowering(Lowering):
|
||||
|
||||
def mhlo(self) -> ir.Module:
|
||||
"""Return an MHLO representation of this computation."""
|
||||
raise NotImplementedError("must override")
|
||||
if use_stablehlo:
|
||||
module_str = xla_extension.mlir.stablehlo_to_mhlo(
|
||||
mlir.module_to_bytecode(self.stablehlo()))
|
||||
with self.stablehlo().context:
|
||||
return ir.Module.parse(module_str)
|
||||
else:
|
||||
raise NotImplementedError("must override")
|
||||
|
||||
def stablehlo(self) -> ir.Module:
|
||||
"""Return a StableHLO representation of this computation."""
|
||||
module_str = xla_extension.mlir.mhlo_to_stablehlo(
|
||||
mlir.module_to_string(self.mhlo()))
|
||||
with mlir.make_ir_context():
|
||||
return ir.Module.parse(module_str)
|
||||
if use_stablehlo:
|
||||
raise NotImplementedError("must override")
|
||||
else:
|
||||
module_str = xla_extension.mlir.mhlo_to_stablehlo(
|
||||
mlir.module_to_bytecode(self.mhlo()))
|
||||
with self.mhlo().context:
|
||||
return ir.Module.parse(module_str)
|
||||
|
||||
def compile(self) -> Executable:
|
||||
raise NotImplementedError("must override")
|
||||
|
||||
def as_text(self, dialect: Optional[str] = None) -> str:
|
||||
if dialect is None or dialect == "mhlo":
|
||||
if dialect is None:
|
||||
dialect = "stablehlo" if use_stablehlo else "mhlo"
|
||||
if dialect == "mhlo":
|
||||
return str(self.mhlo())
|
||||
elif dialect == "stablehlo":
|
||||
return str(self.stablehlo())
|
||||
@ -287,7 +300,9 @@ class XlaLowering(Lowering):
|
||||
raise ValueError(f"unknown dialect: {dialect}")
|
||||
|
||||
def compiler_ir(self, dialect: Optional[str] = None) -> Any:
|
||||
if dialect is None or dialect == "mhlo":
|
||||
if dialect is None:
|
||||
dialect = "stablehlo" if use_stablehlo else "mhlo"
|
||||
if dialect == "mhlo":
|
||||
return self.mhlo()
|
||||
elif dialect == "stablehlo":
|
||||
return self.stablehlo()
|
||||
@ -579,7 +594,7 @@ class Lowered(Stage):
|
||||
nor reliable serialization. It is relayed directly to external callers.
|
||||
|
||||
Args:
|
||||
dialect: Optional string specifying a lowering dialect (e.g. "mhlo")
|
||||
dialect: Optional string specifying a lowering dialect (e.g. "stablehlo")
|
||||
"""
|
||||
return self._lowering.as_text(dialect)
|
||||
|
||||
@ -594,7 +609,7 @@ class Lowered(Stage):
|
||||
runtime.
|
||||
|
||||
Args:
|
||||
dialect: Optional string specifying a lowering dialect (e.g. "mhlo")
|
||||
dialect: Optional string specifying a lowering dialect (e.g. "stablehlo")
|
||||
"""
|
||||
try:
|
||||
return self._lowering.compiler_ir(dialect)
|
||||
|
@ -77,7 +77,7 @@ from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.lib import pmap_lib
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib.mlir.dialects import hlo
|
||||
from jax._src.lib.mlir.dialects import hlo, use_stablehlo
|
||||
from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list,
|
||||
new_name_stack, wrap_name, assert_unreachable,
|
||||
tuple_insert, tuple_delete, distributed_debug_log,
|
||||
@ -1520,7 +1520,16 @@ class PmapComputation(stages.XlaLowering):
|
||||
use_tuple_args=self.compile_args["tuple_args"])
|
||||
|
||||
def mhlo(self) -> ir.Module:
|
||||
return self._hlo
|
||||
if use_stablehlo:
|
||||
return super().mhlo()
|
||||
else:
|
||||
return self._hlo
|
||||
|
||||
def stablehlo(self) -> ir.Module:
|
||||
if use_stablehlo:
|
||||
return self._hlo
|
||||
else:
|
||||
return super().stablehlo()
|
||||
|
||||
@profiler.annotate_function
|
||||
def compile(self) -> PmapExecutable:
|
||||
@ -3187,9 +3196,20 @@ class MeshComputation(stages.XlaLowering):
|
||||
use_tuple_args=self.compile_args["tuple_args"])
|
||||
|
||||
def mhlo(self) -> ir.Module:
|
||||
if self.is_trivial:
|
||||
raise ValueError("A trivial computation has no MHLO")
|
||||
return self._hlo
|
||||
if use_stablehlo:
|
||||
return super().mhlo()
|
||||
else:
|
||||
if self.is_trivial:
|
||||
raise ValueError("A trivial computation has no MHLO")
|
||||
return self._hlo
|
||||
|
||||
def stablehlo(self) -> ir.Module:
|
||||
if use_stablehlo:
|
||||
if self.is_trivial:
|
||||
raise ValueError("A trivial computation has no StableHLO")
|
||||
return self._hlo
|
||||
else:
|
||||
return super().stablehlo()
|
||||
|
||||
def compile(self,
|
||||
_allow_propagation_to_outputs : bool = False,
|
||||
|
@ -15,7 +15,7 @@
|
||||
from typing import List
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
import jaxlib.mlir.dialects.mhlo as hlo
|
||||
import jaxlib.mlir.dialects.stablehlo as hlo
|
||||
|
||||
|
||||
from .hlo_helpers import custom_call
|
||||
|
@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
import jaxlib.mlir.dialects.mhlo as hlo
|
||||
import jaxlib.mlir.dialects.stablehlo as hlo
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -18,7 +18,7 @@ from functools import partial
|
||||
import operator
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
import jaxlib.mlir.dialects.mhlo as hlo
|
||||
import jaxlib.mlir.dialects.stablehlo as hlo
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
||||
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
import jaxlib.mlir.ir as ir
|
||||
import jaxlib.mlir.dialects.mhlo as hlo
|
||||
import jaxlib.mlir.dialects.stablehlo as hlo
|
||||
import numpy as np
|
||||
|
||||
|
||||
|
@ -16,7 +16,7 @@
|
||||
# via CustomCallWithLayout.
|
||||
|
||||
import jaxlib.mlir.ir as ir
|
||||
import jaxlib.mlir.dialects.mhlo as hlo
|
||||
import jaxlib.mlir.dialects.stablehlo as hlo
|
||||
|
||||
import numpy as np
|
||||
from jaxlib import xla_client
|
||||
|
@ -209,10 +209,10 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
cc.initialize_cache(tmpdir)
|
||||
computation1 = str(jax.jit(lambda x, y: x + y)
|
||||
.lower(1, 1)
|
||||
.compiler_ir(dialect="mhlo"))
|
||||
.compiler_ir())
|
||||
computation2 = str(jax.jit(lambda x, y: x * y)
|
||||
.lower(2, 2)
|
||||
.compiler_ir(dialect="mhlo"))
|
||||
.compiler_ir())
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
backend = xla_bridge.get_backend()
|
||||
@ -230,7 +230,7 @@ class CompilationCacheTest(jtu.JaxTestCase):
|
||||
cc.initialize_cache(tmpdir)
|
||||
computation = str(jax.jit(lambda x, y: x + y)
|
||||
.lower(np.int32(1), np.int32(1))
|
||||
.compiler_ir(dialect="mhlo"))
|
||||
.compiler_ir())
|
||||
compile_options = xla_bridge.get_compile_options(
|
||||
num_replicas=1, num_partitions=1)
|
||||
backend = xla_bridge.get_backend()
|
||||
|
@ -923,6 +923,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIsInstance(f.as_text(), str)
|
||||
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
||||
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
|
||||
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
def testLowerCompilerIR(self):
|
||||
@ -938,6 +939,7 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertIsNotNone(f.compiler_ir())
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning)
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
|
@ -276,6 +276,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
self.assertIsInstance(f.as_text(), str)
|
||||
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
||||
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
|
||||
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
|
||||
|
||||
def testLowerCompilerIR(self):
|
||||
f = self.pmap(lambda x: x - lax.pmean(x, 'i'), axis_name='i')
|
||||
@ -285,6 +286,7 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
self.assertIsNotNone(f.compiler_ir())
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning)
|
||||
def testLowerCompileCompilerIR(self):
|
||||
|
@ -744,6 +744,7 @@ class XMapTest(XMapTestCase):
|
||||
self.assertIsInstance(f.as_text(), str)
|
||||
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
||||
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
|
||||
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
|
||||
|
||||
def testLowerCompilerIR(self):
|
||||
f = xmap(lambda x: x + 4, in_axes=['i', ...], out_axes=['i', ...])
|
||||
@ -752,6 +753,7 @@ class XMapTest(XMapTestCase):
|
||||
self.assertIsNotNone(f.compiler_ir())
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
|
||||
|
||||
@jtu.ignore_warning(category=DeprecationWarning)
|
||||
def testLowerCompileCompilerIR(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user