diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b20f3bc0..8412f0948 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. ## jax 0.2.28 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.2.27...main). + * `jax.jit(f).lower(...).compiler_ir()` now defaults to the MHLO dialect if no + `dialect=` is passed. * The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR `ir.Module` object instead of its string representation. diff --git a/jax/_src/api.py b/jax/_src/api.py index 0a9901b96..e208c0410 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -522,9 +522,9 @@ class Lowered: self.donate_argnums, self._no_kwargs) def compiler_ir(self, dialect: Optional[str] = None): - if dialect == "mhlo": + if dialect is None or dialect == "mhlo": return self._lowering.mhlo() - elif dialect == "hlo" or dialect is None: + elif dialect == "hlo": return self._lowering.hlo() else: raise ValueError(f"Unknown dialect {dialect}")