# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Callable, Mapping, Sequence
import ctypes
import functools
import os
from typing import Any, overload

import numpy as np

import jax
from jax._src import core
from jax._src import deprecations
from jax._src import dispatch
from jax._src import effects
from jax._src import util
from jax._src import xla_bridge
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.layout import DeviceLocalLayout
from jax._src.lib import jaxlib
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.typing import (Array, ArrayLike, DeprecatedArg, DuckTypedArray,
                             Shape)

# TODO(dfm): Remove after 6 months or less because there aren't any official
# compatibility guarantees for jax.extend (see JEP 15856)
# Added Oct 13, 2024
deprecations.register("jax-ffi-call-args")

map, unsafe_map = util.safe_map, map
FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None


def register_ffi_target(
    name: str,
    fn: Any,
    platform: str = "cpu",
    api_version: int = 1,
    **kwargs: Any,
) -> None:
  """Registers a foreign function target.

  Args:
    name: the name of the target.
    fn: a ``PyCapsule`` object containing the function pointer, or a ``dict``
      where the keys are FFI stage names (e.g. ``"execute"``) and the values are
      ``PyCapsule`` objects containing a pointer to the handler for that stage.
    platform: the target platform.
    api_version: the XLA custom call API version to use. Supported versions are:
      1 (default) for the typed FFI or 0 for the earlier "custom call" API.
    kwargs: any extra keyword arguments are passed directly to
      :func:`~jaxlib.xla_client.register_custom_call_target` for more advanced
      use cases.
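
  Example usage (a hypothetical sketch: ``libfoo.so`` and its ``foo`` symbol
  are stand-ins for a user-provided library, not part of jaxlib)::

    import ctypes
    import jax

    libfoo = ctypes.cdll.LoadLibrary("./libfoo.so")
    jax.ffi.register_ffi_target(
        "foo", jax.ffi.pycapsule(libfoo.foo), platform="cpu")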
  """
  return xla_client.register_custom_call_target(name, fn, platform, api_version,
                                                **kwargs)


def register_ffi_type_id(
    name: str,
    obj: Any,
    platform: str = "cpu",
) -> None:
  """Registers a custom type ID for an FFI target.

  Args:
    name: the name of the type ID. This name must be unique within the process.
    obj: a ``PyCapsule`` object encapsulating a pointer to the type ID.
    platform: the target platform.
  """
  return xla_client.register_custom_type_id(name, obj, platform=platform)


def register_ffi_target_as_batch_partitionable(name: str) -> None:
  """Registers an FFI target as batch partitionable.

  A batch partitionable FFI call is one that can be trivially partitioned on
  some number of leading (batch) dimensions, with the same call executed
  independently on each shard of data. If the data are sharded on non-batch
  dimensions, they will be re-sharded to be replicated on those dimensions
  before the call.

  Args:
    name: the name of the target.
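
  Example usage (a sketch; assumes an FFI target called ``"foo"`` has already
  been registered via :func:`~jax.ffi.register_ffi_target`)::

    jax.ffi.register_ffi_target_as_batch_partitionable("foo")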
  """
  xla_client.register_custom_call_as_batch_partitionable(name)
  xla_bridge.register_plugin_callbacks(
      functools.partial(xla_client.register_custom_call_as_batch_partitionable,
                        name))


def pycapsule(funcptr):
  """Wrap a ctypes function pointer in a PyCapsule.

  The primary use of this function, and the reason why it lives within the
  ``jax.ffi`` submodule, is to wrap function calls from external compiled
  libraries to be registered as XLA custom calls.

  Example usage::

    import ctypes
    import jax
    from jax.lib import xla_client

    libfoo = ctypes.cdll.LoadLibrary('./foo.so')
    xla_client.register_custom_call_target(
        name="bar",
        fn=jax.ffi.pycapsule(libfoo.bar),
        platform=PLATFORM,
        api_version=API_VERSION
    )

  Args:
    funcptr: A function pointer loaded from a dynamic library using ``ctypes``.

  Returns:
    An opaque ``PyCapsule`` object wrapping ``funcptr``.
  """
  # Build the capsule via the CPython C API: PyCapsule_New(pointer, name,
  # destructor). We pass a NULL name, and destructor(0) constructs a NULL
  # function pointer, so no destructor runs when the capsule is collected.
  destructor = ctypes.CFUNCTYPE(None, ctypes.py_object)
  builder = ctypes.pythonapi.PyCapsule_New
  builder.restype = ctypes.py_object
  builder.argtypes = (ctypes.c_void_p, ctypes.c_char_p, destructor)
  return builder(funcptr, None, destructor(0))


def include_dir() -> str:
  """Get the path to the directory containing header files bundled with jaxlib"""
  jaxlib_dir = os.path.dirname(os.path.abspath(jaxlib.__file__))
  return os.path.join(jaxlib_dir, "include")
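
# Example (a sketch): include_dir() is typically passed to a C++ compiler when
# building an FFI handler against the bundled XLA FFI headers, e.g.
#
#   c++ -shared -fPIC -I"$(python -c 'import jax; print(jax.ffi.include_dir())')" \
#       foo.cc -o libfoo.so
#
# where foo.cc is a user-provided source file, not part of jaxlib.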


def _aval_shape(aval: core.AbstractValue) -> Shape:
  return () if aval is core.abstract_token else aval.shape  # pytype: disable=attribute-error


def _convert_layout_for_lowering(
    aval: core.AbstractValue, layout: FfiLayoutOptions = None) -> Sequence[int]:
  """Convert a layout to the minor-to-major order used by the custom call API."""
  if layout is None:
    return tuple(reversed(range(len(_aval_shape(aval)))))
  elif isinstance(layout, DeviceLocalLayout):
    if layout._tiling is not None:
      raise ValueError("The FFI does not support layouts with tiling")
    return layout.major_to_minor[::-1]
  else:
    return tuple(layout)
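
# For example (illustrative): a rank-3 aval with the default layout converts to
# the minor-to-major tuple (2, 1, 0); a DeviceLocalLayout with major_to_minor
# (0, 2, 1) converts to (1, 2, 0); and a plain sequence is assumed to already
# be minor-to-major and is passed through unchanged.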


def build_ffi_lowering_function(
    call_target_name: str,
    *,
    operand_layouts: Sequence[FfiLayoutOptions] | None = None,
    result_layouts: Sequence[FfiLayoutOptions] | None = None,
    backend_config: Mapping[str, ir.Attribute] | str | None = None,
    **lowering_args: Any,
) -> Callable[..., ir.Operation]:
  """Build a lowering op for a foreign function interface (FFI) target.

  By default, this lowering rule can use the input and output abstract values to
  compute the input and output types and shapes for the custom call, assuming
  row-major layouts.

  Note that layouts passed to this function as tuples should be in
  minor-to-major order (as expected by XLA) rather than major-to-minor as used
  by :func:`~jax.ffi.ffi_call` and ``DeviceLocalLayout``.

  If keyword arguments are passed to the lowering rule, these are treated as
  attributes, and added to ``backend_config``.

  Args:
    call_target_name: The name of the custom call target.
    operand_layouts: A sequence of layouts (dimension orders) for each operand.
      By default, the operands are assumed to be row-major.
    result_layouts: A sequence of layouts (dimension orders) for each result.
      By default, the results are assumed to be row-major.
    backend_config: Configuration data for the custom call. Any keyword
      arguments passed to the lowering rule will be added to this dictionary.
    lowering_args: Any other arguments to :func:`mlir.custom_call` will also be
      passed through if provided as extra arguments to this function.
  """

  def _lowering_op(
      ctx: mlir.LoweringRuleContext, *operands: ir.Value, **params: Any
  ) -> ir.Operation:
    kwargs = dict(lowering_args)
    kwargs.setdefault("api_version", 4)
    if kwargs["api_version"] >= 4:
      if backend_config is not None and not isinstance(backend_config, dict):
        raise ValueError(
            "When api_version >= 4, backend_config must be a dictionary.")
      kwargs["backend_config"] = dict(
          backend_config or {}, **{k: mlir.ir_attribute(v) for k, v in params.items()})
    else:
      if params:
        raise ValueError(
            "The use of ffi_call attributes requires a custom call API version "
            f"of at least 4; got api_version={kwargs['api_version']}.")
      kwargs["backend_config"] = backend_config
    if "result_types" not in kwargs:
      kwargs["result_types"] = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
    if operand_layouts is None:
      kwargs["operand_layouts"] = map(_convert_layout_for_lowering, ctx.avals_in)
    else:
      kwargs["operand_layouts"] = [
          _convert_layout_for_lowering(*args)
          for args in zip(ctx.avals_in, operand_layouts)]
    if result_layouts is None:
      kwargs["result_layouts"] = map(_convert_layout_for_lowering, ctx.avals_out)
    else:
      kwargs["result_layouts"] = [
          _convert_layout_for_lowering(*args)
          for args in zip(ctx.avals_out, result_layouts)]
    if "result_shapes" not in kwargs and not all(
        core.is_constant_shape(_aval_shape(aval)) for aval in ctx.avals_out):
      kwargs["result_shapes"] = [
          mlir.shape_tensor(mlir.eval_dynamic_shape_as_ivals(ctx, _aval_shape(aval)))
          for aval in ctx.avals_out]

    return mlir.custom_call(call_target_name, operands=operands, **kwargs)

  return _lowering_op


def ffi_lowering(
    call_target_name: str,
    *,
    operand_layouts: Sequence[FfiLayoutOptions] | None = None,
    result_layouts: Sequence[FfiLayoutOptions] | None = None,
    backend_config: Mapping[str, ir.Attribute] | str | None = None,
    **lowering_args: Any
) -> mlir.LoweringRule:
  """Build a lowering rule for a foreign function interface (FFI) target.

  By default, this lowering rule can use the input and output abstract values to
  compute the input and output types and shapes for the custom call, assuming
  row-major layouts.

  Note that layouts passed to this function as tuples should be in
  minor-to-major order (as expected by XLA) rather than major-to-minor as used
  by :func:`~jax.ffi.ffi_call` and ``DeviceLocalLayout``.

  If keyword arguments are passed to the lowering rule, these are treated as
  attributes, and added to ``backend_config``.

  Args:
    call_target_name: The name of the custom call target.
    operand_layouts: A sequence of layouts (dimension orders) for each operand.
      By default, the operands are assumed to be row-major.
    result_layouts: A sequence of layouts (dimension orders) for each result.
      By default, the results are assumed to be row-major.
    backend_config: Configuration data for the custom call. Any keyword
      arguments passed to the lowering rule will be added to this dictionary.
    lowering_args: Any other arguments to :func:`mlir.custom_call` will also be
      passed through if provided as extra arguments to this function.
  """

  def _lowering(
      ctx: mlir.LoweringRuleContext, *operands: ir.Value, **params: Any
  ) -> Sequence[ir.Value | Sequence[ir.Value]]:
    result = build_ffi_lowering_function(
        call_target_name,
        operand_layouts=operand_layouts,
        result_layouts=result_layouts,
        backend_config=backend_config,
        **lowering_args,
    )(ctx, *operands, **params)

    return result.results  # type: ignore

  return _lowering
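
# Example (a sketch): ffi_lowering is typically used to register the MLIR
# lowering rule for a custom primitive backed by an FFI target; the primitive
# name "foo" here is hypothetical.
#
#   foo_p = core.Primitive("foo")
#   mlir.register_lowering(foo_p, ffi_lowering("foo"), platform="cpu")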


ResultMetadata = DuckTypedArray | core.AbstractToken


def _result_avals(results: Sequence[ResultMetadata]) -> tuple[core.AbstractValue, ...]:
  avals: list[core.AbstractValue] = []
  for idx, result in enumerate(results):
    if isinstance(result, core.AbstractToken):
      avals.append(result)
    else:
      if not hasattr(result, "shape") or not hasattr(result, "dtype"):
        raise ValueError(
            "All elements of result_shape_dtypes must have 'shape' and 'dtype' "
            f"attributes. Got {result} at position {idx}.")
      avals.append(core.ShapedArray(result.shape, result.dtype))
  return tuple(avals)


def _check_compatible_avals(a: core.AbstractValue, b: core.AbstractValue) -> bool:
  if isinstance(a, core.AbstractToken) and isinstance(b, core.AbstractToken):
    return True
  if getattr(a, "shape", ()) != getattr(b, "shape", ()):
    return False
  if getattr(a, "dtype", ()) != getattr(b, "dtype", ()):
    return False
  return True


def _convert_layouts_for_ffi_call(
    avals: Sequence[core.AbstractValue],
    layouts: Sequence[FfiLayoutOptions]) -> tuple[Sequence[int], ...]:
  return tuple(
      _convert_layout_for_lowering(
          aval,
          layout if layout is None or isinstance(layout, DeviceLocalLayout)
          else layout[::-1]
      )
      for aval, layout in zip(avals, layouts))
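
# For example (illustrative): a user-facing major-to-minor layout [0, 2, 1]
# for a rank-3 operand is reversed to (1, 2, 0) here, i.e. the minor-to-major
# form expected by the custom call lowering.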


# ffi_call() returns as many results as result_shape_dtypes.
@overload
def ffi_call(
    target_name: str,
    result_shape_dtypes: ResultMetadata,
    *deprecated_args: ArrayLike,
    has_side_effect: bool = ...,
    vmap_method: str | None = ...,
    input_layouts: Sequence[FfiLayoutOptions] | None = ...,
    output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = ...,
    input_output_aliases: dict[int, int] | None = ...,
    custom_call_api_version: int = ...,
    legacy_backend_config: str | None = ...,
    vectorized: bool | DeprecatedArg = ...,
    **deprecated_kwargs: Any,
) -> Callable[..., Array] | Array:
  ...


@overload
def ffi_call(
    target_name: str,
    result_shape_dtypes: Sequence[ResultMetadata],
    *deprecated_args: ArrayLike,
    has_side_effect: bool = ...,
    vmap_method: str | None = ...,
    input_layouts: Sequence[FfiLayoutOptions] | None = ...,
    output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = ...,
    input_output_aliases: dict[int, int] | None = ...,
    custom_call_api_version: int = ...,
    legacy_backend_config: str | None = ...,
    vectorized: bool | DeprecatedArg = ...,
    **deprecated_kwargs: Any,
) -> Callable[..., Sequence[Array]] | Sequence[Array]:
  ...


def ffi_call(
    target_name: str,
    result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
    *deprecated_args: ArrayLike,
    has_side_effect: bool = False,
    vmap_method: str | None = None,
    input_layouts: Sequence[FfiLayoutOptions] | None = None,
    output_layouts: FfiLayoutOptions | Sequence[FfiLayoutOptions] | None = None,
    input_output_aliases: dict[int, int] | None = None,
    custom_call_api_version: int = 4,
    legacy_backend_config: str | None = None,
    vectorized: bool | DeprecatedArg = DeprecatedArg(),
    **deprecated_kwargs: Any,
) -> Callable[..., Array | Sequence[Array]] | Array | Sequence[Array]:
  """Call a foreign function interface (FFI) target.

  See the :ref:`ffi-tutorial` tutorial for more information.

  Like :func:`~jax.pure_callback`, the behavior of ``ffi_call`` under
  :func:`~jax.vmap` depends on the value of ``vmap_method``. See the
  :func:`~jax.pure_callback` documentation for more details about the allowed
  values and examples of their behavior.

  The current default behavior is to use ``vmap_method="sequential"`` when
  not specified, but this behavior is deprecated, and in the future, the
  default will be to raise a ``NotImplementedError`` unless ``vmap_method`` is
  explicitly specified.

  Args:
    target_name: the name of the XLA FFI custom call target that was registered
      using :func:`~jax.ffi.register_ffi_target`.
    result_shape_dtypes: an object, or sequence of objects, with ``shape`` and
      ``dtype`` attributes which are expected to match the shape and dtype of
      the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often
      used to define the elements of ``result_shape_dtypes``.
      ``jax.core.abstract_token`` may be used to represent a token-typed output.
    has_side_effect: boolean specifying whether the custom call has side
      effects. When ``True``, the FFI call will be executed even when the
      outputs are not used.
    vmap_method: string specifying how the FFI call transforms under
      :func:`~jax.vmap` as described above.
    input_layouts: a sequence of layouts for each input argument. In each case,
      the layout can be (a) ``None`` indicating that this input is in default
      row-major order, (b) a ``DeviceLocalLayout`` specifying the axis order,
      or (c) a sequence of integers specifying the major-to-minor axis
      ordering. Users who are familiar with XLA layouts should note that this
      function expects layouts in major-to-minor order instead of the
      minor-to-major order that XLA uses. For example, a batch of row-major
      matrices could be specified using the layout ``[0, 1, 2]``, whereas a
      batch of column-major matrices would have layout ``[0, 2, 1]``. In both
      of these examples, the leading/batch dimension is the "slowest" axis. The
      ``input_layouts`` parameter should be used to request the memory layout
      expected by the FFI call target, and XLA will ensure that the buffers
      have the correct layouts before the handler is executed.
    output_layouts: like ``input_layouts``, but specifying the required layouts
      for the output arrays.
    input_output_aliases: a dictionary where the keys are input indices and the
      values are output indices. This mapping indicates which output arrays
      alias specific input arrays.
    custom_call_api_version: the version number of the custom call API
      implemented by the FFI target ``target_name``. The only formally
      supported version is the typed FFI API with ``custom_call_api_version=4``,
      but earlier unsupported custom calls can be executed using this argument.
    legacy_backend_config: for legacy targets implemented using
      ``custom_call_api_version<4``, attributes are passed using the opaque
      string representation provided by this argument. This parameter cannot be
      used with ``custom_call_api_version>=4``.

  Returns:
    A function that can be called with the input arrays as positional arguments
    to execute the FFI handler. Any keyword arguments are passed as named
    attributes to the FFI handler using XLA's FFI interface.
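
  Example usage (a sketch; assumes an FFI target ``"foo"`` that takes one
  array and returns one array of the same shape and dtype, and accepts a
  scalar ``scale`` attribute)::

    import jax
    import jax.numpy as jnp

    def foo(x):
      out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
      return jax.ffi.ffi_call("foo", out_type, vmap_method="broadcast_all")(
          x, scale=1.5)

    y = foo(jnp.ones(4))                   # direct call
    ys = jax.vmap(foo)(jnp.ones((8, 4)))   # batched via vmap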
  """
  if not isinstance(vectorized, DeprecatedArg) and vectorized is not None:
    deprecations.warn(
        "jax-callback-vectorized",
        "The vectorized argument of ffi_call is deprecated and setting "
        "it will soon raise an error. To avoid an error in the future, and to "
        "suppress this warning, please use the vmap_method argument instead.",
        stacklevel=2)
    if vmap_method is not None:
      raise ValueError(
          "the vectorized and vmap_method arguments of ffi_call cannot "
          "be used together. Please use the vmap_method argument.")
    vmap_method = "legacy_vectorized" if vectorized else "sequential"
  allowed_vmap_methods = ["sequential", "sequential_unrolled", "expand_dims",
                          "broadcast_all", "legacy_vectorized", None]
  if vmap_method not in allowed_vmap_methods:
    raise ValueError(
        f"vmap_method must be one of the allowed methods {allowed_vmap_methods}, "
        f"but got: {vmap_method}")

  output_layouts_: Sequence[FfiLayoutOptions] | None
  if isinstance(result_shape_dtypes, Sequence):
    output_layouts_ = output_layouts  # type: ignore
    multiple_results = True
    result_avals = _result_avals(result_shape_dtypes)
  else:
    multiple_results = False
    result_avals = _result_avals((result_shape_dtypes,))
    output_layouts_ = (output_layouts,)  # type: ignore

  if custom_call_api_version >= 4 and legacy_backend_config is not None:
    raise ValueError(
        "The use of the legacy_backend_config parameter requires "
        f"custom_call_api_version < 4; got {custom_call_api_version}.")

  def wrapped(*args: ArrayLike, **kwargs: Any):
    in_avals = [core.get_aval(x) for x in args]

    if input_layouts is None:
      static_input_layouts = tuple(map(_convert_layout_for_lowering, in_avals))
    else:
      if len(input_layouts) != len(in_avals):
        raise ValueError(
            f"The number of input arguments ({len(in_avals)}) must equal the "
            f"number of input layouts ({len(input_layouts)}).")
      static_input_layouts = _convert_layouts_for_ffi_call(in_avals,
                                                           input_layouts)
    if output_layouts_ is None:
      static_output_layouts = tuple(map(_convert_layout_for_lowering,
                                        result_avals))
    else:
      if len(output_layouts_) != len(result_avals):
        raise ValueError(
            f"The number of outputs ({len(result_avals)}) must equal the "
            f"number of output layouts ({len(output_layouts_)}).")
      static_output_layouts = _convert_layouts_for_ffi_call(result_avals,
                                                            output_layouts_)

    static_input_output_aliases: tuple[tuple[int, int], ...] = ()
    if input_output_aliases is not None:
      for i_idx, o_idx in sorted(input_output_aliases.items()):
        i_idx, o_idx = int(i_idx), int(o_idx)
        if i_idx >= len(args):
          raise ValueError(
              f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
              f"with input index {i_idx} outside the range [0, "
              f"{len(args)}).")
        if o_idx >= len(result_avals):
          raise ValueError(
              f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
              f"with output index {o_idx} outside the range [0, "
              f"{len(result_avals)}).")
        in_aval = in_avals[i_idx]
        out_aval = result_avals[o_idx]
        if not _check_compatible_avals(in_aval, out_aval):
          raise ValueError(
              f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
              f"referring to an input with abstract value {in_aval} and an "
              f"output with a different abstract value {out_aval}.")
        if static_input_layouts[i_idx] != static_output_layouts[o_idx]:
          raise ValueError(
              f"input_output_aliases contains the mapping '{i_idx}:{o_idx}' "
              f"referring to an input with layout {static_input_layouts[i_idx]} "
              "and an output with a different layout "
              f"{static_output_layouts[o_idx]}.")
        static_input_output_aliases += ((i_idx, o_idx),)

    results = ffi_call_p.bind(
        *args,
        result_avals=result_avals,
        vectorized=vectorized,
        vmap_method=vmap_method,
        target_name=target_name,
        has_side_effect=has_side_effect,
        input_layouts=static_input_layouts,
        output_layouts=static_output_layouts,
        input_output_aliases=static_input_output_aliases,
        custom_call_api_version=custom_call_api_version,
        legacy_backend_config=legacy_backend_config,
        attributes=_wrap_kwargs_hashable(kwargs),
    )
    if multiple_results:
      if isinstance(result_shape_dtypes, tuple):
        return tuple(results)
      return results
    else:
      return results[0]

  if deprecated_args or deprecated_kwargs:
    deprecations.warn(
        "jax-ffi-call-args",
        "Calling ffi_call directly with input arguments is deprecated. "
        "Instead, ffi_call should be used to construct a callable, which can "
        "then be called with the appropriate inputs. For example,\n"
        "  ffi_call('target_name', output_type, x, argument=5)\n"
        "should be replaced with\n"
        "  ffi_call('target_name', output_type)(x, argument=5)",
        stacklevel=2)
    return wrapped(*deprecated_args, **deprecated_kwargs)
  else:
    return wrapped


# ffi_call must support some small non-hashable input arguments, like NumPy
# arrays and dicts, to support calling FFI targets with array inputs or
# user-defined structs. Since these arguments will eventually be embedded in
# the HLO as dense attributes, we assume that they are small and hash by
# making an immutable copy and hashing by value.
def _wrap_kwargs_hashable(kwargs: dict[str, Any]) -> Sequence[tuple[str, Any]]:
  hashable_kwargs: list[tuple[str, Any]] = []
  for k, v in sorted(kwargs.items()):
    if isinstance(v, np.ndarray):
      hashable_kwargs.append((k, HashableArray(v)))
    elif isinstance(v, dict):
      hashable_kwargs.append((k, HashableDict(v)))
    else:
      try:
        hash(v)
      except TypeError as e:
        raise TypeError(
            f"Non-hashable keyword argument to ffi_call {k}: {v}") from e
      else:
        hashable_kwargs.append((k, v))
  return tuple(hashable_kwargs)


def _unwrap_kwargs_hashable(kwargs: Sequence[tuple[str, Any]]) -> dict[str, Any]:
  unwrapped_kwargs: dict[str, Any] = {}
  for k, v in kwargs:
    if isinstance(v, HashableArray):
      unwrapped_kwargs[k] = v.val
    elif isinstance(v, HashableDict):
      unwrapped_kwargs[k] = dict(v.val)
    else:
      unwrapped_kwargs[k] = v
  return unwrapped_kwargs


class HashableArray:
  __slots__ = ["val"]

  def __init__(self, val):
    assert isinstance(val, np.ndarray)
    self.val = np.copy(val)
    self.val.setflags(write=False)

  def __repr__(self):
    return f"HashableArray({self.val})"

  def __hash__(self):
    return hash((self.val.shape, self.val.dtype, self.val.tobytes()))

  def __eq__(self, other):
    return isinstance(other, HashableArray) and np.array_equal(self.val, other.val)


class HashableDict:
  __slots__ = ["val"]

  def __init__(self, val):
    assert isinstance(val, dict)
    self.val = tuple(sorted(val.items()))

  def __repr__(self):
    return f"HashableDict({dict(self.val)})"

  def __hash__(self):
    return hash(self.val)

  def __eq__(self, other):
    return isinstance(other, HashableDict) and self.val == other.val
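
# For example (an illustrative round trip):
#
#   wrapped = _wrap_kwargs_hashable({"eps": 1e-5, "table": np.arange(4)})
#   unwrapped = _unwrap_kwargs_hashable(wrapped)
#   # unwrapped["eps"] == 1e-5, and unwrapped["table"] is a read-only copy of
#   # the original array.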


class FfiEffect(effects.Effect):
  def __str__(self):
    return "FFI"


_FfiEffect = FfiEffect()
effects.lowerable_effects.add_type(FfiEffect)
effects.control_flow_allowed_effects.add_type(FfiEffect)


def ffi_call_abstract_eval(
    *avals_in,
    result_avals: tuple[core.AbstractValue, ...],
    has_side_effect: bool,
    **_,
):
  del avals_in  # unused
  effects = {_FfiEffect} if has_side_effect else core.no_effects
  return result_avals, effects


def ffi_call_jvp(*args, target_name, **_):
  del args
  raise ValueError(
      f"The FFI call to `{target_name}` cannot be differentiated. "
      "You can use `jax.custom_jvp` or `jax.custom_vjp` to add support.")


def ffi_call_transpose(*args, target_name, **_):
  del args
  raise ValueError(
      f"The FFI call to `{target_name}` cannot be differentiated. "
      "You can use `jax.custom_jvp` or `jax.custom_vjp` to add support.")


def ffi_call_lowering(
    ctx: mlir.LoweringRuleContext,
    *operands: ir.Value,
    target_name: str,
    has_side_effect: bool,
    input_layouts: Sequence[Sequence[int]],
    output_layouts: Sequence[Sequence[int]],
    input_output_aliases: Sequence[tuple[int, int]],
    custom_call_api_version: int,
    legacy_backend_config: str | None,
    attributes: Sequence[tuple[str, Any]],
    **_,
) -> Sequence[ir.Value]:
  rule = ffi_lowering(target_name, has_side_effect=has_side_effect,
                      operand_layouts=input_layouts,
                      result_layouts=output_layouts,
                      operand_output_aliases=dict(input_output_aliases),
                      api_version=custom_call_api_version,
                      backend_config=legacy_backend_config)
  return rule(ctx, *operands, **_unwrap_kwargs_hashable(attributes))


def ffi_batching_rule(
    prim,
    args,
    dims,
    *,
    vectorized: bool | None | DeprecatedArg,
    vmap_method: str | None,
    result_avals: Sequence[core.ShapedArray],
    **kwargs: Any,
):
  if isinstance(vectorized, DeprecatedArg) and vmap_method is None:
    deprecations.warn(
        "jax-callback-vectorized",
        f"The default behavior of {prim.name} under vmap will soon "
        "change. Currently, the default behavior is to generate a sequential "
        "vmap (i.e. a loop), but in the future the default will be to raise "
        "an error. To keep the current default, set vmap_method='sequential'.",
        stacklevel=6)
    vmap_method = "sequential"

  axis_size, = {a.shape[d] for a, d in zip(args, dims)
                if d is not batching.not_mapped}
  new_args = [arg if dim is batching.not_mapped else
              batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
  batched_result_avals = tuple(
      core.unmapped_aval(axis_size, 0, aval) for aval in result_avals)

  # For FFI calls we must update the layouts. We handle the output layouts
  # here, but the input layout updates depend on the vmap_method parameter.
  if (
      vmap_method not in ("sequential", "sequential_unrolled") and
      kwargs.get("output_layouts") is not None
  ):
    kwargs["output_layouts"] = tuple(
        None if layout is None else tuple(n + 1 for n in layout) + (0,)
        for layout in kwargs["output_layouts"])

  if vmap_method == "legacy_vectorized":
    # This method is kept to support the behavior that was previously exposed
    # when using `vectorized=True`.
    if kwargs.get("input_layouts") is not None:
      kwargs["input_layouts"] = tuple(
          layout if d is batching.not_mapped else
          (None if layout is None else tuple(n + 1 for n in layout) + (0,))
          for layout, d in zip(kwargs["input_layouts"], dims))
    outvals = prim.bind(
        *new_args,
        vectorized=vectorized,
        vmap_method=vmap_method,
        result_avals=batched_result_avals,
        **kwargs,
    )
  elif vmap_method == "expand_dims" or vmap_method == "broadcast_all":
    size = axis_size if vmap_method == "broadcast_all" else 1
    bcast_args = [
        jax.lax.broadcast(x, (size,)) if d is batching.not_mapped else x
        for x, d in zip(new_args, dims)]
    if kwargs.get("input_layouts") is not None:
      kwargs["input_layouts"] = tuple(
          None if layout is None else tuple(n + 1 for n in layout) + (0,)
          for layout in kwargs["input_layouts"])
    outvals = prim.bind(
        *bcast_args,
        vectorized=vectorized,
        vmap_method=vmap_method,
        result_avals=batched_result_avals,
        **kwargs,
    )
  elif vmap_method == "sequential" or vmap_method == "sequential_unrolled":
    is_batched = [d is not batching.not_mapped for d in dims]
    unbatched_args, batched_args = util.partition_list(is_batched, new_args)
    def _batch_fun(batched_args):
      merged_args = util.merge_lists(is_batched, unbatched_args, batched_args)
      return prim.bind(
          *merged_args,
          result_avals=result_avals,
          vectorized=vectorized,
          vmap_method=vmap_method,
          **kwargs,
      )
    unroll = vmap_method == "sequential_unrolled"
    g = lambda _, x: ((), _batch_fun(x))
    _, outvals = jax.lax.scan(g, (), batched_args, unroll=unroll)
  else:
    raise NotImplementedError(
        f"vmap is only supported for the {prim.name} primitive when vmap_method "
        "is one of 'sequential', 'sequential_unrolled', 'expand_dims', "
        f"'broadcast_all', or 'legacy_vectorized'. Got {vmap_method=}.")
  return tuple(outvals), (0,) * len(outvals)
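
# For example (illustrative): under vmap with an axis size of 3, an unmapped
# operand is passed through with a size-1 leading axis when
# vmap_method="expand_dims", or broadcast to a leading axis of size 3 when
# vmap_method="broadcast_all"; "sequential" instead lowers to a lax.scan loop.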


ffi_call_p = core.Primitive("ffi_call")
ffi_call_p.multiple_results = True
dispatch.simple_impl(ffi_call_p)
ffi_call_p.def_effectful_abstract_eval(ffi_call_abstract_eval)
ad.primitive_jvps[ffi_call_p] = ffi_call_jvp
ad.primitive_transposes[ffi_call_p] = ffi_call_transpose
batching.primitive_batchers[ffi_call_p] = functools.partial(
    ffi_batching_rule, ffi_call_p)
mlir.register_lowering(ffi_call_p, ffi_call_lowering)