Deprecate support for the mhlo dialect.

JAX has not used the mhlo dialect for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's APIs and remove the associated tests.

PiperOrigin-RevId: 598550225
Author: Peter Hawkins, 2024-01-15 02:12:52 -08:00 (committed by jax authors)
Parent: 912a5ef771
Commit: e558feaa5e
20 changed files with 88 additions and 87 deletions


@@ -66,6 +66,9 @@ Remember to align the itemized text with the first line of an item within a list
removed. Use {func}`jax.random.key_data` instead.
* `bool(empty_array)` now raises an error rather than returning `False`. This
previously raised a deprecation warning, and follows a similar change in NumPy.
* Support for the mhlo MLIR dialect has been deprecated. JAX no longer uses
the mhlo dialect, in favor of stablehlo. APIs that refer to "mhlo" will be
removed in the future. Use the "stablehlo" dialect instead.
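For illustration, a minimal migration sketch using only the public lowering API exercised by the tests later in this change (`as_text` and `compiler_ir` both take a `dialect` argument):

```python
# Request StableHLO from a lowered computation instead of the deprecated MHLO.
import jax

lowered = jax.jit(lambda x: x + 4).lower(1.)

print(lowered.as_text(dialect="stablehlo"))        # textual IR, preferred
module = lowered.compiler_ir(dialect="stablehlo")  # MLIR module object

# The old spelling still works for now but emits a DeprecationWarning:
# lowered.as_text(dialect="mhlo")
```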
## jaxlib 0.4.24


@@ -154,7 +154,7 @@ def f(runtime_token, x):
```
Notice how the runtime tokens are only used at the JIT boundary and the compiler tokens
are only within the compiled code. Compiler tokens are created during
"lowering" (we convert Python code to a lower level representation like HLO or MHLO)
"lowering" (we convert Python code to a lower level representation like HLO or StableHLO)
but runtime tokens need to be managed in Python since they're being threaded in and out
of JIT-ted functions.
@@ -200,7 +200,7 @@ this copy is not necessary.
## Adding compiler tokens
When we lower Python code to HLO or MHLO we need to create a token at the start of the computation and
When we lower Python code to HLO or StableHLO we need to create a token at the start of the computation and
ensure they are available when we have side-effecting computations that need to be ordered. The side-effecting
computations will take the token as input and return it as an output.
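As a hedged sketch of the token plumbing (assuming `jax.debug.print` with `ordered=True` as the side effect; any ordered effect would do), lowering a side-effecting function and printing its StableHLO shows the compiler token threaded through the effectful op, while the runtime token stays on the Python side of the `jit` boundary:

```python
# Sketch: compiler tokens show up in the lowered StableHLO of an ordered effect.
import jax

def effectful(x):
    jax.debug.print("x = {}", x, ordered=True)  # ordered effect, needs a token
    return x + 1

print(jax.jit(effectful).lower(1.0).as_text(dialect="stablehlo"))
```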


@@ -350,9 +350,9 @@ out and end-to-end compile a `shmap`ped function, just put a `jit` around it. A
consequence is that `shmap` doesn't have its own dispatch and compilation paths
like `xmap` and `pmap` currently do; it's just the `jit` path.
When it's staged out by e.g. an enclosing `jit`, the lowering of `shmap` to MHLO
is trivial: it just involves switching into 'manual SPMD mode' on the inputs,
and switching back on the outputs. (We don't currently plan to support
When it's staged out by e.g. an enclosing `jit`, the lowering of `shmap` to
StableHLO is trivial: it just involves switching into 'manual SPMD mode' on the
inputs, and switching back on the outputs. (We don't currently plan to support
partially-manual-partially-automatic modes.)
The interaction with effects is the same as with `pmap`.
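A hedged end-to-end sketch (the `shard_map` import path and the example values are illustrative assumptions, not part of this design note): staging the mapped function out under `jit` and lowering it shows the manual-SPMD body in the resulting StableHLO module.

```python
# Sketch: shard_map has no dispatch path of its own; jit stages it out.
import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

mesh = Mesh(jax.devices(), ('i',))
f = shard_map(lambda x: x * 2, mesh=mesh, in_specs=P('i'), out_specs=P('i'))

x = jnp.arange(len(jax.devices()), dtype='float32')
print(jax.jit(f).lower(x).as_text(dialect="stablehlo"))  # manual-SPMD body inside
```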


@@ -7,7 +7,7 @@ Python wheel, and `jaxlib`, which is a mostly-C++ wheel that contains libraries
such as:
* XLA,
* pieces of LLVM used by XLA,
* MLIR infrastructure, such as the MHLO Python bindings.
* MLIR infrastructure, such as the StableHLO Python bindings.
* JAX-specific C++ libraries for fast JIT and PyTree manipulation.
We distribute separate `jax` and `jaxlib` packages because it makes it easy to


@@ -217,7 +217,7 @@ pl.pallas_call(kernel2, out_shape=x, grid=1)(1., 1.)
#### Emulation mode
By representing kernels as programs with JAX primitives and some new Pallas primitives, we can also lower Pallas programs to MHLO directly and compile/execute them with XLA. Specifically, a `pallas_call` can be implemented as a `lax.scan` over the grid. This enables us to develop GPU or TPU kernels on any XLA-supported platform (even CPU!) and debug them using JAX/XLA debugging tools (like `jax.debug.print`). We can also use the more reliable and better tested XLA numerics to verify the correctness of the Triton and Mosaic compilers. One could also imagine perturbing the `scan` ordering to simulate the parallel reads and writes that happen on GPU.
By representing kernels as programs with JAX primitives and some new Pallas primitives, we can also lower Pallas programs to StableHLO directly and compile/execute them with XLA. Specifically, a `pallas_call` can be implemented as a `lax.scan` over the grid. This enables us to develop GPU or TPU kernels on any XLA-supported platform (even CPU!) and debug them using JAX/XLA debugging tools (like `jax.debug.print`). We can also use the more reliable and better tested XLA numerics to verify the correctness of the Triton and Mosaic compilers. One could also imagine perturbing the `scan` ordering to simulate the parallel reads and writes that happen on GPU.
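A hedged sketch of the emulation idea, assuming `pallas_call`'s `interpret=True` flag as the switch (the flag and the trivial kernel below are illustrative, not taken from this doc): the kernel is lowered through ordinary JAX/XLA, so it runs on CPU and works with tools like `jax.debug.print`.

```python
# Sketch: run a Pallas kernel in interpret (emulation) mode on any XLA backend.
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def add_one_kernel(x_ref, o_ref):
    o_ref[...] = x_ref[...] + 1.0

x = jnp.arange(8, dtype=jnp.float32)
y = pl.pallas_call(
    add_one_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # emulate instead of compiling with Triton/Mosaic
)(x)
```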
### Examples


@@ -216,7 +216,7 @@ class Config:
self.jax_enable_memories,
self.jax_disable_jit,
self.jax_xla_profile_version,
# Technically this affects jaxpr->MHLO lowering, not tracing.
# Technically this affects jaxpr->stablehlo lowering, not tracing.
self.jax_hlo_source_file_canonicalization_regex)


@@ -33,6 +33,7 @@ from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, NamedTuple, Protocol, Union
import warnings
import jax
@@ -318,6 +319,11 @@ class XlaLowering(Lowering):
def mhlo(self) -> ir.Module:
"""Return an MHLO representation of this computation."""
warnings.warn(
"mhlo support is deprecated and will be removed "
"from a future release of JAX. Use stablehlo instead.",
DeprecationWarning,
)
module_str = xla_extension.mlir.stablehlo_to_mhlo(
mlir.module_to_bytecode(self.stablehlo()))
with self.stablehlo().context:
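A hedged sketch of what the deprecation means for callers (assuming the public `compiler_ir` entry point routes to this method): requesting the `mhlo` dialect still returns a module for now, but raises a `DeprecationWarning`.

```python
# Sketch: the deprecated dialect name now warns; prefer dialect="stablehlo".
import warnings
import jax

lowered = jax.jit(lambda x: x * 2).lower(1.0)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    lowered.compiler_ir(dialect="mhlo")  # still works, but deprecated
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```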


@@ -275,7 +275,7 @@ def count_jit_and_pmap_compiles():
@contextmanager
def count_subjaxpr_to_mhlo_conversion(fun_name: str):
def count_subjaxpr_to_hlo_conversion(fun_name: str):
# No need to clear any caches since we generally jit and pmap fresh callables
# in tests.


@@ -888,7 +888,7 @@ def _check_module(mod: ir.Module, *,
if op_name == "func.func":
check_sharding(op.operation, op.location)
elif op_name == "stablehlo.custom_call" or op_name == "mhlo.custom_call":
elif op_name == "stablehlo.custom_call":
call_target_name_attr = op.operation.attributes["call_target_name"]
if (call_target_name_attr not in allowed_custom_call_targets_attrs):
disallowed_custom_call_ops.append(f"{op} at {op.location}")


@@ -2708,7 +2708,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
harness.check_result = False
if harness.group_name == "vmap_tan":
# Tan (b/274462307) require support for custom call mhlo.tan.
# Tan (b/274462307) require support for custom call stablehlo.tan.
raise unittest.SkipTest(
"native lowering with shape polymorphism requires additional StableHLO feature support")


@@ -107,7 +107,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
def log_jax_hlo(self, f_jax, args: Sequence[Any], *,
num_replicas=1, num_partitions=2):
"""Log the HLO generated from JAX before and after optimizations"""
jax_comp = f_jax.lower(*args).compiler_ir(dialect="mhlo")
jax_comp = f_jax.lower(*args).compiler_ir(dialect="stablehlo")
jax_hlo = str(jax_comp)
logging.info("[%s] got JAX HLO %s", self._testMethodName, jax_hlo)


@@ -14,4 +14,9 @@
# ruff: noqa: F403
import warnings
warnings.warn("jax.extend.mlir.dialects.mhlo is deprecated and will be removed "
"from a future release of JAX. Use stablehlo instead.",
DeprecationWarning)
from jaxlib.mlir.dialects.mhlo import *
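A hedged sketch of the resulting behavior (with the caveat that Python caches imports, so the warning fires only on the first import in a process):

```python
# Sketch: importing the deprecated dialect module emits a DeprecationWarning.
import warnings

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    import jax.extend.mlir.dialects.mhlo  # noqa: F401
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```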


@@ -83,6 +83,7 @@ filterwarnings = [
"ignore:not machine-readable.*:UserWarning",
"ignore:Special cases found for .* but none were parsed.*:UserWarning",
# end array_api_tests-related warnings
"ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
]
doctest_optionflags = [
"NUMBER",


@@ -1129,21 +1129,18 @@ class JitTest(jtu.BufferDonationTestCase):
f = jit(lambda x: x + 4).lower(1.)
self.assertIsInstance(f.as_text(), str)
self.assertIsInstance(f.as_text(dialect='hlo'), str)
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
self.assertIsInstance(f.as_text(dialect="stablehlo"), str)
def test_jit_lower_compiler_ir(self):
f = jit(lambda x: x + 4).lower(1.)
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
self.assertIsNotNone(f.compiler_ir(dialect="stablehlo"))
def test_jit_lower_trivial_compiler_ir(self):
f = jit(lambda x: x).lower(1.)
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
self.assertIsNotNone(f.compiler_ir(dialect="stablehlo"))
def test_jit_replica_attributes(self):
@@ -1216,13 +1213,13 @@ class JitTest(jtu.BufferDonationTestCase):
return y['hi'] + args[1] + sum(kwargs.values())
lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.)
mhlo_str = mlir.module_to_string(lowered.compiler_ir('mhlo'))
self.assertNotIn("\"x\"", mhlo_str)
self.assertIn("y['hi']", mhlo_str)
self.assertNotIn("args[0]", mhlo_str)
self.assertIn("args[1]", mhlo_str)
self.assertIn("kwargs['z']", mhlo_str)
self.assertIn("kwargs['w']", mhlo_str)
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertNotIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
@parameterized.parameters([0, 2, [(0, 2)]])
def test_jit_lower_arg_info_static_argnums(self, static_argnums):
@@ -1230,14 +1227,14 @@ class JitTest(jtu.BufferDonationTestCase):
return y['hi'] + args[1] + sum(kwargs.values())
ir = jax.jit(f, static_argnums=static_argnums).lower(
(1.,), {'hi': 2.}, 3., 4., z=5., w=6.).compiler_ir('mhlo')
mhlo_str = mlir.module_to_string(ir)
self.assertNotIn("\"x\"", mhlo_str)
self.assertIn("y['hi']", mhlo_str)
self.assertNotIn("args[0]", mhlo_str)
self.assertIn("args[1]", mhlo_str)
self.assertIn("kwargs['z']", mhlo_str)
self.assertIn("kwargs['w']", mhlo_str)
(1.,), {'hi': 2.}, 3., 4., z=5., w=6.).compiler_ir('stablehlo')
hlo_str = mlir.module_to_string(ir)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertNotIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
@parameterized.parameters(['a', 'b', [('a', 'b')]])
def test_jit_lower_arg_info_static_argnames(self, static_argnames):
@@ -1245,25 +1242,25 @@ class JitTest(jtu.BufferDonationTestCase):
return y['hi'] + args[1] + kwargs['z'] + kwargs['w']
ir = jax.jit(f, static_argnames=static_argnames).lower(
(1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.).compiler_ir('mhlo')
mhlo_str = mlir.module_to_string(ir)
self.assertNotIn("\"x\"", mhlo_str)
self.assertIn("y['hi']", mhlo_str)
self.assertNotIn("args[0]", mhlo_str)
self.assertIn("args[1]", mhlo_str)
self.assertIn("kwargs['z']", mhlo_str)
self.assertIn("kwargs['w']", mhlo_str)
self.assertNotIn("kwargs['a']", mhlo_str)
self.assertNotIn("kwargs['b']", mhlo_str)
(1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.).compiler_ir('stablehlo')
hlo_str = mlir.module_to_string(ir)
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertNotIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
self.assertNotIn("kwargs['a']", hlo_str)
self.assertNotIn("kwargs['b']", hlo_str)
def test_jit_lower_result_info(self):
def f(x, y, z):
return {'a': x, 'b': [y]}
ir = jax.jit(f).lower(1., (2,), [3]).compiler_ir('mhlo')
mhlo_str = mlir.module_to_string(ir)
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
ir = jax.jit(f).lower(1., (2,), [3]).compiler_ir('stablehlo')
hlo_str = mlir.module_to_string(ir)
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
def test_jit_lower_compile_with_compiler_options(self):
def f(x):
@@ -2786,9 +2783,6 @@ class APITest(jtu.JaxTestCase):
hlo = api.jit(e).lower(2.).compiler_ir(dialect="hlo").as_hlo_text()
self.assertIn(' cosine', hlo)
self.assertIn(' sine', hlo)
mhlo = str(api.jit(e).lower(2.).compiler_ir(dialect="mhlo"))
self.assertIn('mhlo.cosine', mhlo)
self.assertIn('mhlo.sine', mhlo)
stablehlo = str(api.jit(e).lower(2.).compiler_ir(dialect="stablehlo"))
self.assertIn("stablehlo.cosine", stablehlo)
self.assertIn("stablehlo.sine", stablehlo)


@@ -311,17 +311,15 @@ class JaxExportTest(jtu.JaxTestCase):
@jtu.parameterized_filterable(
testcase_name=lambda kw: kw["dialect"],
kwargs=[dict(dialect=dialect)
for dialect in ("mhlo", "stablehlo")]
for dialect in ("stablehlo",)]
)
def test_error_disallowed_custom_call(self, dialect):
# If we use hlo.custom_call or mhlo.custom_call we detect
# invalid custom call targets.
# If we use hlo.custom_call we detect invalid custom call targets.
# Set up a primitive with custom lowering rules
test_primitive = core.Primitive("_test_primitive_disallowed_custom_call")
test_primitive.def_abstract_eval(lambda in_aval: in_aval)
def test_primitive_lowering(ctx, arg):
from jax._src.lib.mlir.dialects import mhlo
op = dict(stablehlo=hlo.CustomCallOp, mhlo=mhlo.CustomCallOp)[dialect]
op = dict(stablehlo=hlo.CustomCallOp)[dialect]
return op([arg.type], [arg], "disallowed_call_target").results
mlir.register_lowering(test_primitive, test_primitive_lowering)
self.addCleanup(lambda: mlir.register_lowering(test_primitive, None))


@@ -1064,7 +1064,6 @@ class PJitTest(jtu.BufferDonationTestCase):
f = f.lower(x, x + 1)
self.assertIsInstance(f.as_text(), str)
self.assertIsInstance(f.as_text(dialect='hlo'), str)
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
@jtu.with_mesh([('x', 2), ('y', 2)])
@@ -1080,7 +1079,6 @@ class PJitTest(jtu.BufferDonationTestCase):
f = f.lower(x, x + 1)
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
@jtu.with_mesh([('x', 2)])
@@ -3217,13 +3215,13 @@ class ArrayPjitTest(jtu.JaxTestCase):
lowered = pjit(f, in_shardings=P(), out_shardings=P()).lower(
{'hi': 1.}, {'hi': 2.}, 3., 4.)
mhlo_str = mlir.module_to_string(lowered.compiler_ir('mhlo'))
self.assertNotIn("\"x\"", mhlo_str)
self.assertIn("y['hi']", mhlo_str)
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
# TODO(yashkatariya): Add keep_unused support to lower_mesh_computation
# and then uncomment the below line.
# self.assertNotIn("args[0]", mhlo_str)
self.assertIn("args[1]", mhlo_str)
# self.assertNotIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
@jtu.with_mesh([('x', 2), ('y', 1)])
def test_jit_nested_xmap_lower_result_info(self):
@@ -3234,9 +3232,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
lowered = pjit(f, in_shardings=P(), out_shardings=P()).lower(
1., (2.,), [3.])
mhlo_str = mlir.module_to_string(lowered.compiler_ir('mhlo'))
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
def test_with_sharding_constraint_with_two_meshes(self):
if jax.device_count() < 4:
@@ -3596,7 +3594,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
b = nest(a)
return b
with jtu.count_subjaxpr_to_mhlo_conversion(fun_name='nest') as count:
with jtu.count_subjaxpr_to_hlo_conversion(fun_name='nest') as count:
top(jnp.arange(8))
# The count should be 1 because `nest`'s lowering to MHLO should be cached.


@@ -271,7 +271,6 @@ class PythonPmapTest(jtu.JaxTestCase):
f = f.lower(x)
self.assertIsInstance(f.as_text(), str)
self.assertIsInstance(f.as_text(dialect='hlo'), str)
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
def testLowerCompilerIR(self):
@@ -281,7 +280,6 @@ class PythonPmapTest(jtu.JaxTestCase):
f = f.lower(x)
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
def testLowerCompileCompilerIR(self):
@@ -2130,13 +2128,13 @@ class PythonPmapTest(jtu.JaxTestCase):
lowered = jax.pmap(f).lower(
{'hi': jnp.array([1.])}, {'hi': jnp.array([2.])}, jnp.array([3.]),
jnp.array([4.]), z=jnp.array([5.]), w=jnp.array([6.]))
mhlo_str = mlir.module_to_string(lowered.compiler_ir('mhlo'))
self.assertNotIn("\"x\"", mhlo_str)
self.assertIn("y['hi']", mhlo_str)
self.assertIn("args[0]", mhlo_str)
self.assertIn("args[1]", mhlo_str)
self.assertIn("kwargs['z']", mhlo_str)
self.assertIn("kwargs['w']", mhlo_str)
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertNotIn("\"x\"", hlo_str)
self.assertIn("y['hi']", hlo_str)
self.assertIn("args[0]", hlo_str)
self.assertIn("args[1]", hlo_str)
self.assertIn("kwargs['z']", hlo_str)
self.assertIn("kwargs['w']", hlo_str)
def test_pmap_lower_result_info(self):
def f(x, y, z):
@@ -2144,9 +2142,9 @@ class PythonPmapTest(jtu.JaxTestCase):
lowered = jax.pmap(f).lower(jnp.array([1.]), (jnp.array([2]),),
[jnp.array([3])])
mhlo_str = mlir.module_to_string(lowered.compiler_ir('mhlo'))
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
def test_axis_name_shadowing_with_vmap(self):
# vmap-of-pmap with mismatched axis sizes


@@ -2538,7 +2538,7 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
harness.check_result = False
if harness.group_name == "vmap_tan":
# Tan (b/274462307) require support for custom call mhlo.tan.
# Tan (b/274462307) require support for custom call stablehlo.tan.
raise unittest.SkipTest(
"native lowering with shape polymorphism requires additional StableHLO feature support")


@@ -909,7 +909,7 @@ class ShardMapTest(jtu.JaxTestCase):
in_specs=P(), out_specs=P())())(2.0)
self.assertAllClose(g, jnp.cos(2.0), check_dtypes=False)
def test_sharding_metadata_in_mhlo_attrs(self):
def test_sharding_metadata_in_hlo_attrs(self):
mesh = Mesh(jax.devices(), ('i',))
x = jnp.arange(len(jax.devices()), dtype='float32')
y = jnp.array([3.], dtype='float32')
@@ -922,14 +922,14 @@ class ShardMapTest(jtu.JaxTestCase):
in_specs=P('i'), out_specs=P('i'))(x)
return x
mhlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('mhlo'))
self.assertIn("call @shmap_body", mhlo_str)
self.assertIn("call @shmap_body_0", mhlo_str)
self.assertIn("%arg0: tensor<1xf32>", mhlo_str)
self.assertIn("\"[None]\"", mhlo_str)
self.assertIn("%arg1: tensor<1xf32>", mhlo_str)
self.assertIn("\"[('i',)]\"", mhlo_str)
self.assertIn("-> (tensor<1xf32> {jax.result_info = \"[('i',)]\"})", mhlo_str)
hlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('stablehlo'))
self.assertIn("call @shmap_body", hlo_str)
self.assertIn("call @shmap_body_0", hlo_str)
self.assertIn("%arg0: tensor<1xf32>", hlo_str)
self.assertIn("\"[None]\"", hlo_str)
self.assertIn("%arg1: tensor<1xf32>", hlo_str)
self.assertIn("\"[('i',)]\"", hlo_str)
self.assertIn("-> (tensor<1xf32> {jax.result_info = \"[('i',)]\"})", hlo_str)
def test_rewrite_process_call(self):
def f(x):


@@ -729,7 +729,6 @@ class XMapTest(XMapTestCase):
f = f.lower(x)
self.assertIsInstance(f.as_text(), str)
self.assertIsInstance(f.as_text(dialect='hlo'), str)
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
def testLowerCompilerIR(self):
@@ -738,7 +737,6 @@ class XMapTest(XMapTestCase):
f = f.lower(x)
self.assertIsNotNone(f.compiler_ir())
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
@jtu.with_mesh([('x', 2)])