mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
[JAX] Convert stablehlo to MLIR bytecode, not an MLIR string.
Bytecode is considerably more compact. PiperOrigin-RevId: 615386276
This commit is contained in:
parent
f0c5051004
commit
642f20de1c
@ -589,10 +589,14 @@ def xla_computation(fun: Callable,
|
||||
arg_shardings=None,
|
||||
result_shardings=None,
|
||||
lowering_parameters=mlir.LoweringParameters())
|
||||
|
||||
if xla_extension_version >= 244:
|
||||
m = mlir.module_to_bytecode(lowering_result.module)
|
||||
else:
|
||||
m = mlir.module_to_string(lowering_result.module)
|
||||
|
||||
built = xc._xla.mlir.mlir_module_to_xla_computation(
|
||||
mlir.module_to_string(lowering_result.module),
|
||||
use_tuple_args=tuple_args,
|
||||
return_tuple=True)
|
||||
m, use_tuple_args=tuple_args, return_tuple=True)
|
||||
out_shapes_flat = [
|
||||
ShapeDtypeStruct(a.shape, a.dtype, a.named_shape) for a in out_avals]
|
||||
out_shape = tree_unflatten(out_tree(), out_shapes_flat)
|
||||
|
@ -48,6 +48,7 @@ from jax._src.layout import SpecifiedLayout
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
@ -314,9 +315,14 @@ class XlaLowering(Lowering):
|
||||
|
||||
def hlo(self) -> xc.XlaComputation:
|
||||
"""Return an HLO representation of this computation."""
|
||||
hlo = self.stablehlo()
|
||||
m: Union[str, bytes]
|
||||
if xla_extension_version >= 244:
|
||||
m = mlir.module_to_bytecode(hlo)
|
||||
else:
|
||||
m = mlir.module_to_string(hlo)
|
||||
return xla_extension.mlir.mlir_module_to_xla_computation(
|
||||
mlir.module_to_string(self.stablehlo()),
|
||||
use_tuple_args=self.compile_args["tuple_args"])
|
||||
m, use_tuple_args=self.compile_args["tuple_args"])
|
||||
|
||||
def mhlo(self) -> ir.Module:
|
||||
"""Return an MHLO representation of this computation."""
|
||||
|
Loading…
x
Reference in New Issue
Block a user