[JAX] Convert stablehlo to MLIR bytecode, not an MLIR string.

Bytecode is considerably more compact.

PiperOrigin-RevId: 615386276
This commit is contained in:
Peter Hawkins 2024-03-13 06:01:37 -07:00 committed by jax authors
parent f0c5051004
commit 642f20de1c
2 changed files with 15 additions and 5 deletions

View File

@ -589,10 +589,14 @@ def xla_computation(fun: Callable,
arg_shardings=None,
result_shardings=None,
lowering_parameters=mlir.LoweringParameters())
if xla_extension_version >= 244:
m = mlir.module_to_bytecode(lowering_result.module)
else:
m = mlir.module_to_string(lowering_result.module)
built = xc._xla.mlir.mlir_module_to_xla_computation(
mlir.module_to_string(lowering_result.module),
use_tuple_args=tuple_args,
return_tuple=True)
m, use_tuple_args=tuple_args, return_tuple=True)
out_shapes_flat = [
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
out_shape = tree_unflatten(out_tree(), out_shapes_flat)

View File

@ -48,6 +48,7 @@ from jax._src.layout import SpecifiedLayout
from jax._src.interpreters import mlir
from jax._src.lib.mlir import ir
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
source_info_util.register_exclusion(__file__)
@ -314,9 +315,14 @@ class XlaLowering(Lowering):
def hlo(self) -> xc.XlaComputation:
"""Return an HLO representation of this computation."""
hlo = self.stablehlo()
m: Union[str, bytes]
if xla_extension_version >= 244:
m = mlir.module_to_bytecode(hlo)
else:
m = mlir.module_to_string(hlo)
return xla_extension.mlir.mlir_module_to_xla_computation(
mlir.module_to_string(self.stablehlo()),
use_tuple_args=self.compile_args["tuple_args"])
m, use_tuple_args=self.compile_args["tuple_args"])
def mhlo(self) -> ir.Module:
"""Return an MHLO representation of this computation."""