Change the default IR dialect returned by .compiler_ir() to MHLO.

PiperOrigin-RevId: 423091674
This commit is contained in:
Peter Hawkins 2022-01-20 09:49:40 -08:00 committed by jax authors
parent 2bebf50783
commit 74e4db47da
2 changed files with 4 additions and 2 deletions

View File

@ -11,6 +11,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
## jax 0.2.28 (Unreleased)
* [GitHub
commits](https://github.com/google/jax/compare/jax-v0.2.27...main).
* `jax.jit(f).lower(...).compiler_ir()` now defaults to the MHLO dialect if no
`dialect=` is passed.
* The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR
`ir.Module` object instead of its string representation.

View File

@ -522,9 +522,9 @@ class Lowered:
self.donate_argnums, self._no_kwargs)
def compiler_ir(self, dialect: Optional[str] = None):
if dialect == "mhlo":
if dialect is None or dialect == "mhlo":
return self._lowering.mhlo()
elif dialect == "hlo" or dialect is None:
elif dialect == "hlo":
return self._lowering.hlo()
else:
raise ValueError(f"Unknown dialect {dialect}")