Return mlir modules instead of XlaComputation from custom_partitioning.

This will help with exporting this call to the c-api.

PiperOrigin-RevId: 599921028
This commit is contained in:
Parker Schuh 2024-01-19 13:22:59 -08:00 committed by jax authors
parent a7023b18d5
commit 899765edd0
2 changed files with 10 additions and 7 deletions

View File

@ -2539,10 +2539,10 @@ def emit_python_callback(
token, *results = results
return results, token, ifrt_callback
def build_xla_computation_helper(
def build_mlir_module_helper(
closed_jaxpr: core.ClosedJaxpr, *, name: str,
platforms: Sequence[str],
backend_or_name: str, axis_context: AxisContext) -> xc.XlaComputation:
backend_or_name: str, axis_context: AxisContext) -> ir.Module:
"""Helper to generate pmap-style XLA computations for custom partitioners."""
if closed_jaxpr.effects:
raise NotImplementedError
@ -2552,9 +2552,7 @@ def build_xla_computation_helper(
donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
axis_context=axis_context, platforms=platforms,
lowering_parameters=LoweringParameters())
return xc._xla.mlir.mlir_module_to_xla_computation(
module_to_string(lowering_result.module), use_tuple_args=False,
return_tuple=False)
return lowering_result.module
def custom_call(
call_target_name: str,

View File

@ -33,6 +33,7 @@ from jax._src.interpreters import partial_eval as pe
from jax._src.sharding_impls import _op_sharding_to_pos_sharding
from jax._src import custom_api_util
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.api_util import flatten_fun_nokwargs, argnums_partial
@ -180,7 +181,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
% (repr(closed_jaxpr.out_avals), repr(tiled_results))
)
axis_context = sharding_impls.SPMDAxisContext(mesh)
built = mlir.build_xla_computation_helper(
module = mlir.build_mlir_module_helper(
closed_jaxpr,
name="tmp_xla_computation",
platforms=module_context.platforms,
@ -188,7 +189,11 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)),
)
result_sharding = _pack_result_sharding(result_shape, result_shardings)
return built, arg_shardings, result_sharding
if xla_extension_version < 232:
built = xc._xla.mlir.mlir_module_to_xla_computation(
mlir.module_to_string(module), use_tuple_args=False, return_tuple=False)
return built, arg_shardings, result_sharding
return mlir.module_to_bytecode(module), arg_shardings, result_sharding
def _custom_partitioning_infer_sharding_from_operands(arg_shapes, arg_shardings,