# rocm_jax/jax/_src/ffi.py

# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Callable, Mapping, Sequence
import ctypes
import functools
import os
from typing import Any, overload
import numpy as np
import jax
from jax._src import core
from jax._src import deprecations
from jax._src import dispatch
from jax._src import effects
from jax._src import util
from jax._src import xla_bridge
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.layout import DeviceLocalLayout
from jax._src.lib import jaxlib
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.typing import (Array, ArrayLike, DeprecatedArg, DuckTypedArray,
Shape)
# TODO(dfm): Remove after 6 months or less because there aren't any official
# compatibility guarantees for jax.extend (see JEP 15856)
# Added Oct 13, 2024
deprecations.register("jax-ffi-call-args")
map, unsafe_map = util.safe_map, map
FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None
def register_ffi_target(
name: str,
fn: Any,
platform: str = "cpu",
api_version: int = 1,
**kwargs: Any,
) -> None:
"""Registers a foreign function target.
Args:
name: the name of the target.
fn: a ``PyCapsule`` object containing the function pointer, or a ``dict``
where the keys are FFI stage names (e.g. `"execute"`) and the values are
      ``PyCapsule`` objects containing a pointer to the handler for that stage.
platform: the target platform.
api_version: the XLA custom call API version to use. Supported versions are:
1 (default) for the typed FFI or 0 for the earlier "custom call" API.
kwargs: any extra keyword arguments are passed directly to
:func:`~jaxlib.xla_client.register_custom_call_target` for more advanced
use cases.
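
  Example usage (a sketch; ``./libfoo.so`` and its ``bar`` symbol are
  hypothetical)::

    import ctypes
    import jax

    libfoo = ctypes.cdll.LoadLibrary("./libfoo.so")
    jax.ffi.register_ffi_target(
        "bar", jax.ffi.pycapsule(libfoo.bar), platform="cpu")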
"""
return xla_client.register_custom_call_target(name, fn, platform, api_version,
**kwargs)
def register_ffi_type_id(
name: str,
obj: Any,
platform: str = "cpu",
) -> None:
"""Registers a custom type ID for a FFI target.
Args:
name: the name of the type ID. This name must be unique within the process.
obj: a ``PyCapsule`` object encapsulating a pointer to the type ID.
platform: the target platform.
"""
return xla_client.register_custom_type_id(name, obj, platform=platform)
def register_ffi_target_as_batch_partitionable(name: str) -> None:
"""Registers an FFI target as batch partitionable.
Args:
name: the name of the target.
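
  Example usage (a sketch; ``"my_solver"`` is a hypothetical target that was
  registered using :func:`~jax.ffi.register_ffi_target` and computes
  independently over its leading batch dimensions)::

    jax.ffi.register_ffi_target_as_batch_partitionable("my_solver")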
"""
xla_client.register_custom_call_as_batch_partitionable(name)
xla_bridge.register_plugin_callbacks(
functools.partial(xla_client.register_custom_call_as_batch_partitionable,
name))
def pycapsule(funcptr):
"""Wrap a ctypes function pointer in a PyCapsule.
  The primary use of this function, and the reason why it lives within the
  ``jax.ffi`` submodule, is to wrap function calls from external compiled
  libraries to be registered as XLA custom calls.
Example usage::
import ctypes
import jax
from jax.lib import xla_client
libfoo = ctypes.cdll.LoadLibrary('./foo.so')
xla_client.register_custom_call_target(
name="bar",
fn=jax.ffi.pycapsule(libfoo.bar),
platform=PLATFORM,
api_version=API_VERSION
)
Args:
funcptr: A function pointer loaded from a dynamic library using ``ctypes``.
Returns:
An opaque ``PyCapsule`` object wrapping ``funcptr``.
"""
destructor = ctypes.CFUNCTYPE(None, ctypes.py_object)
builder = ctypes.pythonapi.PyCapsule_New
builder.restype = ctypes.py_object
builder.argtypes = (ctypes.c_void_p, ctypes.c_char_p, destructor)
return builder(funcptr, None, destructor(0))
def include_dir() -> str:
"""Get the path to the directory containing header files bundled with jaxlib"""
jaxlib_dir = os.path.dirname(os.path.abspath(jaxlib.__file__))
return os.path.join(jaxlib_dir, "include")
def _aval_shape(aval: core.AbstractValue) -> Shape:
return () if aval is core.abstract_token else aval.shape # pytype: disable=attribute-error
def _convert_layout_for_lowering(
aval: core.AbstractValue, layout: FfiLayoutOptions = None) -> Sequence[int]:
"""Convert a layout to the minor-to-major order used by the custom call API."""
if layout is None:
return tuple(reversed(range(len(_aval_shape(aval)))))
elif isinstance(layout, DeviceLocalLayout):
if layout._tiling is not None:
raise ValueError("The FFI does not support layouts with tiling")
return layout.major_to_minor[::-1]
else:
return tuple(layout)
def build_ffi_lowering_function(
call_target_name: str,
*,
operand_layouts: Sequence[FfiLayoutOptions] | None = None,
result_layouts: Sequence[FfiLayoutOptions] | None = None,
backend_config: Mapping[str, ir.Attribute] | str | None = None,
**lowering_args: Any,
) -> Callable[..., ir.Operation]:
"""Build a lowering op for an foreign function interface (FFI) target.
By default, this lowering rule can use the input and output abstract values to
compute the input and output types and shapes for the custom call, assuming
row-major layouts.
Note that layouts passed to this function as tuples should be in
minor-to-major order (as expected by XLA) rather than major-to-minor as used
by :func:`~jax.ffi.ffi_call` and ``DeviceLocalLayout``.
If keyword arguments are passed to the lowering rule, these are treated as
  attributes, and added to ``backend_config``.
Args:
call_target_name: The name of the custom call target.
operand_layouts: A sequence of layouts (dimension orders) for each operand.
By default, the operands are assumed to be row-major.
result_layouts: A sequence of layouts (dimension orders) for each result.
By default, the results are assumed to be row-major.
backend_config: Configuration data for the custom call. Any keyword
      arguments passed to the lowering rule will be added to this dictionary.
lowering_args: Any other arguments to :func:`mlir.custom_call` will also be
passed through if provided as extra arguments to this function.
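
  Example usage (a sketch; ``"my_target"`` is a hypothetical registered custom
  call target)::

    lower = build_ffi_lowering_function("my_target")
    # Inside a lowering rule, `lower(ctx, *operands, **attributes)` emits the
    # custom call and returns the underlying `ir.Operation`, which the caller
    # can inspect or wrap before returning its results.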
"""
def _lowering_op(
ctx: mlir.LoweringRuleContext, *operands: ir.Value, **params: Any
) -> ir.Operation:
kwargs = dict(lowering_args)
kwargs.setdefault("api_version", 4)
if kwargs["api_version"] >= 4:
if backend_config is not None and not isinstance(backend_config, dict):
raise ValueError(
"When api_version > 4, backend_config must be a dictionary.")
kwargs["backend_config"] = dict(
backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()})
else:
if params:
raise ValueError(
"The use of ffi_call attributes requires a custom call API version "
f"of at least 4; got api_version={kwargs['api_version']}.")
kwargs["backend_config"] = backend_config
if "result_types" not in kwargs:
kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
if operand_layouts is None:
kwargs["operand_layouts"] = map(_convert_layout_for_lowering, ctx.avals_in)
else:
kwargs["operand_layouts"] = [
_convert_layout_for_lowering(*args)
for args in zip(ctx.avals_in, operand_layouts)]
if result_layouts is None:
kwargs["result_layouts"] = map(_convert_layout_for_lowering, ctx.avals_out)
else:
kwargs["result_layouts"] = [
_convert_layout_for_lowering(*args)
for args in zip(ctx.avals_out, result_layouts)]
if "result_shapes" not in kwargs and not all(
core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out):
kwargs["result_shapes"] = [
mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, _aval_shape(aval)))
for aval in ctx.avals_out]
return mlir.custom_call(call_target_name, operands=operands, **kwargs)
return _lowering_op
def ffi_lowering(
call_target_name: str,
*,
operand_layouts: Sequence[FfiLayoutOptions] | None = None,
result_layouts: Sequence[FfiLayoutOptions] | None = None,
backend_config: Mapping[str, ir.Attribute] | str | None = None,
**lowering_args: Any
) -> mlir.LoweringRule:
"""Build a lowering rule for an foreign function interface (FFI) target.
By default, this lowering rule can use the input and output abstract values to
compute the input and output types and shapes for the custom call, assuming
row-major layouts.
Note that layouts passed to this function as tuples should be in
minor-to-major order (as expected by XLA) rather than major-to-minor as used
by :func:`~jax.ffi.ffi_call` and ``DeviceLocalLayout``.
If keyword arguments are passed to the lowering rule, these are treated as
  attributes, and added to ``backend_config``.
Args:
call_target_name: The name of the custom call target.
operand_layouts: A sequence of layouts (dimension orders) for each operand.
By default, the operands are assumed to be row-major.
result_layouts: A sequence of layouts (dimension orders) for each result.
By default, the results are assumed to be row-major.
backend_config: Configuration data for the custom call. Any keyword
      arguments passed to the lowering rule will be added to this dictionary.
lowering_args: Any other arguments to :func:`mlir.custom_call` will also be
passed through if provided as extra arguments to this function.
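
  Example usage (a sketch; ``my_prim`` is a hypothetical primitive and
  ``"my_target"`` its registered FFI target)::

    mlir.register_lowering(my_prim, ffi_lowering("my_target"))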
"""
def _lowering(
ctx: mlir.LoweringRuleContext, *operands: ir.Value, **params: Any
) -> Sequence[ir.Value | Sequence[ir.Value]]:
result = build_ffi_lowering_function(
call_target_name,
operand_layouts=operand_layouts,
result_layouts=result_layouts,
backend_config=backend_config,
**lowering_args,
)(ctx, *operands, **params)
return result.results # type: ignore
return _lowering
ResultMetadata = DuckTypedArray | core.AbstractToken
def _result_avals(results: Sequence[ResultMetadata]) -> tuple[core.AbstractValue, ...]:
avals: list[core.AbstractValue] = []
for idx, result in enumerate(results):
if isinstance(result, core.AbstractToken):
avals.append(result)
else:
if not hasattr(result, "shape") or not hasattr(result, "dtype"):
raise ValueError(
"All elements of result_shape_dtypes must have 'shape' and 'dtype' "
f"attributes. Got {result} at position {idx}.")
avals.append(core.ShapedArray(result.shape, result.dtype))
return tuple(avals)
def _check_compatible_avals(a: core.AbstractValue, b: core.AbstractValue) -> bool:
if isinstance(a, core.AbstractToken) and isinstance(b, core.AbstractToken):
return True
if getattr(a, "shape", ()) != getattr(b, "shape", ()):
return False
if getattr(a, "dtype", ()) != getattr(b, "dtype", ()):
return False
return True
def _convert_layouts_for_ffi_call(
avals: Sequence[core.AbstractValue],
layouts: Sequence[FfiLayoutOptions]) -> tuple[Sequence[int], ...]:
return tuple(
_convert_layout_for_lowering(
aval,
layout if layout is None or isinstance(layout, DeviceLocalLayout)
else layout[::-1]
)
for aval, layout in zip(avals, layouts))
# ffi_call() returns as many results as result_shape_dtypes.
@overload
def ffi_call(
target_name: str,
result_shape_dtypes: ResultMetadata,
*deprecated_args: ArrayLike,
has_side_effect: bool = ...,
vmap_method: str | None = ...,
input_layouts: Sequence[FfiLayoutOptions] | None = ...,
output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = ...,
input_output_aliases: dict[int, int] | None = ...,
custom_call_api_version: int = ...,
legacy_backend_config: str | None = ...,
vectorized: bool | DeprecatedArg = ...,
**deprecated_kwargs: Any,
) -> Callable[..., Array] | Array:
...
@overload
def ffi_call(
target_name: str,
result_shape_dtypes: Sequence[ResultMetadata],
*deprecated_args: ArrayLike,
has_side_effect: bool = ...,
vmap_method: str | None = ...,
input_layouts: Sequence[FfiLayoutOptions] | None = ...,
output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = ...,
input_output_aliases: dict[int, int] | None = ...,
custom_call_api_version: int = ...,
legacy_backend_config: str | None = ...,
vectorized: bool | DeprecatedArg = ...,
**deprecated_kwargs: Any,
) -> Callable[..., Sequence[Array]] | Sequence[Array]:
...
def ffi_call(
target_name: str,
result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
*deprecated_args: ArrayLike,
has_side_effect: bool = False,
vmap_method: str | None = None,
input_layouts: Sequence[FfiLayoutOptions] | None = None,
output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = None,
input_output_aliases: dict[int, int] | None = None,
custom_call_api_version: int = 4,
legacy_backend_config: str | None = None,
vectorized: bool | DeprecatedArg = DeprecatedArg(),
**deprecated_kwargs: Any,
) -> Callable[..., Array | Sequence[Array]] | Array | Sequence[Array]:
"""Call a foreign function interface (FFI) target.
See the :ref:`ffi-tutorial` tutorial for more information.
Like :func:`~jax.pure_callback`, the behavior of ``ffi_call`` under
:func:`~jax.vmap` depends on the value of ``vmap_method``. See the
  :func:`~jax.pure_callback` documentation for more details about the allowed
values and examples of their behavior.
The current default behavior is to use ``vmap_method="sequential"`` when
not specified, but this behavior is deprecated, and in the future, the
default will be to raise a ``NotImplementedError`` unless ``vmap_method`` is
explicitly specified.
Args:
target_name: the name of the XLA FFI custom call target that was registered
using :func:`~jax.ffi.register_ffi_target`.
result_shape_dtypes: an object, or sequence of objects, with ``shape`` and
``dtype`` attributes which are expected to match the shape and dtype of
the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often
used to define the elements of ``result_shape_dtypes``.
``jax.core.abstract_token`` may be used to represent a token-typed output.
has_side_effect: boolean specifying whether the custom call has side
effects. When ``True``, the FFI call will be executed even when the
outputs are not used.
vmap_method: string specifying how the FFI call transforms under
:func:`~jax.vmap` as described above.
input_layouts: a sequence of layouts for each input argument. In each case,
the layout can be (a) ``None`` indicating that this input is in default
row-major order, (b) a ``DeviceLocalLayout`` specifying the axis order,
or (c) a sequence of integers specifying the major-to-minor axis
ordering. Users who are familiar with XLA layouts should note that this
function expects layouts in major-to-minor order instead of the
minor-to-major order that XLA uses. For example, a batch of row-major
matrices could be specified using the layout ``[0, 1, 2]``, whereas a
batch of column-major matrices would have layout ``[0, 2, 1]``. In both
of these examples, the leading/batch dimension is the "slowest" axis. The
``input_layouts`` parameter should be used to request the memory layout
expected by the FFI call target, and XLA will ensure that the buffers
have the correct layouts before the handler is executed.
output_layouts: like ``input_layouts``, but specifying the required layouts
for the output arrays.
input_output_aliases: a dictionary where the keys are input indices and the
values are output indices. This mapping indicates which output arrays
alias specific input arrays.
custom_call_api_version: the version number of the custom call API
implemented by the FFI target ``target_name``. The only formally
supported version is the typed FFI API with ``custom_call_api_version=4``,
but earlier unsupported custom calls can be executed using this argument.
legacy_backend_config: for legacy targets implemented using
``custom_call_api_version<4``, attributes are passed using the opaque
string representation provided by this argument. This parameter cannot be
used with ``custom_call_api_version>=4``.
Returns:
A function that can be called with the input arrays as positional arguments
to execute the FFI handler. Any keyword arguments are passed as named
attributes to the FFI handler using XLA's FFI interface.
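
  Example usage (a sketch; ``"rms_norm"`` is a hypothetical target registered
  using :func:`~jax.ffi.register_ffi_target`)::

    import jax
    import numpy as np

    def rms_norm(x, eps=1e-5):
      return jax.ffi.ffi_call(
          "rms_norm",
          jax.ShapeDtypeStruct(x.shape, x.dtype),
          vmap_method="broadcast_all",
      )(x, eps=np.float32(eps))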
"""
  if not isinstance(vectorized, DeprecatedArg) and vectorized is not None:
deprecations.warn(
"jax-callback-vectorized",
"The vectorized argument of ffi_call is deprecated and setting "
"it will soon raise an error. To avoid an error in the future, and to "
"suppress this warning, please use the vmap_method argument instead.",
stacklevel=2)
if vmap_method is not None:
raise ValueError(
"the vectorized and vmap_method arguments of ffi_call cannot "
"be used together. Please use the vmap_method argument.")
vmap_method = "legacy_vectorized" if vectorized else "sequential"
allowed_vmap_methods = ["sequential", "sequential_unrolled", "expand_dims",
"broadcast_all", "legacy_vectorized", None]
if vmap_method not in allowed_vmap_methods:
raise ValueError(
f"vmap_method must be on of the allowed methods {allowed_vmap_methods}, "
f"but got: {vmap_method}")
output_layouts_: Sequence[FfiLayoutOptions] | None
if isinstance(result_shape_dtypes, Sequence):
output_layouts_ = output_layouts # type: ignore
multiple_results = True
result_avals = _result_avals(result_shape_dtypes)
else:
multiple_results = False
result_avals = _result_avals((result_shape_dtypes,))
output_layouts_ = (output_layouts,) # type: ignore
if custom_call_api_version >= 4 and legacy_backend_config is not None:
raise ValueError(
"The use of the legacy_backend_config parameter requires "
f"custom_call_api_version < 4; got {custom_call_api_version}.")
def wrapped(*args: ArrayLike, **kwargs: Any):
in_avals = [core.get_aval(x) for x in args]
if input_layouts is None:
static_input_layouts = tuple(map(_convert_layout_for_lowering, in_avals))
else:
if len(input_layouts) != len(in_avals):
raise ValueError(
f"The number of input arguments ({len(in_avals)}) must equal the "
f"number of input layouts ({len(input_layouts)}).")
static_input_layouts = _convert_layouts_for_ffi_call(in_avals,
input_layouts)
if output_layouts_ is None:
static_output_layouts = tuple(map(_convert_layout_for_lowering,
result_avals))
else:
if len(output_layouts_) != len(result_avals):
raise ValueError(
f"The number of outputs ({len(result_avals)}) must equal the "
f"number of output layouts ({len(output_layouts_)}).")
static_output_layouts = _convert_layouts_for_ffi_call(result_avals,
output_layouts_)
static_input_output_aliases: tuple[tuple[int, int], ...] = ()
if input_output_aliases is not None:
for i_idx, o_idx in sorted(input_output_aliases.items()):
i_idx, o_idx = int(i_idx), int(o_idx)
if i_idx >= len(args):
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
f"with input index {i_idx} outside the range [0, "
f"{len(args)}).")
if o_idx >= len(result_avals):
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
f"with output index {o_idx} outside the range [0, "
f"{len(result_avals)}).")
in_aval = in_avals[i_idx]
out_aval = result_avals[o_idx]
if not _check_compatible_avals(in_aval, out_aval):
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
f"referring to an input with abstract value {in_aval} and an "
f"output with a different abstract value {out_aval}.")
if static_input_layouts[i_idx] != static_output_layouts[o_idx]:
raise ValueError(
f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
f"referring to an input with layout {static_input_layouts[i_idx]} "
"and an output with a different layout "
f"{static_output_layouts[o_idx]}.")
static_input_output_aliases += ((i_idx, o_idx),)
results = ffi_call_p.bind(
*args,
result_avals=result_avals,
vectorized=vectorized,
vmap_method=vmap_method,
target_name=target_name,
has_side_effect=has_side_effect,
input_layouts=static_input_layouts,
output_layouts=static_output_layouts,
input_output_aliases=static_input_output_aliases,
custom_call_api_version=custom_call_api_version,
legacy_backend_config=legacy_backend_config,
attributes=_wrap_kwargs_hashable(kwargs),
)
if multiple_results:
if isinstance(result_shape_dtypes, tuple):
return tuple(results)
return results
else:
return results[0]
if deprecated_args or deprecated_kwargs:
deprecations.warn(
"jax-ffi-call-args",
"Calling ffi_call directly with input arguments is deprecated. "
"Instead, ffi_call should be used to construct a callable, which can "
"then be called with the appropriate inputs. For example,\n"
" ffi_call('target_name', output_type, x, argument=5)\n"
"should be replaced with\n"
" ffi_call('target_name', output_type)(x, argument=5)",
stacklevel=2)
return wrapped(*deprecated_args, **deprecated_kwargs)
else:
return wrapped
# ffi_call must support some small non-hashable input arguments, like np.arrays
# and dicts, to support calling FFI targets with array inputs or user defined
# structs. Since these arguments will eventually be embedded in the HLO as
# dense attributes, we assume that they are small and hash by making an
# immutable copy and hashing by value.
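# For example (a sketch): _wrap_kwargs_hashable({"b": np.zeros(2), "a": 1})
# returns (("a", 1), ("b", HashableArray(...))), sorted by key, and
# _unwrap_kwargs_hashable recovers the original dict.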
def _wrap_kwargs_hashable(kwargs: dict[str, Any]) -> Sequence[tuple[str, Any]]:
hashable_kwargs: list[tuple[str, Any]] = []
for k, v in sorted(kwargs.items()):
if isinstance(v, np.ndarray):
hashable_kwargs.append((k, HashableArray(v)))
elif isinstance(v, dict):
hashable_kwargs.append((k, HashableDict(v)))
else:
try:
hash(v)
except TypeError as e:
raise TypeError(
f"Non-hashable keyword argument to ffi_call {k}: {v}") from e
else:
hashable_kwargs.append((k, v))
return tuple(hashable_kwargs)
def _unwrap_kwargs_hashable(kwargs: Sequence[tuple[str, Any]]) -> dict[str, Any]:
unwrapped_kwargs: dict[str, Any] = {}
for k, v in kwargs:
if isinstance(v, HashableArray):
unwrapped_kwargs[k] = v.val
elif isinstance(v, HashableDict):
unwrapped_kwargs[k] = dict(v.val)
else:
unwrapped_kwargs[k] = v
return unwrapped_kwargs
class HashableArray:
__slots__ = ["val"]
def __init__(self, val):
assert isinstance(val, np.ndarray)
self.val = np.copy(val)
self.val.setflags(write=False)
def __repr__(self):
return f"HashableArray({self.val})"
def __hash__(self):
return hash((self.val.shape, self.val.dtype, self.val.tobytes()))
def __eq__(self, other):
return isinstance(other, HashableArray) and np.array_equal(self.val, other.val)
class HashableDict:
__slots__ = ["val"]
def __init__(self, val):
assert isinstance(val, dict)
self.val = tuple(sorted(val.items()))
def __repr__(self):
return f"HashableDict({dict(self.val)})"
def __hash__(self):
return hash(self.val)
def __eq__(self, other):
return isinstance(other, HashableDict) and self.val == other.val
class FfiEffect(effects.Effect):
def __str__(self):
return "FFI"
_FfiEffect = FfiEffect()
effects.lowerable_effects.add_type(FfiEffect)
effects.control_flow_allowed_effects.add_type(FfiEffect)
def ffi_call_abstract_eval(
*avals_in,
result_avals: tuple[core.AbstractValue, ...],
has_side_effect: bool,
**_,
):
del avals_in # unused
effects = {_FfiEffect} if has_side_effect else core.no_effects
return result_avals, effects
def ffi_call_jvp(*args, target_name, **_):
del args
raise ValueError(
f"The FFI call to `{target_name}` cannot be differentiated. "
"You can use `jax.custom_jvp` or `jax.custom_jvp` to add support.")
def ffi_call_transpose(*args, target_name, **_):
del args
raise ValueError(
f"The FFI call to `{target_name}` cannot be differentiated. "
"You can use `jax.custom_jvp` or `jax.custom_jvp` to add support.")
def ffi_call_lowering(
ctx: mlir.LoweringRuleContext,
*operands: ir.Value,
target_name: str,
has_side_effect: bool,
input_layouts: Sequence[Sequence[int]],
output_layouts: Sequence[Sequence[int]],
input_output_aliases: Sequence[tuple[int, int]],
custom_call_api_version: int,
legacy_backend_config: str | None,
attributes: Sequence[tuple[str, Any]],
**_,
) -> Sequence[ir.Value]:
rule = ffi_lowering(target_name, has_side_effect=has_side_effect,
operand_layouts=input_layouts,
result_layouts=output_layouts,
operand_output_aliases=dict(input_output_aliases),
api_version=custom_call_api_version,
backend_config=legacy_backend_config)
return rule(ctx, *operands, **_unwrap_kwargs_hashable(attributes))
def ffi_batching_rule(
prim,
args,
dims,
*,
vectorized: bool | None | DeprecatedArg,
vmap_method: str | None,
result_avals: Sequence[core.ShapedArray],
**kwargs: Any,
):
if isinstance(vectorized, DeprecatedArg) and vmap_method is None:
deprecations.warn(
"jax-callback-vectorized",
f"The default behavior of {prim.name} under vmap will soon "
"change. Currently, the default behavior is to generate a sequential "
"vmap (i.e. a loop), but in the future the default will be to raise "
"an error. To keep the current default, set vmap_method='sequential'.",
stacklevel=6)
vmap_method = "sequential"
axis_size, = {a.shape[d] for a, d in zip(args, dims)
if d is not batching.not_mapped}
new_args = [arg if dim is batching.not_mapped else
batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
batched_result_avals = tuple(
core.unmapped_aval(axis_size, 0, aval) for aval in result_avals)
# For FFI calls we must update the layouts. We handle the output layouts
# here, but the input layout updates depend on the vmap_method parameter.
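  # For example, a minor-to-major layout (1, 0) for a matrix result becomes
  # (2, 1, 0): existing dimension indices shift up by one, and the new leading
  # batch dimension (index 0) is appended as the most-major axis.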
if (
vmap_method not in ("sequential", "sequential_unrolled") and
kwargs.get("output_layouts") is not None
):
kwargs["output_layouts"] = tuple(
None if layout is None else tuple(n + 1 for n in layout) + (0,)
for layout in kwargs["output_layouts"])
if vmap_method == "legacy_vectorized":
# This method is kept to support the behavior that was previously exposed
# when using `vectorized=True`.
if kwargs.get("input_layouts") is not None:
kwargs["input_layouts"] = tuple(
layout if d is batching.not_mapped else
(None if layout is None else tuple(n + 1 for n in layout) + (0,))
for layout, d in zip(kwargs["input_layouts"], dims))
outvals = prim.bind(
*new_args,
vectorized=vectorized,
vmap_method=vmap_method,
result_avals=batched_result_avals,
**kwargs,
)
elif vmap_method == "expand_dims" or vmap_method == "broadcast_all":
size = axis_size if vmap_method == "broadcast_all" else 1
bcast_args = [
jax.lax.broadcast(x, (size,)) if d is batching.not_mapped else x
for x, d in zip(new_args, dims)]
if kwargs.get("input_layouts") is not None:
kwargs["input_layouts"] = tuple(
None if layout is None else tuple(n + 1 for n in layout) + (0,)
for layout in kwargs["input_layouts"])
outvals = prim.bind(
*bcast_args,
vectorized=vectorized,
vmap_method=vmap_method,
result_avals=batched_result_avals,
**kwargs,
)
elif vmap_method == "sequential" or vmap_method == "sequential_unrolled":
is_batched = [d is not batching.not_mapped for d in dims]
unbatched_args, batched_args = util.partition_list(is_batched, new_args)
def _batch_fun(batched_args):
merged_args = util.merge_lists(is_batched, unbatched_args, batched_args)
return prim.bind(
*merged_args,
result_avals=result_avals,
vectorized=vectorized,
vmap_method=vmap_method,
**kwargs,
)
unroll = vmap_method == "sequential_unrolled"
g = lambda _, x: ((), _batch_fun(x))
_, outvals = jax.lax.scan(g, (), batched_args, unroll=unroll)
else:
raise NotImplementedError(
f"vmap is only supported for the {prim.name} primitive when vmap_method "
"is one of 'sequential', 'sequential_unrolled', 'expand_dims', "
f"'broadcast_all', or 'legacy_vectorized'. Got {vmap_method=}.")
return tuple(outvals), (0,) * len(outvals)
ffi_call_p = core.Primitive("ffi_call")
ffi_call_p.multiple_results = True
dispatch.simple_impl(ffi_call_p)
ffi_call_p.def_effectful_abstract_eval(ffi_call_abstract_eval)
ad.primitive_jvps[ffi_call_p] = ffi_call_jvp
ad.primitive_transposes[ffi_call_p] = ffi_call_transpose
batching.primitive_batchers[ffi_call_p] = functools.partial(
ffi_batching_rule, ffi_call_p)
mlir.register_lowering(ffi_call_p, ffi_call_lowering)