From 2d1bc5c2a0138dca1a695dc1ed92edaaabc0c5aa Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Fri, 21 Feb 2025 09:45:14 -0800 Subject: [PATCH] Refactor Jax FFI lowering to prepare for implementing CPU/GPU callbacks using XLA's FFI. - This refactor just moves code around and should have no impact on tests or public-facing APIs. - `mlir.emit_python_callback` would eventually depend on `ffi.ffi_lowering`, which in turn depends on definitions in `mlir.py`. We break this circular dependency. PiperOrigin-RevId: 729561359 --- jax/_src/callback.py | 393 +++++++++++++++++++++++++--------- jax/_src/checkify.py | 3 +- jax/_src/debugging.py | 5 +- jax/_src/ffi.py | 212 ++++++++++++++---- jax/_src/interpreters/mlir.py | 279 ------------------------ jax/_src/pallas/primitives.py | 3 +- jax/interpreters/mlir.py | 8 +- tests/jaxpr_effects_test.py | 3 +- 8 files changed, 487 insertions(+), 419 deletions(-) diff --git a/jax/_src/callback.py b/jax/_src/callback.py index db257a2bd..bdceb98d9 100644 --- a/jax/_src/callback.py +++ b/jax/_src/callback.py @@ -27,17 +27,21 @@ from jax._src import deprecations from jax._src import dispatch from jax._src import dtypes from jax._src import effects +from jax._src import ffi +from jax._src import pickle_util from jax._src import sharding_impls from jax._src import tree_util from jax._src import util +from jax._src import xla_bridge as xb from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir -from jax._src.lax import lax +from jax._src.interpreters import xla from jax._src.lax.control_flow.loops import map as lax_map -from jax._src.lax.control_flow.loops import scan from jax._src.lib import xla_client as xc -from jax._src.sharding_impls import SingleDeviceSharding +from jax._src.lib.mlir import ir +from jax._src.lib.mlir.dialects import hlo +from jax._src.sharding_impls import SdyArraySharding, SdyArrayShardingList, SingleDeviceSharding from jax._src.typing import DeprecatedArg import numpy as np @@ -135,99 +139,8 @@ def pure_callback_transpose_rule(*args, **kwargs): ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule -def callback_batching_rule( - prim, - args, - dims, - *, - vectorized: bool | None | DeprecatedArg, - vmap_method: str | None, - result_avals: Sequence[core.ShapedArray], - **kwargs: Any, -): - if isinstance(vectorized, DeprecatedArg) and vmap_method is None: - deprecations.warn( - "jax-callback-vectorized", - f"The default behavior of {prim.name} under vmap will soon " - "change. Currently, the default behavior is to generate a sequential " - "vmap (i.e. a loop), but in the future the default will be to raise " - "an error. To keep the current default, set vmap_method='sequential'.", - stacklevel=6) - vmap_method = "sequential" - - axis_size, = {a.shape[d] for a, d in zip(args, dims) - if d is not batching.not_mapped} - new_args = [arg if dim is batching.not_mapped else - batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)] - batched_result_avals = tuple( - core.unmapped_aval(axis_size, 0, aval) for aval in result_avals) - - # For FFI calls we must update the layouts. We handle the output layouts - # here, but the input layout updates depend on the vmap_method parameter. 
- if ( - vmap_method not in ("sequential", "sequential_unrolled") and - kwargs.get("output_layouts") is not None - ): - kwargs["output_layouts"] = tuple( - None if layout is None else tuple(n + 1 for n in layout) + (0,) - for layout in kwargs["output_layouts"]) - - if vmap_method == "legacy_vectorized": - # This method is kept to support the behavior that was previously exposed - # when using `vectorized=True`. - if kwargs.get("input_layouts") is not None: - kwargs["input_layouts"] = tuple( - layout if d is batching.not_mapped else - (None if layout is None else tuple(n + 1 for n in layout) + (0,)) - for layout, d in zip(kwargs["input_layouts"], dims)) - outvals = prim.bind( - *new_args, - vectorized=vectorized, - vmap_method=vmap_method, - result_avals=batched_result_avals, - **kwargs, - ) - elif vmap_method == "expand_dims" or vmap_method == "broadcast_all": - size = axis_size if vmap_method == "broadcast_all" else 1 - bcast_args = [ - lax.broadcast(x, (size,)) if d is batching.not_mapped else x - for x, d in zip(new_args, dims)] - if kwargs.get("input_layouts") is not None: - kwargs["input_layouts"] = tuple( - None if layout is None else tuple(n + 1 for n in layout) + (0,) - for layout in kwargs["input_layouts"]) - outvals = prim.bind( - *bcast_args, - vectorized=vectorized, - vmap_method=vmap_method, - result_avals=batched_result_avals, - **kwargs, - ) - elif vmap_method == "sequential" or vmap_method == "sequential_unrolled": - is_batched = [d is not batching.not_mapped for d in dims] - unbatched_args, batched_args = util.partition_list(is_batched, new_args) - def _batch_fun(batched_args): - merged_args = util.merge_lists(is_batched, unbatched_args, batched_args) - return prim.bind( - *merged_args, - result_avals=result_avals, - vectorized=vectorized, - vmap_method=vmap_method, - **kwargs, - ) - unroll = vmap_method == "sequential_unrolled" - g = lambda _, x: ((), _batch_fun(x)) - _, outvals = scan(g, (), batched_args, unroll=unroll) - else: - raise NotImplementedError( - f"vmap is only supported for the {prim.name} primitive when vmap_method " - "is one of 'sequential', 'sequential_unrolled', 'expand_dims', " - f"'broadcast_all', or 'legacy_vectorized'. 
Got {vmap_method=}.") - return tuple(outvals), (0,) * len(outvals) - - batching.primitive_batchers[pure_callback_p] = functools.partial( - callback_batching_rule, pure_callback_p + ffi.ffi_batching_rule, pure_callback_p ) @@ -318,7 +231,7 @@ def pure_callback_lowering( op_sharding = _callback_op_sharding( ctx.module_context.axis_context, sharding, ctx.avals_out) - result, _, _ = mlir.emit_python_callback( + result, _, _ = emit_python_callback( ctx, _callback, None, @@ -600,7 +513,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params): ctx.module_context.axis_context, sharding, ctx.avals_out) if ordered: token = ctx.tokens_in.get(_OrderedIOEffect) - result, token, _ = mlir.emit_python_callback( + result, token, _ = emit_python_callback( ctx, _callback, token, @@ -612,7 +525,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params): ) ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: token})) else: - result, _, _ = mlir.emit_python_callback( + result, _, _ = emit_python_callback( ctx, _callback, None, @@ -677,3 +590,287 @@ def io_callback( ordered=ordered, ) return tree_util.tree_unflatten(out_tree, out_flat) + + + +def is_empty_shape(s: core.Shape) -> bool: + return any(d == 0 for d in s) + + +def send_to_host( + channel: int, + token: hlo.TokenType, + operand: Any, + name: str, + *, + sharding: SdyArrayShardingList | xc.OpSharding | None = None, +) -> ir.Value: + channel_handle = hlo.ChannelHandle.get(channel, mlir.SEND_TO_HOST_TYPE) + send_op = hlo.SendOp([operand], token, channel_handle, + is_host_transfer=ir.BoolAttr.get(True)) + send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( + dict( + _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)), + _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name)))) + if sharding is not None: + if config.use_shardy_partitioner.value: + # `SendOp`'s return type is a StableHLO `TokenType`. However JAX passed + # in the maximal sharding of the array type. Since a token has no rank, + # we need to create an equivalent sharding with no dimensions. If there + # are multiple shardings, just grab the first one since all these + # shardings should be the same. + assert isinstance(sharding, SdyArrayShardingList) + assert len(sharding.shardings) >= 1 + sharding = SdyArrayShardingList([ + SdyArraySharding( + mesh_shape=(), dimension_shardings=[], + logical_device_ids=sharding.shardings[0].logical_device_ids)]) + mlir.set_sharding(send_op, sharding) + return send_op.result + + +def receive_from_host( + channel: int, + token: hlo.TokenType, + out_aval: core.ShapedArray, + name: str, + *, + sharding: SdyArrayShardingList | xc.OpSharding | None = None, +) -> tuple[ir.Value, ir.Value]: + channel_handle = hlo.ChannelHandle.get(channel, mlir.RECV_FROM_HOST_TYPE) + recv_op = hlo.RecvOp([mlir.aval_to_ir_type(out_aval), + hlo.TokenType.get()], token, channel_handle, + is_host_transfer=ir.BoolAttr.get(True)) + recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( + dict( + _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)), + _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name)))) + if sharding is not None: + if config.use_shardy_partitioner.value: + assert isinstance(sharding, SdyArrayShardingList) + assert len(sharding.shardings) >= 1 + # `RecvOp`'s last argument is a `TokenType`. Since Shardy requires the + # number of shardings to match the number of results, but JAX only sees + # the array result, we need to add an equivalent sharding for the token. 
+ # Note that even if a function returns N results, we will end up with N + # `RecvOp`s, so we only need to get the first sharding. All shardings are + # the same anyway, operating on the same single device ID. + sharding = SdyArrayShardingList([ + sharding.shardings[0], + SdyArraySharding( + mesh_shape=(), dimension_shardings=[], + logical_device_ids=sharding.shardings[0].logical_device_ids)]) + mlir.set_sharding(recv_op, sharding) + # Token should be at the end of the results + result, token = recv_op.results + return token, result + + +def _emit_tpu_python_callback( + backend: xb.XlaBackend, + ctx: mlir.LoweringRuleContext, + callback, + token: Any | None, + operands: Sequence[ir.Value], + operand_avals: Sequence[core.ShapedArray], + operand_shapes: Sequence[xc.Shape], + result_avals: Sequence[core.ShapedArray], + result_shapes: Sequence[xc.Shape], + *, + sharding: SdyArrayShardingList | xc.OpSharding | None = None, +) -> tuple[Sequence[ir.Value], Any]: + token = token or hlo.create_token() + _wrapped_callback = callback + + send_channels = [] + if not operand_avals: + # If there are no operands to the callback, we need to insert a dummy send + # op or the callback will never be triggered! + # TODO(sharadmv,chky): Enable this fix in the runtime as opposed to in + # MLIR builder. + callback_without_args = _wrapped_callback + def _wrapped_callback(*args): # pylint: disable=function-redefined + del args + return callback_without_args() + send_channel = ctx.module_context.new_channel() + dummy_send_aval = core.ShapedArray((1,), np.float32) + dummy_send_val = mlir.ir_constant(np.zeros(1, np.float32)) + operand_shapes = [*operand_shapes, + xla.aval_to_xla_shapes(dummy_send_aval)[0]] + token = send_to_host(send_channel, token, dummy_send_val, callback.__name__, + sharding=sharding) + send_channels.append(send_channel) + else: + for operand in operands: + channel = ctx.module_context.new_channel() + token = send_to_host(channel, token, operand, callback.__name__, + sharding=sharding) + send_channels.append(channel) + + recv_channels = [] + outputs = [] + for result_aval in result_avals: + channel = ctx.module_context.new_channel() + assert isinstance(result_aval, core.ShapedArray) + token, out = receive_from_host(channel, token, result_aval, + callback.__name__, sharding=sharding) + outputs.append(out) + recv_channels.append(channel) + ifrt_callback = backend.make_python_callback_from_host_send_and_recv( + _wrapped_callback, operand_shapes, result_shapes, send_channels, + recv_channels, pickle_util.dumps) + ctx.module_context.add_host_callback(ifrt_callback) + return outputs, token + + +def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None): + if minor_to_major is None: + # Needed for token layouts + layout: np.ndarray = np.zeros((0,), dtype="int64") + else: + layout = np.array(minor_to_major, dtype="int64") + return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get()) + + +def _aval_to_default_layouts(aval): + avals = [core.physical_aval(aval)] + # Row-major order is the default for NumPy.
+ return [list(range(aval.ndim - 1, -1, -1)) for aval in avals] + + +def emit_python_callback( + ctx: mlir.LoweringRuleContext, + callback, + token: Any | None, + operands: Sequence[ir.Value], + operand_avals: Sequence[core.ShapedArray], + result_avals: Sequence[core.ShapedArray], + *, + has_side_effect: bool, + sharding: SdyArrayShardingList | xc.OpSharding | None = None, + operand_layouts: Sequence[Sequence[int] | None] | None = None, + result_layouts: Sequence[Sequence[int] | None] | None = None, +) -> tuple[Sequence[mlir.IrValues], Any, Any]: + """Emits MLIR that calls back to a provided Python function.""" + if len(ctx.module_context.platforms) > 1: + raise NotImplementedError("multi-platform lowering for python_callback") + platform = ctx.module_context.platforms[0] + if platform not in {"cpu", "cuda", "rocm", "tpu"}: + raise ValueError( + f"`EmitPythonCallback` not supported on {platform} backend.") + backend = ctx.module_context.get_backend() + result_shapes = util.flatten( + [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals]) + operand_shapes = util.flatten( + [xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals]) + # Handling layouts + if operand_layouts is None: + operand_layouts = util.concatenate( + map(_aval_to_default_layouts, operand_avals)) + operand_mlir_layouts = map(_layout_to_mlir_layout, operand_layouts) + if result_layouts is None: + result_layouts = util.concatenate(map(_aval_to_default_layouts, result_avals)) + result_mlir_layouts = map(_layout_to_mlir_layout, result_layouts) + + # First we apply checks to ensure output shapes and dtypes match the expected + # ones. + def _wrapped_callback(*args): + out_vals = callback(*args) + if len(out_vals) != len(result_avals): + raise RuntimeError( + "Mismatched number of outputs from callback. " + "Expected: {}, Actual: {}".format(len(result_avals), len(out_vals))) + # Handle Python literals, and custom arrays, e.g., tf.Tensor. + out_vals = tuple(xla.canonicalize_dtype(np.asarray(a)) for a in out_vals) + for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)): + if out_val.shape != out_aval.shape: + raise RuntimeError( + f"Incorrect output shape for return value #{i}: " + f"Expected: {out_aval.shape}, Actual: {out_val.shape}") + if out_val.dtype != out_aval.dtype: + raise RuntimeError( + f"Incorrect output dtype for return value #{i}: " + f"Expected: {out_aval.dtype}, Actual: {out_val.dtype}") + + if platform == "tpu": + # On TPU we cannot receive empty arrays. So, we return from the wrapped + # callback only the non-empty results, and we will create empty constants + # in the receiving computation. + # TODO(b/238239458): fix TPU Recv to work with empty arrays. 
+ non_empty_out_vals = tuple( + out_val + for out_val, result_aval in zip(out_vals, result_avals) + if not is_empty_shape(result_aval.shape)) + return non_empty_out_vals + else: + return out_vals + + if platform == "tpu": + non_empty_result_avals, non_empty_result_shapes = util.unzip2([ + (aval, shape) + for aval, shape in zip(result_avals, result_shapes) + if not is_empty_shape(aval.shape)]) + non_empty_outputs, token = _emit_tpu_python_callback( + backend, ctx, _wrapped_callback, token, + operands, operand_avals, operand_shapes, + non_empty_result_avals, non_empty_result_shapes, + sharding=sharding) + non_empty_outputs_iter = iter(non_empty_outputs) + outputs = [ + mlir.ir_constant(np.zeros(result_aval.shape, dtype=result_aval.dtype)) + if is_empty_shape(result_aval.shape) else next(non_empty_outputs_iter) + for result_aval in result_avals] + return outputs, token, None + + result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals]) + if token: + + callback_without_token = _wrapped_callback + def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined + return (token, *callback_without_token(*args)) + + operand_shapes = [ + xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes + ] + result_shapes = [ + xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes + ] + operands = [token, *operands] + result_types = [mlir.token_type(), *result_types] + operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts] + result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts] + callback_descriptor, ifrt_callback = ( + backend.get_emit_python_callback_descriptor(_wrapped_callback, + operand_shapes, + result_shapes)) + ctx.module_context.add_host_callback(ifrt_callback) + descriptor_operand = mlir.ir_constant(callback_descriptor) + callback_operands = [descriptor_operand, *operands] + if operand_mlir_layouts is not None: + operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts] + result_type = ir.TupleType.get_tuple(result_types) + call_target_name = ("xla_python_gpu_callback" + if platform in {"cuda", "rocm"} else "xla_python_cpu_callback") + result = hlo.CustomCallOp( + [result_type], + callback_operands, + call_target_name=ir.StringAttr.get(call_target_name), + has_side_effect=ir.BoolAttr.get(has_side_effect), + api_version=mlir.i32_attr(2), + called_computations=ir.ArrayAttr.get([]), + backend_config=ir.StringAttr.get(str(callback_descriptor)), + operand_layouts=( + None if operand_mlir_layouts is None + else ir.ArrayAttr.get(operand_mlir_layouts)), + result_layouts=( + None if result_mlir_layouts is None + else ir.ArrayAttr.get(result_mlir_layouts))) + if sharding is not None: + mlir.set_sharding(result, sharding) + results = [ + hlo.get_tuple_element(result, mlir.i32_attr(i)) + for i in range(len(result_types)) + ] + if token: + token, *results = results + return results, token, ifrt_callback diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index e18ddd03f..5a6561afa 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -30,6 +30,7 @@ from jax._src import api from jax._src import api_util from jax._src import ad_checkpoint from jax._src import linear_util as lu +from jax._src import callback from jax._src import config from jax._src import core from jax._src import custom_derivatives @@ -518,7 +519,7 @@ def check_lowering_rule(ctx, *args, err_tree, debug): if not config.xla_runtime_errors.value: raise functionalization_error - out_op, _, _ = 
mlir.emit_python_callback( + out_op, _, _ = callback.emit_python_callback( ctx, callback=functools.partial(python_err, err_tree), token=None, operands=args, diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index 88874f2b9..45eac11c4 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -29,6 +29,7 @@ import numpy as np import jax import jax.numpy as jnp from jax import lax +from jax._src import callback as cb from jax._src import config from jax._src import core from jax._src import dispatch @@ -194,12 +195,12 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params): return () if effects.ordered_effects.contains(effect): token = ctx.tokens_in.get(effect) - result, token, _ = mlir.emit_python_callback( + result, token, _ = cb.emit_python_callback( ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out, has_side_effect=True) ctx.set_tokens_out(mlir.TokenSet({effect: token})) else: - result, _, _ = mlir.emit_python_callback( + result, _, _ = cb.emit_python_callback( ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out, has_side_effect=True, sharding=sharding) return result diff --git a/jax/_src/ffi.py b/jax/_src/ffi.py index d0cdaa30b..05697f00e 100644 --- a/jax/_src/ffi.py +++ b/jax/_src/ffi.py @@ -22,13 +22,13 @@ from typing import Any, overload import numpy as np +import jax from jax._src import core from jax._src import deprecations from jax._src import dispatch from jax._src import effects from jax._src import util from jax._src import xla_bridge -from jax._src.callback import callback_batching_rule from jax._src.interpreters import ad from jax._src.interpreters import batching from jax._src.interpreters import mlir @@ -157,6 +157,81 @@ def _convert_layout_for_lowering( return tuple(layout) + +def build_ffi_lowering_function( + call_target_name: str, + *, + operand_layouts: Sequence[FfiLayoutOptions] | None = None, + result_layouts: Sequence[FfiLayoutOptions] | None = None, + backend_config: Mapping[str, ir.Attribute] | str | None = None, + **lowering_args: Any, +) -> Callable[..., ir.Operation]: + """Build a lowering op for a foreign function interface (FFI) target. + + By default, this lowering rule can use the input and output abstract values to + compute the input and output types and shapes for the custom call, assuming + row-major layouts. + + Note that layouts passed to this function as tuples should be in + minor-to-major order (as expected by XLA) rather than major-to-minor as used + by :func:`~jax.ffi.ffi_call` and ``DeviceLocalLayout``. + + If keyword arguments are passed to the lowering rule, these are treated as + attributes and added to `backend_config`. + + Args: + call_target_name: The name of the custom call target. + operand_layouts: A sequence of layouts (dimension orders) for each operand. + By default, the operands are assumed to be row-major. + result_layouts: A sequence of layouts (dimension orders) for each result. + By default, the results are assumed to be row-major. + backend_config: Configuration data for the custom call. Any keyword + arguments passed to the lowering rule will be added to this dictionary. + lowering_args: Any other arguments to :func:`mlir.custom_call` will also be + passed through if provided as extra arguments to this function. 
+ """ + + def _lowering_op( + ctx: mlir.LoweringRuleContext, *operands: ir.Value, **params: Any + ) -> ir.Operation: + kwargs = dict(lowering_args) + kwargs.setdefault("api_version", 4) + if kwargs["api_version"] >= 4: + if backend_config is not None and not isinstance(backend_config, dict): + raise ValueError( + "When api_version > 4, backend_config must be a dictionary.") + kwargs["backend_config"] = dict( + backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()}) + else: + if params: + raise ValueError( + "The use of ffi_call attributes requires a custom call API version " + f"of at least 4; got api_version={kwargs['api_version']}.") + kwargs["backend_config"] = backend_config + if "result_types" not in kwargs: + kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out] + if operand_layouts is None: + kwargs["operand_layouts"] = map(_convert_layout_for_lowering, ctx.avals_in) + else: + kwargs["operand_layouts"] = [ + _convert_layout_for_lowering(*args) + for args in zip(ctx.avals_in, operand_layouts)] + if result_layouts is None: + kwargs["result_layouts"] = map(_convert_layout_for_lowering, ctx.avals_out) + else: + kwargs["result_layouts"] = [ + _convert_layout_for_lowering(*args) + for args in zip(ctx.avals_out, result_layouts)] + if "result_shapes" not in kwargs and not all( + core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out): + kwargs["result_shapes"] = [ + mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, _aval_shape(aval))) + for aval in ctx.avals_out] + + return mlir.custom_call(call_target_name, operands=operands, **kwargs) + + return _lowering_op + + def ffi_lowering( call_target_name: str, *, @@ -193,41 +268,15 @@ def ffi_lowering( def _lowering( ctx: mlir.LoweringRuleContext, *operands: ir.Value, **params: Any ) -> Sequence[ir.Value | Sequence[ir.Value]]: - kwargs = dict(lowering_args) - kwargs.setdefault("api_version", 4) - if kwargs["api_version"] >= 4: - if backend_config is not None and not isinstance(backend_config, dict): - raise ValueError( - "When api_version > 4, backend_config must be a dictionary.") - kwargs["backend_config"] = dict( - backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()}) - else: - if params: - raise ValueError( - "The use of ffi_call attributes requires a custom call API version " - f"of at least 4; got api_version={kwargs['api_version']}.") - kwargs["backend_config"] = backend_config - if "result_types" not in kwargs: - kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out] - if operand_layouts is None: - kwargs["operand_layouts"] = map(_convert_layout_for_lowering, ctx.avals_in) - else: - kwargs["operand_layouts"] = [ - _convert_layout_for_lowering(*args) - for args in zip(ctx.avals_in, operand_layouts)] - if result_layouts is None: - kwargs["result_layouts"] = map(_convert_layout_for_lowering, ctx.avals_out) - else: - kwargs["result_layouts"] = [ - _convert_layout_for_lowering(*args) - for args in zip(ctx.avals_out, result_layouts)] - if "result_shapes" not in kwargs and not all( - core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out): - kwargs["result_shapes"] = [ - mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, _aval_shape(aval))) - for aval in ctx.avals_out] + result = build_ffi_lowering_function( + call_target_name, + operand_layouts=operand_layouts, + result_layouts=result_layouts, + backend_config=backend_config, + **lowering_args, + )(ctx, *operands, **params) - return 
mlir.custom_call(call_target_name, operands=operands, **kwargs).results # type: ignore + return result.results # type: ignore return _lowering @@ -630,6 +679,97 @@ def ffi_call_lowering( return rule(ctx, *operands, **_unwrap_kwargs_hashable(attributes)) +def ffi_batching_rule( + prim, + args, + dims, + *, + vectorized: bool | None | DeprecatedArg, + vmap_method: str | None, + result_avals: Sequence[core.ShapedArray], + **kwargs: Any, +): + if isinstance(vectorized, DeprecatedArg) and vmap_method is None: + deprecations.warn( + "jax-callback-vectorized", + f"The default behavior of {prim.name} under vmap will soon " + "change. Currently, the default behavior is to generate a sequential " + "vmap (i.e. a loop), but in the future the default will be to raise " + "an error. To keep the current default, set vmap_method='sequential'.", + stacklevel=6) + vmap_method = "sequential" + + axis_size, = {a.shape[d] for a, d in zip(args, dims) + if d is not batching.not_mapped} + new_args = [arg if dim is batching.not_mapped else + batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)] + batched_result_avals = tuple( + core.unmapped_aval(axis_size, 0, aval) for aval in result_avals) + + # For FFI calls we must update the layouts. We handle the output layouts + # here, but the input layout updates depend on the vmap_method parameter. + if ( + vmap_method not in ("sequential", "sequential_unrolled") and + kwargs.get("output_layouts") is not None + ): + kwargs["output_layouts"] = tuple( + None if layout is None else tuple(n + 1 for n in layout) + (0,) + for layout in kwargs["output_layouts"]) + + if vmap_method == "legacy_vectorized": + # This method is kept to support the behavior that was previously exposed + # when using `vectorized=True`. + if kwargs.get("input_layouts") is not None: + kwargs["input_layouts"] = tuple( + layout if d is batching.not_mapped else + (None if layout is None else tuple(n + 1 for n in layout) + (0,)) + for layout, d in zip(kwargs["input_layouts"], dims)) + outvals = prim.bind( + *new_args, + vectorized=vectorized, + vmap_method=vmap_method, + result_avals=batched_result_avals, + **kwargs, + ) + elif vmap_method == "expand_dims" or vmap_method == "broadcast_all": + size = axis_size if vmap_method == "broadcast_all" else 1 + bcast_args = [ + jax.lax.broadcast(x, (size,)) if d is batching.not_mapped else x + for x, d in zip(new_args, dims)] + if kwargs.get("input_layouts") is not None: + kwargs["input_layouts"] = tuple( + None if layout is None else tuple(n + 1 for n in layout) + (0,) + for layout in kwargs["input_layouts"]) + outvals = prim.bind( + *bcast_args, + vectorized=vectorized, + vmap_method=vmap_method, + result_avals=batched_result_avals, + **kwargs, + ) + elif vmap_method == "sequential" or vmap_method == "sequential_unrolled": + is_batched = [d is not batching.not_mapped for d in dims] + unbatched_args, batched_args = util.partition_list(is_batched, new_args) + def _batch_fun(batched_args): + merged_args = util.merge_lists(is_batched, unbatched_args, batched_args) + return prim.bind( + *merged_args, + result_avals=result_avals, + vectorized=vectorized, + vmap_method=vmap_method, + **kwargs, + ) + unroll = vmap_method == "sequential_unrolled" + g = lambda _, x: ((), _batch_fun(x)) + _, outvals = jax.lax.scan(g, (), batched_args, unroll=unroll) + else: + raise NotImplementedError( + f"vmap is only supported for the {prim.name} primitive when vmap_method " + "is one of 'sequential', 'sequential_unrolled', 'expand_dims', " + f"'broadcast_all', or 
'legacy_vectorized'. Got {vmap_method=}.") + return tuple(outvals), (0,) * len(outvals) + + ffi_call_p = core.Primitive("ffi_call") ffi_call_p.multiple_results = True dispatch.simple_impl(ffi_call_p) @@ -637,5 +777,5 @@ ffi_call_p.def_effectful_abstract_eval(ffi_call_abstract_eval) ad.primitive_jvps[ffi_call_p] = ffi_call_jvp ad.primitive_transposes[ffi_call_p] = ffi_call_transpose batching.primitive_batchers[ffi_call_p] = functools.partial( - callback_batching_rule, ffi_call_p) + ffi_batching_rule, ffi_call_p) mlir.register_lowering(ffi_call_p, ffi_call_lowering) diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py index d3e8c22cf..06789a680 100644 --- a/jax/_src/interpreters/mlir.py +++ b/jax/_src/interpreters/mlir.py @@ -41,7 +41,6 @@ from jax._src import dtypes from jax._src import effects as effects_lib from jax._src import linear_util as lu from jax._src import path -from jax._src import pickle_util from jax._src import sharding_impls from jax._src import source_info_util from jax._src import util @@ -2824,284 +2823,6 @@ DEVICE_TO_DEVICE_TYPE = 1 SEND_TO_HOST_TYPE = 2 RECV_FROM_HOST_TYPE = 3 - -def is_empty_shape(s: core.Shape) -> bool: - return any(d == 0 for d in s) - - -def send_to_host( - channel: int, - token: hlo.TokenType, - operand: Any, - name: str, - *, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, -) -> ir.Value: - channel_handle = hlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE) - send_op = hlo.SendOp([operand], token, channel_handle, - is_host_transfer=ir.BoolAttr.get(True)) - send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( - dict( - _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)), - _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name)))) - if sharding is not None: - if config.use_shardy_partitioner.value: - # `SendOp`'s return type is a StableHLO `TokenType`. However JAX passed - # in the maximal sharding of the array type. Since a token has no rank, - # we need to create an equivalent sharding with no dimensions. - assert isinstance(sharding, SdyArrayShardingList) - assert len(sharding.shardings) == 1 - sharding = SdyArrayShardingList([ - SdyArraySharding( - mesh_shape=(), dimension_shardings=[], - logical_device_ids=sharding.shardings[0].logical_device_ids)]) - set_sharding(send_op, sharding) - return send_op.result - - -def receive_from_host( - channel: int, - token: hlo.TokenType, - out_aval: core.ShapedArray, - name: str, - *, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, -) -> tuple[ir.Value, ir.Value]: - channel_handle = hlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE) - recv_op = hlo.RecvOp([aval_to_ir_type(out_aval), - hlo.TokenType.get()], token, channel_handle, - is_host_transfer=ir.BoolAttr.get(True)) - recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get( - dict( - _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)), - _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name)))) - if sharding is not None: - if config.use_shardy_partitioner.value: - assert isinstance(sharding, SdyArrayShardingList) - assert len(sharding.shardings) == 1 - # `RecvOp`'s last argument is a `TokenType`. Since Shardy requires the - # number of shardings to match the number of results, but JAX only sees - # the array result, we need to add an equivalent sharding for the token. 
- sharding = SdyArrayShardingList([ - sharding.shardings[0], - SdyArraySharding( - mesh_shape=(), dimension_shardings=[], - logical_device_ids=sharding.shardings[0].logical_device_ids)]) - set_sharding(recv_op, sharding) - # Token should be at the end of the results - result, token = recv_op.results - return token, result - - -def _emit_tpu_python_callback( - backend: xb.XlaBackend, - ctx: LoweringRuleContext, - callback, - token: Any | None, - operands: Sequence[ir.Value], - operand_avals: Sequence[core.ShapedArray], - operand_shapes: Sequence[xc.Shape], - result_avals: Sequence[core.ShapedArray], - result_shapes: Sequence[xc.Shape], - *, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, -) -> tuple[Sequence[ir.Value], Any]: - token = token or hlo.create_token() - _wrapped_callback = callback - - send_channels = [] - if not operand_avals: - # If there are no operands to the callback, we need to insert a dummy send - # op or the callback will never be triggered! - # TODO(sharadmv,chky): Enable this fix in the runtime as opposed to in - # MLIR builder. - callback_without_args = _wrapped_callback - def _wrapped_callback(*args): # pylint: disable=function-redefined - del args - return callback_without_args() - send_channel = ctx.module_context.new_channel() - dummy_send_aval = core.ShapedArray((1,), np.float32) - dummy_send_val = ir_constant(np.zeros(1, np.float32)) - operand_shapes = [*operand_shapes, - xla.aval_to_xla_shapes(dummy_send_aval)[0]] - token = send_to_host(send_channel, token, dummy_send_val, callback.__name__, - sharding=sharding) - send_channels.append(send_channel) - else: - for operand in operands: - channel = ctx.module_context.new_channel() - token = send_to_host(channel, token, operand, callback.__name__, - sharding=sharding) - send_channels.append(channel) - - recv_channels = [] - outputs = [] - for result_aval in result_avals: - channel = ctx.module_context.new_channel() - assert isinstance(result_aval, core.ShapedArray) - token, out = receive_from_host(channel, token, result_aval, - callback.__name__, sharding=sharding) - outputs.append(out) - recv_channels.append(channel) - ifrt_callback = backend.make_python_callback_from_host_send_and_recv( - _wrapped_callback, operand_shapes, result_shapes, send_channels, - recv_channels, pickle_util.dumps) - ctx.module_context.add_host_callback(ifrt_callback) - return outputs, token - - -def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None): - if minor_to_major is None: - # Needed for token layouts - layout: np.ndarray = np.zeros((0,), dtype="int64") - else: - layout = np.array(minor_to_major, dtype="int64") - return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get()) - -def _aval_to_default_layouts(aval): - avals = [core.physical_aval(aval)] - # Row major order is default for `NumPy`. 
- return [list(range(aval.ndim - 1, -1, -1)) for aval in avals] - - -def emit_python_callback( - ctx: LoweringRuleContext, - callback, - token: Any | None, - operands: Sequence[ir.Value], - operand_avals: Sequence[core.ShapedArray], - result_avals: Sequence[core.ShapedArray], - *, - has_side_effect: bool, - sharding: SdyArrayShardingList | xc.OpSharding | None = None, - operand_layouts: Sequence[Sequence[int] | None] | None = None, - result_layouts: Sequence[Sequence[int] | None] | None = None, -) -> tuple[Sequence[IrValues], Any, Any]: - """Emits MLIR that calls back to a provided Python function.""" - if len(ctx.module_context.platforms) > 1: - raise NotImplementedError("multi-platform lowering for python_callback") - platform = ctx.module_context.platforms[0] - if platform not in {"cpu", "cuda", "rocm", "tpu"}: - raise ValueError( - f"`EmitPythonCallback` not supported on {platform} backend.") - backend = ctx.module_context.get_backend() - result_shapes = util.flatten( - [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals]) - operand_shapes = util.flatten( - [xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals]) - # Handling layouts - if operand_layouts is None: - operand_layouts = util.concatenate( - map(_aval_to_default_layouts, operand_avals)) - operand_mlir_layouts = map(_layout_to_mlir_layout, operand_layouts) - if result_layouts is None: - result_layouts = util.concatenate(map(_aval_to_default_layouts, result_avals)) - result_mlir_layouts = map(_layout_to_mlir_layout, result_layouts) - - # First we apply checks to ensure output shapes and dtypes match the expected - # ones. - def _wrapped_callback(*args): - out_vals = callback(*args) - if len(out_vals) != len(result_avals): - raise RuntimeError( - "Mismatched number of outputs from callback. " - "Expected: {}, Actual: {}".format(len(result_avals), len(out_vals))) - # Handle Python literals, and custom arrays, e.g., tf.Tensor. - out_vals = tuple(xla.canonicalize_dtype(np.asarray(a)) for a in out_vals) - for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)): - if out_val.shape != out_aval.shape: - raise RuntimeError( - f"Incorrect output shape for return value #{i}: " - f"Expected: {out_aval.shape}, Actual: {out_val.shape}") - if out_val.dtype != out_aval.dtype: - raise RuntimeError( - f"Incorrect output dtype for return value #{i}: " - f"Expected: {out_aval.dtype}, Actual: {out_val.dtype}") - - if platform == "tpu": - # On TPU we cannot receive empty arrays. So, we return from the wrapped - # callback only the non-empty results, and we will create empty constants - # in the receiving computation. - # TODO(b/238239458): fix TPU Recv to work with empty arrays. 
- non_empty_out_vals = tuple( - out_val - for out_val, result_aval in zip(out_vals, result_avals) - if not is_empty_shape(result_aval.shape)) - return non_empty_out_vals - else: - return out_vals - - if platform == "tpu": - non_empty_result_avals, non_empty_result_shapes = util.unzip2([ - (aval, shape) - for aval, shape in zip(result_avals, result_shapes) - if not is_empty_shape(aval.shape)]) - non_empty_outputs, token = _emit_tpu_python_callback( - backend, ctx, _wrapped_callback, token, - operands, operand_avals, operand_shapes, - non_empty_result_avals, non_empty_result_shapes, - sharding=sharding) - non_empty_outputs_iter = iter(non_empty_outputs) - outputs = [ - ir_constant(np.zeros(result_aval.shape, dtype=result_aval.dtype)) - if is_empty_shape(result_aval.shape) else next(non_empty_outputs_iter) - for result_aval in result_avals] - return outputs, token, None - - result_types = flatten_ir_types([aval_to_ir_type(aval) for aval in result_avals]) - if token: - - callback_without_token = _wrapped_callback - def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined - return (token, *callback_without_token(*args)) - - operand_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes - ] - result_shapes = [ - xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes - ] - operands = [token, *operands] - result_types = [token_type(), *result_types] - operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts] - result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts] - callback_descriptor, ifrt_callback = ( - backend.get_emit_python_callback_descriptor(_wrapped_callback, - operand_shapes, - result_shapes)) - ctx.module_context.add_host_callback(ifrt_callback) - descriptor_operand = ir_constant(callback_descriptor) - callback_operands = [descriptor_operand, *operands] - if operand_mlir_layouts is not None: - operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts] - result_type = ir.TupleType.get_tuple(result_types) - call_target_name = ("xla_python_gpu_callback" - if platform in {"cuda", "rocm"} else "xla_python_cpu_callback") - result = hlo.CustomCallOp( - [result_type], - callback_operands, - call_target_name=ir.StringAttr.get(call_target_name), - has_side_effect=ir.BoolAttr.get(has_side_effect), - api_version=i32_attr(2), - called_computations=ir.ArrayAttr.get([]), - backend_config=ir.StringAttr.get(str(callback_descriptor)), - operand_layouts=( - None if operand_mlir_layouts is None - else ir.ArrayAttr.get(operand_mlir_layouts)), - result_layouts=( - None if result_mlir_layouts is None - else ir.ArrayAttr.get(result_mlir_layouts))) - if sharding is not None: - set_sharding(result, sharding) - results = [ - hlo.get_tuple_element(result, i32_attr(i)) - for i in range(len(result_types)) - ] - if token: - token, *results = results - return results, token, ifrt_callback - - def build_mlir_module_helper( closed_jaxpr: core.ClosedJaxpr, *, name: str, platforms: Sequence[str], diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 57837a2be..07a6fcf0a 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -26,6 +26,7 @@ from jax import lax from jax import tree_util from jax._src import ad_util from jax._src import api_util +from jax._src import callback from jax._src import core as jax_core from jax._src import dtypes from jax._src import effects @@ -811,7 +812,7 @@ batching.primitive_batchers[debug_print_p] = functools.partial( 
@functools.partial(mlir.register_lowering, debug_print_p) def debug_print_lowering_rule(ctx, *args, **params): - result, _, _ = mlir.emit_python_callback( + result, _, _ = callback.emit_python_callback( ctx, functools.partial(debug_print_p.impl, **params), None, diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index 41456c4cf..0f32799f7 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -39,7 +39,6 @@ from jax._src.interpreters.mlir import ( dense_int_array as dense_int_array, dense_int_elements as dense_int_elements, dtype_to_ir_type as dtype_to_ir_type, - emit_python_callback as emit_python_callback, flatten_ir_types as flatten_ir_types, flatten_ir_values as flatten_lowering_ir_args, # TODO(phawkins): remove me # noqa: F401 flatten_ir_values as flatten_ir_values, @@ -74,3 +73,10 @@ from jax._src.sharding_impls import ( SPMDAxisContext as SPMDAxisContext, ShardingContext as ShardingContext, ) + + +# TODO(dsuo): Temporarily maintain symbols related to callback lowering for the +# sake of public APIs. +from jax._src.callback import ( + emit_python_callback as emit_python_callback, +) diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index 43c910f00..c331bfaf4 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -22,6 +22,7 @@ import jax.numpy as jnp from jax import lax from jax.experimental import pjit from jax._src import ad_checkpoint +from jax._src import callback as cb from jax._src import dispatch from jax._src import config from jax._src import core @@ -123,7 +124,7 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out if effects.ordered_effects.contains(effect): token_in = ctx.tokens_in.get(effect) - out_op, token_out, _ = mlir.emit_python_callback( + out_op, token_out, _ = cb.emit_python_callback( ctx, callback, token_in, list(args), list(ctx.avals_in), list(ctx.avals_out), has_side_effect=True) if token_out:
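--
Usage sketch (not applied by this patch): after this refactor, lowering rules import `emit_python_callback` from `jax._src.callback` (or via the `jax.interpreters.mlir` re-export) instead of from `jax._src.interpreters.mlir`, mirroring the `debug_callback_lowering` change above. The primitive `my_print_p`, its rules, and the host function `_host_print` below are hypothetical, invented for illustration.

from jax._src import callback as cb  # new home of emit_python_callback
from jax._src import core
from jax._src import dispatch
from jax._src.interpreters import mlir

# Hypothetical primitive with no outputs, used only to demonstrate the call.
my_print_p = core.Primitive("my_print")
my_print_p.multiple_results = True
my_print_p.def_abstract_eval(lambda *avals: [])
dispatch.simple_impl(my_print_p)

def _host_print(*arrays):
  print(*arrays)
  return ()  # no results, matching the primitive's empty output avals

def my_print_lowering(ctx, *args):
  # Same call shape as before the refactor; only the import location changed.
  result, _, _ = cb.emit_python_callback(
      ctx, _host_print, None, list(args), ctx.avals_in, ctx.avals_out,
      has_side_effect=True)
  return result

mlir.register_lowering(my_print_p, my_print_lowering)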
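Relatedly, `ffi_lowering` keeps its public behavior of returning the custom call's results, while the new `build_ffi_lowering_function` returns the underlying `ir.Operation`, the hook that a CPU/GPU `emit_python_callback` implementation can build on. A minimal sketch; the primitive `example_p` is hypothetical, and the target name "example_target" assumes an FFI target registered via `jax.ffi.register_ffi_target`.

from jax._src import core
from jax._src import dispatch
from jax._src import ffi
from jax._src.interpreters import mlir

# Hypothetical single-output primitive, for illustration only.
example_p = core.Primitive("example")
example_p.def_abstract_eval(lambda aval, **params: aval)
dispatch.simple_impl(example_p)

# Public path, unchanged by this refactor: lower to a custom call and return
# the call's results.
mlir.register_lowering(example_p, ffi.ffi_lowering("example_target"))

# New internal path: obtain the raw ir.Operation instead, e.g. to attach
# shardings or tokens before extracting op.results.
build_op = ffi.build_ffi_lowering_function("example_target")
# build_op(ctx, *operands, **params) -> ir.Operation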