mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[JAX] Change signature of .mhlo() method on compiler IR objects to return an ir.Module object instead of its string representation.
It isn't free to pretty-print IR, so it's best to avoid it unless necessary. In addition, by returning an IR object, the user is now free to, say, print it with different options. For example, one can now write things like: ``` In [1]: import numpy as np, jax, jax.numpy as jnp In [2]: m = jax.jit(lambda x: x + jnp.array(np.arange(1000))).lower(7.).compiler_ir(dialect='mhlo') In [3]: m.operation.print(large_elements_limit=10) module @jit__lambda_.4 { func public @main(%arg0: tensor<f32>) -> tensor<1000xf32> { %0 = mhlo.constant opaque<"_", "0xDEADBEEF"> : tensor<1000xi32> %1 = "mhlo.convert"(%0) : (tensor<1000xi32>) -> tensor<1000xf32> %2 = "mhlo.broadcast_in_dim"(%arg0) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f32>) -> tensor<1000xf32> %3 = mhlo.add %2, %1 : tensor<1000xf32> return %3 : tensor<1000xf32> } } ``` Fixes https://github.com/google/jax/issues/9226 PiperOrigin-RevId: 422855649
This commit is contained in:
parent
e152000c03
commit
3fef74b2d0
@ -11,6 +11,8 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
|
||||
## jax 0.2.28 (Unreleased)
|
||||
* [GitHub
|
||||
commits](https://github.com/google/jax/compare/jax-v0.2.27...main).
|
||||
* The `jax.jit(f).lower(...).compiler_ir(dialect='mhlo')` now returns an MLIR
|
||||
`ir.Module` object instead of its string representation.
|
||||
|
||||
## jaxlib 0.1.76 (Unreleased)
|
||||
|
||||
|
@ -510,12 +510,14 @@ class XlaComputation:
|
||||
mlir.module_to_string(self._hlo),
|
||||
use_tuple_args=self.compile_args["tuple_args"])
|
||||
|
||||
def mhlo(self) -> str:
|
||||
def mhlo(self) -> ir.Module:
|
||||
if self.is_trivial():
|
||||
raise ValueError("A trivial computation has no MHLO")
|
||||
if isinstance(self._hlo, xc.XlaComputation):
|
||||
return xe.mlir.xla_computation_to_mlir_module(self._hlo)
|
||||
return mlir.module_to_string(self._hlo)
|
||||
module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
|
||||
with mlir.make_ir_context():
|
||||
return ir.Module.parse(module_str)
|
||||
return self._hlo
|
||||
|
||||
def compile(self) -> 'XlaCompiledComputation':
|
||||
if self._executable is None:
|
||||
|
@ -281,6 +281,14 @@ def _source_info_to_location(
|
||||
|
||||
# Translation rules
|
||||
|
||||
def make_ir_context() -> ir.Context:
|
||||
"""Creates an MLIR context suitable for JAX IR."""
|
||||
context = ir.Context()
|
||||
mhlo.register_mhlo_dialect(context)
|
||||
chlo.register_chlo_dialect(context)
|
||||
return context
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ModuleContext:
|
||||
"""Module-wide context information for MLIR lowering."""
|
||||
@ -303,7 +311,7 @@ class ModuleContext:
|
||||
symbol_table: Optional[ir.SymbolTable] = None,
|
||||
cached_primitive_lowerings: Optional[Dict[Any, builtin.FuncOp]] = None):
|
||||
assert platform is not None
|
||||
self.context = context or ir.Context()
|
||||
self.context = context or make_ir_context()
|
||||
self.module = module or ir.Module.create(loc=ir.Location.unknown(self.context))
|
||||
self.ip = ip or ir.InsertionPoint(self.module.operation.opview.body)
|
||||
self.symbol_table = symbol_table or ir.SymbolTable(self.module.operation)
|
||||
@ -312,8 +320,6 @@ class ModuleContext:
|
||||
self.name_stack = name_stack
|
||||
self.cached_primitive_lowerings = ({} if cached_primitive_lowerings is None
|
||||
else cached_primitive_lowerings)
|
||||
mhlo.register_mhlo_dialect(self.context)
|
||||
chlo.register_chlo_dialect(self.context)
|
||||
|
||||
def replace(self, **kw): return dataclasses.replace(self, **kw)
|
||||
|
||||
|
@ -1040,7 +1040,8 @@ def lower_parallel_callable(
|
||||
|
||||
|
||||
class PmapComputation:
|
||||
def __init__(self, hlo, **compile_args):
|
||||
_hlo: Union[ir.Module, xc.XlaComputation]
|
||||
def __init__(self, hlo: Union[ir.Module, xc.XlaComputation], **compile_args):
|
||||
self._executable = None
|
||||
self._hlo = hlo
|
||||
self.compile_args = compile_args
|
||||
@ -1054,10 +1055,12 @@ class PmapComputation:
|
||||
mlir.module_to_string(self._hlo),
|
||||
use_tuple_args=self.compile_args["tuple_args"])
|
||||
|
||||
def mhlo(self) -> str:
|
||||
def mhlo(self) -> ir.Module:
|
||||
if isinstance(self._hlo, xc.XlaComputation):
|
||||
return xe.mlir.xla_computation_to_mlir_module(self._hlo)
|
||||
return mlir.module_to_string(self._hlo)
|
||||
module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
|
||||
with mlir.make_ir_context():
|
||||
return ir.Module.parse(module_str)
|
||||
return self._hlo
|
||||
|
||||
@profiler.annotate_function
|
||||
def compile(self):
|
||||
@ -2015,9 +2018,10 @@ def lower_mesh_computation(
|
||||
|
||||
|
||||
class MeshComputation:
|
||||
_hlo: Union[ir.Module, xc.XlaComputation]
|
||||
_executable: Optional['MeshExecutable']
|
||||
|
||||
def __init__(self, name: str, hlo: Union[str, xc.XlaComputation],
|
||||
def __init__(self, name: str, hlo: Union[ir.Module, xc.XlaComputation],
|
||||
donated_invars: Sequence[bool], **compile_args):
|
||||
self._name = name
|
||||
self._hlo = hlo
|
||||
@ -2035,8 +2039,10 @@ class MeshComputation:
|
||||
|
||||
def mhlo(self) -> str:
|
||||
if isinstance(self._hlo, xc.XlaComputation):
|
||||
return xe.mlir.xla_computation_to_mlir_module(self._hlo)
|
||||
return mlir.module_to_string(self._hlo)
|
||||
module_str = xe.mlir.xla_computation_to_mlir_module(self._hlo)
|
||||
with mlir.make_ir_context():
|
||||
return ir.Module.parse(module_str)
|
||||
return self._hlo
|
||||
|
||||
def compile(self,
|
||||
_allow_propagation_to_outputs : bool = False,
|
||||
|
@ -1994,7 +1994,7 @@ class APITest(jtu.JaxTestCase):
|
||||
hlo = api.jit(e).lower(2.).compiler_ir(dialect="hlo").as_hlo_text()
|
||||
self.assertIn(' cosine', hlo)
|
||||
self.assertIn(' sine', hlo)
|
||||
mhlo = api.jit(e).lower(2.).compiler_ir(dialect="mhlo")
|
||||
mhlo = str(api.jit(e).lower(2.).compiler_ir(dialect="mhlo"))
|
||||
self.assertIn('mhlo.cosine', mhlo)
|
||||
self.assertIn('mhlo.sine', mhlo)
|
||||
|
||||
|
@ -256,9 +256,9 @@ class PJitTest(jtu.BufferDonationTestCase):
|
||||
self.assertAllClose(actual, expected, check_dtypes=False)
|
||||
self.assertLen(actual[0]['a'].device_buffers, 4)
|
||||
|
||||
compiler_ir = f.lower(x).compiler_ir(dialect="mhlo")
|
||||
self.assertIn("unspecified_dims=[0]", compiler_ir)
|
||||
self.assertIn("unspecified_dims=[1]", compiler_ir)
|
||||
mhlo_str = str(f.lower(x).compiler_ir(dialect="mhlo"))
|
||||
self.assertIn("unspecified_dims=[0]", mhlo_str)
|
||||
self.assertIn("unspecified_dims=[1]", mhlo_str)
|
||||
|
||||
def testCaching(self):
|
||||
def f(x):
|
||||
|
Loading…
x
Reference in New Issue
Block a user