mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Return mlir modules instead of XlaComputation from custom_partitioning.
This will help with exporting this call to the c-api. PiperOrigin-RevId: 599921028
This commit is contained in:
parent
a7023b18d5
commit
899765edd0
@ -2539,10 +2539,10 @@ def emit_python_callback(
|
||||
token, *results = results
|
||||
return results, token, ifrt_callback
|
||||
|
||||
def build_xla_computation_helper(
|
||||
def build_mlir_module_helper(
|
||||
closed_jaxpr: core.ClosedJaxpr, *, name: str,
|
||||
platforms: Sequence[str],
|
||||
backend_or_name: str, axis_context: AxisContext) -> xc.XlaComputation:
|
||||
backend_or_name: str, axis_context: AxisContext) -> ir.Module:
|
||||
"""Helper to generate pmap-style XLA computations for custom partitioners."""
|
||||
if closed_jaxpr.effects:
|
||||
raise NotImplementedError
|
||||
@ -2552,9 +2552,7 @@ def build_xla_computation_helper(
|
||||
donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
|
||||
axis_context=axis_context, platforms=platforms,
|
||||
lowering_parameters=LoweringParameters())
|
||||
return xc._xla.mlir.mlir_module_to_xla_computation(
|
||||
module_to_string(lowering_result.module), use_tuple_args=False,
|
||||
return_tuple=False)
|
||||
return lowering_result.module
|
||||
|
||||
def custom_call(
|
||||
call_target_name: str,
|
||||
|
@ -33,6 +33,7 @@ from jax._src.interpreters import partial_eval as pe
|
||||
from jax._src.sharding_impls import _op_sharding_to_pos_sharding
|
||||
from jax._src import custom_api_util
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.lib import xla_extension_version
|
||||
from jax._src.api_util import flatten_fun_nokwargs, argnums_partial
|
||||
|
||||
|
||||
@ -180,7 +181,7 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
|
||||
% (repr(closed_jaxpr.out_avals), repr(tiled_results))
|
||||
)
|
||||
axis_context = sharding_impls.SPMDAxisContext(mesh)
|
||||
built = mlir.build_xla_computation_helper(
|
||||
module = mlir.build_mlir_module_helper(
|
||||
closed_jaxpr,
|
||||
name="tmp_xla_computation",
|
||||
platforms=module_context.platforms,
|
||||
@ -188,7 +189,11 @@ def _custom_partitioning_partition(arg_shapes, arg_shardings, result_shape,
|
||||
axis_context=axis_context.extend_manual(frozenset(mesh.axis_names)),
|
||||
)
|
||||
result_sharding = _pack_result_sharding(result_shape, result_shardings)
|
||||
return built, arg_shardings, result_sharding
|
||||
if xla_extension_version < 232:
|
||||
built = xc._xla.mlir.mlir_module_to_xla_computation(
|
||||
mlir.module_to_string(module), use_tuple_args=False, return_tuple=False)
|
||||
return built, arg_shardings, result_sharding
|
||||
return mlir.module_to_bytecode(module), arg_shardings, result_sharding
|
||||
|
||||
|
||||
def _custom_partitioning_infer_sharding_from_operands(arg_shapes, arg_shardings,
|
||||
|
Loading…
x
Reference in New Issue
Block a user