mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Change the default IR dialect returned by .compiler_ir() to MHLO.
PiperOrigin-RevId: 423091674
This commit is contained in:
parent
2bebf50783
commit
74e4db47da
@ -11,6 +11,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
## jax 0.2.28 (Unreleased)
|
||||
* [GitHub
|
||||
commits](https://github.com/google/jax/compare/jax-v0.2.27...main).
|
||||
* `jax.jit(f).lower(...).compiler_ir()` now defaults to the MHLO dialect if no
|
||||
`dialect=` is passed.
|
||||
* The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR
|
||||
`ir.Module` object instead of its string representation.
|
||||
|
||||
|
@ -522,9 +522,9 @@ class Lowered:
|
||||
self.donate_argnums, self._no_kwargs)
|
||||
|
||||
def compiler_ir(self, dialect: Optional[str] = None):
|
||||
if dialect == "mhlo":
|
||||
if dialect is None or dialect == "mhlo":
|
||||
return self._lowering.mhlo()
|
||||
elif dialect == "hlo" or dialect is None:
|
||||
elif dialect == "hlo":
|
||||
return self._lowering.hlo()
|
||||
else:
|
||||
raise ValueError(f"Unknown dialect {dialect}")
|
||||
|
Loading…
x
Reference in New Issue
Block a user