Mirror of https://github.com/ROCm/jax.git, synced 2025-04-14 10:56:06 +00:00
Refactor Jax FFI lowering to prepare for implementing CPU/GPU callbacks using XLA's FFI.

- This refactor just moves code around and should have no impact on tests or public-facing APIs.
- `mlir.emit_python_callback` would eventually depend on `ffi.ffi_lowering`, which in turn depends on definitions in `mlir.py`. We break this circular dependency.

PiperOrigin-RevId: 729561359
parent 673a02d614
commit 2d1bc5c2a0
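To make the dependency problem concrete before reading the diff, here is a minimal sketch of the import graph before and after the move. The module contents are simplified, hypothetical stand-ins; only the import direction reflects the commit.

# Before this commit (a cycle, if emit_python_callback starts using ffi):
#   jax/_src/interpreters/mlir.py   defines emit_python_callback
#   jax/_src/ffi.py                 imports mlir (for lowering helpers)
#   => mlir -> ffi -> mlir would be circular.
#
# After this commit (acyclic):
#   jax/_src/interpreters/mlir.py   low-level lowering helpers only
#   jax/_src/ffi.py                 imports mlir
#   jax/_src/callback.py            imports mlir and ffi; defines emit_python_callback
#
# Simplified stand-ins (hypothetical, for illustration only):

# --- mlir.py ---
def custom_call(target, **kwargs):
  """Low-level helper for emitting a custom call."""

# --- ffi.py (imports mlir) ---
def ffi_lowering(target):
  """Builds a lowering rule on top of custom_call."""
  def rule(ctx, *operands):
    return custom_call(target)
  return rule

# --- callback.py (imports mlir and ffi; nothing here imports callback) ---
def emit_python_callback(ctx, fn, *args):
  """Free to call ffi_lowering without creating an import cycle."""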
@@ -27,17 +27,21 @@ from jax._src import deprecations
 from jax._src import dispatch
 from jax._src import dtypes
 from jax._src import effects
 from jax._src import ffi
 from jax._src import pickle_util
 from jax._src import sharding_impls
 from jax._src import tree_util
 from jax._src import util
 from jax._src import xla_bridge as xb
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
 from jax._src.lax import lax
 from jax._src.interpreters import xla
 from jax._src.lax.control_flow.loops import map as lax_map
 from jax._src.lax.control_flow.loops import scan
 from jax._src.lib import xla_client as xc
-from jax._src.sharding_impls import SingleDeviceSharding
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
+from jax._src.sharding_impls import SdyArraySharding, SdyArrayShardingList, SingleDeviceSharding
 from jax._src.typing import DeprecatedArg
 import numpy as np
@@ -135,99 +139,8 @@ def pure_callback_transpose_rule(*args, **kwargs):
 ad.primitive_transposes[pure_callback_p] = pure_callback_transpose_rule
 
 
-def callback_batching_rule(
-    prim,
-    args,
-    dims,
-    *,
-    vectorized: bool | None | DeprecatedArg,
-    vmap_method: str | None,
-    result_avals: Sequence[core.ShapedArray],
-    **kwargs: Any,
-):
-  if isinstance(vectorized, DeprecatedArg) and vmap_method is None:
-    deprecations.warn(
-        "jax-callback-vectorized",
-        f"The default behavior of {prim.name} under vmap will soon "
-        "change. Currently, the default behavior is to generate a sequential "
-        "vmap (i.e. a loop), but in the future the default will be to raise "
-        "an error. To keep the current default, set vmap_method='sequential'.",
-        stacklevel=6)
-    vmap_method = "sequential"
-
-  axis_size, = {a.shape[d] for a, d in zip(args, dims)
-                if d is not batching.not_mapped}
-  new_args = [arg if dim is batching.not_mapped else
-              batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
-  batched_result_avals = tuple(
-      core.unmapped_aval(axis_size, 0, aval) for aval in result_avals)
-
-  # For FFI calls we must update the layouts. We handle the output layouts
-  # here, but the input layout updates depend on the vmap_method parameter.
-  if (
-      vmap_method not in ("sequential", "sequential_unrolled") and
-      kwargs.get("output_layouts") is not None
-  ):
-    kwargs["output_layouts"] = tuple(
-        None if layout is None else tuple(n + 1 for n in layout) + (0,)
-        for layout in kwargs["output_layouts"])
-
-  if vmap_method == "legacy_vectorized":
-    # This method is kept to support the behavior that was previously exposed
-    # when using `vectorized=True`.
-    if kwargs.get("input_layouts") is not None:
-      kwargs["input_layouts"] = tuple(
-          layout if d is batching.not_mapped else
-          (None if layout is None else tuple(n + 1 for n in layout) + (0,))
-          for layout, d in zip(kwargs["input_layouts"], dims))
-    outvals = prim.bind(
-        *new_args,
-        vectorized=vectorized,
-        vmap_method=vmap_method,
-        result_avals=batched_result_avals,
-        **kwargs,
-    )
-  elif vmap_method == "expand_dims" or vmap_method == "broadcast_all":
-    size = axis_size if vmap_method == "broadcast_all" else 1
-    bcast_args = [
-        lax.broadcast(x, (size,)) if d is batching.not_mapped else x
-        for x, d in zip(new_args, dims)]
-    if kwargs.get("input_layouts") is not None:
-      kwargs["input_layouts"] = tuple(
-          None if layout is None else tuple(n + 1 for n in layout) + (0,)
-          for layout in kwargs["input_layouts"])
-    outvals = prim.bind(
-        *bcast_args,
-        vectorized=vectorized,
-        vmap_method=vmap_method,
-        result_avals=batched_result_avals,
-        **kwargs,
-    )
-  elif vmap_method == "sequential" or vmap_method == "sequential_unrolled":
-    is_batched = [d is not batching.not_mapped for d in dims]
-    unbatched_args, batched_args = util.partition_list(is_batched, new_args)
-    def _batch_fun(batched_args):
-      merged_args = util.merge_lists(is_batched, unbatched_args, batched_args)
-      return prim.bind(
-          *merged_args,
-          result_avals=result_avals,
-          vectorized=vectorized,
-          vmap_method=vmap_method,
-          **kwargs,
-      )
-    unroll = vmap_method == "sequential_unrolled"
-    g = lambda _, x: ((), _batch_fun(x))
-    _, outvals = scan(g, (), batched_args, unroll=unroll)
-  else:
-    raise NotImplementedError(
-        f"vmap is only supported for the {prim.name} primitive when vmap_method "
-        "is one of 'sequential', 'sequential_unrolled', 'expand_dims', "
-        f"'broadcast_all', or 'legacy_vectorized'. Got {vmap_method=}.")
-  return tuple(outvals), (0,) * len(outvals)
-
-
 batching.primitive_batchers[pure_callback_p] = functools.partial(
-    callback_batching_rule, pure_callback_p
+    ffi.ffi_batching_rule, pure_callback_p
 )
@@ -318,7 +231,7 @@ def pure_callback_lowering(
 
   op_sharding = _callback_op_sharding(
       ctx.module_context.axis_context, sharding, ctx.avals_out)
-  result, _, _ = mlir.emit_python_callback(
+  result, _, _ = emit_python_callback(
       ctx,
       _callback,
       None,
@@ -600,7 +513,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
       ctx.module_context.axis_context, sharding, ctx.avals_out)
   if ordered:
     token = ctx.tokens_in.get(_OrderedIOEffect)
-    result, token, _ = mlir.emit_python_callback(
+    result, token, _ = emit_python_callback(
         ctx,
         _callback,
         token,
@@ -612,7 +525,7 @@ def io_callback_lowering(ctx, *args, callback, sharding, ordered, **params):
     )
     ctx.set_tokens_out(mlir.TokenSet({_OrderedIOEffect: token}))
   else:
-    result, _, _ = mlir.emit_python_callback(
+    result, _, _ = emit_python_callback(
         ctx,
         _callback,
         None,
@@ -677,3 +590,287 @@ def io_callback(
       ordered=ordered,
   )
   return tree_util.tree_unflatten(out_tree, out_flat)
+
+
+def is_empty_shape(s: core.Shape) -> bool:
+  return any(d == 0 for d in s)
+
+
+def send_to_host(
+    channel: int,
+    token: hlo.TokenType,
+    operand: Any,
+    name: str,
+    *,
+    sharding: SdyArrayShardingList | xc.OpSharding | None = None,
+) -> ir.Value:
+  channel_handle = hlo.ChannelHandle.get(channel, mlir.SEND_TO_HOST_TYPE)
+  send_op = hlo.SendOp([operand], token, channel_handle,
+                       is_host_transfer=ir.BoolAttr.get(True))
+  send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
+      dict(
+          _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
+          _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
+  if sharding is not None:
+    if config.use_shardy_partitioner.value:
+      # `SendOp`'s return type is a StableHLO `TokenType`. However JAX passed
+      # in the maximal sharding of the array type. Since a token has no rank,
+      # we need to create an equivalent sharding with no dimensions. If there
+      # are multiple shardings, just grab the first one since all these
+      # shardings should be the same.
+      assert isinstance(sharding, SdyArrayShardingList)
+      assert len(sharding.shardings) >= 1
+      sharding = SdyArrayShardingList([
+          SdyArraySharding(
+              mesh_shape=(), dimension_shardings=[],
+              logical_device_ids=sharding.shardings[0].logical_device_ids)])
+    mlir.set_sharding(send_op, sharding)
+  return send_op.result
+
+
+def receive_from_host(
+    channel: int,
+    token: hlo.TokenType,
+    out_aval: core.ShapedArray,
+    name: str,
+    *,
+    sharding: SdyArrayShardingList | xc.OpSharding | None = None,
+) -> tuple[ir.Value, ir.Value]:
+  channel_handle = hlo.ChannelHandle.get(channel, mlir.RECV_FROM_HOST_TYPE)
+  recv_op = hlo.RecvOp([mlir.aval_to_ir_type(out_aval),
+                        hlo.TokenType.get()], token, channel_handle,
+                       is_host_transfer=ir.BoolAttr.get(True))
+  recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
+      dict(
+          _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
+          _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
+  if sharding is not None:
+    if config.use_shardy_partitioner.value:
+      assert isinstance(sharding, SdyArrayShardingList)
+      assert len(sharding.shardings) >= 1
+      # `RecvOp`'s last argument is a `TokenType`. Since Shardy requires the
+      # number of shardings to match the number of results, but JAX only sees
+      # the array result, we need to add an equivalent sharding for the token.
+      # Note that even if a function returns N results, we will end up with N
+      # `RecvOp`s, so we only need to get the first sharding. All shardings are
+      # the same anyways, operating on the same single device ID.
+      sharding = SdyArrayShardingList([
+          sharding.shardings[0],
+          SdyArraySharding(
+              mesh_shape=(), dimension_shardings=[],
+              logical_device_ids=sharding.shardings[0].logical_device_ids)])
+    mlir.set_sharding(recv_op, sharding)
+  # Token should be at the end of the results
+  result, token = recv_op.results
+  return token, result
+
+
+def _emit_tpu_python_callback(
+    backend: xb.XlaBackend,
+    ctx: mlir.LoweringRuleContext,
+    callback,
+    token: Any | None,
+    operands: Sequence[ir.Value],
+    operand_avals: Sequence[core.ShapedArray],
+    operand_shapes: Sequence[xc.Shape],
+    result_avals: Sequence[core.ShapedArray],
+    result_shapes: Sequence[xc.Shape],
+    *,
+    sharding: SdyArrayShardingList | xc.OpSharding | None = None,
+) -> tuple[Sequence[ir.Value], Any]:
+  token = token or hlo.create_token()
+  _wrapped_callback = callback
+
+  send_channels = []
+  if not operand_avals:
+    # If there are no operands to the callback, we need to insert a dummy send
+    # op or the callback will never be triggered!
+    # TODO(sharadmv,chky): Enable this fix in the runtime as opposed to in
+    # MLIR builder.
+    callback_without_args = _wrapped_callback
+    def _wrapped_callback(*args):  # pylint: disable=function-redefined
+      del args
+      return callback_without_args()
+    send_channel = ctx.module_context.new_channel()
+    dummy_send_aval = core.ShapedArray((1,), np.float32)
+    dummy_send_val = mlir.ir_constant(np.zeros(1, np.float32))
+    operand_shapes = [*operand_shapes,
+                      xla.aval_to_xla_shapes(dummy_send_aval)[0]]
+    token = send_to_host(send_channel, token, dummy_send_val, callback.__name__,
+                         sharding=sharding)
+    send_channels.append(send_channel)
+  else:
+    for operand in operands:
+      channel = ctx.module_context.new_channel()
+      token = send_to_host(channel, token, operand, callback.__name__,
+                           sharding=sharding)
+      send_channels.append(channel)
+
+  recv_channels = []
+  outputs = []
+  for result_aval in result_avals:
+    channel = ctx.module_context.new_channel()
+    assert isinstance(result_aval, core.ShapedArray)
+    token, out = receive_from_host(channel, token, result_aval,
+                                   callback.__name__, sharding=sharding)
+    outputs.append(out)
+    recv_channels.append(channel)
+  ifrt_callback = backend.make_python_callback_from_host_send_and_recv(
+      _wrapped_callback, operand_shapes, result_shapes, send_channels,
+      recv_channels, pickle_util.dumps)
+  ctx.module_context.add_host_callback(ifrt_callback)
+  return outputs, token
+
+
+def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None):
+  if minor_to_major is None:
+    # Needed for token layouts
+    layout: np.ndarray = np.zeros((0,), dtype="int64")
+  else:
+    layout = np.array(minor_to_major, dtype="int64")
+  return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())
+
+
+def _aval_to_default_layouts(aval):
+  avals = [core.physical_aval(aval)]
+  # Row major order is default for `NumPy`.
+  return [list(range(aval.ndim - 1, -1, -1)) for aval in avals]
+
+
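An editorial aside on the minor-to-major layout convention used by `_aval_to_default_layouts` above: a tiny, self-contained NumPy illustration (not part of the commit; the array here is hypothetical).

import numpy as np

# For a rank-3 array, the default row-major (C-order) layout in minor-to-major
# form is (2, 1, 0): the last dimension varies fastest in memory, which is
# exactly what _aval_to_default_layouts computes via range(ndim - 1, -1, -1).
ndim = 3
assert list(range(ndim - 1, -1, -1)) == [2, 1, 0]

# Column-major (Fortran order) would instead be (0, 1, 2): first axis fastest.
a = np.zeros((2, 3, 4))
assert a.strides[-1] == a.itemsize  # C order: innermost axis is contiguous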
+def emit_python_callback(
+    ctx: mlir.LoweringRuleContext,
+    callback,
+    token: Any | None,
+    operands: Sequence[ir.Value],
+    operand_avals: Sequence[core.ShapedArray],
+    result_avals: Sequence[core.ShapedArray],
+    *,
+    has_side_effect: bool,
+    sharding: SdyArrayShardingList | xc.OpSharding | None = None,
+    operand_layouts: Sequence[Sequence[int] | None] | None = None,
+    result_layouts: Sequence[Sequence[int] | None] | None = None,
+) -> tuple[Sequence[mlir.IrValues], Any, Any]:
+  """Emits MLIR that calls back to a provided Python function."""
+  if len(ctx.module_context.platforms) > 1:
+    raise NotImplementedError("multi-platform lowering for python_callback")
+  platform = ctx.module_context.platforms[0]
+  if platform not in {"cpu", "cuda", "rocm", "tpu"}:
+    raise ValueError(
+        f"`EmitPythonCallback` not supported on {platform} backend.")
+  backend = ctx.module_context.get_backend()
+  result_shapes = util.flatten(
+      [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])
+  operand_shapes = util.flatten(
+      [xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals])
+  # Handling layouts
+  if operand_layouts is None:
+    operand_layouts = util.concatenate(
+        map(_aval_to_default_layouts, operand_avals))
+  operand_mlir_layouts = map(_layout_to_mlir_layout, operand_layouts)
+  if result_layouts is None:
+    result_layouts = util.concatenate(map(_aval_to_default_layouts, result_avals))
+  result_mlir_layouts = map(_layout_to_mlir_layout, result_layouts)
+
+  # First we apply checks to ensure output shapes and dtypes match the expected
+  # ones.
+  def _wrapped_callback(*args):
+    out_vals = callback(*args)
+    if len(out_vals) != len(result_avals):
+      raise RuntimeError(
+          "Mismatched number of outputs from callback. "
+          "Expected: {}, Actual: {}".format(len(result_avals), len(out_vals)))
+    # Handle Python literals, and custom arrays, e.g., tf.Tensor.
+    out_vals = tuple(xla.canonicalize_dtype(np.asarray(a)) for a in out_vals)
+    for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)):
+      if out_val.shape != out_aval.shape:
+        raise RuntimeError(
+            f"Incorrect output shape for return value #{i}: "
+            f"Expected: {out_aval.shape}, Actual: {out_val.shape}")
+      if out_val.dtype != out_aval.dtype:
+        raise RuntimeError(
+            f"Incorrect output dtype for return value #{i}: "
+            f"Expected: {out_aval.dtype}, Actual: {out_val.dtype}")
+
+    if platform == "tpu":
+      # On TPU we cannot receive empty arrays. So, we return from the wrapped
+      # callback only the non-empty results, and we will create empty constants
+      # in the receiving computation.
+      # TODO(b/238239458): fix TPU Recv to work with empty arrays.
+      non_empty_out_vals = tuple(
+          out_val
+          for out_val, result_aval in zip(out_vals, result_avals)
+          if not is_empty_shape(result_aval.shape))
+      return non_empty_out_vals
+    else:
+      return out_vals
+
+  if platform == "tpu":
+    non_empty_result_avals, non_empty_result_shapes = util.unzip2([
+        (aval, shape)
+        for aval, shape in zip(result_avals, result_shapes)
+        if not is_empty_shape(aval.shape)])
+    non_empty_outputs, token = _emit_tpu_python_callback(
+        backend, ctx, _wrapped_callback, token,
+        operands, operand_avals, operand_shapes,
+        non_empty_result_avals, non_empty_result_shapes,
+        sharding=sharding)
+    non_empty_outputs_iter = iter(non_empty_outputs)
+    outputs = [
+        mlir.ir_constant(np.zeros(result_aval.shape, dtype=result_aval.dtype))
+        if is_empty_shape(result_aval.shape) else next(non_empty_outputs_iter)
+        for result_aval in result_avals]
+    return outputs, token, None
+
+  result_types = mlir.flatten_ir_types([mlir.aval_to_ir_type(aval) for aval in result_avals])
+  if token:
+
+    callback_without_token = _wrapped_callback
+    def _wrapped_callback(token, *args):  # type: ignore  # pylint: disable=function-redefined
+      return (token, *callback_without_token(*args))
+
+    operand_shapes = [
+        xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes
+    ]
+    result_shapes = [
+        xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes
+    ]
+    operands = [token, *operands]
+    result_types = [mlir.token_type(), *result_types]
+    operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts]
+    result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts]
+  callback_descriptor, ifrt_callback = (
+      backend.get_emit_python_callback_descriptor(_wrapped_callback,
+                                                  operand_shapes,
+                                                  result_shapes))
+  ctx.module_context.add_host_callback(ifrt_callback)
+  descriptor_operand = mlir.ir_constant(callback_descriptor)
+  callback_operands = [descriptor_operand, *operands]
+  if operand_mlir_layouts is not None:
+    operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts]
+  result_type = ir.TupleType.get_tuple(result_types)
+  call_target_name = ("xla_python_gpu_callback"
+                      if platform in {"cuda", "rocm"} else "xla_python_cpu_callback")
+  result = hlo.CustomCallOp(
+      [result_type],
+      callback_operands,
+      call_target_name=ir.StringAttr.get(call_target_name),
+      has_side_effect=ir.BoolAttr.get(has_side_effect),
+      api_version=mlir.i32_attr(2),
+      called_computations=ir.ArrayAttr.get([]),
+      backend_config=ir.StringAttr.get(str(callback_descriptor)),
+      operand_layouts=(
+          None if operand_mlir_layouts is None
+          else ir.ArrayAttr.get(operand_mlir_layouts)),
+      result_layouts=(
+          None if result_mlir_layouts is None
+          else ir.ArrayAttr.get(result_mlir_layouts)))
+  if sharding is not None:
+    mlir.set_sharding(result, sharding)
+  results = [
+      hlo.get_tuple_element(result, mlir.i32_attr(i))
+      for i in range(len(result_types))
+  ]
+  if token:
+    token, *results = results
+  return results, token, ifrt_callback
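Before the diff moves on to the call sites, a minimal sketch of how a lowering rule invokes `emit_python_callback`, modeled on the `checkify` and `debug_callback_lowering` call sites updated below. The primitive and host function are hypothetical, and this is an internal, unstable API.

# A hypothetical lowering rule for some primitive `my_prim_p`, mirroring the
# call sites in this commit (signatures as shown in the diff above):
from jax._src import callback as cb
from jax._src.interpreters import mlir

def _host_fn(*arrays):
  # Runs in Python on the host at execution time; must return a tuple of
  # arrays matching ctx.avals_out.
  return tuple(arrays)

def my_prim_lowering(ctx: mlir.LoweringRuleContext, *args, **params):
  result, _, _ = cb.emit_python_callback(
      ctx, _host_fn, None, list(args), ctx.avals_in, ctx.avals_out,
      has_side_effect=True)
  return result

# mlir.register_lowering(my_prim_p, my_prim_lowering)  # my_prim_p: hypothetical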
@@ -30,6 +30,7 @@ from jax._src import api
 from jax._src import api_util
 from jax._src import ad_checkpoint
 from jax._src import linear_util as lu
+from jax._src import callback
 from jax._src import config
 from jax._src import core
 from jax._src import custom_derivatives
@@ -518,7 +519,7 @@ def check_lowering_rule(ctx, *args, err_tree, debug):
   if not config.xla_runtime_errors.value:
     raise functionalization_error
 
-  out_op, _, _ = mlir.emit_python_callback(
+  out_op, _, _ = callback.emit_python_callback(
       ctx, callback=functools.partial(python_err, err_tree),
       token=None,
       operands=args,
@@ -29,6 +29,7 @@ import numpy as np
 import jax
 import jax.numpy as jnp
 from jax import lax
+from jax._src import callback as cb
 from jax._src import config
 from jax._src import core
 from jax._src import dispatch
@@ -194,12 +195,12 @@ def debug_callback_lowering(ctx, *args, effect, callback, **params):
     return ()
   if effects.ordered_effects.contains(effect):
     token = ctx.tokens_in.get(effect)
-    result, token, _ = mlir.emit_python_callback(
+    result, token, _ = cb.emit_python_callback(
         ctx, _callback, token, list(args), ctx.avals_in, ctx.avals_out,
         has_side_effect=True)
     ctx.set_tokens_out(mlir.TokenSet({effect: token}))
   else:
-    result, _, _ = mlir.emit_python_callback(
+    result, _, _ = cb.emit_python_callback(
        ctx, _callback, None, list(args), ctx.avals_in, ctx.avals_out,
        has_side_effect=True, sharding=sharding)
   return result
jax/_src/ffi.py
@@ -22,13 +22,13 @@ from typing import Any, overload
 
 import numpy as np
 
 import jax
 from jax._src import core
 from jax._src import deprecations
 from jax._src import dispatch
 from jax._src import effects
 from jax._src import util
 from jax._src import xla_bridge
-from jax._src.callback import callback_batching_rule
 from jax._src.interpreters import ad
 from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
@@ -157,6 +157,81 @@ def _convert_layout_for_lowering(
   return tuple(layout)
 
 
+def build_ffi_lowering_function(
+    call_target_name: str,
+    *,
+    operand_layouts: Sequence[FfiLayoutOptions] | None = None,
+    result_layouts: Sequence[FfiLayoutOptions] | None = None,
+    backend_config: Mapping[str, ir.Attribute] | str | None = None,
+    **lowering_args: Any,
+) -> Callable[..., ir.Operation]:
+  """Build a lowering op for a foreign function interface (FFI) target.
+
+  By default, this lowering rule can use the input and output abstract values to
+  compute the input and output types and shapes for the custom call, assuming
+  row-major layouts.
+
+  Note that layouts passed to this function as tuples should be in
+  minor-to-major order (as expected by XLA) rather than major-to-minor as used
+  by :func:`~jax.ffi.ffi_call` and ``DeviceLocalLayout``.
+
+  If keyword arguments are passed to the lowering rule, these are treated as
+  attributes, and added to `backend_config`.
+
+  Args:
+    call_target_name: The name of the custom call target.
+    operand_layouts: A sequence of layouts (dimension orders) for each operand.
+      By default, the operands are assumed to be row-major.
+    result_layouts: A sequence of layouts (dimension orders) for each result.
+      By default, the results are assumed to be row-major.
+    backend_config: Configuration data for the custom call. Any keyword
+      arguments passed to the lowering rule will be added to this dictionary.
+    lowering_args: Any other arguments to :func:`mlir.custom_call` will also be
+      passed through if provided as extra arguments to this function.
+  """
+
+  def _lowering_op(
+      ctx: mlir.LoweringRuleContext, *operands: ir.Value, **params: Any
+  ) -> ir.Operation:
+    kwargs = dict(lowering_args)
+    kwargs.setdefault("api_version", 4)
+    if kwargs["api_version"] >= 4:
+      if backend_config is not None and not isinstance(backend_config, dict):
+        raise ValueError(
+            "When api_version > 4, backend_config must be a dictionary.")
+      kwargs["backend_config"] = dict(
+          backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()})
+    else:
+      if params:
+        raise ValueError(
+            "The use of ffi_call attributes requires a custom call API version "
+            f"of at least 4; got api_version={kwargs['api_version']}.")
+      kwargs["backend_config"] = backend_config
+    if "result_types" not in kwargs:
+      kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
+    if operand_layouts is None:
+      kwargs["operand_layouts"] = map(_convert_layout_for_lowering, ctx.avals_in)
+    else:
+      kwargs["operand_layouts"] = [
+          _convert_layout_for_lowering(*args)
+          for args in zip(ctx.avals_in, operand_layouts)]
+    if result_layouts is None:
+      kwargs["result_layouts"] = map(_convert_layout_for_lowering, ctx.avals_out)
+    else:
+      kwargs["result_layouts"] = [
+          _convert_layout_for_lowering(*args)
+          for args in zip(ctx.avals_out, result_layouts)]
+    if "result_shapes" not in kwargs and not all(
+        core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out):
+      kwargs["result_shapes"] = [
+          mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, _aval_shape(aval)))
+          for aval in ctx.avals_out]
+
+    return mlir.custom_call(call_target_name, operands=operands, **kwargs)
+
+  return _lowering_op
+
+
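To show how the new builder composes with the existing entry point, here is a minimal sketch of both uses. The primitive and the FFI target name (`my_custom_op`) are hypothetical, and the target is assumed to have been registered with XLA elsewhere (e.g. via `jax.ffi.register_ffi_target`).

from jax._src import core, dispatch
from jax._src import ffi
from jax._src.interpreters import mlir

my_op_p = core.Primitive("my_op")        # hypothetical primitive
my_op_p.def_abstract_eval(lambda x: x)   # output aval matches the input
dispatch.simple_impl(my_op_p)

# ffi_lowering (refactored below to delegate to build_ffi_lowering_function)
# returns the custom call's results, which is what register_lowering expects:
mlir.register_lowering(my_op_p, ffi.ffi_lowering("my_custom_op"))

# When a caller needs the raw ir.Operation instead (e.g. to attach extra
# attributes), the new builder can be used directly inside a custom rule:
def my_rule(ctx, *operands, **params):
  op = ffi.build_ffi_lowering_function("my_custom_op")(ctx, *operands, **params)
  return op.results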
 def ffi_lowering(
     call_target_name: str,
     *,
@@ -193,41 +268,15 @@ def ffi_lowering(
   def _lowering(
       ctx: mlir.LoweringRuleContext, *operands: ir.Value, **params: Any
   ) -> Sequence[ir.Value | Sequence[ir.Value]]:
-    kwargs = dict(lowering_args)
-    kwargs.setdefault("api_version", 4)
-    if kwargs["api_version"] >= 4:
-      if backend_config is not None and not isinstance(backend_config, dict):
-        raise ValueError(
-            "When api_version > 4, backend_config must be a dictionary.")
-      kwargs["backend_config"] = dict(
-          backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()})
-    else:
-      if params:
-        raise ValueError(
-            "The use of ffi_call attributes requires a custom call API version "
-            f"of at least 4; got api_version={kwargs['api_version']}.")
-      kwargs["backend_config"] = backend_config
-    if "result_types" not in kwargs:
-      kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
-    if operand_layouts is None:
-      kwargs["operand_layouts"] = map(_convert_layout_for_lowering, ctx.avals_in)
-    else:
-      kwargs["operand_layouts"] = [
-          _convert_layout_for_lowering(*args)
-          for args in zip(ctx.avals_in, operand_layouts)]
-    if result_layouts is None:
-      kwargs["result_layouts"] = map(_convert_layout_for_lowering, ctx.avals_out)
-    else:
-      kwargs["result_layouts"] = [
-          _convert_layout_for_lowering(*args)
-          for args in zip(ctx.avals_out, result_layouts)]
-    if "result_shapes" not in kwargs and not all(
-        core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out):
-      kwargs["result_shapes"] = [
-          mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, _aval_shape(aval)))
-          for aval in ctx.avals_out]
+    result = build_ffi_lowering_function(
+        call_target_name,
+        operand_layouts=operand_layouts,
+        result_layouts=result_layouts,
+        backend_config=backend_config,
+        **lowering_args,
+    )(ctx, *operands, **params)
 
-    return mlir.custom_call(call_target_name, operands=operands, **kwargs).results  # type: ignore
+    return result.results  # type: ignore
 
   return _lowering
@@ -630,6 +679,97 @@ def ffi_call_lowering(
   return rule(ctx, *operands, **_unwrap_kwargs_hashable(attributes))
 
 
+def ffi_batching_rule(
+    prim,
+    args,
+    dims,
+    *,
+    vectorized: bool | None | DeprecatedArg,
+    vmap_method: str | None,
+    result_avals: Sequence[core.ShapedArray],
+    **kwargs: Any,
+):
+  if isinstance(vectorized, DeprecatedArg) and vmap_method is None:
+    deprecations.warn(
+        "jax-callback-vectorized",
+        f"The default behavior of {prim.name} under vmap will soon "
+        "change. Currently, the default behavior is to generate a sequential "
+        "vmap (i.e. a loop), but in the future the default will be to raise "
+        "an error. To keep the current default, set vmap_method='sequential'.",
+        stacklevel=6)
+    vmap_method = "sequential"
+
+  axis_size, = {a.shape[d] for a, d in zip(args, dims)
+                if d is not batching.not_mapped}
+  new_args = [arg if dim is batching.not_mapped else
+              batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
+  batched_result_avals = tuple(
+      core.unmapped_aval(axis_size, 0, aval) for aval in result_avals)
+
+  # For FFI calls we must update the layouts. We handle the output layouts
+  # here, but the input layout updates depend on the vmap_method parameter.
+  if (
+      vmap_method not in ("sequential", "sequential_unrolled") and
+      kwargs.get("output_layouts") is not None
+  ):
+    kwargs["output_layouts"] = tuple(
+        None if layout is None else tuple(n + 1 for n in layout) + (0,)
+        for layout in kwargs["output_layouts"])
+
+  if vmap_method == "legacy_vectorized":
+    # This method is kept to support the behavior that was previously exposed
+    # when using `vectorized=True`.
+    if kwargs.get("input_layouts") is not None:
+      kwargs["input_layouts"] = tuple(
+          layout if d is batching.not_mapped else
+          (None if layout is None else tuple(n + 1 for n in layout) + (0,))
+          for layout, d in zip(kwargs["input_layouts"], dims))
+    outvals = prim.bind(
+        *new_args,
+        vectorized=vectorized,
+        vmap_method=vmap_method,
+        result_avals=batched_result_avals,
+        **kwargs,
+    )
+  elif vmap_method == "expand_dims" or vmap_method == "broadcast_all":
+    size = axis_size if vmap_method == "broadcast_all" else 1
+    bcast_args = [
+        jax.lax.broadcast(x, (size,)) if d is batching.not_mapped else x
+        for x, d in zip(new_args, dims)]
+    if kwargs.get("input_layouts") is not None:
+      kwargs["input_layouts"] = tuple(
+          None if layout is None else tuple(n + 1 for n in layout) + (0,)
+          for layout in kwargs["input_layouts"])
+    outvals = prim.bind(
+        *bcast_args,
+        vectorized=vectorized,
+        vmap_method=vmap_method,
+        result_avals=batched_result_avals,
+        **kwargs,
+    )
+  elif vmap_method == "sequential" or vmap_method == "sequential_unrolled":
+    is_batched = [d is not batching.not_mapped for d in dims]
+    unbatched_args, batched_args = util.partition_list(is_batched, new_args)
+    def _batch_fun(batched_args):
+      merged_args = util.merge_lists(is_batched, unbatched_args, batched_args)
+      return prim.bind(
+          *merged_args,
+          result_avals=result_avals,
+          vectorized=vectorized,
+          vmap_method=vmap_method,
+          **kwargs,
+      )
+    unroll = vmap_method == "sequential_unrolled"
+    g = lambda _, x: ((), _batch_fun(x))
+    _, outvals = jax.lax.scan(g, (), batched_args, unroll=unroll)
+  else:
+    raise NotImplementedError(
+        f"vmap is only supported for the {prim.name} primitive when vmap_method "
+        "is one of 'sequential', 'sequential_unrolled', 'expand_dims', "
+        f"'broadcast_all', or 'legacy_vectorized'. Got {vmap_method=}.")
+  return tuple(outvals), (0,) * len(outvals)
+
+
 ffi_call_p = core.Primitive("ffi_call")
 ffi_call_p.multiple_results = True
 dispatch.simple_impl(ffi_call_p)
@@ -637,5 +777,5 @@ ffi_call_p.def_effectful_abstract_eval(ffi_call_abstract_eval)
 ad.primitive_jvps[ffi_call_p] = ffi_call_jvp
 ad.primitive_transposes[ffi_call_p] = ffi_call_transpose
 batching.primitive_batchers[ffi_call_p] = functools.partial(
-    callback_batching_rule, ffi_call_p)
+    ffi_batching_rule, ffi_call_p)
 mlir.register_lowering(ffi_call_p, ffi_call_lowering)
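Because `ffi_batching_rule` implements the public `vmap_method` options, a short, self-contained example of how those options surface through `jax.pure_callback` may help; the host function here is illustrative, and behavior is as documented for JAX's callback APIs.

import jax
import jax.numpy as jnp
import numpy as np

def host_sin(x):
  # Runs on the host; receives NumPy arrays.
  return np.sin(x)

def f(x, vmap_method):
  out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
  return jax.pure_callback(host_sin, out_spec, x, vmap_method=vmap_method)

xs = jnp.linspace(0.0, 1.0, 4)

# "sequential": lowers vmap to a scan, calling the host once per element.
jax.vmap(lambda x: f(x, "sequential"))(xs)

# "expand_dims" / "broadcast_all": a single host call on stacked inputs, with
# unmapped arguments given a leading axis of size 1 or broadcast to the full
# batch size, respectively (the layout bookkeeping is handled by the rule above).
jax.vmap(lambda x: f(x, "broadcast_all"))(xs)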
@@ -41,7 +41,6 @@ from jax._src import dtypes
 from jax._src import effects as effects_lib
 from jax._src import linear_util as lu
 from jax._src import path
-from jax._src import pickle_util
 from jax._src import sharding_impls
 from jax._src import source_info_util
 from jax._src import util
@@ -2824,284 +2823,6 @@ DEVICE_TO_DEVICE_TYPE = 1
 SEND_TO_HOST_TYPE = 2
 RECV_FROM_HOST_TYPE = 3
-
-
-def is_empty_shape(s: core.Shape) -> bool:
-  return any(d == 0 for d in s)
-
-
-def send_to_host(
-    channel: int,
-    token: hlo.TokenType,
-    operand: Any,
-    name: str,
-    *,
-    sharding: SdyArrayShardingList | xc.OpSharding | None = None,
-) -> ir.Value:
-  channel_handle = hlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE)
-  send_op = hlo.SendOp([operand], token, channel_handle,
-                       is_host_transfer=ir.BoolAttr.get(True))
-  send_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
-      dict(
-          _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
-          _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
-  if sharding is not None:
-    if config.use_shardy_partitioner.value:
-      # `SendOp`'s return type is a StableHLO `TokenType`. However JAX passed
-      # in the maximal sharding of the array type. Since a token has no rank,
-      # we need to create an equivalent sharding with no dimensions.
-      assert isinstance(sharding, SdyArrayShardingList)
-      assert len(sharding.shardings) == 1
-      sharding = SdyArrayShardingList([
-          SdyArraySharding(
-              mesh_shape=(), dimension_shardings=[],
-              logical_device_ids=sharding.shardings[0].logical_device_ids)])
-    set_sharding(send_op, sharding)
-  return send_op.result
-
-
-def receive_from_host(
-    channel: int,
-    token: hlo.TokenType,
-    out_aval: core.ShapedArray,
-    name: str,
-    *,
-    sharding: SdyArrayShardingList | xc.OpSharding | None = None,
-) -> tuple[ir.Value, ir.Value]:
-  channel_handle = hlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE)
-  recv_op = hlo.RecvOp([aval_to_ir_type(out_aval),
-                        hlo.TokenType.get()], token, channel_handle,
-                       is_host_transfer=ir.BoolAttr.get(True))
-  recv_op.attributes["mhlo.frontend_attributes"] = ir.DictAttr.get(
-      dict(
-          _xla_host_transfer_handler_name=ir.StringAttr.get(str(name)),
-          _xla_host_transfer_rendezvous=ir.StringAttr.get(str(name))))
-  if sharding is not None:
-    if config.use_shardy_partitioner.value:
-      assert isinstance(sharding, SdyArrayShardingList)
-      assert len(sharding.shardings) == 1
-      # `RecvOp`'s last argument is a `TokenType`. Since Shardy requires the
-      # number of shardings to match the number of results, but JAX only sees
-      # the array result, we need to add an equivalent sharding for the token.
-      sharding = SdyArrayShardingList([
-          sharding.shardings[0],
-          SdyArraySharding(
-              mesh_shape=(), dimension_shardings=[],
-              logical_device_ids=sharding.shardings[0].logical_device_ids)])
-    set_sharding(recv_op, sharding)
-  # Token should be at the end of the results
-  result, token = recv_op.results
-  return token, result
-
-
-def _emit_tpu_python_callback(
-    backend: xb.XlaBackend,
-    ctx: LoweringRuleContext,
-    callback,
-    token: Any | None,
-    operands: Sequence[ir.Value],
-    operand_avals: Sequence[core.ShapedArray],
-    operand_shapes: Sequence[xc.Shape],
-    result_avals: Sequence[core.ShapedArray],
-    result_shapes: Sequence[xc.Shape],
-    *,
-    sharding: SdyArrayShardingList | xc.OpSharding | None = None,
-) -> tuple[Sequence[ir.Value], Any]:
-  token = token or hlo.create_token()
-  _wrapped_callback = callback
-
-  send_channels = []
-  if not operand_avals:
-    # If there are no operands to the callback, we need to insert a dummy send
-    # op or the callback will never be triggered!
-    # TODO(sharadmv,chky): Enable this fix in the runtime as opposed to in
-    # MLIR builder.
-    callback_without_args = _wrapped_callback
-    def _wrapped_callback(*args):  # pylint: disable=function-redefined
-      del args
-      return callback_without_args()
-    send_channel = ctx.module_context.new_channel()
-    dummy_send_aval = core.ShapedArray((1,), np.float32)
-    dummy_send_val = ir_constant(np.zeros(1, np.float32))
-    operand_shapes = [*operand_shapes,
-                      xla.aval_to_xla_shapes(dummy_send_aval)[0]]
-    token = send_to_host(send_channel, token, dummy_send_val, callback.__name__,
-                         sharding=sharding)
-    send_channels.append(send_channel)
-  else:
-    for operand in operands:
-      channel = ctx.module_context.new_channel()
-      token = send_to_host(channel, token, operand, callback.__name__,
-                           sharding=sharding)
-      send_channels.append(channel)
-
-  recv_channels = []
-  outputs = []
-  for result_aval in result_avals:
-    channel = ctx.module_context.new_channel()
-    assert isinstance(result_aval, core.ShapedArray)
-    token, out = receive_from_host(channel, token, result_aval,
-                                   callback.__name__, sharding=sharding)
-    outputs.append(out)
-    recv_channels.append(channel)
-  ifrt_callback = backend.make_python_callback_from_host_send_and_recv(
-      _wrapped_callback, operand_shapes, result_shapes, send_channels,
-      recv_channels, pickle_util.dumps)
-  ctx.module_context.add_host_callback(ifrt_callback)
-  return outputs, token
-
-
-def _layout_to_mlir_layout(minor_to_major: Sequence[int] | None):
-  if minor_to_major is None:
-    # Needed for token layouts
-    layout: np.ndarray = np.zeros((0,), dtype="int64")
-  else:
-    layout = np.array(minor_to_major, dtype="int64")
-  return ir.DenseIntElementsAttr.get(layout, type=ir.IndexType.get())
-
-
-def _aval_to_default_layouts(aval):
-  avals = [core.physical_aval(aval)]
-  # Row major order is default for `NumPy`.
-  return [list(range(aval.ndim - 1, -1, -1)) for aval in avals]
-
-
-def emit_python_callback(
-    ctx: LoweringRuleContext,
-    callback,
-    token: Any | None,
-    operands: Sequence[ir.Value],
-    operand_avals: Sequence[core.ShapedArray],
-    result_avals: Sequence[core.ShapedArray],
-    *,
-    has_side_effect: bool,
-    sharding: SdyArrayShardingList | xc.OpSharding | None = None,
-    operand_layouts: Sequence[Sequence[int] | None] | None = None,
-    result_layouts: Sequence[Sequence[int] | None] | None = None,
-) -> tuple[Sequence[IrValues], Any, Any]:
-  """Emits MLIR that calls back to a provided Python function."""
-  if len(ctx.module_context.platforms) > 1:
-    raise NotImplementedError("multi-platform lowering for python_callback")
-  platform = ctx.module_context.platforms[0]
-  if platform not in {"cpu", "cuda", "rocm", "tpu"}:
-    raise ValueError(
-        f"`EmitPythonCallback` not supported on {platform} backend.")
-  backend = ctx.module_context.get_backend()
-  result_shapes = util.flatten(
-      [xla.aval_to_xla_shapes(result_aval) for result_aval in result_avals])
-  operand_shapes = util.flatten(
-      [xla.aval_to_xla_shapes(op_aval) for op_aval in operand_avals])
-  # Handling layouts
-  if operand_layouts is None:
-    operand_layouts = util.concatenate(
-        map(_aval_to_default_layouts, operand_avals))
-  operand_mlir_layouts = map(_layout_to_mlir_layout, operand_layouts)
-  if result_layouts is None:
-    result_layouts = util.concatenate(map(_aval_to_default_layouts, result_avals))
-  result_mlir_layouts = map(_layout_to_mlir_layout, result_layouts)
-
-  # First we apply checks to ensure output shapes and dtypes match the expected
-  # ones.
-  def _wrapped_callback(*args):
-    out_vals = callback(*args)
-    if len(out_vals) != len(result_avals):
-      raise RuntimeError(
-          "Mismatched number of outputs from callback. "
-          "Expected: {}, Actual: {}".format(len(result_avals), len(out_vals)))
-    # Handle Python literals, and custom arrays, e.g., tf.Tensor.
-    out_vals = tuple(xla.canonicalize_dtype(np.asarray(a)) for a in out_vals)
-    for i, (out_val, out_aval) in enumerate(zip(out_vals, result_avals)):
-      if out_val.shape != out_aval.shape:
-        raise RuntimeError(
-            f"Incorrect output shape for return value #{i}: "
-            f"Expected: {out_aval.shape}, Actual: {out_val.shape}")
-      if out_val.dtype != out_aval.dtype:
-        raise RuntimeError(
-            f"Incorrect output dtype for return value #{i}: "
-            f"Expected: {out_aval.dtype}, Actual: {out_val.dtype}")
-
-    if platform == "tpu":
-      # On TPU we cannot receive empty arrays. So, we return from the wrapped
-      # callback only the non-empty results, and we will create empty constants
-      # in the receiving computation.
-      # TODO(b/238239458): fix TPU Recv to work with empty arrays.
-      non_empty_out_vals = tuple(
-          out_val
-          for out_val, result_aval in zip(out_vals, result_avals)
-          if not is_empty_shape(result_aval.shape))
-      return non_empty_out_vals
-    else:
-      return out_vals
-
-  if platform == "tpu":
-    non_empty_result_avals, non_empty_result_shapes = util.unzip2([
-        (aval, shape)
-        for aval, shape in zip(result_avals, result_shapes)
-        if not is_empty_shape(aval.shape)])
-    non_empty_outputs, token = _emit_tpu_python_callback(
-        backend, ctx, _wrapped_callback, token,
-        operands, operand_avals, operand_shapes,
-        non_empty_result_avals, non_empty_result_shapes,
-        sharding=sharding)
-    non_empty_outputs_iter = iter(non_empty_outputs)
-    outputs = [
-        ir_constant(np.zeros(result_aval.shape, dtype=result_aval.dtype))
-        if is_empty_shape(result_aval.shape) else next(non_empty_outputs_iter)
-        for result_aval in result_avals]
-    return outputs, token, None
-
-  result_types = flatten_ir_types([aval_to_ir_type(aval) for aval in result_avals])
-  if token:
-
-    callback_without_token = _wrapped_callback
-    def _wrapped_callback(token, *args):  # type: ignore  # pylint: disable=function-redefined
-      return (token, *callback_without_token(*args))
-
-    operand_shapes = [
-        xla.aval_to_xla_shapes(core.abstract_token)[0], *operand_shapes
-    ]
-    result_shapes = [
-        xla.aval_to_xla_shapes(core.abstract_token)[0], *result_shapes
-    ]
-    operands = [token, *operands]
-    result_types = [token_type(), *result_types]
-    operand_mlir_layouts = [_layout_to_mlir_layout(None), *operand_mlir_layouts]
-    result_mlir_layouts = [_layout_to_mlir_layout(None), *result_mlir_layouts]
-  callback_descriptor, ifrt_callback = (
-      backend.get_emit_python_callback_descriptor(_wrapped_callback,
-                                                  operand_shapes,
-                                                  result_shapes))
-  ctx.module_context.add_host_callback(ifrt_callback)
-  descriptor_operand = ir_constant(callback_descriptor)
-  callback_operands = [descriptor_operand, *operands]
-  if operand_mlir_layouts is not None:
-    operand_mlir_layouts = [_layout_to_mlir_layout([]), *operand_mlir_layouts]
-  result_type = ir.TupleType.get_tuple(result_types)
-  call_target_name = ("xla_python_gpu_callback"
-                      if platform in {"cuda", "rocm"} else "xla_python_cpu_callback")
-  result = hlo.CustomCallOp(
-      [result_type],
-      callback_operands,
-      call_target_name=ir.StringAttr.get(call_target_name),
-      has_side_effect=ir.BoolAttr.get(has_side_effect),
-      api_version=i32_attr(2),
-      called_computations=ir.ArrayAttr.get([]),
-      backend_config=ir.StringAttr.get(str(callback_descriptor)),
-      operand_layouts=(
-          None if operand_mlir_layouts is None
-          else ir.ArrayAttr.get(operand_mlir_layouts)),
-      result_layouts=(
-          None if result_mlir_layouts is None
-          else ir.ArrayAttr.get(result_mlir_layouts)))
-  if sharding is not None:
-    set_sharding(result, sharding)
-  results = [
-      hlo.get_tuple_element(result, i32_attr(i))
-      for i in range(len(result_types))
-  ]
-  if token:
-    token, *results = results
-  return results, token, ifrt_callback
-
-
 def build_mlir_module_helper(
     closed_jaxpr: core.ClosedJaxpr, *, name: str,
     platforms: Sequence[str],
@@ -26,6 +26,7 @@ from jax import lax
 from jax import tree_util
 from jax._src import ad_util
 from jax._src import api_util
+from jax._src import callback
 from jax._src import core as jax_core
 from jax._src import dtypes
 from jax._src import effects
@@ -811,7 +812,7 @@ batching.primitive_batchers[debug_print_p] = functools.partial(
 
 @functools.partial(mlir.register_lowering, debug_print_p)
 def debug_print_lowering_rule(ctx, *args, **params):
-  result, _, _ = mlir.emit_python_callback(
+  result, _, _ = callback.emit_python_callback(
      ctx,
      functools.partial(debug_print_p.impl, **params),
      None,
@@ -39,7 +39,6 @@ from jax._src.interpreters.mlir import (
   dense_int_array as dense_int_array,
   dense_int_elements as dense_int_elements,
   dtype_to_ir_type as dtype_to_ir_type,
-  emit_python_callback as emit_python_callback,
   flatten_ir_types as flatten_ir_types,
   flatten_ir_values as flatten_lowering_ir_args,  # TODO(phawkins): remove me  # noqa: F401
   flatten_ir_values as flatten_ir_values,
@@ -74,3 +73,10 @@ from jax._src.sharding_impls import (
   SPMDAxisContext as SPMDAxisContext,
   ShardingContext as ShardingContext,
 )
+
+
+# TODO(dsuo): Temporarily maintain symbols related to callback lowering for sake
+# of public APIs.
+from jax._src.callback import (
+  emit_python_callback as emit_python_callback,
+)
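A hedged sanity check of the shim above, assuming the re-exporting module shown here is the public `jax.interpreters.mlir` (the file name is not visible in this view):

# The public symbol still resolves after the move; it is now re-exported from
# jax._src.callback rather than defined in jax._src.interpreters.mlir.
from jax.interpreters import mlir as public_mlir
from jax._src import callback as _callback

assert public_mlir.emit_python_callback is _callback.emit_python_callback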
@@ -22,6 +22,7 @@ import jax.numpy as jnp
 from jax import lax
 from jax.experimental import pjit
 from jax._src import ad_checkpoint
+from jax._src import callback as cb
 from jax._src import dispatch
 from jax._src import config
 from jax._src import core
@@ -123,7 +124,7 @@ def callback_effect_lowering(ctx: mlir.LoweringRuleContext, *args, callback, out
   if effects.ordered_effects.contains(effect):
     token_in = ctx.tokens_in.get(effect)
 
-    out_op, token_out, _ = mlir.emit_python_callback(
+    out_op, token_out, _ = cb.emit_python_callback(
         ctx, callback, token_in, list(args), list(ctx.avals_in),
         list(ctx.avals_out), has_side_effect=True)
     if token_out: