Mirror of https://github.com/ROCm/jax.git (synced 2025-04-19 05:16:06 +00:00)
Add support for layouts and other advanced features in ffi_call.
parent 4972f84c94
commit 21f3353544
@@ -160,9 +160,22 @@ def callback_batching_rule(
   batched_result_avals = tuple(
       core.unmapped_aval(axis_size, core.no_axis_name, 0, aval)
       for aval in result_avals)
+
+  # For FFI calls we must update the layouts. We handle the output layouts
+  # here, but the input layout updates depend on the vmap_method parameter.
+  if vmap_method != "sequential" and kwargs.get("output_layouts") is not None:
+    kwargs["output_layouts"] = tuple(
+        None if layout is None else tuple(n + 1 for n in layout) + (0,)
+        for layout in kwargs["output_layouts"])
+
   if vmap_method == "legacy_vectorized":
     # This method is kept to support the behavior that was previously exposed
     # when using `vectorized=True`.
+    if kwargs.get("input_layouts") is not None:
+      kwargs["input_layouts"] = tuple(
+          layout if d is batching.not_mapped else
+          (None if layout is None else tuple(n + 1 for n in layout) + (0,))
+          for layout, d in zip(kwargs["input_layouts"], dims))
     outvals = prim.bind(
         *new_args,
         vectorized=vectorized,
@@ -175,6 +188,10 @@ def callback_batching_rule(
     bcast_args = [
         lax.broadcast(x, (size,)) if d is batching.not_mapped else x
         for x, d in zip(new_args, dims)]
+    if kwargs.get("input_layouts") is not None:
+      kwargs["input_layouts"] = tuple(
+          None if layout is None else tuple(n + 1 for n in layout) + (0,)
+          for layout in kwargs["input_layouts"])
     outvals = prim.bind(
         *bcast_args,
         vectorized=vectorized,
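Aside (not part of the diff): a minimal sketch of the layout update performed in both branches above. At this point the layouts are already in XLA's minor-to-major form, so shifting every entry by one and appending 0 preserves the original ordering and makes the new batch dimension the slowest-varying axis. The helper name below is made up for illustration.

def bump_layout_for_batch_dim(layout):
  # layout is minor-to-major; None means "default row-major" and needs no update.
  return None if layout is None else tuple(n + 1 for n in layout) + (0,)

# A row-major matrix has minor-to-major layout (1, 0); after vmap adds a leading
# batch dimension it becomes (2, 1, 0), i.e. a batch of row-major matrices.
assert bump_layout_for_batch_dim((1, 0)) == (2, 1, 0)
assert bump_layout_for_batch_dim(None) is None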
@@ -116,17 +116,17 @@ def _aval_shape(aval: core.AbstractValue) -> Shape:
   return () if aval is core.abstract_token else aval.shape  # pytype: disable=attribute-error


-def _convert_layout(aval: core.AbstractValue,
-                    layout: FfiLayoutOptions = None) -> Sequence[int]:
+def _convert_layout_for_lowering(
+    aval: core.AbstractValue, layout: FfiLayoutOptions = None) -> Sequence[int]:
   """Convert a layout to the minor-to-major order used by the custom call API."""
   if layout is None:
-    return list(reversed(range(len(_aval_shape(aval)))))
+    return tuple(reversed(range(len(_aval_shape(aval)))))
   elif isinstance(layout, DeviceLocalLayout):
     if layout._tiling is not None:
       raise ValueError("The FFI does not support layouts with tiling")
     return layout.major_to_minor[::-1]
   else:
-    return layout
+    return tuple(layout)


 def ffi_lowering(
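Aside (not part of the diff): a self-contained sketch of what the renamed helper computes, assuming a rank-3 operand. The function below is a simplified stand-in for illustration only, not the JAX implementation.

def convert_for_lowering(ndim, layout):
  if layout is None:
    return tuple(reversed(range(ndim)))        # default row-major, minor-to-major
  if hasattr(layout, "major_to_minor"):        # DeviceLocalLayout-like objects
    return tuple(layout.major_to_minor[::-1])  # reverse major-to-minor
  return tuple(layout)                         # tuples are passed through unchanged

assert convert_for_lowering(3, None) == (2, 1, 0)
assert convert_for_lowering(3, (0, 1, 2)) == (0, 1, 2)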
@@ -134,7 +134,7 @@ def ffi_lowering(
     *,
     operand_layouts: Sequence[FfiLayoutOptions] | None = None,
     result_layouts: Sequence[FfiLayoutOptions] | None = None,
-    backend_config: Mapping[str, ir.Attribute] | None = None,
+    backend_config: Mapping[str, ir.Attribute] | str | None = None,
     **lowering_args: Any
 ) -> mlir.LoweringRule:
   """Build a lowering rule for a foreign function interface (FFI) target.
@@ -143,6 +143,10 @@ def ffi_lowering(
   compute the input and output types and shapes for the custom call, assuming
   row-major layouts.
+
+  Note that layouts passed to this function as tuples should be in
+  minor-to-major order (as expected by XLA) rather than major-to-minor as used
+  by :func:`~jax.extend.ffi.ffi_call` and ``DeviceLocalLayout``.

   If keyword arguments are passed to the lowering rule, these are treated as
   attributes, and added to `backend_config`.

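Aside (not part of the diff): a small illustration of the two conventions named in the note above, for a rank-3 operand holding a batch of column-major matrices.

major_to_minor = (0, 2, 1)              # what jax.extend.ffi.ffi_call expects
minor_to_major = major_to_minor[::-1]   # what ffi_lowering and XLA expect
assert minor_to_major == (1, 2, 0)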
@@ -163,20 +167,32 @@ def ffi_lowering(
   ) -> Sequence[ir.Value | Sequence[ir.Value]]:
     kwargs = dict(lowering_args)
     kwargs.setdefault("api_version", 4)
-    kwargs["backend_config"] = dict(
-        backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()})
+    if kwargs["api_version"] >= 4:
+      if backend_config is not None and not isinstance(backend_config, dict):
+        raise ValueError(
+            "When api_version >= 4, backend_config must be a dictionary.")
+      kwargs["backend_config"] = dict(
+          backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()})
+    else:
+      if params:
+        raise ValueError(
+            "The use of ffi_call attributes requires a custom call API version "
+            f"of at least 4; got api_version={kwargs['api_version']}.")
+      kwargs["backend_config"] = backend_config
     if "result_types" not in kwargs:
       kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
     if operand_layouts is None:
-      kwargs["operand_layouts"] = map(_convert_layout, ctx.avals_in)
+      kwargs["operand_layouts"] = map(_convert_layout_for_lowering, ctx.avals_in)
     else:
       kwargs["operand_layouts"] = [
-          _convert_layout(*args) for args in zip(ctx.avals_in, operand_layouts)]
+          _convert_layout_for_lowering(*args)
+          for args in zip(ctx.avals_in, operand_layouts)]
     if result_layouts is None:
-      kwargs["result_layouts"] = map(_convert_layout, ctx.avals_out)
+      kwargs["result_layouts"] = map(_convert_layout_for_lowering, ctx.avals_out)
     else:
       kwargs["result_layouts"] = [
-          _convert_layout(*args) for args in zip(ctx.avals_out, result_layouts)]
+          _convert_layout_for_lowering(*args)
+          for args in zip(ctx.avals_out, result_layouts)]
     if "result_shapes" not in kwargs and not all(
         core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out):
       kwargs["result_shapes"] = [
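Aside (not part of the diff): a hedged sketch of how the extended backend_config handling above could be used to build a lowering rule for a legacy (pre-typed-FFI) custom call target; "my_legacy_target" and "my_prim" are hypothetical names.

import jax.extend as jex
from jax._src.interpreters import mlir

rule = jex.ffi.ffi_lowering("my_legacy_target", api_version=2,
                            backend_config="opaque-config-bytes")
# mlir.register_lowering(my_prim, rule, platform="gpu")  # typical registration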
@@ -202,12 +218,39 @@ def _result_avals(results: Sequence[ResultMetadata]) -> tuple[core.AbstractValue
   return tuple(avals)


+def _check_compatible_avals(a: core.AbstractValue, b: core.AbstractValue) -> bool:
+  if isinstance(a, core.AbstractToken) and isinstance(b, core.AbstractToken):
+    return True
+  if getattr(a, "shape", ()) != getattr(b, "shape", ()):
+    return False
+  if getattr(a, "dtype", ()) != getattr(b, "dtype", ()):
+    return False
+  return True
+
+
+def _convert_layouts_for_ffi_call(
+    avals: Sequence[core.AbstractValue],
+    layouts: Sequence[FfiLayoutOptions]) -> tuple[Sequence[int], ...]:
+  return tuple(
+      _convert_layout_for_lowering(
+          aval,
+          layout if layout is None or isinstance(layout, DeviceLocalLayout)
+          else layout[::-1]
+      )
+      for aval, layout in zip(avals, layouts))
+
+
 def ffi_call(
     target_name: str,
     result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
     *deprecated_args: ArrayLike,
     has_side_effect: bool = False,
     vmap_method: str | None = None,
+    input_layouts: Sequence[FfiLayoutOptions] | None = None,
+    output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = None,
+    input_output_aliases: dict[int, int] | None = None,
+    custom_call_api_version: int = 4,
+    legacy_backend_config: str | None = None,
     vectorized: bool | DeprecatedArg = DeprecatedArg(),
     **deprecated_kwargs: Any,
 ) -> Callable[..., Array | Sequence[Array]] | Array | Sequence[Array]:
@@ -227,7 +270,7 @@ def ffi_call(

   Args:
     target_name: the name of the XLA FFI custom call target that was registered
-      using :func:`~jaxlib.xla_client.register_custom_call_target`.
+      using :func:`~jax.extend.ffi.register_ffi_target`.
     result_shape_dtypes: an object, or sequence of objects, with ``shape`` and
       ``dtype`` attributes which are expected to match the shape and dtype of
       the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often
@@ -238,6 +281,32 @@ def ffi_call(
       outputs are not used.
     vmap_method: string specifying how the FFI call transforms under
       :func:`~jax.vmap` as described above.
+    input_layouts: a sequence of layouts for each input argument. In each case,
+      the layout can be (a) ``None`` indicating that this input is in default
+      row-major order, (b) a ``DeviceLocalLayout`` specifying the axis order,
+      or (c) a sequence of integers specifying the major-to-minor axis
+      ordering. Users who are familiar with XLA layouts should note that this
+      function expects layouts in major-to-minor order instead of the
+      minor-to-major order that XLA uses. For example, a batch of row-major
+      matrices could be specified using the layout ``[0, 1, 2]``, whereas a
+      batch of column-major matrices would have layout ``[0, 2, 1]``. In both
+      of these examples, the leading/batch dimension is the "slowest" axis. The
+      ``input_layouts`` parameter should be used to request the memory layout
+      expected by the FFI call target, and XLA will ensure that the buffers
+      have the correct layouts before the handler is executed.
+    output_layouts: like ``input_layouts``, but specifying the required layouts
+      for the output arrays.
+    input_output_aliases: a dictionary where the keys are input indices and the
+      values are output indices. This mapping indicates which output arrays
+      alias specific input arrays.
+    custom_call_api_version: the version number of the custom call API
+      implemented by the FFI target ``target_name``. The only formally
+      supported version is the typed FFI API with ``custom_call_api_version=4``,
+      but earlier unsupported custom calls can be executed using this argument.
+    legacy_backend_config: for legacy targets implemented using
+      ``custom_call_api_version<4``, attributes are passed using the opaque
+      string representation provided by this argument. This parameter cannot be
+      used with ``custom_call_api_version>=4``.

   Returns:
     A function that can be called with the input arrays as positional arguments
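Aside (not part of the diff): a sketch of the new parameters documented above used together. "my_solver" is a hypothetical registered FFI target that operates on at least 2-D float inputs; only the layout and aliasing plumbing is meant to be illustrative.

import jax
import jax.extend as jex

def solve(x):
  ndim = x.ndim
  # Major-to-minor order: batch dimensions slowest, trailing matrix column-major.
  col_major = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2)
  return jex.ffi.ffi_call(
      "my_solver",                              # hypothetical target name
      jax.ShapeDtypeStruct(x.shape, x.dtype),   # single output, same shape/dtype
      input_layouts=[col_major],
      output_layouts=col_major,                 # single result, so a single layout
      input_output_aliases={0: 0},              # output 0 may reuse input 0's buffer
  )(x)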
@@ -263,14 +332,73 @@ def ffi_call(
         f"vmap_method must be one of the allowed methods {allowed_vmap_methods}, "
         f"but got: {vmap_method}")

+  output_layouts_: Sequence[FfiLayoutOptions] | None
   if isinstance(result_shape_dtypes, Sequence):
+    output_layouts_ = output_layouts  # type: ignore
     multiple_results = True
     result_avals = _result_avals(result_shape_dtypes)
   else:
     multiple_results = False
     result_avals = _result_avals((result_shape_dtypes,))
+    output_layouts_ = (output_layouts,)  # type: ignore
+
+  if custom_call_api_version >= 4 and legacy_backend_config is not None:
+    raise ValueError(
+        "The use of the legacy_backend_config parameter requires "
+        f"custom_call_api_version < 4; got {custom_call_api_version}.")

   def wrapped(*args: ArrayLike, **kwargs: Any):
+    in_avals = [core.raise_to_shaped(core.get_aval(x)) for x in args]
+
+    if input_layouts is None:
+      static_input_layouts = tuple(map(_convert_layout_for_lowering, in_avals))
+    else:
+      if len(input_layouts) != len(in_avals):
+        raise ValueError(
+            f"The number of input arguments ({len(in_avals)}) must equal the "
+            f"number of input layouts ({len(input_layouts)}).")
+      static_input_layouts = _convert_layouts_for_ffi_call(in_avals,
+                                                           input_layouts)
+    if output_layouts_ is None:
+      static_output_layouts = tuple(map(_convert_layout_for_lowering,
+                                        result_avals))
+    else:
+      if len(output_layouts_) != len(result_avals):
+        raise ValueError(
+            f"The number of outputs ({len(result_avals)}) must equal the "
+            f"number of output layouts ({len(output_layouts_)}).")
+      static_output_layouts = _convert_layouts_for_ffi_call(result_avals,
+                                                            output_layouts_)
+
+    static_input_output_aliases: tuple[tuple[int, int], ...] = ()
+    if input_output_aliases is not None:
+      for i_idx, o_idx in sorted(input_output_aliases.items()):
+        i_idx, o_idx = int(i_idx), int(o_idx)
+        if i_idx >= len(args):
+          raise ValueError(
+              f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
+              f"with input index {i_idx} outside the range [0, "
+              f"{len(args)}).")
+        if o_idx >= len(result_avals):
+          raise ValueError(
+              f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
+              f"with output index {o_idx} outside the range [0, "
+              f"{len(result_avals)}).")
+        in_aval = in_avals[i_idx]
+        out_aval = result_avals[o_idx]
+        if not _check_compatible_avals(in_aval, out_aval):
+          raise ValueError(
+              f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
+              f"referring to an input with abstract value {in_aval} and an "
+              f"output with a different abstract value {out_aval}.")
+        if static_input_layouts[i_idx] != static_output_layouts[o_idx]:
+          raise ValueError(
+              f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
+              f"referring to an input with layout {static_input_layouts[i_idx]} "
+              "and an output with a different layout "
+              f"{static_output_layouts[o_idx]}.")
+        static_input_output_aliases += ((i_idx, o_idx),)
+
     results = ffi_call_p.bind(
         *args,
         result_avals=result_avals,
@@ -278,6 +406,11 @@ def ffi_call(
         vmap_method=vmap_method,
         target_name=target_name,
         has_side_effect=has_side_effect,
+        input_layouts=static_input_layouts,
+        output_layouts=static_output_layouts,
+        input_output_aliases=static_input_output_aliases,
+        custom_call_api_version=custom_call_api_version,
+        legacy_backend_config=legacy_backend_config,
         attributes=_wrap_kwargs_hashable(kwargs),
     )
     if multiple_results:
@@ -383,26 +516,23 @@ effects.control_flow_allowed_effects.add_type(FfiEffect)
 def ffi_call_abstract_eval(
     *avals_in,
     result_avals: tuple[core.AbstractValue, ...],
-    target_name: str,
-    vectorized: bool | DeprecatedArg,
-    vmap_method: str | None,
     has_side_effect: bool,
-    attributes: Sequence[tuple[str, Any]],
+    **_,
 ):
-  del avals_in, target_name, vectorized, vmap_method, attributes
+  del avals_in  # unused
   effects = {_FfiEffect} if has_side_effect else core.no_effects
   return result_avals, effects


-def ffi_call_jvp(*args, target_name, **kwargs):
-  del args, kwargs
+def ffi_call_jvp(*args, target_name, **_):
+  del args
   raise ValueError(
       f"The FFI call to `{target_name}` cannot be differentiated. "
       "You can use `jax.custom_jvp` or `jax.custom_vjp` to add support.")


-def ffi_call_transpose(*args, target_name, **kwargs):
-  del args, kwargs
+def ffi_call_transpose(*args, target_name, **_):
+  del args
   raise ValueError(
       f"The FFI call to `{target_name}` cannot be differentiated. "
       "You can use `jax.custom_jvp` or `jax.custom_vjp` to add support.")
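Aside (not part of the diff): since FFI calls cannot be differentiated directly, a custom derivative can be layered on top with jax.custom_jvp, as the error message suggests. A hedged sketch, where "my_scale" is a hypothetical registered target assumed to compute 2 * x:

import jax
import jax.extend as jex

@jax.custom_jvp
def my_scale(x):
  return jex.ffi.ffi_call("my_scale", jax.ShapeDtypeStruct(x.shape, x.dtype))(x)

@my_scale.defjvp
def my_scale_jvp(primals, tangents):
  (x,), (dx,) = primals, tangents
  # Assuming the target computes 2 * x, its JVP is 2 * dx.
  return my_scale(x), 2.0 * dx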
@@ -411,15 +541,22 @@ def ffi_call_transpose(*args, target_name, **kwargs):
 def ffi_call_lowering(
     ctx: mlir.LoweringRuleContext,
     *operands: ir.Value,
     result_avals: tuple[core.AbstractValue, ...],
     target_name: str,
     vectorized: bool | DeprecatedArg,
     vmap_method: str | None,
     has_side_effect: bool,
+    input_layouts: Sequence[Sequence[int]],
+    output_layouts: Sequence[Sequence[int]],
+    input_output_aliases: Sequence[tuple[int, int]],
+    custom_call_api_version: int,
+    legacy_backend_config: str | None,
     attributes: Sequence[tuple[str, Any]],
     **_,
 ) -> Sequence[ir.Value]:
   del result_avals, vectorized, vmap_method
-  rule = ffi_lowering(target_name, has_side_effect=has_side_effect)
+  rule = ffi_lowering(target_name, has_side_effect=has_side_effect,
+                      operand_layouts=input_layouts,
+                      result_layouts=output_layouts,
+                      operand_output_aliases=dict(input_output_aliases),
+                      api_version=custom_call_api_version,
+                      backend_config=legacy_backend_config)
   return rule(ctx, *operands, **_unwrap_kwargs_hashable(attributes))
@@ -14,6 +14,7 @@

 import os
 import unittest
+from functools import partial

 import numpy as np
 from absl.testing import absltest
@@ -34,6 +35,7 @@ from jax._src import xla_bridge
 from jax._src.interpreters import mlir
 from jax._src.layout import DeviceLocalLayout
 from jax._src.lib.mlir.dialects import hlo
+from jax._src.lax import linalg as lax_linalg_internal

 jax.config.parse_flags_with_absl()

@@ -122,7 +124,6 @@ class FfiTest(jtu.JaxTestCase):
     # layouts.
     def lowering_rule(ctx, x):
       aval, = ctx.avals_in
-      ndim = len(aval.shape)
       return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec],
                                   result_layouts=[layout_spec])(ctx, x)
     prim = core.Primitive("test_ffi")
@@ -228,51 +229,42 @@ class FfiTest(jtu.JaxTestCase):
       fun(jnp.ones(5))
     self.assertNotIsInstance(manager.exception, TypeError)

-  @jtu.sample_product(
-      shape=[(1,), (4,), (5,)],
-      dtype=(np.int32,),
-  )
-  @jtu.run_on_devices("gpu")
-  def testFfiCall(self, shape, dtype):
-    pivots_size = shape[-1]
-    permutation_size = 2 * pivots_size
-    pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
-    pivots = jnp.broadcast_to(pivots, shape)
-    expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
-    actual = ffi_call_lu_pivots_to_permutation(pivots, permutation_size)
-    self.assertArraysEqual(actual, expected)
+  @jtu.sample_product(shape=[(6, 5), (4, 5, 6)])
+  @jtu.run_on_devices("gpu", "cpu")
+  def testFfiCall(self, shape):
+    x = self.rng().randn(*shape).astype(np.float32)
+    expected = lax_linalg_internal.geqrf(x)
+    actual = ffi_call_geqrf(x)
+    for a, b in zip(actual, expected):
+      self.assertArraysEqual(a, b)

   @jtu.sample_product(
-      shape=[(1,), (4,), (5,)],
-      dtype=(np.int32,),
-      vmap_method=("expand_dims", "broadcast_all", "sequential",
-                   "legacy_vectorized"),
+      shape=[(6, 5), (4, 5, 6)],
+      vmap_method=["expand_dims", "broadcast_all", "sequential"],
   )
-  @jtu.run_on_devices("gpu")
-  def testFfiCallBatching(self, shape, dtype, vmap_method):
+  @jtu.run_on_devices("gpu", "cpu")
+  def testFfiCallBatching(self, shape, vmap_method):
     shape = (10,) + shape
-    pivots_size = shape[-1]
-    permutation_size = 2 * pivots_size
-    pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1, dtype=dtype)
-    pivots = jnp.broadcast_to(pivots, shape)
-    expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
-    actual = jax.vmap(lambda x: ffi_call_lu_pivots_to_permutation(
-        x, permutation_size, vmap_method=vmap_method))(pivots)
-    self.assertArraysEqual(actual, expected)
+    x = self.rng().randn(*shape).astype(np.float32)
+    expected = lax_linalg_internal.geqrf(x)
+    actual = jax.vmap(partial(ffi_call_geqrf, vmap_method=vmap_method))(x)
+    for a, b in zip(actual, expected):
+      if vmap_method == "sequential" and len(shape) == 3:
+        # On GPU, the batched FFI call to geqrf uses an algorithm with
+        # different numerics than the unbatched version (which is used when
+        # vmap_method="sequential"). Therefore, we need to include floating
+        # point tolerance for this check.
+        self.assertArraysAllClose(a, b)
+      else:
+        self.assertArraysEqual(a, b)

-  @jtu.run_on_devices("gpu")
+  @jtu.run_on_devices("gpu", "cpu")
   def testVectorizedDeprecation(self):
-    pivots_size = 4
-    shape = (10, pivots_size)
-    permutation_size = 2 * pivots_size
-    pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1,
-                        dtype=np.int32)
-    pivots = jnp.broadcast_to(pivots, shape)
+    x = self.rng().randn(3, 5, 4).astype(np.float32)
     with self.assertWarns(DeprecationWarning):
-      ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True)
+      ffi_call_geqrf(x, vectorized=True)
     with self.assertWarns(DeprecationWarning):
-      jax.vmap(
-          lambda x: ffi_call_lu_pivots_to_permutation(x, permutation_size))(pivots)
+      jax.vmap(ffi_call_geqrf)(x)

   def testBackwardCompatSyntax(self):
     def fun(x):
@@ -280,20 +272,82 @@ class FfiTest(jtu.JaxTestCase):
     with self.assertWarns(DeprecationWarning):
       jax.jit(fun).lower(jnp.ones(5))

+  def testInputOutputAliases(self):
+    def fun(x):
+      return jex.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x)
+    hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
+    self.assertRegex(hlo, r"output_operand_aliases = \[.*operand_index = 0.*\]")
+
-# TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation`
-# custom call target because that's the only one in jaxlib that uses the
-# new FFI interface. Once more are available, consider using something that
-# can be run on multiple platforms.
-def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, **kwargs):
-  return jex.ffi.ffi_call(
-      "cu_lu_pivots_to_permutation",
-      jax.ShapeDtypeStruct(
-          shape=pivots.shape[:-1] + (permutation_size,),
-          dtype=pivots.dtype,
-      ),
-      **kwargs,
-  )(pivots)
+  def testInvalidInputOutputAliases(self):
+    def fun(x):
+      return jex.ffi.ffi_call("test", x, input_output_aliases={1: 0})(x)
+    with self.assertRaisesRegex(ValueError, "with input index"):
+      jax.jit(fun).lower(jnp.ones(5)).as_text()
+
+    def fun(x):
+      return jex.ffi.ffi_call("test", x, input_output_aliases={0: 1})(x)
+    with self.assertRaisesRegex(ValueError, "with output index"):
+      jax.jit(fun).lower(jnp.ones(5)).as_text()
+
+    def fun(x):
+      return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape, np.int32),
+                              input_output_aliases={0: 0})(x)
+    with self.assertRaisesRegex(ValueError,
+                                "referring to an input with abstract value"):
+      jax.jit(fun).lower(jnp.ones(5)).as_text()
+
+    def fun(x):
+      return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape + x.shape,
+                                                           x.dtype),
+                              input_output_aliases={0: 0})(x)
+    with self.assertRaisesRegex(ValueError,
+                                "referring to an input with abstract value"):
+      jax.jit(fun).lower(jnp.ones(5)).as_text()
+
+  def testLegacyBackendConfig(self):
+    def fun(x):
+      return jex.ffi.ffi_call("test", x, custom_call_api_version=2,
+                              legacy_backend_config="12345")(x)
+    hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
+    self.assertRegex(hlo, 'backend_config = "12345"')
+
+  def testInvalidBackendConfig(self):
+    def fun(x):
+      return jex.ffi.ffi_call("test", x, legacy_backend_config="12345")(x)
+    with self.assertRaisesRegex(ValueError,
+                                "The use of the legacy_backend_config"):
+      jax.jit(fun).lower(jnp.ones(5)).as_text()
+
+    def fun(x):
+      return jex.ffi.ffi_call("test", x,
+                              custom_call_api_version=2)(x, attribute=1)
+    with self.assertRaisesRegex(ValueError,
+                                "The use of ffi_call attributes requires"):
+      jax.jit(fun).lower(jnp.ones(5)).as_text()
+
+
+def ffi_call_geqrf(x, **kwargs):
+  assert x.dtype == np.float32
+  ndim = x.ndim
+  x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2)
+  output_types = [
+      x, jax.ShapeDtypeStruct(x.shape[:-2] + (min(*x.shape[-2:]),), x.dtype)]
+
+  def call(platform, x):
+    target_name = dict(
+        cpu="lapack_sgeqrf_ffi",
+        rocm="hipsolver_geqrf_ffi",
+        cuda="cusolver_geqrf_ffi",
+    )[platform]
+    return jex.ffi.ffi_call(
+        target_name, output_types, input_output_aliases={0: 0},
+        input_layouts=[x_major_to_minor],
+        output_layouts=[x_major_to_minor, None],
+        **kwargs)(x)
+
+  return lax.platform_dependent(
+      x, cpu=partial(call, "cpu"), rocm=partial(call, "rocm"),
+      cuda=partial(call, "cuda"))


 class MlirRegisterLoweringTest(jtu.JaxTestCase):