mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing. PiperOrigin-RevId: 598550225
This commit is contained in:
parent
912a5ef771
commit
e558feaa5e
@ -66,6 +66,9 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
removed. Use {func}`jax.random.key_data` instead.
|
removed. Use {func}`jax.random.key_data` instead.
|
||||||
* `bool(empty_array)` now raises an error rather than returning `False`. This
|
* `bool(empty_array)` now raises an error rather than returning `False`. This
|
||||||
previously raised a deprecation warning, and follows a similar change in NumPy.
|
previously raised a deprecation warning, and follows a similar change in NumPy.
|
||||||
|
* Support for the mhlo MLIR dialect has been deprecated. JAX no longer uses
|
||||||
|
the mhlo dialect, in favor of stablehlo. APIs that refer to "mhlo" will be
|
||||||
|
removed in the future. Use the "stablehlo" dialect instead.
|
||||||
|
|
||||||
## jaxlib 0.4.24
|
## jaxlib 0.4.24
|
||||||
|
|
||||||
|
@ -154,7 +154,7 @@ def f(runtime_token, x):
|
|||||||
```
|
```
|
||||||
Notice how the runtime tokens are only used at the JIT boundary and the compiler tokens
|
Notice how the runtime tokens are only used at the JIT boundary and the compiler tokens
|
||||||
are only within the compiled code. Compiler tokens are created during
|
are only within the compiled code. Compiler tokens are created during
|
||||||
"lowering" (we convert Python code to a lower level representation like HLO or MHLO)
|
"lowering" (we convert Python code to a lower level representation like HLO or StableHLO)
|
||||||
but runtime tokens need to be managed in Python since they're being threaded in and out
|
but runtime tokens need to be managed in Python since they're being threaded in and out
|
||||||
of JIT-ted functions.
|
of JIT-ted functions.
|
||||||
|
|
||||||
@ -200,7 +200,7 @@ this copy is not necessary.
|
|||||||
|
|
||||||
## Adding compiler tokens
|
## Adding compiler tokens
|
||||||
|
|
||||||
When we lower Python code to HLO or MHLO we need to create a token at the start of the computation and
|
When we lower Python code to HLO or StableHLO we need to create a token at the start of the computation and
|
||||||
ensure they are available when we have side-effecting computations that need to be ordered. The side-effecting
|
ensure they are available when we have side-effecting computations that need to be ordered. The side-effecting
|
||||||
computations will take the token as input and return it as an output.
|
computations will take the token as input and return it as an output.
|
||||||
|
|
||||||
|
@ -350,9 +350,9 @@ out and end-to-end compile a `shmap`ped function, just put a `jit` around it. A
|
|||||||
consequence is that `shmap` doesn't have its own dispatch and compilation paths
|
consequence is that `shmap` doesn't have its own dispatch and compilation paths
|
||||||
like `xmap` and `pmap` currently do; it's just the `jit` path.
|
like `xmap` and `pmap` currently do; it's just the `jit` path.
|
||||||
|
|
||||||
When it's staged out by e.g. an enclosing `jit`, the lowering of `shmap` to MHLO
|
When it's staged out by e.g. an enclosing `jit`, the lowering of `shmap` to
|
||||||
is trivial: it just involves switching into 'manual SPMD mode' on the inputs,
|
StableHLO is trivial: it just involves switching into 'manual SPMD mode' on the
|
||||||
and switching back on the outputs. (We don't currently plan to support
|
inputs, and switching back on the outputs. (We don't currently plan to support
|
||||||
partially-manual-partially-automatic modes.)
|
partially-manual-partially-automatic modes.)
|
||||||
|
|
||||||
The interaction with effects is the same as with `pmap`.
|
The interaction with effects is the same as with `pmap`.
|
||||||
|
@ -7,7 +7,7 @@ Python wheel, and `jaxlib`, which is a mostly-C++ wheel that contains libraries
|
|||||||
such as:
|
such as:
|
||||||
* XLA,
|
* XLA,
|
||||||
* pieces of LLVM used by XLA,
|
* pieces of LLVM used by XLA,
|
||||||
* MLIR infrastructure, such as the MHLO Python bindings.
|
* MLIR infrastructure, such as the StableHLO Python bindings.
|
||||||
* JAX-specific C++ libraries for fast JIT and PyTree manipulation.
|
* JAX-specific C++ libraries for fast JIT and PyTree manipulation.
|
||||||
|
|
||||||
We distribute separate `jax` and `jaxlib` packages because it makes it easy to
|
We distribute separate `jax` and `jaxlib` packages because it makes it easy to
|
||||||
|
@ -217,7 +217,7 @@ pl.pallas_call(kernel2, out_shape=x, grid=1)(1., 1.)
|
|||||||
|
|
||||||
#### Emulation mode
|
#### Emulation mode
|
||||||
|
|
||||||
By representing kernels as programs with JAX primitives and some new Pallas primitives, we can also lower Pallas programs to MHLO directly and compile/execute them with XLA. Specifically, a `pallas_call` can be implemented as a `lax.scan` over the grid. This enables us to develop GPU or TPU kernels on any XLA-supported platform (even CPU!) and debug them using JAX/XLA debugging tools (like `jax.debug.print`). We can also use the more reliable and better tested XLA numerics to verify the correctness of the Triton and Mosaic compilers. One could also imagine perturbing the `scan` ordering to simulate the parallel reads and writes that happen on GPU.
|
By representing kernels as programs with JAX primitives and some new Pallas primitives, we can also lower Pallas programs to StableHLO directly and compile/execute them with XLA. Specifically, a `pallas_call` can be implemented as a `lax.scan` over the grid. This enables us to develop GPU or TPU kernels on any XLA-supported platform (even CPU!) and debug them using JAX/XLA debugging tools (like `jax.debug.print`). We can also use the more reliable and better tested XLA numerics to verify the correctness of the Triton and Mosaic compilers. One could also imagine perturbing the `scan` ordering to simulate the parallel reads and writes that happen on GPU.
|
||||||
|
|
||||||
### Examples
|
### Examples
|
||||||
|
|
||||||
|
@ -216,7 +216,7 @@ class Config:
|
|||||||
self.jax_enable_memories,
|
self.jax_enable_memories,
|
||||||
self.jax_disable_jit,
|
self.jax_disable_jit,
|
||||||
self.jax_xla_profile_version,
|
self.jax_xla_profile_version,
|
||||||
# Technically this affects jaxpr->MHLO lowering, not tracing.
|
# Technically this affects jaxpr->stablehlo lowering, not tracing.
|
||||||
self.jax_hlo_source_file_canonicalization_regex)
|
self.jax_hlo_source_file_canonicalization_regex)
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,6 +33,7 @@ from __future__ import annotations
|
|||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, NamedTuple, Protocol, Union
|
from typing import Any, NamedTuple, Protocol, Union
|
||||||
|
import warnings
|
||||||
|
|
||||||
import jax
|
import jax
|
||||||
|
|
||||||
@ -318,6 +319,11 @@ class XlaLowering(Lowering):
|
|||||||
|
|
||||||
def mhlo(self) -> ir.Module:
|
def mhlo(self) -> ir.Module:
|
||||||
"""Return an MHLO representation of this computation."""
|
"""Return an MHLO representation of this computation."""
|
||||||
|
warnings.warn(
|
||||||
|
"mhlo support is deprecated and will be removed "
|
||||||
|
"from a future release of JAX. Use stablehlo instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
)
|
||||||
module_str = xla_extension.mlir.stablehlo_to_mhlo(
|
module_str = xla_extension.mlir.stablehlo_to_mhlo(
|
||||||
mlir.module_to_bytecode(self.stablehlo()))
|
mlir.module_to_bytecode(self.stablehlo()))
|
||||||
with self.stablehlo().context:
|
with self.stablehlo().context:
|
||||||
|
@ -275,7 +275,7 @@ def count_jit_and_pmap_compiles():
|
|||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def count_subjaxpr_to_mhlo_conversion(fun_name: str):
|
def count_subjaxpr_to_hlo_conversion(fun_name: str):
|
||||||
# No need to clear any caches since we generally jit and pmap fresh callables
|
# No need to clear any caches since we generally jit and pmap fresh callables
|
||||||
# in tests.
|
# in tests.
|
||||||
|
|
||||||
|
@ -888,7 +888,7 @@ def _check_module(mod: ir.Module, *,
|
|||||||
if op_name == "func.func":
|
if op_name == "func.func":
|
||||||
check_sharding(op.operation, op.location)
|
check_sharding(op.operation, op.location)
|
||||||
|
|
||||||
elif op_name == "stablehlo.custom_call" or op_name == "mhlo.custom_call":
|
elif op_name == "stablehlo.custom_call":
|
||||||
call_target_name_attr = op.operation.attributes["call_target_name"]
|
call_target_name_attr = op.operation.attributes["call_target_name"]
|
||||||
if (call_target_name_attr not in allowed_custom_call_targets_attrs):
|
if (call_target_name_attr not in allowed_custom_call_targets_attrs):
|
||||||
disallowed_custom_call_ops.append(f"{op} at {op.location}")
|
disallowed_custom_call_ops.append(f"{op} at {op.location}")
|
||||||
|
@ -2708,7 +2708,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
|||||||
harness.check_result = False
|
harness.check_result = False
|
||||||
|
|
||||||
if harness.group_name == "vmap_tan":
|
if harness.group_name == "vmap_tan":
|
||||||
# Tan (b/274462307) require support for custom call mhlo.tan.
|
# Tan (b/274462307) require support for custom call stablehlo.tan.
|
||||||
raise unittest.SkipTest(
|
raise unittest.SkipTest(
|
||||||
"native lowering with shape polymorphism requires additional StableHLO feature support")
|
"native lowering with shape polymorphism requires additional StableHLO feature support")
|
||||||
|
|
||||||
|
@ -107,7 +107,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
|
|||||||
def log_jax_hlo(self, f_jax, args: Sequence[Any], *,
|
def log_jax_hlo(self, f_jax, args: Sequence[Any], *,
|
||||||
num_replicas=1, num_partitions=2):
|
num_replicas=1, num_partitions=2):
|
||||||
"""Log the HLO generated from JAX before and after optimizations"""
|
"""Log the HLO generated from JAX before and after optimizations"""
|
||||||
jax_comp = f_jax.lower(*args).compiler_ir(dialect="mhlo")
|
jax_comp = f_jax.lower(*args).compiler_ir(dialect="stablehlo")
|
||||||
jax_hlo = str(jax_comp)
|
jax_hlo = str(jax_comp)
|
||||||
logging.info("[%s] got JAX HLO %s", self._testMethodName, jax_hlo)
|
logging.info("[%s] got JAX HLO %s", self._testMethodName, jax_hlo)
|
||||||
|
|
||||||
|
@ -14,4 +14,9 @@
|
|||||||
|
|
||||||
# ruff: noqa: F403
|
# ruff: noqa: F403
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
warnings.warn("jax.extend.mlir.dialects.mhlo is deprecated and will be removed "
|
||||||
|
"from a future release of JAX. Use stablehlo instead.",
|
||||||
|
DeprecationWarning)
|
||||||
|
|
||||||
from jaxlib.mlir.dialects.mhlo import *
|
from jaxlib.mlir.dialects.mhlo import *
|
||||||
|
@ -83,6 +83,7 @@ filterwarnings = [
|
|||||||
"ignore:not machine-readable.*:UserWarning",
|
"ignore:not machine-readable.*:UserWarning",
|
||||||
"ignore:Special cases found for .* but none were parsed.*:UserWarning",
|
"ignore:Special cases found for .* but none were parsed.*:UserWarning",
|
||||||
# end array_api_tests-related warnings
|
# end array_api_tests-related warnings
|
||||||
|
"ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
|
||||||
]
|
]
|
||||||
doctest_optionflags = [
|
doctest_optionflags = [
|
||||||
"NUMBER",
|
"NUMBER",
|
||||||
|
@ -1129,21 +1129,18 @@ class JitTest(jtu.BufferDonationTestCase):
|
|||||||
f = jit(lambda x: x + 4).lower(1.)
|
f = jit(lambda x: x + 4).lower(1.)
|
||||||
self.assertIsInstance(f.as_text(), str)
|
self.assertIsInstance(f.as_text(), str)
|
||||||
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
||||||
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
|
|
||||||
self.assertIsInstance(f.as_text(dialect="stablehlo"), str)
|
self.assertIsInstance(f.as_text(dialect="stablehlo"), str)
|
||||||
|
|
||||||
def test_jit_lower_compiler_ir(self):
|
def test_jit_lower_compiler_ir(self):
|
||||||
f = jit(lambda x: x + 4).lower(1.)
|
f = jit(lambda x: x + 4).lower(1.)
|
||||||
self.assertIsNotNone(f.compiler_ir())
|
self.assertIsNotNone(f.compiler_ir())
|
||||||
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
||||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
|
||||||
self.assertIsNotNone(f.compiler_ir(dialect="stablehlo"))
|
self.assertIsNotNone(f.compiler_ir(dialect="stablehlo"))
|
||||||
|
|
||||||
def test_jit_lower_trivial_compiler_ir(self):
|
def test_jit_lower_trivial_compiler_ir(self):
|
||||||
f = jit(lambda x: x).lower(1.)
|
f = jit(lambda x: x).lower(1.)
|
||||||
self.assertIsNotNone(f.compiler_ir())
|
self.assertIsNotNone(f.compiler_ir())
|
||||||
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
||||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
|
||||||
self.assertIsNotNone(f.compiler_ir(dialect="stablehlo"))
|
self.assertIsNotNone(f.compiler_ir(dialect="stablehlo"))
|
||||||
|
|
||||||
def test_jit_replica_attributes(self):
|
def test_jit_replica_attributes(self):
|
||||||
@ -1216,13 +1213,13 @@ class JitTest(jtu.BufferDonationTestCase):
|
|||||||
return y['hi'] + args[1] + sum(kwargs.values())
|
return y['hi'] + args[1] + sum(kwargs.values())
|
||||||
|
|
||||||
lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.)
|
lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.)
|
||||||
mhlo_str = mlir.module_to_string(lowered.compiler_ir('mhlo'))
|
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
|
||||||
self.assertNotIn("\"x\"", mhlo_str)
|
self.assertNotIn("\"x\"", hlo_str)
|
||||||
self.assertIn("y['hi']", mhlo_str)
|
self.assertIn("y['hi']", hlo_str)
|
||||||
self.assertNotIn("args[0]", mhlo_str)
|
self.assertNotIn("args[0]", hlo_str)
|
||||||
self.assertIn("args[1]", mhlo_str)
|
self.assertIn("args[1]", hlo_str)
|
||||||
self.assertIn("kwargs['z']", mhlo_str)
|
self.assertIn("kwargs['z']", hlo_str)
|
||||||
self.assertIn("kwargs['w']", mhlo_str)
|
self.assertIn("kwargs['w']", hlo_str)
|
||||||
|
|
||||||
@parameterized.parameters([0, 2, [(0, 2)]])
|
@parameterized.parameters([0, 2, [(0, 2)]])
|
||||||
def test_jit_lower_arg_info_static_argnums(self, static_argnums):
|
def test_jit_lower_arg_info_static_argnums(self, static_argnums):
|
||||||
@ -1230,14 +1227,14 @@ class JitTest(jtu.BufferDonationTestCase):
|
|||||||
return y['hi'] + args[1] + sum(kwargs.values())
|
return y['hi'] + args[1] + sum(kwargs.values())
|
||||||
|
|
||||||
ir = jax.jit(f, static_argnums=static_argnums).lower(
|
ir = jax.jit(f, static_argnums=static_argnums).lower(
|
||||||
(1.,), {'hi': 2.}, 3., 4., z=5., w=6.).compiler_ir('mhlo')
|
(1.,), {'hi': 2.}, 3., 4., z=5., w=6.).compiler_ir('stablehlo')
|
||||||
mhlo_str = mlir.module_to_string(ir)
|
hlo_str = mlir.module_to_string(ir)
|
||||||
self.assertNotIn("\"x\"", mhlo_str)
|
self.assertNotIn("\"x\"", hlo_str)
|
||||||
self.assertIn("y['hi']", mhlo_str)
|
self.assertIn("y['hi']", hlo_str)
|
||||||
self.assertNotIn("args[0]", mhlo_str)
|
self.assertNotIn("args[0]", hlo_str)
|
||||||
self.assertIn("args[1]", mhlo_str)
|
self.assertIn("args[1]", hlo_str)
|
||||||
self.assertIn("kwargs['z']", mhlo_str)
|
self.assertIn("kwargs['z']", hlo_str)
|
||||||
self.assertIn("kwargs['w']", mhlo_str)
|
self.assertIn("kwargs['w']", hlo_str)
|
||||||
|
|
||||||
@parameterized.parameters(['a', 'b', [('a', 'b')]])
|
@parameterized.parameters(['a', 'b', [('a', 'b')]])
|
||||||
def test_jit_lower_arg_info_static_argnames(self, static_argnames):
|
def test_jit_lower_arg_info_static_argnames(self, static_argnames):
|
||||||
@ -1245,25 +1242,25 @@ class JitTest(jtu.BufferDonationTestCase):
|
|||||||
return y['hi'] + args[1] + kwargs['z'] + kwargs['w']
|
return y['hi'] + args[1] + kwargs['z'] + kwargs['w']
|
||||||
|
|
||||||
ir = jax.jit(f, static_argnames=static_argnames).lower(
|
ir = jax.jit(f, static_argnames=static_argnames).lower(
|
||||||
(1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.).compiler_ir('mhlo')
|
(1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.).compiler_ir('stablehlo')
|
||||||
mhlo_str = mlir.module_to_string(ir)
|
hlo_str = mlir.module_to_string(ir)
|
||||||
self.assertNotIn("\"x\"", mhlo_str)
|
self.assertNotIn("\"x\"", hlo_str)
|
||||||
self.assertIn("y['hi']", mhlo_str)
|
self.assertIn("y['hi']", hlo_str)
|
||||||
self.assertNotIn("args[0]", mhlo_str)
|
self.assertNotIn("args[0]", hlo_str)
|
||||||
self.assertIn("args[1]", mhlo_str)
|
self.assertIn("args[1]", hlo_str)
|
||||||
self.assertIn("kwargs['z']", mhlo_str)
|
self.assertIn("kwargs['z']", hlo_str)
|
||||||
self.assertIn("kwargs['w']", mhlo_str)
|
self.assertIn("kwargs['w']", hlo_str)
|
||||||
self.assertNotIn("kwargs['a']", mhlo_str)
|
self.assertNotIn("kwargs['a']", hlo_str)
|
||||||
self.assertNotIn("kwargs['b']", mhlo_str)
|
self.assertNotIn("kwargs['b']", hlo_str)
|
||||||
|
|
||||||
def test_jit_lower_result_info(self):
|
def test_jit_lower_result_info(self):
|
||||||
def f(x, y, z):
|
def f(x, y, z):
|
||||||
return {'a': x, 'b': [y]}
|
return {'a': x, 'b': [y]}
|
||||||
|
|
||||||
ir = jax.jit(f).lower(1., (2,), [3]).compiler_ir('mhlo')
|
ir = jax.jit(f).lower(1., (2,), [3]).compiler_ir('stablehlo')
|
||||||
mhlo_str = mlir.module_to_string(ir)
|
hlo_str = mlir.module_to_string(ir)
|
||||||
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
|
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
|
||||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
|
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
|
||||||
|
|
||||||
def test_jit_lower_compile_with_compiler_options(self):
|
def test_jit_lower_compile_with_compiler_options(self):
|
||||||
def f(x):
|
def f(x):
|
||||||
@ -2786,9 +2783,6 @@ class APITest(jtu.JaxTestCase):
|
|||||||
hlo = api.jit(e).lower(2.).compiler_ir(dialect="hlo").as_hlo_text()
|
hlo = api.jit(e).lower(2.).compiler_ir(dialect="hlo").as_hlo_text()
|
||||||
self.assertIn(' cosine', hlo)
|
self.assertIn(' cosine', hlo)
|
||||||
self.assertIn(' sine', hlo)
|
self.assertIn(' sine', hlo)
|
||||||
mhlo = str(api.jit(e).lower(2.).compiler_ir(dialect="mhlo"))
|
|
||||||
self.assertIn('mhlo.cosine', mhlo)
|
|
||||||
self.assertIn('mhlo.sine', mhlo)
|
|
||||||
stablehlo = str(api.jit(e).lower(2.).compiler_ir(dialect="stablehlo"))
|
stablehlo = str(api.jit(e).lower(2.).compiler_ir(dialect="stablehlo"))
|
||||||
self.assertIn("stablehlo.cosine", stablehlo)
|
self.assertIn("stablehlo.cosine", stablehlo)
|
||||||
self.assertIn("stablehlo.sine", stablehlo)
|
self.assertIn("stablehlo.sine", stablehlo)
|
||||||
|
@ -311,17 +311,15 @@ class JaxExportTest(jtu.JaxTestCase):
|
|||||||
@jtu.parameterized_filterable(
|
@jtu.parameterized_filterable(
|
||||||
testcase_name=lambda kw: kw["dialect"],
|
testcase_name=lambda kw: kw["dialect"],
|
||||||
kwargs=[dict(dialect=dialect)
|
kwargs=[dict(dialect=dialect)
|
||||||
for dialect in ("mhlo", "stablehlo")]
|
for dialect in ("stablehlo",)]
|
||||||
)
|
)
|
||||||
def test_error_disallowed_custom_call(self, dialect):
|
def test_error_disallowed_custom_call(self, dialect):
|
||||||
# If we use hlo.custom_call or mhlo.custom_call we detect
|
# If we use hlo.custom_call we detect invalid custom call targets.
|
||||||
# invalid custom call targets.
|
|
||||||
# Set up a primitive with custom lowering rules
|
# Set up a primitive with custom lowering rules
|
||||||
test_primitive = core.Primitive("_test_primitive_disallowed_custom_call")
|
test_primitive = core.Primitive("_test_primitive_disallowed_custom_call")
|
||||||
test_primitive.def_abstract_eval(lambda in_aval: in_aval)
|
test_primitive.def_abstract_eval(lambda in_aval: in_aval)
|
||||||
def test_primitive_lowering(ctx, arg):
|
def test_primitive_lowering(ctx, arg):
|
||||||
from jax._src.lib.mlir.dialects import mhlo
|
op = dict(stablehlo=hlo.CustomCallOp)[dialect]
|
||||||
op = dict(stablehlo=hlo.CustomCallOp, mhlo=mhlo.CustomCallOp)[dialect]
|
|
||||||
return op([arg.type], [arg], "disallowed_call_target").results
|
return op([arg.type], [arg], "disallowed_call_target").results
|
||||||
mlir.register_lowering(test_primitive, test_primitive_lowering)
|
mlir.register_lowering(test_primitive, test_primitive_lowering)
|
||||||
self.addCleanup(lambda: mlir.register_lowering(test_primitive, None))
|
self.addCleanup(lambda: mlir.register_lowering(test_primitive, None))
|
||||||
|
@ -1064,7 +1064,6 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
f = f.lower(x, x + 1)
|
f = f.lower(x, x + 1)
|
||||||
self.assertIsInstance(f.as_text(), str)
|
self.assertIsInstance(f.as_text(), str)
|
||||||
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
||||||
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
|
|
||||||
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
|
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
|
||||||
|
|
||||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||||
@ -1080,7 +1079,6 @@ class PJitTest(jtu.BufferDonationTestCase):
|
|||||||
f = f.lower(x, x + 1)
|
f = f.lower(x, x + 1)
|
||||||
self.assertIsNotNone(f.compiler_ir())
|
self.assertIsNotNone(f.compiler_ir())
|
||||||
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
||||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
|
||||||
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
|
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
|
||||||
|
|
||||||
@jtu.with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
@ -3217,13 +3215,13 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
|||||||
|
|
||||||
lowered = pjit(f, in_shardings=P(), out_shardings=P()).lower(
|
lowered = pjit(f, in_shardings=P(), out_shardings=P()).lower(
|
||||||
{'hi': 1.}, {'hi': 2.}, 3., 4.)
|
{'hi': 1.}, {'hi': 2.}, 3., 4.)
|
||||||
mhlo_str = mlir.module_to_string(lowered.compiler_ir('mhlo'))
|
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
|
||||||
self.assertNotIn("\"x\"", mhlo_str)
|
self.assertNotIn("\"x\"", hlo_str)
|
||||||
self.assertIn("y['hi']", mhlo_str)
|
self.assertIn("y['hi']", hlo_str)
|
||||||
# TODO(yashkatariya): Add keep_unused support to lower_mesh_computation
|
# TODO(yashkatariya): Add keep_unused support to lower_mesh_computation
|
||||||
# and then uncomment the below line.
|
# and then uncomment the below line.
|
||||||
# self.assertNotIn("args[0]", mhlo_str)
|
# self.assertNotIn("args[0]", hlo_str)
|
||||||
self.assertIn("args[1]", mhlo_str)
|
self.assertIn("args[1]", hlo_str)
|
||||||
|
|
||||||
@jtu.with_mesh([('x', 2), ('y', 1)])
|
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||||
def test_jit_nested_xmap_lower_result_info(self):
|
def test_jit_nested_xmap_lower_result_info(self):
|
||||||
@ -3234,9 +3232,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
|||||||
|
|
||||||
lowered = pjit(f, in_shardings=P(), out_shardings=P()).lower(
|
lowered = pjit(f, in_shardings=P(), out_shardings=P()).lower(
|
||||||
1., (2.,), [3.])
|
1., (2.,), [3.])
|
||||||
mhlo_str = mlir.module_to_string(lowered.compiler_ir('mhlo'))
|
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
|
||||||
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
|
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
|
||||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
|
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
|
||||||
|
|
||||||
def test_with_sharding_constraint_with_two_meshes(self):
|
def test_with_sharding_constraint_with_two_meshes(self):
|
||||||
if jax.device_count() < 4:
|
if jax.device_count() < 4:
|
||||||
@ -3596,7 +3594,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
|||||||
b = nest(a)
|
b = nest(a)
|
||||||
return b
|
return b
|
||||||
|
|
||||||
with jtu.count_subjaxpr_to_mhlo_conversion(fun_name='nest') as count:
|
with jtu.count_subjaxpr_to_hlo_conversion(fun_name='nest') as count:
|
||||||
top(jnp.arange(8))
|
top(jnp.arange(8))
|
||||||
|
|
||||||
# The count should be 1 because `nest`'s lowering to MHLO should be cached.
|
# The count should be 1 because `nest`'s lowering to MHLO should be cached.
|
||||||
|
@ -271,7 +271,6 @@ class PythonPmapTest(jtu.JaxTestCase):
|
|||||||
f = f.lower(x)
|
f = f.lower(x)
|
||||||
self.assertIsInstance(f.as_text(), str)
|
self.assertIsInstance(f.as_text(), str)
|
||||||
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
||||||
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
|
|
||||||
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
|
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
|
||||||
|
|
||||||
def testLowerCompilerIR(self):
|
def testLowerCompilerIR(self):
|
||||||
@ -281,7 +280,6 @@ class PythonPmapTest(jtu.JaxTestCase):
|
|||||||
f = f.lower(x)
|
f = f.lower(x)
|
||||||
self.assertIsNotNone(f.compiler_ir())
|
self.assertIsNotNone(f.compiler_ir())
|
||||||
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
||||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
|
||||||
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
|
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
|
||||||
|
|
||||||
def testLowerCompileCompilerIR(self):
|
def testLowerCompileCompilerIR(self):
|
||||||
@ -2130,13 +2128,13 @@ class PythonPmapTest(jtu.JaxTestCase):
|
|||||||
lowered = jax.pmap(f).lower(
|
lowered = jax.pmap(f).lower(
|
||||||
{'hi': jnp.array([1.])}, {'hi': jnp.array([2.])}, jnp.array([3.]),
|
{'hi': jnp.array([1.])}, {'hi': jnp.array([2.])}, jnp.array([3.]),
|
||||||
jnp.array([4.]), z=jnp.array([5.]), w=jnp.array([6.]))
|
jnp.array([4.]), z=jnp.array([5.]), w=jnp.array([6.]))
|
||||||
mhlo_str = mlir.module_to_string(lowered.compiler_ir('mhlo'))
|
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
|
||||||
self.assertNotIn("\"x\"", mhlo_str)
|
self.assertNotIn("\"x\"", hlo_str)
|
||||||
self.assertIn("y['hi']", mhlo_str)
|
self.assertIn("y['hi']", hlo_str)
|
||||||
self.assertIn("args[0]", mhlo_str)
|
self.assertIn("args[0]", hlo_str)
|
||||||
self.assertIn("args[1]", mhlo_str)
|
self.assertIn("args[1]", hlo_str)
|
||||||
self.assertIn("kwargs['z']", mhlo_str)
|
self.assertIn("kwargs['z']", hlo_str)
|
||||||
self.assertIn("kwargs['w']", mhlo_str)
|
self.assertIn("kwargs['w']", hlo_str)
|
||||||
|
|
||||||
def test_pmap_lower_result_info(self):
|
def test_pmap_lower_result_info(self):
|
||||||
def f(x, y, z):
|
def f(x, y, z):
|
||||||
@ -2144,9 +2142,9 @@ class PythonPmapTest(jtu.JaxTestCase):
|
|||||||
|
|
||||||
lowered = jax.pmap(f).lower(jnp.array([1.]), (jnp.array([2]),),
|
lowered = jax.pmap(f).lower(jnp.array([1.]), (jnp.array([2]),),
|
||||||
[jnp.array([3])])
|
[jnp.array([3])])
|
||||||
mhlo_str = mlir.module_to_string(lowered.compiler_ir('mhlo'))
|
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
|
||||||
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
|
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
|
||||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
|
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
|
||||||
|
|
||||||
def test_axis_name_shadowing_with_vmap(self):
|
def test_axis_name_shadowing_with_vmap(self):
|
||||||
# vmap-of-pmap with mismatched axis sizes
|
# vmap-of-pmap with mismatched axis sizes
|
||||||
|
@ -2538,7 +2538,7 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
|
|||||||
harness.check_result = False
|
harness.check_result = False
|
||||||
|
|
||||||
if harness.group_name == "vmap_tan":
|
if harness.group_name == "vmap_tan":
|
||||||
# Tan (b/274462307) require support for custom call mhlo.tan.
|
# Tan (b/274462307) require support for custom call stablehlo.tan.
|
||||||
raise unittest.SkipTest(
|
raise unittest.SkipTest(
|
||||||
"native lowering with shape polymorphism requires additional StableHLO feature support")
|
"native lowering with shape polymorphism requires additional StableHLO feature support")
|
||||||
|
|
||||||
|
@ -909,7 +909,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
|||||||
in_specs=P(), out_specs=P())())(2.0)
|
in_specs=P(), out_specs=P())())(2.0)
|
||||||
self.assertAllClose(g, jnp.cos(2.0), check_dtypes=False)
|
self.assertAllClose(g, jnp.cos(2.0), check_dtypes=False)
|
||||||
|
|
||||||
def test_sharding_metadata_in_mhlo_attrs(self):
|
def test_sharding_metadata_in_hlo_attrs(self):
|
||||||
mesh = Mesh(jax.devices(), ('i',))
|
mesh = Mesh(jax.devices(), ('i',))
|
||||||
x = jnp.arange(len(jax.devices()), dtype='float32')
|
x = jnp.arange(len(jax.devices()), dtype='float32')
|
||||||
y = jnp.array([3.], dtype='float32')
|
y = jnp.array([3.], dtype='float32')
|
||||||
@ -922,14 +922,14 @@ class ShardMapTest(jtu.JaxTestCase):
|
|||||||
in_specs=P('i'), out_specs=P('i'))(x)
|
in_specs=P('i'), out_specs=P('i'))(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
mhlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('mhlo'))
|
hlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('stablehlo'))
|
||||||
self.assertIn("call @shmap_body", mhlo_str)
|
self.assertIn("call @shmap_body", hlo_str)
|
||||||
self.assertIn("call @shmap_body_0", mhlo_str)
|
self.assertIn("call @shmap_body_0", hlo_str)
|
||||||
self.assertIn("%arg0: tensor<1xf32>", mhlo_str)
|
self.assertIn("%arg0: tensor<1xf32>", hlo_str)
|
||||||
self.assertIn("\"[None]\"", mhlo_str)
|
self.assertIn("\"[None]\"", hlo_str)
|
||||||
self.assertIn("%arg1: tensor<1xf32>", mhlo_str)
|
self.assertIn("%arg1: tensor<1xf32>", hlo_str)
|
||||||
self.assertIn("\"[('i',)]\"", mhlo_str)
|
self.assertIn("\"[('i',)]\"", hlo_str)
|
||||||
self.assertIn("-> (tensor<1xf32> {jax.result_info = \"[('i',)]\"})", mhlo_str)
|
self.assertIn("-> (tensor<1xf32> {jax.result_info = \"[('i',)]\"})", hlo_str)
|
||||||
|
|
||||||
def test_rewrite_process_call(self):
|
def test_rewrite_process_call(self):
|
||||||
def f(x):
|
def f(x):
|
||||||
|
@ -729,7 +729,6 @@ class XMapTest(XMapTestCase):
|
|||||||
f = f.lower(x)
|
f = f.lower(x)
|
||||||
self.assertIsInstance(f.as_text(), str)
|
self.assertIsInstance(f.as_text(), str)
|
||||||
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
||||||
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
|
|
||||||
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
|
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
|
||||||
|
|
||||||
def testLowerCompilerIR(self):
|
def testLowerCompilerIR(self):
|
||||||
@ -738,7 +737,6 @@ class XMapTest(XMapTestCase):
|
|||||||
f = f.lower(x)
|
f = f.lower(x)
|
||||||
self.assertIsNotNone(f.compiler_ir())
|
self.assertIsNotNone(f.compiler_ir())
|
||||||
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
||||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
|
||||||
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
|
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
|
||||||
|
|
||||||
@jtu.with_mesh([('x', 2)])
|
@jtu.with_mesh([('x', 2)])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user