[JAX] Change signature of .mhlo() method on compiler IR objects to return an ir.Module object instead of its string representation.

It isn't free to pretty-print IR, so it's best to avoid it unless necessary. In addition, by returning an IR object, the user is now free to, say, print it with different options.

For example, one can now write things like:

```
In [1]: import numpy as np, jax, jax.numpy as jnp
In [2]: m = jax.jit(lambda x: x + jnp.array(np.arange(1000))).lower(7.).compiler_ir(dialect='mhlo')
In [3]: m.operation.print(large_elements_limit=10)
module @jit__lambda_.4 {
  func public @main(%arg0: tensor<f32>) -> tensor<1000xf32> {
    %0 = mhlo.constant opaque<"_", "0xDEADBEEF"> : tensor<1000xi32>
    %1 = "mhlo.convert"(%0) : (tensor<1000xi32>) -> tensor<1000xf32>
    %2 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1000xf32>
    %3 = mhlo.add %2, %1 : tensor<1000xf32>
    return %3 : tensor<1000xf32>
  }
}
```

Fixes https://github.com/google/jax/issues/9226

PiperOrigin-RevId: 422855649
This commit is contained in:
Peter Hawkins 2022-01-19 11:01:03 -08:00 committed by jax authors
parent e152000c03
commit 3fef74b2d0
6 changed files with 33 additions and 17 deletions

View File

@ -11,6 +11,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.28 (Unreleased)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.27...main).
* The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR
`ir.Module` object instead of its string representation.
## jaxlib 0.1.76 (Unreleased)

View File

@ -510,12 +510,14 @@ class XlaComputation:
mlir.module_to_string(self._hlo),
use_tuple_args=self.compile_args["tuple_args"])
def mhlo(self) -> str:
def mhlo(self) -> ir.Module:
if self.is_trivial():
raise ValueError("A trivial computation has no MHLO")
if isinstance(self._hlo, xc.XlaComputation):
return xe.mlir.xla_computation_to_mlir_module(self._hlo)
return mlir.module_to_string(self._hlo)
module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
with mlir.make_ir_context():
return ir.Module.parse(module_str)
return self._hlo
def compile(self) -> 'XlaCompiledComputation':
if self._executable is None:

View File

@ -281,6 +281,14 @@ def _source_info_to_location(
# Translation rules
def make_ir_context() -> ir.Context:
"""Creates an MLIR context suitable for JAX IR."""
context = ir.Context()
mhlo.register_mhlo_dialect(context)
chlo.register_chlo_dialect(context)
return context
@dataclasses.dataclass
class ModuleContext:
"""Module-wide context information for MLIR lowering."""
@ -303,7 +311,7 @@ class ModuleContext:
symbol_table: Optional[ir.SymbolTable] = None,
cached_primitive_lowerings: Optional[Dict[Any, builtin.FuncOp]] = None):
assert platform is not None
self.context = context or ir.Context()
self.context = context or make_ir_context()
self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
self.ip = ip or ir.InsertionPoint(self.module.operation.opview.body)
self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
@ -312,8 +320,6 @@ class ModuleContext:
self.name_stack = name_stack
self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
else cached_primitive_lowerings)
mhlo.register_mhlo_dialect(self.context)
chlo.register_chlo_dialect(self.context)
def replace(self, **kw): return dataclasses.replace(self, **kw)

View File

@ -1040,7 +1040,8 @@ def lower_parallel_callable(
class PmapComputation:
def __init__(self, hlo, **compile_args):
_hlo: Union[ir.Module, xc.XlaComputation]
def __init__(self, hlo: Union[ir.Module, xc.XlaComputation], **compile_args):
self._executable = None
self._hlo = hlo
self.compile_args = compile_args
@ -1054,10 +1055,12 @@ class PmapComputation:
mlir.module_to_string(self._hlo),
use_tuple_args=self.compile_args["tuple_args"])
def mhlo(self) -> str:
def mhlo(self) -> ir.Module:
if isinstance(self._hlo, xc.XlaComputation):
return xe.mlir.xla_computation_to_mlir_module(self._hlo)
return mlir.module_to_string(self._hlo)
module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
with mlir.make_ir_context():
return ir.Module.parse(module_str)
return self._hlo
@profiler.annotate_function
def compile(self):
@ -2015,9 +2018,10 @@ def lower_mesh_computation(
class MeshComputation:
_hlo: Union[ir.Module, xc.XlaComputation]
_executable: Optional['MeshExecutable']
def __init__(self, name: str, hlo: Union[str, xc.XlaComputation],
def __init__(self, name: str, hlo: Union[ir.Module, xc.XlaComputation],
donated_invars: Sequence[bool], **compile_args):
self._name = name
self._hlo = hlo
@ -2035,8 +2039,10 @@ class MeshComputation:
def mhlo(self) -> str:
if isinstance(self._hlo, xc.XlaComputation):
return xe.mlir.xla_computation_to_mlir_module(self._hlo)
return mlir.module_to_string(self._hlo)
module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
with mlir.make_ir_context():
return ir.Module.parse(module_str)
return self._hlo
def compile(self,
_allow_propagation_to_outputs : bool = False,

View File

@ -1994,7 +1994,7 @@ class APITest(jtu.JaxTestCase):
hlo = api.jit(e).lower(2.).compiler_ir(dialect="hlo").as_hlo_text()
self.assertIn(' cosine', hlo)
self.assertIn(' sine', hlo)
mhlo = api.jit(e).lower(2.).compiler_ir(dialect="mhlo")
mhlo = str(api.jit(e).lower(2.).compiler_ir(dialect="mhlo"))
self.assertIn('mhlo.cosine', mhlo)
self.assertIn('mhlo.sine', mhlo)

View File

@ -256,9 +256,9 @@ class PJitTest(jtu.BufferDonationTestCase):
self.assertAllClose(actual, expected, check_dtypes=False)
self.assertLen(actual[0]['a'].device_buffers, 4)
compiler_ir = f.lower(x).compiler_ir(dialect="mhlo")
self.assertIn("unspecified_dims=[0]", compiler_ir)
self.assertIn("unspecified_dims=[1]", compiler_ir)
mhlo_str = str(f.lower(x).compiler_ir(dialect="mhlo"))
self.assertIn("unspecified_dims=[0]", mhlo_str)
self.assertIn("unspecified_dims=[1]", mhlo_str)
def testCaching(self):
def f(x):