Add support for layouts and other advanced features in ffi_call.

Dan Foreman-Mackey 2024-10-21 11:34:57 -04:00
parent 4972f84c94
commit 21f3353544
3 changed files with 285 additions and 77 deletions

View File

@ -160,9 +160,22 @@ def callback_batching_rule(
batched_result_avals = tuple(
core.unmapped_aval(axis_size, core.no_axis_name, 0, aval)
for aval in result_avals)
# For FFI calls we must update the layouts. We handle the output layouts
# here, but the input layout updates depend on the vmap_method parameter.
if vmap_method != "sequential" and kwargs.get("output_layouts") is not None:
kwargs["output_layouts"] = tuple(
None if layout is None else tuple(n + 1 for n in layout) + (0,)
for layout in kwargs["output_layouts"])
if vmap_method == "legacy_vectorized":
# This method is kept to support the behavior that was previously exposed
# when using `vectorized=True`.
if kwargs.get("input_layouts") is not None:
kwargs["input_layouts"] = tuple(
layout if d is batching.not_mapped else
(None if layout is None else tuple(n + 1 for n in layout) + (0,))
for layout, d in zip(kwargs["input_layouts"], dims))
outvals = prim.bind(
*new_args,
vectorized=vectorized,
@ -175,6 +188,10 @@ def callback_batching_rule(
bcast_args = [
lax.broadcast(x, (size,)) if d is batching.not_mapped else x
for x, d in zip(new_args, dims)]
if kwargs.get("input_layouts") is not None:
kwargs["input_layouts"] = tuple(
None if layout is None else tuple(n + 1 for n in layout) + (0,)
for layout in kwargs["input_layouts"])
outvals = prim.bind(
*bcast_args,
vectorized=vectorized,
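
The layout bookkeeping in this batching rule can be illustrated in isolation. The following standalone sketch is not part of the diff (the helper name `batch_layout` is hypothetical); it shows how a stored minor-to-major layout is updated when a new leading batch axis is added: every existing axis index shifts up by one and the new axis 0 is appended as the most-major (slowest-varying) dimension.

def batch_layout(layout):
  # `layout` is the minor-to-major permutation of the unbatched array, or
  # None for the default layout. Prepending a batch axis at position 0
  # shifts every existing axis index up by one; appending 0 makes the new
  # batch axis the slowest-varying dimension.
  return None if layout is None else tuple(n + 1 for n in layout) + (0,)

# A row-major matrix has minor-to-major layout (1, 0); after batching it
# becomes (2, 1, 0), i.e. a row-major batch of row-major matrices.
assert batch_layout((1, 0)) == (2, 1, 0)
assert batch_layout(None) is None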

View File

@ -116,17 +116,17 @@ def _aval_shape(aval: core.AbstractValue) -> Shape:
return () if aval is core.abstract_token else aval.shape # pytype: disable=attribute-error
def _convert_layout(aval: core.AbstractValue,
layout: FfiLayoutOptions = None) -> Sequence[int]:
def _convert_layout_for_lowering(
aval: core.AbstractValue, layout: FfiLayoutOptions = None) -> Sequence[int]:
"""Convert a layout to the minor-to-major order used by the custom call API."""
if layout is None:
return list(reversed(range(len(_aval_shape(aval)))))
return tuple(reversed(range(len(_aval_shape(aval)))))
elif isinstance(layout, DeviceLocalLayout):
if layout._tiling is not None:
raise ValueError("The FFI does not support layouts with tiling")
return layout.major_to_minor[::-1]
else:
return layout
return tuple(layout)
def ffi_lowering(
@ -134,7 +134,7 @@ def ffi_lowering(
*,
operand_layouts: Sequence[FfiLayoutOptions] | None = None,
result_layouts: Sequence[FfiLayoutOptions] | None = None,
backend_config: Mapping[str, ir.Attribute] | None = None,
backend_config: Mapping[str, ir.Attribute] | str | None = None,
**lowering_args: Any
) -> mlir.LoweringRule:
"""Build a lowering rule for an foreign function interface (FFI) target.
@ -143,6 +143,10 @@ def ffi_lowering(
compute the input and output types and shapes for the custom call, assuming
row-major layouts.
Note that layouts passed to this function as tuples should be in
minor-to-major order (as expected by XLA) rather than major-to-minor as used
by :func:`~jax.extend.ffi.ffi_call` and ``DeviceLocalLayout``.
If keyword arguments are passed to the lowering rule, these are treated as
attributes and added to `backend_config`.
@ -163,20 +167,32 @@ def ffi_lowering(
) -> Sequence[ir.Value | Sequence[ir.Value]]:
kwargs = dict(lowering_args)
kwargs.setdefault("api_version", 4)
kwargs["backend_config"] = dict(
backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()})
if kwargs["api_version"] >= 4:
if backend_config is not None and not isinstance(backend_config, dict):
raise ValueError(
"When api_version > 4, backend_config must be a dictionary.")
kwargs["backend_config"] = dict(
backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()})
else:
if params:
raise ValueError(
"The use of ffi_call attributes requires a custom call API version "
f"of at least 4; got api_version={kwargs['api_version']}.")
kwargs["backend_config"] = backend_config
if "result_types" not in kwargs:
kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
if operand_layouts is None:
kwargs["operand_layouts"] = map(_convert_layout, ctx.avals_in)
kwargs["operand_layouts"] = map(_convert_layout_for_lowering, ctx.avals_in)
else:
kwargs["operand_layouts"] = [
_convert_layout(*args) for args in zip(ctx.avals_in, operand_layouts)]
_convert_layout_for_lowering(*args)
for args in zip(ctx.avals_in, operand_layouts)]
if result_layouts is None:
kwargs["result_layouts"] = map(_convert_layout, ctx.avals_out)
kwargs["result_layouts"] = map(_convert_layout_for_lowering, ctx.avals_out)
else:
kwargs["result_layouts"] = [
_convert_layout(*args) for args in zip(ctx.avals_out, result_layouts)]
_convert_layout_for_lowering(*args)
for args in zip(ctx.avals_out, result_layouts)]
if "result_shapes" not in kwargs and not all(
core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out):
kwargs["result_shapes"] = [
@ -202,12 +218,39 @@ def _result_avals(results: Sequence[ResultMetadata]) -> tuple[core.AbstractValue
return tuple(avals)
def _check_compatible_avals(a: core.AbstractValue, b: core.AbstractValue) -> bool:
if isinstance(a, core.AbstractToken) and isinstance(b, core.AbstractToken):
return True
if getattr(a, "shape", ()) != getattr(b, "shape", ()):
return False
if getattr(a, "dtype", ()) != getattr(b, "dtype", ()):
return False
return True
def _convert_layouts_for_ffi_call(
avals: Sequence[core.AbstractValue],
layouts: Sequence[FfiLayoutOptions]) -> tuple[Sequence[int], ...]:
return tuple(
_convert_layout_for_lowering(
aval,
layout if layout is None or isinstance(layout, DeviceLocalLayout)
else layout[::-1]
)
for aval, layout in zip(avals, layouts))
def ffi_call(
target_name: str,
result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
*deprecated_args: ArrayLike,
has_side_effect: bool = False,
vmap_method: str | None = None,
input_layouts: Sequence[FfiLayoutOptions] | None = None,
output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = None,
input_output_aliases: dict[int, int] | None = None,
custom_call_api_version: int = 4,
legacy_backend_config: str | None = None,
vectorized: bool | DeprecatedArg = DeprecatedArg(),
**deprecated_kwargs: Any,
) -> Callable[..., Array | Sequence[Array]] | Array | Sequence[Array]:
@ -227,7 +270,7 @@ def ffi_call(
Args:
target_name: the name of the XLA FFI custom call target that was registered
using :func:`~jaxlib.xla_client.register_custom_call_target`.
using :func:`~jax.extend.ffi.register_ffi_target`.
result_shape_dtypes: an object, or sequence of objects, with ``shape`` and
``dtype`` attributes which are expected to match the shape and dtype of
the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often
@ -238,6 +281,32 @@ def ffi_call(
outputs are not used.
vmap_method: string specifying how the FFI call transforms under
:func:`~jax.vmap` as described above.
input_layouts: a sequence of layouts for each input argument. In each case,
the layout can be (a) ``None`` indicating that this input is in default
row-major order, (b) a ``DeviceLocalLayout`` specifying the axis order,
or (c) a sequence of integers specifying the major-to-minor axis
ordering. Users who are familiar with XLA layouts should note that this
function expects layouts in major-to-minor order instead of the
minor-to-major order that XLA uses. For example, a batch of row-major
matrices could be specified using the layout ``[0, 1, 2]``, whereas a
batch of column-major matrices would have layout ``[0, 2, 1]``. In both
of these examples, the leading/batch dimension is the "slowest" axis. The
``input_layouts`` parameter should be used to request the memory layout
expected by the FFI call target, and XLA will ensure that the buffers
have the correct layouts before the handler is executed.
output_layouts: like ``input_layouts``, but specifying the required layouts
for the output arrays.
input_output_aliases: a dictionary where the keys are input indices and the
values are output indices. This mapping indicates which output arrays
alias specific input arrays.
custom_call_api_version: the version number of the custom call API
implemented by the FFI target ``target_name``. The only formally
supported version is the typed FFI API with ``custom_call_api_version=4``,
but earlier unsupported custom calls can be executed using this argument.
legacy_backend_config: for legacy targets implemented using
``custom_call_api_version<4``, attributes are passed using the opaque
string representation provided by this argument. This parameter cannot be
used with ``custom_call_api_version>=4``.
Returns:
A function that can be called with the input arrays as positional arguments
@ -263,14 +332,73 @@ def ffi_call(
f"vmap_method must be on of the allowed methods {allowed_vmap_methods}, "
f"but got: {vmap_method}")
output_layouts_: Sequence[FfiLayoutOptions] | None
if isinstance(result_shape_dtypes, Sequence):
output_layouts_ = output_layouts # type: ignore
multiple_results = True
result_avals = _result_avals(result_shape_dtypes)
else:
multiple_results = False
result_avals = _result_avals((result_shape_dtypes,))
output_layouts_ = (output_layouts,) # type: ignore
if custom_call_api_version >= 4 and legacy_backend_config is not None:
raise ValueError(
"The use of the legacy_backend_config parameter requires "
f"custom_call_api_version < 4; got {custom_call_api_version}.")
def wrapped(*args: ArrayLike, **kwargs: Any):
in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
if input_layouts is None:
static_input_layouts = tuple(map(_convert_layout_for_lowering, in_avals))
else:
if len(input_layouts) != len(in_avals):
raise ValueError(
f"The number of input arguments ({len(in_avals)}) must equal the "
f"number of input layouts ({len(input_layouts)}).")
static_input_layouts = _convert_layouts_for_ffi_call(in_avals,
input_layouts)
if output_layouts_ is None:
static_output_layouts = tuple(map(_convert_layout_for_lowering,
result_avals))
else:
if len(output_layouts_) != len(result_avals):
raise ValueError(
f"The number of outputs ({len(result_avals)}) must equal the "
f"number of output layouts ({len(output_layouts_)}).")
static_output_layouts = _convert_layouts_for_ffi_call(result_avals,
output_layouts_)
static_input_output_aliases: tuple[tuple[int, int], ...] = ()
if input_output_aliases is not None:
for i_idx, o_idx in sorted(input_output_aliases.items()):
i_idx, o_idx = int(i_idx), int(o_idx)
if i_idx >= len(args):
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
f"with input index {i_idx} outside the range [0, "
f"{len(args)}).")
if o_idx >= len(result_avals):
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
f"with output index {o_idx} outside the range [0, "
f"{len(result_avals)}).")
in_aval = in_avals[i_idx]
out_aval = result_avals[o_idx]
if not _check_compatible_avals(in_aval, out_aval):
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
f"referring to an input with abstract value {in_aval} and an "
f"output with a different abstract value {out_aval}.")
if static_input_layouts[i_idx] != static_output_layouts[o_idx]:
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
f"referring to an input with layout {static_input_layouts[i_idx]} "
"and an output with a different layout "
f"{static_output_layouts[o_idx]}.")
static_input_output_aliases += ((i_idx, o_idx),)
results = ffi_call_p.bind(
*args,
result_avals=result_avals,
@ -278,6 +406,11 @@ def ffi_call(
vmap_method=vmap_method,
target_name=target_name,
has_side_effect=has_side_effect,
input_layouts=static_input_layouts,
output_layouts=static_output_layouts,
input_output_aliases=static_input_output_aliases,
custom_call_api_version=custom_call_api_version,
legacy_backend_config=legacy_backend_config,
attributes=_wrap_kwargs_hashable(kwargs),
)
if multiple_results:
@ -383,26 +516,23 @@ effects.control_flow_allowed_effects.add_type(FfiEffect)
def ffi_call_abstract_eval(
*avals_in,
result_avals: tuple[core.AbstractValue, ...],
target_name: str,
vectorized: bool | DeprecatedArg,
vmap_method: str | None,
has_side_effect: bool,
attributes: Sequence[tuple[str, Any]],
**_,
):
del avals_in, target_name, vectorized, vmap_method, attributes
del avals_in # unused
effects = {_FfiEffect} if has_side_effect else core.no_effects
return result_avals, effects
def ffi_call_jvp(*args, target_name, **kwargs):
del args, kwargs
def ffi_call_jvp(*args, target_name, **_):
del args
raise ValueError(
f"The FFI call to `{target_name}` cannot be differentiated. "
"You can use `jax.custom_jvp` or `jax.custom_jvp` to add support.")
def ffi_call_transpose(*args, target_name, **kwargs):
del args, kwargs
def ffi_call_transpose(*args, target_name, **_):
del args
raise ValueError(
f"The FFI call to `{target_name}` cannot be differentiated. "
"You can use `jax.custom_jvp` or `jax.custom_jvp` to add support.")
@ -411,15 +541,22 @@ def ffi_call_transpose(*args, target_name, **kwargs):
def ffi_call_lowering(
ctx: mlir.LoweringRuleContext,
*operands: ir.Value,
result_avals: tuple[core.AbstractValue, ...],
target_name: str,
vectorized: bool | DeprecatedArg,
vmap_method: str | None,
has_side_effect: bool,
input_layouts: Sequence[Sequence[int]],
output_layouts: Sequence[Sequence[int]],
input_output_aliases: Sequence[tuple[int, int]],
custom_call_api_version: int,
legacy_backend_config: str | None,
attributes: Sequence[tuple[str, Any]],
**_,
) -> Sequence[ir.Value]:
del result_avals, vectorized, vmap_method
rule = ffi_lowering(target_name, has_side_effect=has_side_effect)
rule = ffi_lowering(target_name, has_side_effect=has_side_effect,
operand_layouts=input_layouts,
result_layouts=output_layouts,
operand_output_aliases=dict(input_output_aliases),
api_version=custom_call_api_version,
backend_config=legacy_backend_config)
return rule(ctx, *operands, **_unwrap_kwargs_hashable(attributes))
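
Taken together, the new `ffi_call` parameters compose roughly as in the following sketch. This is a minimal illustration, not code from the commit: the target name "example_qr" is hypothetical, and it assumes a registered FFI handler that expects a column-major matrix (with any leading batch dimensions) and overwrites its input buffer in place.

import jax
import jax.extend as jex

def example_qr(x):
  ndim = x.ndim
  # Layouts are specified in major-to-minor order, so the batch axes come
  # first and swapping the last two entries requests column-major matrices.
  col_major = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2)
  return jex.ffi.ffi_call(
      "example_qr",                           # hypothetical target name
      jax.ShapeDtypeStruct(x.shape, x.dtype), # single output, same shape/dtype
      input_layouts=[col_major],              # layout required for the input
      output_layouts=col_major,               # single result, single layout
      input_output_aliases={0: 0},            # output 0 reuses input 0's buffer
  )(x)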

View File

@ -14,6 +14,7 @@
import os
import unittest
from functools import partial
import numpy as np
from absl.testing import absltest
@ -34,6 +35,7 @@ from jax._src import xla_bridge
from jax._src.interpreters import mlir
from jax._src.layout import DeviceLocalLayout
from jax._src.lib.mlir.dialects import hlo
from jax._src.lax import linalg as lax_linalg_internal
jax.config.parse_flags_with_absl()
@ -122,7 +124,6 @@ class FfiTest(jtu.JaxTestCase):
# layouts.
def lowering_rule(ctx, x):
aval, = ctx.avals_in
ndim = len(aval.shape)
return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec],
result_layouts=[layout_spec])(ctx, x)
prim = core.Primitive("test_ffi")
@ -228,51 +229,42 @@ class FfiTest(jtu.JaxTestCase):
fun(jnp.ones(5))
self.assertNotIsInstance(manager.exception, TypeError)
@jtu.sample_product(
shape=[(1,), (4,), (5,)],
dtype=(np.int32,),
)
@jtu.run_on_devices("gpu")
def testFfiCall(self, shape, dtype):
pivots_size = shape[-1]
permutation_size = 2 * pivots_size
pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
pivots = jnp.broadcast_to(pivots, shape)
expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
actual = ffi_call_lu_pivots_to_permutation(pivots, permutation_size)
self.assertArraysEqual(actual, expected)
@jtu.sample_product(shape=[(6, 5), (4, 5, 6)])
@jtu.run_on_devices("gpu", "cpu")
def testFfiCall(self, shape):
x = self.rng().randn(*shape).astype(np.float32)
expected = lax_linalg_internal.geqrf(x)
actual = ffi_call_geqrf(x)
for a, b in zip(actual, expected):
self.assertArraysEqual(a, b)
@jtu.sample_product(
shape=[(1,), (4,), (5,)],
dtype=(np.int32,),
vmap_method=("expand_dims", "broadcast_all", "sequential",
"legacy_vectorized"),
shape=[(6, 5), (4, 5, 6)],
vmap_method=["expand_dims", "broadcast_all", "sequential"],
)
@jtu.run_on_devices("gpu")
def testFfiCallBatching(self, shape, dtype, vmap_method):
@jtu.run_on_devices("gpu", "cpu")
def testFfiCallBatching(self, shape, vmap_method):
shape = (10,) + shape
pivots_size = shape[-1]
permutation_size = 2 * pivots_size
pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
pivots = jnp.broadcast_to(pivots, shape)
expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
actual = jax.vmap(lambda x: ffi_call_lu_pivots_to_permutation(
x, permutation_size, vmap_method=vmap_method))(pivots)
self.assertArraysEqual(actual, expected)
x = self.rng().randn(*shape).astype(np.float32)
expected = lax_linalg_internal.geqrf(x)
actual = jax.vmap(partial(ffi_call_geqrf, vmap_method=vmap_method))(x)
for a, b in zip(actual, expected):
if vmap_method == "sequential" and len(shape) == 3:
# On GPU, the batched FFI call to geqrf uses an algorithm with
# different numerics than the unbatched version (which is used when
# vmap_method="sequential"). Therefore, we need to include floating
# point tolerance for this check.
self.assertArraysAllClose(a, b)
else:
self.assertArraysEqual(a, b)
@jtu.run_on_devices("gpu")
@jtu.run_on_devices("gpu", "cpu")
def testVectorizedDeprecation(self):
pivots_size = 4
shape = (10, pivots_size)
permutation_size = 2 * pivots_size
pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1,
dtype=np.int32)
pivots = jnp.broadcast_to(pivots, shape)
x = self.rng().randn(3, 5, 4).astype(np.float32)
with self.assertWarns(DeprecationWarning):
ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True)
ffi_call_geqrf(x, vectorized=True)
with self.assertWarns(DeprecationWarning):
jax.vmap(
lambda x: ffi_call_lu_pivots_to_permutation(x, permutation_size))(pivots)
jax.vmap(ffi_call_geqrf)(x)
def testBackwardCompatSyntax(self):
def fun(x):
@ -280,20 +272,82 @@ class FfiTest(jtu.JaxTestCase):
with self.assertWarns(DeprecationWarning):
jax.jit(fun).lower(jnp.ones(5))
def testInputOutputAliases(self):
def fun(x):
return jex.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x)
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertRegex(hlo, r"output_operand_aliases = \[.*operand_index = 0.*\]")
# TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation`
# custom call target because that's the only one in jaxlib that uses the
# new FFI interface. Once more are available, consider using something that
# can be run on multiple platforms.
def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, **kwargs):
return jex.ffi.ffi_call(
"cu_lu_pivots_to_permutation",
jax.ShapeDtypeStruct(
shape=pivots.shape[:-1] + (permutation_size,),
dtype=pivots.dtype,
),
**kwargs,
)(pivots)
def testInvalidInputOutputAliases(self):
def fun(x):
return jex.ffi.ffi_call("test", x, input_output_aliases={1: 0})(x)
with self.assertRaisesRegex(ValueError, "with input index"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jex.ffi.ffi_call("test", x, input_output_aliases={0: 1})(x)
with self.assertRaisesRegex(ValueError, "with output index"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape, np.int32),
input_output_aliases={0: 0})(x)
with self.assertRaisesRegex(ValueError,
"referring to an input with abstract value"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape + x.shape,
x.dtype),
input_output_aliases={0: 0})(x)
with self.assertRaisesRegex(ValueError,
"referring to an input with abstract value"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def testLegacyBackendConfig(self):
def fun(x):
return jex.ffi.ffi_call("test", x, custom_call_api_version=2,
legacy_backend_config="12345")(x)
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertRegex(hlo, 'backend_config = "12345"')
def testInvalidBackendConfig(self):
def fun(x):
return jex.ffi.ffi_call("test", x, legacy_backend_config="12345")(x)
with self.assertRaisesRegex(ValueError,
"The use of the legacy_backend_config"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jex.ffi.ffi_call("test", x,
custom_call_api_version=2)(x, attribute=1)
with self.assertRaisesRegex(ValueError,
"The use of ffi_call attributes requires"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def ffi_call_geqrf(x, **kwargs):
assert x.dtype == np.float32
ndim = x.ndim
x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2)
output_types = [
x, jax.ShapeDtypeStruct(x.shape[:-2] + (min(*x.shape[-2:]),), x.dtype)]
def call(platform, x):
target_name = dict(
cpu="lapack_sgeqrf_ffi",
rocm="hipsolver_geqrf_ffi",
cuda="cusolver_geqrf_ffi",
)[platform]
return jex.ffi.ffi_call(
target_name, output_types, input_output_aliases={0: 0},
input_layouts=[x_major_to_minor],
output_layouts=[x_major_to_minor, None],
**kwargs)(x)
return lax.platform_dependent(
x, cpu=partial(call, "cpu"), rocm=partial(call, "rocm"),
cuda=partial(call, "cuda"))
class MlirRegisterLoweringTest(jtu.JaxTestCase):