diff --git a/CHANGELOG.md b/CHANGELOG.md index f7ccedb32..9b20f3bc0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. ## jax 0.2.28 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.27...main). + * The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR + `ir.Module` object instead of its string representation. ## jaxlib 0.1.76 (Unreleased) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index 0d2f65fde..deae68b16 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -510,12 +510,14 @@ class XlaComputation: mlir.module_to_string(self._hlo), use_tuple_args=self.compile_args["tuple_args"]) - def mhlo(self) -> str: + def mhlo(self) -> ir.Module: if self.is_trivial(): raise ValueError("A trivial computation has no MHLO") if isinstance(self._hlo, xc.XlaComputation): - return xe.mlir.xla_computation_to_mlir_module(self._hlo) - return mlir.module_to_string(self._hlo) + module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo) + with mlir.make_ir_context(): + return ir.Module.parse(module_str) + return self._hlo def compile(self) -> 'XlaCompiledComputation': if self._executable is None: diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 317f1cc1d..87b067bfc 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -281,6 +281,14 @@ def _source_info_to_location( # Translation rules +def make_ir_context() -> ir.Context: + """Creates an MLIR context suitable for JAX IR.""" + context = ir.Context() + mhlo.register_mhlo_dialect(context) + chlo.register_chlo_dialect(context) + return context + + @dataclasses.dataclass class ModuleContext: """Module-wide context information for MLIR lowering.""" @@ -303,7 +311,7 @@ class ModuleContext: symbol_table: Optional[ir.SymbolTable] = None, cached_primitive_lowerings: Optional[Dict[Any, builtin.FuncOp]] = None): assert platform is not None - self.context = context or ir.Context() + self.context = context or make_ir_context() self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context)) self.ip = ip or ir.InsertionPoint(self.module.operation.opview.body) self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation) @@ -312,8 +320,6 @@ class ModuleContext: self.name_stack = name_stack self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None else cached_primitive_lowerings) - mhlo.register_mhlo_dialect(self.context) - chlo.register_chlo_dialect(self.context) def replace(self, **kw): return dataclasses.replace(self, **kw) diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 7efca3d8c..7ba45ac36 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -1040,7 +1040,8 @@ def lower_parallel_callable( class PmapComputation: - def __init__(self, hlo, **compile_args): + _hlo: Union[ir.Module, xc.XlaComputation] + def __init__(self, hlo: Union[ir.Module, xc.XlaComputation], **compile_args): self._executable = None self._hlo = hlo self.compile_args = compile_args @@ -1054,10 +1055,12 @@ class PmapComputation: mlir.module_to_string(self._hlo), use_tuple_args=self.compile_args["tuple_args"]) - def mhlo(self) -> str: + def mhlo(self) -> ir.Module: if isinstance(self._hlo, xc.XlaComputation): - return xe.mlir.xla_computation_to_mlir_module(self._hlo) - return mlir.module_to_string(self._hlo) + module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo) + with mlir.make_ir_context(): + return ir.Module.parse(module_str) + return self._hlo @profiler.annotate_function def compile(self): @@ -2015,9 +2018,10 @@ def lower_mesh_computation( class MeshComputation: + _hlo: Union[ir.Module, xc.XlaComputation] _executable: Optional['MeshExecutable'] - def __init__(self, name: str, hlo: Union[str, xc.XlaComputation], + def __init__(self, name: str, hlo: Union[ir.Module, xc.XlaComputation], donated_invars: Sequence[bool], **compile_args): self._name = name self._hlo = hlo @@ -2035,8 +2039,10 @@ class MeshComputation: def mhlo(self) -> str: if isinstance(self._hlo, xc.XlaComputation): - return xe.mlir.xla_computation_to_mlir_module(self._hlo) - return mlir.module_to_string(self._hlo) + module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo) + with mlir.make_ir_context(): + return ir.Module.parse(module_str) + return self._hlo def compile(self, _allow_propagation_to_outputs : bool = False, diff --git a/tests/api_test.py b/tests/api_test.py index 1b0029c84..e50ef7098 100644 --- a/tests/api_test.py +++ b/tests/api_test.py @@ -1994,7 +1994,7 @@ class APITest(jtu.JaxTestCase): hlo = api.jit(e).lower(2.).compiler_ir(dialect="hlo").as_hlo_text() self.assertIn(' cosine', hlo) self.assertIn(' sine', hlo) - mhlo = api.jit(e).lower(2.).compiler_ir(dialect="mhlo") + mhlo = str(api.jit(e).lower(2.).compiler_ir(dialect="mhlo")) self.assertIn('mhlo.cosine', mhlo) self.assertIn('mhlo.sine', mhlo) diff --git a/tests/pjit_test.py b/tests/pjit_test.py index 9f6212bf1..826d8b283 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -256,9 +256,9 @@ class PJitTest(jtu.BufferDonationTestCase): self.assertAllClose(actual, expected, check_dtypes=False) self.assertLen(actual[0]['a'].device_buffers, 4) - compiler_ir = f.lower(x).compiler_ir(dialect="mhlo") - self.assertIn("unspecified_dims=[0]", compiler_ir) - self.assertIn("unspecified_dims=[1]", compiler_ir) + mhlo_str = str(f.lower(x).compiler_ir(dialect="mhlo")) + self.assertIn("unspecified_dims=[0]", mhlo_str) + self.assertIn("unspecified_dims=[1]", mhlo_str) def testCaching(self): def f(x):