mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Deprecate support for the mhlo dialect.
JAX has not used mhlo for some time, in favor of stablehlo. Deprecate support for this dialect in JAX's API and remove testing. PiperOrigin-RevId: 598550225
This commit is contained in:
parent
912a5ef771
commit
e558feaa5e
@ -66,6 +66,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
removed. Use {func}`jax.random.key_data` instead.
|
||||
* `bool(empty_array)` now raises an error rather than returning `False`. This
|
||||
previously raised a deprecation warning, and follows a similar change in NumPy.
|
||||
* Support for the mhlo MLIR dialect has been deprecated. JAX no longer uses
|
||||
the mhlo dialect, in favor of stablehlo. APIs that refer to "mhlo" will be
|
||||
removed in the future. Use the "stablehlo" dialect instead.
|
||||
|
||||
## jaxlib 0.4.24
|
||||
|
||||
|
@ -154,7 +154,7 @@ def f(runtime_token, x):
|
||||
```
|
||||
Notice how the runtime tokens are only used at the JIT boundary and the compiler tokens
|
||||
are only within the compiled code. Compiler tokens are created during
|
||||
"lowering" (we convert Python code to a lower level representation like HLO or MHLO)
|
||||
"lowering" (we convert Python code to a lower level representation like HLO or StableHLO)
|
||||
but runtime tokens need to be managed in Python since they're being threaded in and out
|
||||
of JIT-ted functions.
|
||||
|
||||
@ -200,7 +200,7 @@ this copy is not necessary.
|
||||
|
||||
## Adding compiler tokens
|
||||
|
||||
When we lower Python code to HLO or MHLO we need to create a token at the start of the computation and
|
||||
When we lower Python code to HLO or StableHLO we need to create a token at the start of the computation and
|
||||
ensure they are available when we have side-effecting computations that need to be ordered. The side-effecting
|
||||
computations will take the token as input and return it as an output.
|
||||
|
||||
|
@ -350,9 +350,9 @@ out and end-to-end compile a `shmap`ped function, just put a `jit` around it. A
|
||||
consequence is that `shmap` doesn't have its own dispatch and compilation paths
|
||||
like `xmap` and `pmap` currently do; it's just the `jit` path.
|
||||
|
||||
When it's staged out by e.g. an enclosing `jit`, the lowering of `shmap` to MHLO
|
||||
is trivial: it just involves switching into 'manual SPMD mode' on the inputs,
|
||||
and switching back on the outputs. (We don't currently plan to support
|
||||
When it's staged out by e.g. an enclosing `jit`, the lowering of `shmap` to
|
||||
StableHLO is trivial: it just involves switching into 'manual SPMD mode' on the
|
||||
inputs, and switching back on the outputs. (We don't currently plan to support
|
||||
partially-manual-partially-automatic modes.)
|
||||
|
||||
The interaction with effects is the same as with `pmap`.
|
||||
|
@ -7,7 +7,7 @@ Python wheel, and `jaxlib`, which is a mostly-C++ wheel that contains libraries
|
||||
such as:
|
||||
* XLA,
|
||||
* pieces of LLVM used by XLA,
|
||||
* MLIR infrastructure, such as the MHLO Python bindings.
|
||||
* MLIR infrastructure, such as the StableHLO Python bindings.
|
||||
* JAX-specific C++ libraries for fast JIT and PyTree manipulation.
|
||||
|
||||
We distribute separate `jax` and `jaxlib` packages because it makes it easy to
|
||||
|
@ -217,7 +217,7 @@ pl.pallas_call(kernel2, out_shape=x, grid=1)(1., 1.)
|
||||
|
||||
#### Emulation mode
|
||||
|
||||
By representing kernels as programs with JAX primitives and some new Pallas primitives, we can also lower Pallas programs to MHLO directly and compile/execute them with XLA. Specifically, a `pallas_call` can be implemented as a `lax.scan` over the grid. This enables us to develop GPU or TPU kernels on any XLA-supported platform (even CPU!) and debug them using JAX/XLA debugging tools (like `jax.debug.print`). We can also use the more reliable and better tested XLA numerics to verify the correctness of the Triton and Mosaic compilers. One could also imagine perturbing the `scan` ordering to simulate the parallel reads and writes that happen on GPU.
|
||||
By representing kernels as programs with JAX primitives and some new Pallas primitives, we can also lower Pallas programs to StableHLO directly and compile/execute them with XLA. Specifically, a `pallas_call` can be implemented as a `lax.scan` over the grid. This enables us to develop GPU or TPU kernels on any XLA-supported platform (even CPU!) and debug them using JAX/XLA debugging tools (like `jax.debug.print`). We can also use the more reliable and better tested XLA numerics to verify the correctness of the Triton and Mosaic compilers. One could also imagine perturbing the `scan` ordering to simulate the parallel reads and writes that happen on GPU.
|
||||
|
||||
### Examples
|
||||
|
||||
|
@ -216,7 +216,7 @@ class Config:
|
||||
self.jax_enable_memories,
|
||||
self.jax_disable_jit,
|
||||
self.jax_xla_profile_version,
|
||||
# Technically this affects jaxpr->MHLO lowering, not tracing.
|
||||
# Technically this affects jaxpr->stablehlo lowering, not tracing.
|
||||
self.jax_hlo_source_file_canonicalization_regex)
|
||||
|
||||
|
||||
|
@ -33,6 +33,7 @@ from __future__ import annotations
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, NamedTuple, Protocol, Union
|
||||
import warnings
|
||||
|
||||
import jax
|
||||
|
||||
@ -318,6 +319,11 @@ class XlaLowering(Lowering):
|
||||
|
||||
def mhlo(self) -> ir.Module:
|
||||
"""Return an MHLO representation of this computation."""
|
||||
warnings.warn(
|
||||
"mhlo support is deprecated and will be removed "
|
||||
"from a future release of JAX. Use stablehlo instead.",
|
||||
DeprecationWarning,
|
||||
)
|
||||
module_str = xla_extension.mlir.stablehlo_to_mhlo(
|
||||
mlir.module_to_bytecode(self.stablehlo()))
|
||||
with self.stablehlo().context:
|
||||
|
@ -275,7 +275,7 @@ def count_jit_and_pmap_compiles():
|
||||
|
||||
|
||||
@contextmanager
|
||||
def count_subjaxpr_to_mhlo_conversion(fun_name: str):
|
||||
def count_subjaxpr_to_hlo_conversion(fun_name: str):
|
||||
# No need to clear any caches since we generally jit and pmap fresh callables
|
||||
# in tests.
|
||||
|
||||
|
@ -888,7 +888,7 @@ def _check_module(mod: ir.Module, *,
|
||||
if op_name == "func.func":
|
||||
check_sharding(op.operation, op.location)
|
||||
|
||||
elif op_name == "stablehlo.custom_call" or op_name == "mhlo.custom_call":
|
||||
elif op_name == "stablehlo.custom_call":
|
||||
call_target_name_attr = op.operation.attributes["call_target_name"]
|
||||
if (call_target_name_attr not in allowed_custom_call_targets_attrs):
|
||||
disallowed_custom_call_ops.append(f"{op} at {op.location}")
|
||||
|
@ -2708,7 +2708,7 @@ class ShapePolyPrimitivesTest(tf_test_util.JaxToTfTestCase):
|
||||
harness.check_result = False
|
||||
|
||||
if harness.group_name == "vmap_tan":
|
||||
# Tan (b/274462307) require support for custom call mhlo.tan.
|
||||
# Tan (b/274462307) require support for custom call stablehlo.tan.
|
||||
raise unittest.SkipTest(
|
||||
"native lowering with shape polymorphism requires additional StableHLO feature support")
|
||||
|
||||
|
@ -107,7 +107,7 @@ class ShardingTest(tf_test_util.JaxToTfTestCase):
|
||||
def log_jax_hlo(self, f_jax, args: Sequence[Any], *,
|
||||
num_replicas=1, num_partitions=2):
|
||||
"""Log the HLO generated from JAX before and after optimizations"""
|
||||
jax_comp = f_jax.lower(*args).compiler_ir(dialect="mhlo")
|
||||
jax_comp = f_jax.lower(*args).compiler_ir(dialect="stablehlo")
|
||||
jax_hlo = str(jax_comp)
|
||||
logging.info("[%s] got JAX HLO %s", self._testMethodName, jax_hlo)
|
||||
|
||||
|
@ -14,4 +14,9 @@
|
||||
|
||||
# ruff: noqa: F403
|
||||
|
||||
import warnings
|
||||
warnings.warn("jax.extend.mlir.dialects.mhlo is deprecated and will be removed "
|
||||
"from a future release of JAX. Use stablehlo instead.",
|
||||
DeprecationWarning)
|
||||
|
||||
from jaxlib.mlir.dialects.mhlo import *
|
||||
|
@ -83,6 +83,7 @@ filterwarnings = [
|
||||
"ignore:not machine-readable.*:UserWarning",
|
||||
"ignore:Special cases found for .* but none were parsed.*:UserWarning",
|
||||
# end array_api_tests-related warnings
|
||||
"ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning",
|
||||
]
|
||||
doctest_optionflags = [
|
||||
"NUMBER",
|
||||
|
@ -1129,21 +1129,18 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
f = jit(lambda x: x + 4).lower(1.)
|
||||
self.assertIsInstance(f.as_text(), str)
|
||||
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
||||
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
|
||||
self.assertIsInstance(f.as_text(dialect="stablehlo"), str)
|
||||
|
||||
def test_jit_lower_compiler_ir(self):
|
||||
f = jit(lambda x: x + 4).lower(1.)
|
||||
self.assertIsNotNone(f.compiler_ir())
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect="stablehlo"))
|
||||
|
||||
def test_jit_lower_trivial_compiler_ir(self):
|
||||
f = jit(lambda x: x).lower(1.)
|
||||
self.assertIsNotNone(f.compiler_ir())
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect="stablehlo"))
|
||||
|
||||
def test_jit_replica_attributes(self):
|
||||
@ -1216,13 +1213,13 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
return y['hi'] + args[1] + sum(kwargs.values())
|
||||
|
||||
lowered = jax.jit(f).lower({'hi': 1.}, {'hi': 2.}, 3., 4., z=5., w=6.)
|
||||
mhlo_str = mlir.module_to_string(lowered.compiler_ir('mhlo'))
|
||||
self.assertNotIn("\"x\"", mhlo_str)
|
||||
self.assertIn("y['hi']", mhlo_str)
|
||||
self.assertNotIn("args[0]", mhlo_str)
|
||||
self.assertIn("args[1]", mhlo_str)
|
||||
self.assertIn("kwargs['z']", mhlo_str)
|
||||
self.assertIn("kwargs['w']", mhlo_str)
|
||||
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
|
||||
self.assertNotIn("\"x\"", hlo_str)
|
||||
self.assertIn("y['hi']", hlo_str)
|
||||
self.assertNotIn("args[0]", hlo_str)
|
||||
self.assertIn("args[1]", hlo_str)
|
||||
self.assertIn("kwargs['z']", hlo_str)
|
||||
self.assertIn("kwargs['w']", hlo_str)
|
||||
|
||||
@parameterized.parameters([0, 2, [(0, 2)]])
|
||||
def test_jit_lower_arg_info_static_argnums(self, static_argnums):
|
||||
@ -1230,14 +1227,14 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
return y['hi'] + args[1] + sum(kwargs.values())
|
||||
|
||||
ir = jax.jit(f, static_argnums=static_argnums).lower(
|
||||
(1.,), {'hi': 2.}, 3., 4., z=5., w=6.).compiler_ir('mhlo')
|
||||
mhlo_str = mlir.module_to_string(ir)
|
||||
self.assertNotIn("\"x\"", mhlo_str)
|
||||
self.assertIn("y['hi']", mhlo_str)
|
||||
self.assertNotIn("args[0]", mhlo_str)
|
||||
self.assertIn("args[1]", mhlo_str)
|
||||
self.assertIn("kwargs['z']", mhlo_str)
|
||||
self.assertIn("kwargs['w']", mhlo_str)
|
||||
(1.,), {'hi': 2.}, 3., 4., z=5., w=6.).compiler_ir('stablehlo')
|
||||
hlo_str = mlir.module_to_string(ir)
|
||||
self.assertNotIn("\"x\"", hlo_str)
|
||||
self.assertIn("y['hi']", hlo_str)
|
||||
self.assertNotIn("args[0]", hlo_str)
|
||||
self.assertIn("args[1]", hlo_str)
|
||||
self.assertIn("kwargs['z']", hlo_str)
|
||||
self.assertIn("kwargs['w']", hlo_str)
|
||||
|
||||
@parameterized.parameters(['a', 'b', [('a', 'b')]])
|
||||
def test_jit_lower_arg_info_static_argnames(self, static_argnames):
|
||||
@ -1245,25 +1242,25 @@ class JitTest(jtu.BufferDonationTestCase):
|
||||
return y['hi'] + args[1] + kwargs['z'] + kwargs['w']
|
||||
|
||||
ir = jax.jit(f, static_argnames=static_argnames).lower(
|
||||
(1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.).compiler_ir('mhlo')
|
||||
mhlo_str = mlir.module_to_string(ir)
|
||||
self.assertNotIn("\"x\"", mhlo_str)
|
||||
self.assertIn("y['hi']", mhlo_str)
|
||||
self.assertNotIn("args[0]", mhlo_str)
|
||||
self.assertIn("args[1]", mhlo_str)
|
||||
self.assertIn("kwargs['z']", mhlo_str)
|
||||
self.assertIn("kwargs['w']", mhlo_str)
|
||||
self.assertNotIn("kwargs['a']", mhlo_str)
|
||||
self.assertNotIn("kwargs['b']", mhlo_str)
|
||||
(1.,), {'hi': 2.}, 3., 4., z=5., w=6., a=7., b=8.).compiler_ir('stablehlo')
|
||||
hlo_str = mlir.module_to_string(ir)
|
||||
self.assertNotIn("\"x\"", hlo_str)
|
||||
self.assertIn("y['hi']", hlo_str)
|
||||
self.assertNotIn("args[0]", hlo_str)
|
||||
self.assertIn("args[1]", hlo_str)
|
||||
self.assertIn("kwargs['z']", hlo_str)
|
||||
self.assertIn("kwargs['w']", hlo_str)
|
||||
self.assertNotIn("kwargs['a']", hlo_str)
|
||||
self.assertNotIn("kwargs['b']", hlo_str)
|
||||
|
||||
def test_jit_lower_result_info(self):
|
||||
def f(x, y, z):
|
||||
return {'a': x, 'b': [y]}
|
||||
|
||||
ir = jax.jit(f).lower(1., (2,), [3]).compiler_ir('mhlo')
|
||||
mhlo_str = mlir.module_to_string(ir)
|
||||
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
|
||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
|
||||
ir = jax.jit(f).lower(1., (2,), [3]).compiler_ir('stablehlo')
|
||||
hlo_str = mlir.module_to_string(ir)
|
||||
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
|
||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
|
||||
|
||||
def test_jit_lower_compile_with_compiler_options(self):
|
||||
def f(x):
|
||||
@ -2786,9 +2783,6 @@ class APITest(jtu.JaxTestCase):
|
||||
hlo = api.jit(e).lower(2.).compiler_ir(dialect="hlo").as_hlo_text()
|
||||
self.assertIn(' cosine', hlo)
|
||||
self.assertIn(' sine', hlo)
|
||||
mhlo = str(api.jit(e).lower(2.).compiler_ir(dialect="mhlo"))
|
||||
self.assertIn('mhlo.cosine', mhlo)
|
||||
self.assertIn('mhlo.sine', mhlo)
|
||||
stablehlo = str(api.jit(e).lower(2.).compiler_ir(dialect="stablehlo"))
|
||||
self.assertIn("stablehlo.cosine", stablehlo)
|
||||
self.assertIn("stablehlo.sine", stablehlo)
|
||||
|
@ -311,17 +311,15 @@ class JaxExportTest(jtu.JaxTestCase):
|
||||
@jtu.parameterized_filterable(
|
||||
testcase_name=lambda kw: kw["dialect"],
|
||||
kwargs=[dict(dialect=dialect)
|
||||
for dialect in ("mhlo", "stablehlo")]
|
||||
for dialect in ("stablehlo",)]
|
||||
)
|
||||
def test_error_disallowed_custom_call(self, dialect):
|
||||
# If we use hlo.custom_call or mhlo.custom_call we detect
|
||||
# invalid custom call targets.
|
||||
# If we use hlo.custom_call we detect invalid custom call targets.
|
||||
# Set up a primitive with custom lowering rules
|
||||
test_primitive = core.Primitive("_test_primitive_disallowed_custom_call")
|
||||
test_primitive.def_abstract_eval(lambda in_aval: in_aval)
|
||||
def test_primitive_lowering(ctx, arg):
|
||||
from jax._src.lib.mlir.dialects import mhlo
|
||||
op = dict(stablehlo=hlo.CustomCallOp, mhlo=mhlo.CustomCallOp)[dialect]
|
||||
op = dict(stablehlo=hlo.CustomCallOp)[dialect]
|
||||
return op([arg.type], [arg], "disallowed_call_target").results
|
||||
mlir.register_lowering(test_primitive, test_primitive_lowering)
|
||||
self.addCleanup(lambda: mlir.register_lowering(test_primitive, None))
|
||||
|
@ -1064,7 +1064,6 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
f = f.lower(x, x + 1)
|
||||
self.assertIsInstance(f.as_text(), str)
|
||||
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
||||
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
|
||||
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 2)])
|
||||
@ -1080,7 +1079,6 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
f = f.lower(x, x + 1)
|
||||
self.assertIsNotNone(f.compiler_ir())
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
@ -3217,13 +3215,13 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
lowered = pjit(f, in_shardings=P(), out_shardings=P()).lower(
|
||||
{'hi': 1.}, {'hi': 2.}, 3., 4.)
|
||||
mhlo_str = mlir.module_to_string(lowered.compiler_ir('mhlo'))
|
||||
self.assertNotIn("\"x\"", mhlo_str)
|
||||
self.assertIn("y['hi']", mhlo_str)
|
||||
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
|
||||
self.assertNotIn("\"x\"", hlo_str)
|
||||
self.assertIn("y['hi']", hlo_str)
|
||||
# TODO(yashkatariya): Add keep_unused support to lower_mesh_computation
|
||||
# and then uncomment the below line.
|
||||
# self.assertNotIn("args[0]", mhlo_str)
|
||||
self.assertIn("args[1]", mhlo_str)
|
||||
# self.assertNotIn("args[0]", hlo_str)
|
||||
self.assertIn("args[1]", hlo_str)
|
||||
|
||||
@jtu.with_mesh([('x', 2), ('y', 1)])
|
||||
def test_jit_nested_xmap_lower_result_info(self):
|
||||
@ -3234,9 +3232,9 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
|
||||
lowered = pjit(f, in_shardings=P(), out_shardings=P()).lower(
|
||||
1., (2.,), [3.])
|
||||
mhlo_str = mlir.module_to_string(lowered.compiler_ir('mhlo'))
|
||||
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
|
||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
|
||||
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
|
||||
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
|
||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
|
||||
|
||||
def test_with_sharding_constraint_with_two_meshes(self):
|
||||
if jax.device_count() < 4:
|
||||
@ -3596,7 +3594,7 @@ class ArrayPjitTest(jtu.JaxTestCase):
|
||||
b = nest(a)
|
||||
return b
|
||||
|
||||
with jtu.count_subjaxpr_to_mhlo_conversion(fun_name='nest') as count:
|
||||
with jtu.count_subjaxpr_to_hlo_conversion(fun_name='nest') as count:
|
||||
top(jnp.arange(8))
|
||||
|
||||
# The count should be 1 because `nest`'s lowering to MHLO should be cached.
|
||||
|
@ -271,7 +271,6 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
f = f.lower(x)
|
||||
self.assertIsInstance(f.as_text(), str)
|
||||
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
||||
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
|
||||
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
|
||||
|
||||
def testLowerCompilerIR(self):
|
||||
@ -281,7 +280,6 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
f = f.lower(x)
|
||||
self.assertIsNotNone(f.compiler_ir())
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
|
||||
|
||||
def testLowerCompileCompilerIR(self):
|
||||
@ -2130,13 +2128,13 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
lowered = jax.pmap(f).lower(
|
||||
{'hi': jnp.array([1.])}, {'hi': jnp.array([2.])}, jnp.array([3.]),
|
||||
jnp.array([4.]), z=jnp.array([5.]), w=jnp.array([6.]))
|
||||
mhlo_str = mlir.module_to_string(lowered.compiler_ir('mhlo'))
|
||||
self.assertNotIn("\"x\"", mhlo_str)
|
||||
self.assertIn("y['hi']", mhlo_str)
|
||||
self.assertIn("args[0]", mhlo_str)
|
||||
self.assertIn("args[1]", mhlo_str)
|
||||
self.assertIn("kwargs['z']", mhlo_str)
|
||||
self.assertIn("kwargs['w']", mhlo_str)
|
||||
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
|
||||
self.assertNotIn("\"x\"", hlo_str)
|
||||
self.assertIn("y['hi']", hlo_str)
|
||||
self.assertIn("args[0]", hlo_str)
|
||||
self.assertIn("args[1]", hlo_str)
|
||||
self.assertIn("kwargs['z']", hlo_str)
|
||||
self.assertIn("kwargs['w']", hlo_str)
|
||||
|
||||
def test_pmap_lower_result_info(self):
|
||||
def f(x, y, z):
|
||||
@ -2144,9 +2142,9 @@ class PythonPmapTest(jtu.JaxTestCase):
|
||||
|
||||
lowered = jax.pmap(f).lower(jnp.array([1.]), (jnp.array([2]),),
|
||||
[jnp.array([3])])
|
||||
mhlo_str = mlir.module_to_string(lowered.compiler_ir('mhlo'))
|
||||
self.assertIn("jax.result_info = \"['a']\"", mhlo_str)
|
||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", mhlo_str)
|
||||
hlo_str = mlir.module_to_string(lowered.compiler_ir('stablehlo'))
|
||||
self.assertIn("jax.result_info = \"['a']\"", hlo_str)
|
||||
self.assertIn("jax.result_info = \"['b'][0][0]\"", hlo_str)
|
||||
|
||||
def test_axis_name_shadowing_with_vmap(self):
|
||||
# vmap-of-pmap with mismatched axis sizes
|
||||
|
@ -2538,7 +2538,7 @@ class ShapePolyHarnessesTest(jtu.JaxTestCase):
|
||||
harness.check_result = False
|
||||
|
||||
if harness.group_name == "vmap_tan":
|
||||
# Tan (b/274462307) require support for custom call mhlo.tan.
|
||||
# Tan (b/274462307) require support for custom call stablehlo.tan.
|
||||
raise unittest.SkipTest(
|
||||
"native lowering with shape polymorphism requires additional StableHLO feature support")
|
||||
|
||||
|
@ -909,7 +909,7 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
in_specs=P(), out_specs=P())())(2.0)
|
||||
self.assertAllClose(g, jnp.cos(2.0), check_dtypes=False)
|
||||
|
||||
def test_sharding_metadata_in_mhlo_attrs(self):
|
||||
def test_sharding_metadata_in_hlo_attrs(self):
|
||||
mesh = Mesh(jax.devices(), ('i',))
|
||||
x = jnp.arange(len(jax.devices()), dtype='float32')
|
||||
y = jnp.array([3.], dtype='float32')
|
||||
@ -922,14 +922,14 @@ class ShardMapTest(jtu.JaxTestCase):
|
||||
in_specs=P('i'), out_specs=P('i'))(x)
|
||||
return x
|
||||
|
||||
mhlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('mhlo'))
|
||||
self.assertIn("call @shmap_body", mhlo_str)
|
||||
self.assertIn("call @shmap_body_0", mhlo_str)
|
||||
self.assertIn("%arg0: tensor<1xf32>", mhlo_str)
|
||||
self.assertIn("\"[None]\"", mhlo_str)
|
||||
self.assertIn("%arg1: tensor<1xf32>", mhlo_str)
|
||||
self.assertIn("\"[('i',)]\"", mhlo_str)
|
||||
self.assertIn("-> (tensor<1xf32> {jax.result_info = \"[('i',)]\"})", mhlo_str)
|
||||
hlo_str = mlir.module_to_string(jax.jit(foo).lower(x).compiler_ir('stablehlo'))
|
||||
self.assertIn("call @shmap_body", hlo_str)
|
||||
self.assertIn("call @shmap_body_0", hlo_str)
|
||||
self.assertIn("%arg0: tensor<1xf32>", hlo_str)
|
||||
self.assertIn("\"[None]\"", hlo_str)
|
||||
self.assertIn("%arg1: tensor<1xf32>", hlo_str)
|
||||
self.assertIn("\"[('i',)]\"", hlo_str)
|
||||
self.assertIn("-> (tensor<1xf32> {jax.result_info = \"[('i',)]\"})", hlo_str)
|
||||
|
||||
def test_rewrite_process_call(self):
|
||||
def f(x):
|
||||
|
@ -729,7 +729,6 @@ class XMapTest(XMapTestCase):
|
||||
f = f.lower(x)
|
||||
self.assertIsInstance(f.as_text(), str)
|
||||
self.assertIsInstance(f.as_text(dialect='hlo'), str)
|
||||
self.assertIsInstance(f.as_text(dialect='mhlo'), str)
|
||||
self.assertIsInstance(f.as_text(dialect='stablehlo'), str)
|
||||
|
||||
def testLowerCompilerIR(self):
|
||||
@ -738,7 +737,6 @@ class XMapTest(XMapTestCase):
|
||||
f = f.lower(x)
|
||||
self.assertIsNotNone(f.compiler_ir())
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='hlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='mhlo'))
|
||||
self.assertIsNotNone(f.compiler_ir(dialect='stablehlo'))
|
||||
|
||||
@jtu.with_mesh([('x', 2)])
|
||||
|
Loading…
x
Reference in New Issue
Block a user