Refactor JAX FFI lowering to prepare for implementing CPU/GPU callbacks using XLA's FFI.

- This refactor just moves code around and should have no impact on tests or public-facing APIs.
- `mlir.emit_python_callback` will eventually depend on `ffi.ffi_lowering`, which in turn depends on definitions in `mlir.py`. To break this circular dependency, `emit_python_callback` (and its host-callback helpers) moves from `mlir.py` into `callback.py`, and the shared callback batching rule moves into `ffi.py`; see the sketch below.
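A rough sketch of the module layering after this change (only the modules touched here are shown; each arrow points from importer to imported module):

    jax._src.ffi           ->  jax._src.interpreters.mlir
    jax._src.callback      ->  jax._src.ffi, jax._src.interpreters.mlir
    jax.interpreters.mlir  ->  jax._src.callback  (re-exports emit_python_callback for compatibility)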

PiperOrigin-RevId: 729561359
Daniel Suo 2025-02-21 09:45:14 -08:00 committed by jax authors
parent 673a02d614
commit 2d1bc5c2a0
8 changed files with 487 additions and 419 deletions

View File

@ -27,17 +27,21 @@ from jax._src import deprecations
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import ffi
from jax._src import pickle_util
from jax._src import sharding_impls
from jax._src import tree_util
from jax._src import util
from jax._src import xla_bridge as xb
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.interpreters import xla
from jax._src.lax.control_flow.loops import map as lax_map
from jax._src.lax.control_flow.loops import scan
from jax._src.lib import xla_client as xc
from jax._src.sharding_impls import SingleDeviceSharding
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import hlo
from jax._src.sharding_impls import SdyArraySharding, SdyArrayShardingList, SingleDeviceSharding
from jax._src.typing import DeprecatedArg
import numpy as np
@ -135,99 +139,8 @@ def pure_callback_transpose_rule(*args, **kwargs):
ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule
def callback_batching_rule(
prim,
args,
dims,
*,
vectorized: bool | None | DeprecatedArg,
vmap_method: str | None,
result_avals: Sequence[core.ShapedArray],
**kwargs: Any,
):
if isinstance(vectorized, DeprecatedArg) and vmap_method is None:
deprecations.warn(
"jax-callback-vectorized",
f"The default behavior of {prim.name} under vmap will soon "
"change. Currently, the default behavior is to generate a sequential "
"vmap (i.e. a loop), but in the future the default will be to raise "
"an error. To keep the current default, set vmap_method='sequential'.",
stacklevel=6)
vmap_method = "sequential"
axis_size, = {a.shape[d] for a, d in zip(args, dims)
if d is not batching.not_mapped}
new_args = [arg if dim is batching.not_mapped else
batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
batched_result_avals = tuple(
core.unmapped_aval(axis_size, 0, aval) for aval in result_avals)
# For FFI calls we must update the layouts. We handle the output layouts
# here, but the input layout updates depend on the vmap_method parameter.
if (
vmap_method not in ("sequential", "sequential_unrolled") and
kwargs.get("output_layouts") is not None
):
kwargs["output_layouts"] = tuple(
None if layout is None else tuple(n + 1 for n in layout) + (0,)
for layout in kwargs["output_layouts"])
if vmap_method == "legacy_vectorized":
# This method is kept to support the behavior that was previously exposed
# when using `vectorized=True`.
if kwargs.get("input_layouts") is not None:
kwargs["input_layouts"] = tuple(
layout if d is batching.not_mapped else
(None if layout is None else tuple(n + 1 for n in layout) + (0,))
for layout, d in zip(kwargs["input_layouts"], dims))
outvals = prim.bind(
*new_args,
vectorized=vectorized,
vmap_method=vmap_method,
result_avals=batched_result_avals,
**kwargs,
)
elif vmap_method == "expand_dims" or vmap_method == "broadcast_all":
size = axis_size if vmap_method == "broadcast_all" else 1
bcast_args = [
lax.broadcast(x, (size,)) if d is batching.not_mapped else x
for x, d in zip(new_args, dims)]
if kwargs.get("input_layouts") is not None:
kwargs["input_layouts"] = tuple(
None if layout is None else tuple(n + 1 for n in layout) + (0,)
for layout in kwargs["input_layouts"])
outvals = prim.bind(
*bcast_args,
vectorized=vectorized,
vmap_method=vmap_method,
result_avals=batched_result_avals,
**kwargs,
)
elif vmap_method == "sequential" or vmap_method == "sequential_unrolled":
is_batched = [d is not batching.not_mapped for d in dims]
unbatched_args, batched_args = util.partition_list(is_batched, new_args)
def _batch_fun(batched_args):
merged_args = util.merge_lists(is_batched, unbatched_args, batched_args)
return prim.bind(
*merged_args,
result_avals=result_avals,
vectorized=vectorized,
vmap_method=vmap_method,
**kwargs,
)
unroll = vmap_method == "sequential_unrolled"
g = lambda _, x: ((), _batch_fun(x))
_, outvals = scan(g, (), batched_args, unroll=unroll)
else:
raise NotImplementedError(
f"vmap is only supported for the {prim.name} primitive when vmap_method "
"is one of 'sequential', 'sequential_unrolled', 'expand_dims', "
f"'broadcast_all', or 'legacy_vectorized'. Got {vmap_method=}.")
return tuple(outvals), (0,) * len(outvals)
batching.primitive_batchers[pure_callback_p] = functools.partial(
callback_batching_rule, pure_callback_p
ffi.ffi_batching_rule, pure_callback_p
)
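For context, a minimal usage sketch of the behavior this shared batching rule governs; the host function, shapes, and vmap_method choice below are illustrative, not part of this change:

    import jax
    import jax.numpy as jnp
    import numpy as np

    def host_fn(x):
      # Illustrative host function; it must accept a leading batch axis when
      # using vmap_method="expand_dims" or "broadcast_all".
      return np.asarray(x) + 1.0

    def f(x):
      out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
      return jax.pure_callback(host_fn, out_spec, x, vmap_method="expand_dims")

    jax.vmap(f)(jnp.arange(4.0))  # batched via the rule registered above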
@ -318,7 +231,7 @@ def pure_callback_lowering(
op_sharding = _callback_op_sharding(
ctx.module_context.axis_context, sharding, ctx.avals_out)
result, _, _ = mlir.emit_python_callback(
result, _, _ = emit_python_callback(
ctx,
_callback,
None,
@ -600,7 +513,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
ctx.module_context.axis_context, sharding, ctx.avals_out)
if ordered:
token = ctx.tokens_in.get(_OrderedIOEffect)
result, token, _ = mlir.emit_python_callback(
result, token, _ = emit_python_callback(
ctx,
_callback,
token,
@ -612,7 +525,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
)
ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: token}))
else:
result, _, _ = mlir.emit_python_callback(
result, _, _ = emit_python_callback(
ctx,
_callback,
None,
@ -677,3 +590,287 @@ def io_callback(
ordered=ordered,
)
return tree_util.tree_unflatten(out_tree, out_flat)
def is_empty_shape(s: core.Shape) -> bool:
return any(d == 0 for d in s)
def send_to_host(
channel: int,
token: hlo.TokenType,
operand: Any,
name: str,
*,
sharding: SdyArrayShardingList | xc.OpSharding | None = None,
) -> ir.Value:
channel_handle = hlo.ChannelHandle.get(channel, mlir.SEND_TO_HOST_TYPE)
send_op = hlo.SendOp([operand], token, channel_handle,
is_host_transfer=ir.BoolAttr.get(True))
send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
dict(
_xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
_xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
if sharding is not None:
if config.use_shardy_partitioner.value:
# `SendOp`'s return type is a StableHLO `TokenType`. However JAX passed
# in the maximal sharding of the array type. Since a token has no rank,
# we need to create an equivalent sharding with no dimensions. If there
# are multiple shardings, just grab the first one since all these
# shardings should be the same.
assert isinstance(sharding, SdyArrayShardingList)
assert len(sharding.shardings) >= 1
sharding = SdyArrayShardingList([
SdyArraySharding(
mesh_shape=(), dimension_shardings=[],
logical_device_ids=sharding.shardings[0].logical_device_ids)])
mlir.set_sharding(send_op, sharding)
return send_op.result
def receive_from_host(
channel: int,
token: hlo.TokenType,
out_aval: core.ShapedArray,
name: str,
*,
sharding: SdyArrayShardingList | xc.OpSharding | None = None,
) -> tuple[ir.Value, ir.Value]:
channel_handle = hlo.ChannelHandle.get(channel, mlir.RECV_FROM_HOST_TYPE)
recv_op = hlo.RecvOp([mlir.aval_to_ir_type(out_aval),
hlo.TokenType.get()], token, channel_handle,
is_host_transfer=ir.BoolAttr.get(True))
recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
dict(
_xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
_xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
if sharding is not None:
if config.use_shardy_partitioner.value:
assert isinstance(sharding, SdyArrayShardingList)
assert len(sharding.shardings) >= 1
# `RecvOp`'s last argument is a `TokenType`. Since Shardy requires the
# number of shardings to match the number of results, but JAX only sees
# the array result, we need to add an equivalent sharding for the token.
# Note that even if a function returns N results, we will end up with N
# `RecvOp`s, so we only need to get the first sharding. All shardings are
# the same anyways, operating on the same single device ID.
sharding = SdyArrayShardingList([
sharding.shardings[0],
SdyArraySharding(
mesh_shape=(), dimension_shardings=[],
logical_device_ids=sharding.shardings[0].logical_device_ids)])
mlir.set_sharding(recv_op, sharding)
# Token should be at the end of the results
result, token = recv_op.results
return token, result
def _emit_tpu_python_callback(
backend: xb.XlaBackend,
ctx: mlir.LoweringRuleContext,
callback,
token: Any | None,
operands: Sequence[ir.Value],
operand_avals: Sequence[core.ShapedArray],
operand_shapes: Sequence[xc.Shape],
result_avals: Sequence[core.ShapedArray],
result_shapes: Sequence[xc.Shape],
*,
sharding: SdyArrayShardingList | xc.OpSharding | None = None,
) -> tuple[Sequence[ir.Value], Any]:
token = token or hlo.create_token()
_wrapped_callback = callback
send_channels = []
if not operand_avals:
# If there are no operands to the callback, we need to insert a dummy send
# op or the callback will never be triggered!
# TODO(sharadmv,chky): Enable this fix in the runtime as opposed to in
# MLIR builder.
callback_without_args = _wrapped_callback
def _wrapped_callback(*args): # pylint: disable=function-redefined
del args
return callback_without_args()
send_channel = ctx.module_context.new_channel()
dummy_send_aval = core.ShapedArray((1,), np.float32)
dummy_send_val = mlir.ir_constant(np.zeros(1, np.float32))
operand_shapes = [*operand_shapes,
xla.aval_to_xla_shapes(dummy_send_aval)[0]]
token = send_to_host(send_channel, token, dummy_send_val, callback.__name__,
sharding=sharding)
send_channels.append(send_channel)
else:
for operand in operands:
channel = ctx.module_context.new_channel()
token = send_to_host(channel, token, operand, callback.__name__,
sharding=sharding)
send_channels.append(channel)
recv_channels = []
outputs = []
for result_aval in result_avals:
channel = ctx.module_context.new_channel()
assert isinstance(result_aval, core.ShapedArray)
token, out = receive_from_host(channel, token, result_aval,
callback.__name__, sharding=sharding)
outputs.append(out)
recv_channels.append(channel)
ifrt_callback = backend.make_python_callback_from_host_send_and_recv(
_wrapped_callback, operand_shapes, result_shapes, send_channels,
recv_channels, pickle_util.dumps)
ctx.module_context.add_host_callback(ifrt_callback)
return outputs, token
def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None):
if minor_to_major is None:
# Needed for token layouts
layout: np.ndarray = np.zeros((0,), dtype="int64")
else:
layout = np.array(minor_to_major, dtype="int64")
return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())
def _aval_to_default_layouts(aval):
avals = [core.physical_aval(aval)]
# Row major order is default for `NumPy`.
return [list(range(aval.ndim - 1, -1, -1)) for aval in avals]
def emit_python_callback(
ctx: mlir.LoweringRuleContext,
callback,
token: Any | None,
operands: Sequence[ir.Value],
operand_avals: Sequence[core.ShapedArray],
result_avals: Sequence[core.ShapedArray],
*,
has_side_effect: bool,
sharding: SdyArrayShardingList | xc.OpSharding | None = None,
operand_layouts: Sequence[Sequence[int] | None] | None = None,
result_layouts: Sequence[Sequence[int] | None] | None = None,
) -> tuple[Sequence[mlir.IrValues], Any, Any]:
"""Emits MLIR that calls back to a provided Python function."""
if len(ctx.module_context.platforms) > 1:
raise NotImplementedError("multi-platform lowering for python_callback")
platform = ctx.module_context.platforms[0]
if platform not in {"cpu", "cuda", "rocm", "tpu"}:
raise ValueError(
f"`EmitPythonCallback` not supported on {platform} backend.")
backend = ctx.module_context.get_backend()
result_shapes = util.flatten(
[xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])
operand_shapes = util.flatten(
[xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals])
# Handling layouts
if operand_layouts is None:
operand_layouts = util.concatenate(
map(_aval_to_default_layouts, operand_avals))
operand_mlir_layouts = map(_layout_to_mlir_layout, operand_layouts)
if result_layouts is None:
result_layouts = util.concatenate(map(_aval_to_default_layouts, result_avals))
result_mlir_layouts = map(_layout_to_mlir_layout, result_layouts)
# First we apply checks to ensure output shapes and dtypes match the expected
# ones.
def _wrapped_callback(*args):
out_vals = callback(*args)
if len(out_vals) != len(result_avals):
raise RuntimeError(
"Mismatched number of outputs from callback. "
"Expected: {}, Actual: {}".format(len(result_avals), len(out_vals)))
# Handle Python literals, and custom arrays, e.g., tf.Tensor.
out_vals = tuple(xla.canonicalize_dtype(np.asarray(a)) for a in out_vals)
for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)):
if out_val.shape != out_aval.shape:
raise RuntimeError(
f"Incorrect output shape for return value #{i}: "
f"Expected: {out_aval.shape}, Actual: {out_val.shape}")
if out_val.dtype != out_aval.dtype:
raise RuntimeError(
f"Incorrect output dtype for return value #{i}: "
f"Expected: {out_aval.dtype}, Actual: {out_val.dtype}")
if platform == "tpu":
# On TPU we cannot receive empty arrays. So, we return from the wrapped
# callback only the non-empty results, and we will create empty constants
# in the receiving computation.
# TODO(b/238239458): fix TPU Recv to work with empty arrays.
non_empty_out_vals = tuple(
out_val
for out_val, result_aval in zip(out_vals, result_avals)
if not is_empty_shape(result_aval.shape))
return non_empty_out_vals
else:
return out_vals
if platform == "tpu":
non_empty_result_avals, non_empty_result_shapes = util.unzip2([
(aval, shape)
for aval, shape in zip(result_avals, result_shapes)
if not is_empty_shape(aval.shape)])
non_empty_outputs, token = _emit_tpu_python_callback(
backend, ctx, _wrapped_callback, token,
operands, operand_avals, operand_shapes,
non_empty_result_avals, non_empty_result_shapes,
sharding=sharding)
non_empty_outputs_iter = iter(non_empty_outputs)
outputs = [
mlir.ir_constant(np.zeros(result_aval.shape, dtype=result_aval.dtype))
if is_empty_shape(result_aval.shape) else next(non_empty_outputs_iter)
for result_aval in result_avals]
return outputs, token, None
result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals])
if token:
callback_without_token = _wrapped_callback
def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined
return (token, *callback_without_token(*args))
operand_shapes = [
xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes
]
result_shapes = [
xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes
]
operands = [token, *operands]
result_types = [mlir.token_type(), *result_types]
operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts]
result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts]
callback_descriptor, ifrt_callback = (
backend.get_emit_python_callback_descriptor(_wrapped_callback,
operand_shapes,
result_shapes))
ctx.module_context.add_host_callback(ifrt_callback)
descriptor_operand = mlir.ir_constant(callback_descriptor)
callback_operands = [descriptor_operand, *operands]
if operand_mlir_layouts is not None:
operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts]
result_type = ir.TupleType.get_tuple(result_types)
call_target_name = ("xla_python_gpu_callback"
if platform in {"cuda", "rocm"} else "xla_python_cpu_callback")
result = hlo.CustomCallOp(
[result_type],
callback_operands,
call_target_name=ir.StringAttr.get(call_target_name),
has_side_effect=ir.BoolAttr.get(has_side_effect),
api_version=mlir.i32_attr(2),
called_computations=ir.ArrayAttr.get([]),
backend_config=ir.StringAttr.get(str(callback_descriptor)),
operand_layouts=(
None if operand_mlir_layouts is None
else ir.ArrayAttr.get(operand_mlir_layouts)),
result_layouts=(
None if result_mlir_layouts is None
else ir.ArrayAttr.get(result_mlir_layouts)))
if sharding is not None:
mlir.set_sharding(result, sharding)
results = [
hlo.get_tuple_element(result, mlir.i32_attr(i))
for i in range(len(result_types))
]
if token:
token, *results = results
return results, token, ifrt_callback

View File

@ -30,6 +30,7 @@ from jax._src import api
from jax._src import api_util
from jax._src import ad_checkpoint
from jax._src import linear_util as lu
from jax._src import callback
from jax._src import config
from jax._src import core
from jax._src import custom_derivatives
@ -518,7 +519,7 @@ def check_lowering_rule(ctx, *args, err_tree, debug):
if not config.xla_runtime_errors.value:
raise functionalization_error
out_op, _, _ = mlir.emit_python_callback(
out_op, _, _ = callback.emit_python_callback(
ctx, callback=functools.partial(python_err, err_tree),
token=None,
operands=args,

View File

@ -29,6 +29,7 @@ import numpy as np
import jax
import jax.numpy as jnp
from jax import lax
from jax._src import callback as cb
from jax._src import config
from jax._src import core
from jax._src import dispatch
@ -194,12 +195,12 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
return ()
if effects.ordered_effects.contains(effect):
token = ctx.tokens_in.get(effect)
result, token, _ = mlir.emit_python_callback(
result, token, _ = cb.emit_python_callback(
ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out,
has_side_effect=True)
ctx.set_tokens_out(mlir.TokenSet({effect: token}))
else:
result, _, _ = mlir.emit_python_callback(
result, _, _ = cb.emit_python_callback(
ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out,
has_side_effect=True, sharding=sharding)
return result

View File

@ -22,13 +22,13 @@ from typing import Any, overload
import numpy as np
import jax
from jax._src import core
from jax._src import deprecations
from jax._src import dispatch
from jax._src import effects
from jax._src import util
from jax._src import xla_bridge
from jax._src.callback import callback_batching_rule
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
@ -157,6 +157,81 @@ def _convert_layout_for_lowering(
return tuple(layout)
def build_ffi_lowering_function(
call_target_name: str,
*,
operand_layouts: Sequence[FfiLayoutOptions] | None = None,
result_layouts: Sequence[FfiLayoutOptions] | None = None,
backend_config: Mapping[str, ir.Attribute] | str | None = None,
**lowering_args: Any,
) -> Callable[..., ir.Operation]:
"""Build a lowering op for an foreign function interface (FFI) target.
By default, this lowering rule can use the input and output abstract values to
compute the input and output types and shapes for the custom call, assuming
row-major layouts.
Note that layouts passed to this function as tuples should be in
minor-to-major order (as expected by XLA) rather than major-to-minor as used
by :func:`~jax.ffi.ffi_call` and ``DeviceLocalLayout``.
If keyword arguments are passed to the lowering rule, these are treated as
attributes, and added to `backend_config`.
Args:
call_target_name: The name of the custom call target.
operand_layouts: A sequence of layouts (dimension orders) for each operand.
By default, the operands are assumed to be row-major.
result_layouts: A sequence of layouts (dimension orders) for each result.
By default, the results are assumed to be row-major.
backend_config: Configuration data for the custom call. Any keyword
arguments passed to the lowering rule will be added to this dictionary.
lowering_args: Any other arguments to :func:`mlir.custom_call` will also be
passed through if provided as extra arguments to this function.
"""
def _lowering_op(
ctx: mlir.LoweringRuleContext, *operands: ir.Value, **params: Any
) -> ir.Operation:
kwargs = dict(lowering_args)
kwargs.setdefault("api_version", 4)
if kwargs["api_version"] >= 4:
if backend_config is not None and not isinstance(backend_config, dict):
raise ValueError(
"When api_version > 4, backend_config must be a dictionary.")
kwargs["backend_config"] = dict(
backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()})
else:
if params:
raise ValueError(
"The use of ffi_call attributes requires a custom call API version "
f"of at least 4; got api_version={kwargs['api_version']}.")
kwargs["backend_config"] = backend_config
if "result_types" not in kwargs:
kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
if operand_layouts is None:
kwargs["operand_layouts"] = map(_convert_layout_for_lowering, ctx.avals_in)
else:
kwargs["operand_layouts"] = [
_convert_layout_for_lowering(*args)
for args in zip(ctx.avals_in, operand_layouts)]
if result_layouts is None:
kwargs["result_layouts"] = map(_convert_layout_for_lowering, ctx.avals_out)
else:
kwargs["result_layouts"] = [
_convert_layout_for_lowering(*args)
for args in zip(ctx.avals_out, result_layouts)]
if "result_shapes" not in kwargs and not all(
core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out):
kwargs["result_shapes"] = [
mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, _aval_shape(aval)))
for aval in ctx.avals_out]
return mlir.custom_call(call_target_name, operands=operands, **kwargs)
return _lowering_op
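As a usage sketch (the primitive and FFI target names below are hypothetical): the helper returns a rule that emits the raw custom call op, so a rule registered with `mlir.register_lowering` takes its `.results`, which is what `ffi_lowering` below does internally:

    from jax._src import core
    from jax._src import ffi
    from jax._src.interpreters import mlir

    # Hypothetical primitive and FFI target name, for illustration only.
    my_prim_p = core.Primitive("my_prim")
    my_prim_p.multiple_results = True

    _lower_op = ffi.build_ffi_lowering_function("my_ffi_target")

    def _my_prim_lowering(ctx, *operands, **params):
      # build_ffi_lowering_function returns the op; a lowering rule returns its results.
      return _lower_op(ctx, *operands, **params).results

    mlir.register_lowering(my_prim_p, _my_prim_lowering)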
def ffi_lowering(
call_target_name: str,
*,
@ -193,41 +268,15 @@ def ffi_lowering(
def _lowering(
ctx: mlir.LoweringRuleContext, *operands: ir.Value, **params: Any
) -> Sequence[ir.Value | Sequence[ir.Value]]:
kwargs = dict(lowering_args)
kwargs.setdefault("api_version", 4)
if kwargs["api_version"] >= 4:
if backend_config is not None and not isinstance(backend_config, dict):
raise ValueError(
"When api_version > 4, backend_config must be a dictionary.")
kwargs["backend_config"] = dict(
backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()})
else:
if params:
raise ValueError(
"The use of ffi_call attributes requires a custom call API version "
f"of at least 4; got api_version={kwargs['api_version']}.")
kwargs["backend_config"] = backend_config
if "result_types" not in kwargs:
kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
if operand_layouts is None:
kwargs["operand_layouts"] = map(_convert_layout_for_lowering, ctx.avals_in)
else:
kwargs["operand_layouts"] = [
_convert_layout_for_lowering(*args)
for args in zip(ctx.avals_in, operand_layouts)]
if result_layouts is None:
kwargs["result_layouts"] = map(_convert_layout_for_lowering, ctx.avals_out)
else:
kwargs["result_layouts"] = [
_convert_layout_for_lowering(*args)
for args in zip(ctx.avals_out, result_layouts)]
if "result_shapes" not in kwargs and not all(
core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out):
kwargs["result_shapes"] = [
mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, _aval_shape(aval)))
for aval in ctx.avals_out]
result = build_ffi_lowering_function(
call_target_name,
operand_layouts=operand_layouts,
result_layouts=result_layouts,
backend_config=backend_config,
**lowering_args,
)(ctx, *operands, **params)
return mlir.custom_call(call_target_name, operands=operands, **kwargs).results # type: ignore
return result.results # type: ignore
return _lowering
@ -630,6 +679,97 @@ def ffi_call_lowering(
return rule(ctx, *operands, **_unwrap_kwargs_hashable(attributes))
def ffi_batching_rule(
prim,
args,
dims,
*,
vectorized: bool | None | DeprecatedArg,
vmap_method: str | None,
result_avals: Sequence[core.ShapedArray],
**kwargs: Any,
):
if isinstance(vectorized, DeprecatedArg) and vmap_method is None:
deprecations.warn(
"jax-callback-vectorized",
f"The default behavior of {prim.name} under vmap will soon "
"change. Currently, the default behavior is to generate a sequential "
"vmap (i.e. a loop), but in the future the default will be to raise "
"an error. To keep the current default, set vmap_method='sequential'.",
stacklevel=6)
vmap_method = "sequential"
axis_size, = {a.shape[d] for a, d in zip(args, dims)
if d is not batching.not_mapped}
new_args = [arg if dim is batching.not_mapped else
batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
batched_result_avals = tuple(
core.unmapped_aval(axis_size, 0, aval) for aval in result_avals)
# For FFI calls we must update the layouts. We handle the output layouts
# here, but the input layout updates depend on the vmap_method parameter.
if (
vmap_method not in ("sequential", "sequential_unrolled") and
kwargs.get("output_layouts") is not None
):
kwargs["output_layouts"] = tuple(
None if layout is None else tuple(n + 1 for n in layout) + (0,)
for layout in kwargs["output_layouts"])
if vmap_method == "legacy_vectorized":
# This method is kept to support the behavior that was previously exposed
# when using `vectorized=True`.
if kwargs.get("input_layouts") is not None:
kwargs["input_layouts"] = tuple(
layout if d is batching.not_mapped else
(None if layout is None else tuple(n + 1 for n in layout) + (0,))
for layout, d in zip(kwargs["input_layouts"], dims))
outvals = prim.bind(
*new_args,
vectorized=vectorized,
vmap_method=vmap_method,
result_avals=batched_result_avals,
**kwargs,
)
elif vmap_method == "expand_dims" or vmap_method == "broadcast_all":
size = axis_size if vmap_method == "broadcast_all" else 1
bcast_args = [
jax.lax.broadcast(x, (size,)) if d is batching.not_mapped else x
for x, d in zip(new_args, dims)]
if kwargs.get("input_layouts") is not None:
kwargs["input_layouts"] = tuple(
None if layout is None else tuple(n + 1 for n in layout) + (0,)
for layout in kwargs["input_layouts"])
outvals = prim.bind(
*bcast_args,
vectorized=vectorized,
vmap_method=vmap_method,
result_avals=batched_result_avals,
**kwargs,
)
elif vmap_method == "sequential" or vmap_method == "sequential_unrolled":
is_batched = [d is not batching.not_mapped for d in dims]
unbatched_args, batched_args = util.partition_list(is_batched, new_args)
def _batch_fun(batched_args):
merged_args = util.merge_lists(is_batched, unbatched_args, batched_args)
return prim.bind(
*merged_args,
result_avals=result_avals,
vectorized=vectorized,
vmap_method=vmap_method,
**kwargs,
)
unroll = vmap_method == "sequential_unrolled"
g = lambda _, x: ((), _batch_fun(x))
_, outvals = jax.lax.scan(g, (), batched_args, unroll=unroll)
else:
raise NotImplementedError(
f"vmap is only supported for the {prim.name} primitive when vmap_method "
"is one of 'sequential', 'sequential_unrolled', 'expand_dims', "
f"'broadcast_all', or 'legacy_vectorized'. Got {vmap_method=}.")
return tuple(outvals), (0,) * len(outvals)
ffi_call_p = core.Primitive("ffi_call")
ffi_call_p.multiple_results = True
dispatch.simple_impl(ffi_call_p)
@ -637,5 +777,5 @@ ffi_call_p.def_effectful_abstract_eval(ffi_call_abstract_eval)
ad.primitive_jvps[ffi_call_p] = ffi_call_jvp
ad.primitive_transposes[ffi_call_p] = ffi_call_transpose
batching.primitive_batchers[ffi_call_p] = functools.partial(
callback_batching_rule, ffi_call_p)
ffi_batching_rule, ffi_call_p)
mlir.register_lowering(ffi_call_p, ffi_call_lowering)

View File

@ -41,7 +41,6 @@ from jax._src import dtypes
from jax._src import effects as effects_lib
from jax._src import linear_util as lu
from jax._src import path
from jax._src import pickle_util
from jax._src import sharding_impls
from jax._src import source_info_util
from jax._src import util
@ -2824,284 +2823,6 @@ DEVICE_TO_DEVICE_TYPE = 1
SEND_TO_HOST_TYPE = 2
RECV_FROM_HOST_TYPE = 3
def is_empty_shape(s: core.Shape) -> bool:
return any(d == 0 for d in s)
def send_to_host(
channel: int,
token: hlo.TokenType,
operand: Any,
name: str,
*,
sharding: SdyArrayShardingList | xc.OpSharding | None = None,
) -> ir.Value:
channel_handle = hlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE)
send_op = hlo.SendOp([operand], token, channel_handle,
is_host_transfer=ir.BoolAttr.get(True))
send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
dict(
_xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
_xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
if sharding is not None:
if config.use_shardy_partitioner.value:
# `SendOp`'s return type is a StableHLO `TokenType`. However JAX passed
# in the maximal sharding of the array type. Since a token has no rank,
# we need to create an equivalent sharding with no dimensions.
assert isinstance(sharding, SdyArrayShardingList)
assert len(sharding.shardings) == 1
sharding = SdyArrayShardingList([
SdyArraySharding(
mesh_shape=(), dimension_shardings=[],
logical_device_ids=sharding.shardings[0].logical_device_ids)])
set_sharding(send_op, sharding)
return send_op.result
def receive_from_host(
channel: int,
token: hlo.TokenType,
out_aval: core.ShapedArray,
name: str,
*,
sharding: SdyArrayShardingList | xc.OpSharding | None = None,
) -> tuple[ir.Value, ir.Value]:
channel_handle = hlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE)
recv_op = hlo.RecvOp([aval_to_ir_type(out_aval),
hlo.TokenType.get()], token, channel_handle,
is_host_transfer=ir.BoolAttr.get(True))
recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
dict(
_xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
_xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
if sharding is not None:
if config.use_shardy_partitioner.value:
assert isinstance(sharding, SdyArrayShardingList)
assert len(sharding.shardings) == 1
# `RecvOp`'s last argument is a `TokenType`. Since Shardy requires the
# number of shardings to match the number of results, but JAX only sees
# the array result, we need to add an equivalent sharding for the token.
sharding = SdyArrayShardingList([
sharding.shardings[0],
SdyArraySharding(
mesh_shape=(), dimension_shardings=[],
logical_device_ids=sharding.shardings[0].logical_device_ids)])
set_sharding(recv_op, sharding)
# Token should be at the end of the results
result, token = recv_op.results
return token, result
def _emit_tpu_python_callback(
backend: xb.XlaBackend,
ctx: LoweringRuleContext,
callback,
token: Any | None,
operands: Sequence[ir.Value],
operand_avals: Sequence[core.ShapedArray],
operand_shapes: Sequence[xc.Shape],
result_avals: Sequence[core.ShapedArray],
result_shapes: Sequence[xc.Shape],
*,
sharding: SdyArrayShardingList | xc.OpSharding | None = None,
) -> tuple[Sequence[ir.Value], Any]:
token = token or hlo.create_token()
_wrapped_callback = callback
send_channels = []
if not operand_avals:
# If there are no operands to the callback, we need to insert a dummy send
# op or the callback will never be triggered!
# TODO(sharadmv,chky): Enable this fix in the runtime as opposed to in
# MLIR builder.
callback_without_args = _wrapped_callback
def _wrapped_callback(*args): # pylint: disable=function-redefined
del args
return callback_without_args()
send_channel = ctx.module_context.new_channel()
dummy_send_aval = core.ShapedArray((1,), np.float32)
dummy_send_val = ir_constant(np.zeros(1, np.float32))
operand_shapes = [*operand_shapes,
xla.aval_to_xla_shapes(dummy_send_aval)[0]]
token = send_to_host(send_channel, token, dummy_send_val, callback.__name__,
sharding=sharding)
send_channels.append(send_channel)
else:
for operand in operands:
channel = ctx.module_context.new_channel()
token = send_to_host(channel, token, operand, callback.__name__,
sharding=sharding)
send_channels.append(channel)
recv_channels = []
outputs = []
for result_aval in result_avals:
channel = ctx.module_context.new_channel()
assert isinstance(result_aval, core.ShapedArray)
token, out = receive_from_host(channel, token, result_aval,
callback.__name__, sharding=sharding)
outputs.append(out)
recv_channels.append(channel)
ifrt_callback = backend.make_python_callback_from_host_send_and_recv(
_wrapped_callback, operand_shapes, result_shapes, send_channels,
recv_channels, pickle_util.dumps)
ctx.module_context.add_host_callback(ifrt_callback)
return outputs, token
def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None):
if minor_to_major is None:
# Needed for token layouts
layout: np.ndarray = np.zeros((0,), dtype="int64")
else:
layout = np.array(minor_to_major, dtype="int64")
return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())
def _aval_to_default_layouts(aval):
avals = [core.physical_aval(aval)]
# Row major order is default for `NumPy`.
return [list(range(aval.ndim - 1, -1, -1)) for aval in avals]
def emit_python_callback(
ctx: LoweringRuleContext,
callback,
token: Any | None,
operands: Sequence[ir.Value],
operand_avals: Sequence[core.ShapedArray],
result_avals: Sequence[core.ShapedArray],
*,
has_side_effect: bool,
sharding: SdyArrayShardingList | xc.OpSharding | None = None,
operand_layouts: Sequence[Sequence[int] | None] | None = None,
result_layouts: Sequence[Sequence[int] | None] | None = None,
) -> tuple[Sequence[IrValues], Any, Any]:
"""Emits MLIR that calls back to a provided Python function."""
if len(ctx.module_context.platforms) > 1:
raise NotImplementedError("multi-platform lowering for python_callback")
platform = ctx.module_context.platforms[0]
if platform not in {"cpu", "cuda", "rocm", "tpu"}:
raise ValueError(
f"`EmitPythonCallback` not supported on {platform} backend.")
backend = ctx.module_context.get_backend()
result_shapes = util.flatten(
[xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])
operand_shapes = util.flatten(
[xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals])
# Handling layouts
if operand_layouts is None:
operand_layouts = util.concatenate(
map(_aval_to_default_layouts, operand_avals))
operand_mlir_layouts = map(_layout_to_mlir_layout, operand_layouts)
if result_layouts is None:
result_layouts = util.concatenate(map(_aval_to_default_layouts, result_avals))
result_mlir_layouts = map(_layout_to_mlir_layout, result_layouts)
# First we apply checks to ensure output shapes and dtypes match the expected
# ones.
def _wrapped_callback(*args):
out_vals = callback(*args)
if len(out_vals) != len(result_avals):
raise RuntimeError(
"Mismatched number of outputs from callback. "
"Expected: {}, Actual: {}".format(len(result_avals), len(out_vals)))
# Handle Python literals, and custom arrays, e.g., tf.Tensor.
out_vals = tuple(xla.canonicalize_dtype(np.asarray(a)) for a in out_vals)
for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)):
if out_val.shape != out_aval.shape:
raise RuntimeError(
f"Incorrect output shape for return value #{i}: "
f"Expected: {out_aval.shape}, Actual: {out_val.shape}")
if out_val.dtype != out_aval.dtype:
raise RuntimeError(
f"Incorrect output dtype for return value #{i}: "
f"Expected: {out_aval.dtype}, Actual: {out_val.dtype}")
if platform == "tpu":
# On TPU we cannot receive empty arrays. So, we return from the wrapped
# callback only the non-empty results, and we will create empty constants
# in the receiving computation.
# TODO(b/238239458): fix TPU Recv to work with empty arrays.
non_empty_out_vals = tuple(
out_val
for out_val, result_aval in zip(out_vals, result_avals)
if not is_empty_shape(result_aval.shape))
return non_empty_out_vals
else:
return out_vals
if platform == "tpu":
non_empty_result_avals, non_empty_result_shapes = util.unzip2([
(aval, shape)
for aval, shape in zip(result_avals, result_shapes)
if not is_empty_shape(aval.shape)])
non_empty_outputs, token = _emit_tpu_python_callback(
backend, ctx, _wrapped_callback, token,
operands, operand_avals, operand_shapes,
non_empty_result_avals, non_empty_result_shapes,
sharding=sharding)
non_empty_outputs_iter = iter(non_empty_outputs)
outputs = [
ir_constant(np.zeros(result_aval.shape, dtype=result_aval.dtype))
if is_empty_shape(result_aval.shape) else next(non_empty_outputs_iter)
for result_aval in result_avals]
return outputs, token, None
result_types = flatten_ir_types([aval_to_ir_type(aval) for aval in result_avals])
if token:
callback_without_token = _wrapped_callback
def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function-redefined
return (token, *callback_without_token(*args))
operand_shapes = [
xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes
]
result_shapes = [
xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes
]
operands = [token, *operands]
result_types = [token_type(), *result_types]
operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts]
result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts]
callback_descriptor, ifrt_callback = (
backend.get_emit_python_callback_descriptor(_wrapped_callback,
operand_shapes,
result_shapes))
ctx.module_context.add_host_callback(ifrt_callback)
descriptor_operand = ir_constant(callback_descriptor)
callback_operands = [descriptor_operand, *operands]
if operand_mlir_layouts is not None:
operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts]
result_type = ir.TupleType.get_tuple(result_types)
call_target_name = ("xla_python_gpu_callback"
if platform in {"cuda", "rocm"} else "xla_python_cpu_callback")
result = hlo.CustomCallOp(
[result_type],
callback_operands,
call_target_name=ir.StringAttr.get(call_target_name),
has_side_effect=ir.BoolAttr.get(has_side_effect),
api_version=i32_attr(2),
called_computations=ir.ArrayAttr.get([]),
backend_config=ir.StringAttr.get(str(callback_descriptor)),
operand_layouts=(
None if operand_mlir_layouts is None
else ir.ArrayAttr.get(operand_mlir_layouts)),
result_layouts=(
None if result_mlir_layouts is None
else ir.ArrayAttr.get(result_mlir_layouts)))
if sharding is not None:
set_sharding(result, sharding)
results = [
hlo.get_tuple_element(result, i32_attr(i))
for i in range(len(result_types))
]
if token:
token, *results = results
return results, token, ifrt_callback
def build_mlir_module_helper(
closed_jaxpr: core.ClosedJaxpr, *, name: str,
platforms: Sequence[str],

View File

@ -26,6 +26,7 @@ from jax import lax
from jax import tree_util
from jax._src import ad_util
from jax._src import api_util
from jax._src import callback
from jax._src import core as jax_core
from jax._src import dtypes
from jax._src import effects
@ -811,7 +812,7 @@ batching.primitive_batchers[debug_print_p] = functools.partial(
@functools.partial(mlir.register_lowering, debug_print_p)
def debug_print_lowering_rule(ctx, *args, **params):
result, _, _ = mlir.emit_python_callback(
result, _, _ = callback.emit_python_callback(
ctx,
functools.partial(debug_print_p.impl, **params),
None,

View File

@ -39,7 +39,6 @@ from jax._src.interpreters.mlir import (
dense_int_array as dense_int_array,
dense_int_elements as dense_int_elements,
dtype_to_ir_type as dtype_to_ir_type,
emit_python_callback as emit_python_callback,
flatten_ir_types as flatten_ir_types,
flatten_ir_values as flatten_lowering_ir_args, # TODO(phawkins): remove me # noqa: F401
flatten_ir_values as flatten_ir_values,
@ -74,3 +73,10 @@ from jax._src.sharding_impls import (
SPMDAxisContext as SPMDAxisContext,
ShardingContext as ShardingContext,
)
# TODO(dsuo): Temporarily maintain symbols related to callback lowering for the
# sake of public APIs.
from jax._src.callback import (
emit_python_callback as emit_python_callback,
)
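A minimal sketch of the compatibility this re-export preserves (assuming this module is the public `jax.interpreters.mlir` shim):

    from jax._src import callback
    from jax.interpreters import mlir as public_mlir

    # Existing callers keep resolving the symbol through the public module; it now
    # points at the implementation that moved to jax._src.callback.
    assert public_mlir.emit_python_callback is callback.emit_python_callback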

View File

@ -22,6 +22,7 @@ import jax.numpy as jnp
from jax import lax
from jax.experimental import pjit
from jax._src import ad_checkpoint
from jax._src import callback as cb
from jax._src import dispatch
from jax._src import config
from jax._src import core
@ -123,7 +124,7 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out
if effects.ordered_effects.contains(effect):
token_in = ctx.tokens_in.get(effect)
out_op, token_out, _ = mlir.emit_python_callback(
out_op, token_out, _ = cb.emit_python_callback(
ctx, callback, token_in, list(args), list(ctx.avals_in),
list(ctx.avals_out), has_side_effect=True)
if token_out: