Deprecate the vectorized argument to pure_callback and ffi_call.

Dan Foreman-Mackey 2024-07-19 17:24:33 -04:00
parent 816947b656
commit 1d27d420ac
9 changed files with 316 additions and 99 deletions

View File

@ -46,6 +46,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead.
* `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use
`jax.errors.JaxRuntimeError` instead.
* The default behavior of {func}`jax.pure_callback` and
{func}`jax.extend.ffi.ffi_call` under `vmap` has been deprecated and so has
the `vectorized` parameter to those functions. The `vmap_method` parameter
should be used instead for better-defined behavior. See the discussion in
{jax-issue}`#23881` for more details.
* Deletion:
* `jax.xla_computation` is deleted. It's been 3 months since its deprecation
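As an illustrative migration sketch (the callback and shapes here are made up, not taken from this change), code that previously passed `vectorized=True` to `jax.pure_callback` would move to an explicit `vmap_method`; `"legacy_vectorized"` reproduces the old behavior exactly, while `"broadcast_fullrank"` is the recommended choice when the callback handles full batch dimensions:

```python
import numpy as np
import jax
import jax.numpy as jnp

def callback(x):
  # np.sin already handles leading batch dimensions, so this callback works
  # with the broadcasting vmap methods.
  return np.sin(x)

def f(x):
  out_type = jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x))
  # Previously: jax.pure_callback(callback, out_type, x, vectorized=True)
  return jax.pure_callback(callback, out_type, x,
                           vmap_method="broadcast_fullrank")

jax.vmap(f)(jnp.arange(4.0))  # one callback invocation on the whole batch
```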

View File

@ -303,9 +303,9 @@
" # type (which corresponds to numpy's `float32` type), and it must be a\n",
" # static parameter (i.e. not a JAX array).\n",
" eps=np.float32(eps),\n",
" # The `vectorized` parameter controls this function's behavior under `vmap`\n",
" # The `vmap_method` parameter controls this function's behavior under `vmap`\n",
" # as discussed below.\n",
" vectorized=True,\n",
" vmap_method=\"broadcast_fullrank\",\n",
" )\n",
"\n",
"\n",
@ -325,7 +325,7 @@
"Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.\n",
"Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n",
"\n",
"The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
"The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
"\n",
"```{tip}\n",
"If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.\n",
@ -336,19 +336,29 @@
"(ffi-call-vmap)=\n",
"### Batching with `vmap`\n",
"\n",
"All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient.\n",
"By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
"This default implementation is general purpose, but it doesn't parallelize very well.\n",
"But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation.\n",
"{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n",
"The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.\n",
"\n",
"The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes.\n",
"The simplest `vmap_method` is `\"sequential\"`.\n",
"In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
"This implementation is general purpose, but it doesn't parallelize very well.\n",
"Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"broadcast\"` or `\"broadcast_fullrank\"` methods can be used to expose a better implementation.\n",
"\n",
"In this case, since we only have one input argument, `\"broadcast\"` and `\"broadcast_fullrank\"` actually have the same behavior.\n",
"The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.\n",
"Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:\n",
"\n",
"```python\n",
"ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])\n",
"```\n",
"\n",
"Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box:"
"```{tip}\n",
"Note that things get a bit more complicated when we have multiple input arguments.\n",
"For simplicity, we will use the `\"broadcast_fullrank\"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"broadcast\"` method.\n",
"The documentation for {func}`~jax.pure_callback` includes some examples of this\n",
"```\n",
"\n",
"Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method=\"broadcast_fullrank\"` out of the box:"
]
},
{
@ -380,7 +390,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:"
"Using `vmap_method=\"sequential\"`, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:"
]
},
{
@ -389,24 +399,24 @@
"metadata": {},
"outputs": [],
"source": [
"def rms_norm_not_vectorized(x, eps=1e-5):\n",
"def rms_norm_sequential(x, eps=1e-5):\n",
" return jex.ffi.ffi_call(\n",
" \"rms_norm\",\n",
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
" x,\n",
" eps=np.float32(eps),\n",
" vectorized=False, # This is the default behavior\n",
" vmap_method=\"sequential\",\n",
" )\n",
"\n",
"\n",
"jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)"
"jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)."
"If your foreign function provides an efficient batching rule that isn't supported by this simple `vmap_method` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)."
]
},
{
@ -454,7 +464,7 @@
" ),\n",
" x,\n",
" eps=np.float32(eps),\n",
" vectorized=True,\n",
" vmap_method=\"broadcast_fullrank\",\n",
" )\n",
" return y, (res, x)\n",
"\n",
@ -471,7 +481,7 @@
" res,\n",
" x,\n",
" ct,\n",
" vectorized=True,\n",
" vmap_method=\"broadcast_fullrank\",\n",
" ),\n",
" )\n",
"\n",
@ -561,7 +571,7 @@
" out_type,\n",
" x,\n",
" eps=np.float32(eps),\n",
" vectorized=True,\n",
" vmap_method=\"broadcast_fullrank\",\n",
" )\n",
"\n",
" return jax.lax.platform_dependent(x, cpu=impl(\"rms_norm\"), cuda=impl(\"rms_norm_cuda\"))\n",

View File

@ -264,9 +264,9 @@ def rms_norm(x, eps=1e-5):
# type (which corresponds to numpy's `float32` type), and it must be a
# static parameter (i.e. not a JAX array).
eps=np.float32(eps),
# The `vectorized` parameter controls this function's behavior under `vmap`
# The `vmap_method` parameter controls this function's behavior under `vmap`
# as discussed below.
vectorized=True,
vmap_method="broadcast_fullrank",
)
@ -282,7 +282,7 @@ It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_cal
Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.
Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.
The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.
The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.
```{tip}
If you are familiar with the earlier "custom call" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.
@ -293,19 +293,29 @@ One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support so
(ffi-call-vmap)=
### Batching with `vmap`
All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient.
By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
This default implementation is general purpose, but it doesn't parallelize very well.
But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation.
{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.
The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.
The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes.
The simplest `vmap_method` is `"sequential"`.
In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
This implementation is general purpose, but it doesn't parallelize very well.
Many FFI calls provide more efficient batching behavior and, in some simple cases, the `"broadcast"` or `"broadcast_fullrank"` methods can be used to expose a better implementation.
In this case, since we only have one input argument, `"broadcast"` and `"broadcast_fullrank"` actually have the same behavior.
The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.
Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:
```python
ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])
```
Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box:
```{tip}
Note that things get a bit more complicated when we have multiple input arguments.
For simplicity, we will use the `"broadcast_fullrank"` method throughout this tutorial, which guarantees that all inputs will be broadcast to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `"broadcast"` method.
The documentation for {func}`~jax.pure_callback` includes some examples of this, and a short sketch follows below.
```
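To make the difference between the two broadcasting methods concrete, here is a small sketch using a plain Python callback through {func}`~jax.pure_callback` (mirroring the examples in its docstring) that prints the shapes the callee receives when one argument is unbatched:

```python
import functools
import jax
import jax.numpy as jnp

def callback(x, y):
  # Report the shapes that the callee actually sees under vmap.
  print(jnp.shape(x), jnp.shape(y))
  return x + y

def fun(x, y, *, vmap_method):
  out_shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
  out_type = jax.ShapeDtypeStruct(out_shape, jnp.result_type(x, y))
  return jax.pure_callback(callback, out_type, x, y, vmap_method=vmap_method)

x = jnp.arange(4.0)
# "broadcast": the unbatched y arrives with a new leading axis of size 1.
jax.vmap(functools.partial(fun, vmap_method="broadcast"), in_axes=(0, None))(x, 1.0)
# prints: (4,) (1,)
# "broadcast_fullrank": y is tiled up to the full batch size.
jax.vmap(functools.partial(fun, vmap_method="broadcast_fullrank"), in_axes=(0, None))(x, 1.0)
# prints: (4,) (4,)
```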
Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method="broadcast_fullrank"` out of the box:
```{code-cell} ipython3
np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
@ -317,23 +327,23 @@ We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms
jax.make_jaxpr(jax.vmap(rms_norm))(x)
```
If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:
Using `vmap_method="sequential"`, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:
```{code-cell} ipython3
def rms_norm_not_vectorized(x, eps=1e-5):
def rms_norm_sequential(x, eps=1e-5):
return jex.ffi.ffi_call(
"rms_norm",
jax.ShapeDtypeStruct(x.shape, x.dtype),
x,
eps=np.float32(eps),
vectorized=False, # This is the default behavior
vmap_method="sequential",
)
jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)
jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)
```
If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues).
If your foreign function provides an efficient batching rule that isn't supported by this simple `vmap_method` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues).
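For reference, a minimal sketch of such a rule using the experimental `jax.custom_batching.custom_vmap` interface might look like the following; `numpy_sin` and its rule are illustrative stand-ins (not part of this change) for a foreign call that implements its own efficient batching:

```python
import numpy as np
import jax
import jax.numpy as jnp

@jax.custom_batching.custom_vmap
def numpy_sin(x):
  out_type = jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x))
  return jax.pure_callback(np.sin, out_type, x, vmap_method="sequential")

@numpy_sin.def_vmap
def numpy_sin_vmap_rule(axis_size, in_batched, x):
  # np.sin handles a leading batch axis natively, so the batched call can be
  # issued directly instead of looping over the batch.
  del axis_size  # the batch size is already reflected in x's shape
  x_batched, = in_batched
  out_type = jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x))
  out = jax.pure_callback(np.sin, out_type, x, vmap_method="sequential")
  return out, x_batched

jax.vmap(numpy_sin)(jnp.arange(4.0))  # uses the custom rule, not a scan
```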
+++
@ -372,7 +382,7 @@ def rms_norm_fwd(x, eps=1e-5):
),
x,
eps=np.float32(eps),
vectorized=True,
vmap_method="broadcast_fullrank",
)
return y, (res, x)
@ -389,7 +399,7 @@ def rms_norm_bwd(eps, res, ct):
res,
x,
ct,
vectorized=True,
vmap_method="broadcast_fullrank",
),
)
@ -469,7 +479,7 @@ def rms_norm_cross_platform(x, eps=1e-5):
out_type,
x,
eps=np.float32(eps),
vectorized=True,
vmap_method="broadcast_fullrank",
)
return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda"))

View File

@ -60,8 +60,7 @@ def rms_norm(x, eps=1e-5):
# type (which corresponds to numpy's `float32` type), and it must be a
# static parameter (i.e. not a JAX array).
eps=np.float32(eps),
# The `vectorized` parameter controls this function's behavior under `vmap`.
vectorized=True,
vmap_method="broadcast_fullrank",
)
@ -74,7 +73,7 @@ def rms_norm_fwd(x, eps=1e-5):
),
x,
eps=np.float32(eps),
vectorized=True,
vmap_method="broadcast_fullrank",
)
return y, (res, x)
@ -91,7 +90,7 @@ def rms_norm_bwd(eps, res, ct):
res,
x,
ct,
vectorized=True,
vmap_method="broadcast_fullrank",
),
)

View File

@ -22,6 +22,7 @@ from typing import Any
import jax
from jax._src import core
from jax._src import deprecations
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
@ -31,13 +32,18 @@ from jax._src import util
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir
from jax._src.lax import lax
from jax._src.lax.control_flow.loops import map as lax_map
from jax._src.lib import xla_client as xc
from jax._src.sharding_impls import SingleDeviceSharding
from jax._src.typing import DeprecatedArg
import numpy as np
logger = logging.getLogger(__name__)
# TODO(dfm): Remove after 6 months.
# Added Oct 1, 2024
deprecations.register("jax-callback-vectorized")
# `pure_callback_p` is the main primitive for staging out Python pure callbacks.
pure_callback_p = core.Primitive("pure_callback")
@ -45,6 +51,7 @@ pure_callback_p.multiple_results = True
dispatch.prim_requires_devices_during_lowering.add(pure_callback_p)
map, unsafe_map = util.safe_map, map
zip, unsafe_zip = util.safe_zip, zip
@dataclasses.dataclass(frozen=True)
@ -69,9 +76,10 @@ def pure_callback_impl(
result_avals,
callback: _FlatCallback,
sharding: SingleDeviceSharding | None,
vectorized: bool,
vectorized: bool | DeprecatedArg,
vmap_method: str | None,
):
del sharding, vectorized, result_avals
del sharding, vectorized, vmap_method, result_avals
try:
cpu_device, *_ = jax.local_devices(backend="cpu")
except RuntimeError as e:
@ -99,9 +107,10 @@ def pure_callback_abstract_eval(
callback: _FlatCallback,
result_avals,
sharding: SingleDeviceSharding | None,
vectorized: bool,
vectorized: bool | DeprecatedArg,
vmap_method: str | None,
):
del avals, callback, sharding, vectorized
del avals, callback, sharding, vectorized, vmap_method
return result_avals
@ -129,25 +138,51 @@ def callback_batching_rule(
args,
dims,
*,
vectorized: bool,
vectorized: bool | None | DeprecatedArg,
vmap_method: str | None,
result_avals: Sequence[core.ShapedArray],
**kwargs: Any,
):
axis_size = next(a.shape[d] for a, d in zip(args, dims)
if d is not batching.not_mapped)
if isinstance(vectorized, DeprecatedArg) and vmap_method is None:
deprecations.warn(
"jax-callback-vectorized",
f"The default behavior of {prim.name} under vmap will soon "
"change. Currently, the default behavior is to generate a sequential "
"vmap (i.e. a loop), but in the future the default will be to raise "
"an error. To keep the current default, set vmap_method='sequential'.",
stacklevel=6)
vmap_method = "sequential"
axis_size, = {a.shape[d] for a, d in zip(args, dims)
if d is not batching.not_mapped}
new_args = [arg if dim is batching.not_mapped else
batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
if vectorized:
result_avals = tuple(
core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) # type: ignore
for aval in result_avals)
batched_result_avals = tuple(
core.unmapped_aval(axis_size, core.no_axis_name, 0, aval)
for aval in result_avals)
if vmap_method == "legacy_vectorized":
# This method is kept to support the behavior that was previously exposed
# when using `vectorized=True`.
outvals = prim.bind(
*new_args,
vectorized=vectorized,
result_avals=result_avals,
vmap_method=vmap_method,
result_avals=batched_result_avals,
**kwargs,
)
else:
elif vmap_method == "broadcast" or vmap_method == "broadcast_fullrank":
size = axis_size if vmap_method == "broadcast_fullrank" else 1
bcast_args = [
lax.broadcast(x, (size,)) if d is batching.not_mapped else x
for x, d in zip(new_args, dims)]
outvals = prim.bind(
*bcast_args,
vectorized=vectorized,
vmap_method=vmap_method,
result_avals=batched_result_avals,
**kwargs,
)
elif vmap_method == "sequential":
is_batched = [d is not batching.not_mapped for d in dims]
unbatched_args, batched_args = util.partition_list(is_batched, new_args)
def _batch_fun(batched_args):
@ -156,9 +191,15 @@ def callback_batching_rule(
*merged_args,
result_avals=result_avals,
vectorized=vectorized,
vmap_method=vmap_method,
**kwargs,
)
outvals = lax_map(_batch_fun, batched_args)
else:
raise NotImplementedError(
f"vmap is only supported for the {prim.name} primitive when vmap_method "
"is one of 'sequential', 'broadcast', 'broadcast_fullrank', or "
"'legacy_vectorized'.")
return tuple(outvals), (0,) * len(outvals)
@ -261,7 +302,8 @@ def pure_callback(
result_shape_dtypes: Any,
*args: Any,
sharding: SingleDeviceSharding | None = None,
vectorized: bool = False,
vectorized: bool | None | DeprecatedArg = DeprecatedArg(),
vmap_method: str | None = None,
**kwargs: Any,
):
"""Calls a pure Python callback. Works under :func:`jit`/:func:`~vmap`/etc.
@ -279,17 +321,25 @@ def pure_callback(
`jit`-decorated function has no data dependence on its value. Pure callbacks
may also be reordered if data-dependence allows.
When `vmap`-ed the behavior will depend on the value of the
``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback
is assumed to obey
``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``.
Therefore, the callback will be called directly on batched inputs (where the
batch axes are the leading dimensions). Additionally, the callbacks should
return outputs that have corresponding leading batch axes. If not vectorized
``callback`` will be mapped sequentially across the batched axis.
For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free
to set ``vectorized=True`` because the ``np.matmul`` function handles
arbitrary leading batch dimensions.
When `vmap`-ed the behavior will depend on the value of the ``vmap_method``.
* Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method``
is deprecated and it will eventually raise ``NotImplementedError``.
* ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over
the batched arguments, calling ``callback`` once for each batch element.
* ``vmap_method="broadcast"`` calls ``callback`` with new axes of size ``1``
added as the leading dimension of the unbatched inputs.
* ``vmap_method="broadcast_fullrank"`` behaves like ``broadcast``, but the
inputs are tiled to the expected batched shape.
If necessary, the legacy behavior provided by the deprecated
``vectorized=True`` argument can be recovered using
``vmap_method="legacy_vectorized"``.
The current default behavior is to use ``vmap_method="sequential"`` when
not specified, but this behavior is deprecated, and in the future, the
default will be to raise a ``NotImplementedError`` unless ``vmap_method`` is
explicitly specified.
Args:
callback: function to execute on the host. The callback is assumed to be a pure
@ -303,8 +353,8 @@ def pure_callback(
*args: arguments to be passed to the callback function
sharding: optional sharding that specifies the device from which the callback should
be invoked.
vectorized: boolean specifying whether the callback function can operate in a
vectorized manner.
vmap_method: string specifying how the callback transforms under
:func:`~jax.vmap` as described above.
**kwargs: keyword arguments to be passed to the callback function
Returns:
@ -316,8 +366,62 @@ def pure_callback(
- :func:`jax.debug.callback`: callback designed for general-purpose debugging.
- :func:`jax.debug.print`: callback designed for printing.
Examples:
The behavior of ``pure_callback`` under :func:`~jax.vmap` is controlled by
the ``vmap_method`` argument as described above. It is useful to consider
some explicit examples that demonstrate the semantics. Consider the
following function:
>>> def callback(x, y):
... print(jnp.shape(x), jnp.shape(y))
... return x + y
>>> def fun(x, y, *, vmap_method):
... shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
... dtype = jnp.result_type(x, y)
... out_type = jax.ShapeDtypeStruct(shape, dtype)
... return jax.pure_callback(callback, out_type, x, y,
... vmap_method=vmap_method)
Calling this with ``vmap_method="broadcast"`` adds a new axis of size ``1``
to ``y``:
>>> from functools import partial
>>> x = jnp.arange(4)
>>> y = 1.0
>>> jax.vmap(partial(fun, vmap_method="broadcast"), in_axes=(0, None))(x, y)
(4,) (1,)
Array([1., 2., 3., 4.], dtype=float32)
Whereas ``vmap_method="broadcast_fullrank"`` adds an axis of size ``4`` to
``y``:
>>> jax.vmap(partial(fun, vmap_method="broadcast_fullrank"),
... in_axes=(0, None))(x, y)
(4,) (4,)
Array([1., 2., 3., 4.], dtype=float32)
.. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
"""
if not isinstance(vectorized, DeprecatedArg) and vectorized is not None:
deprecations.warn(
"jax-callback-vectorized",
"The vectorized argument of jax.pure_callback is deprecated and setting "
"it will soon raise an error. To avoid an error in the future, and to "
"suppress this warning, please use the vmap_method argument instead.",
stacklevel=2)
if vmap_method is not None:
raise ValueError(
"the vectorized and vmap_method arguments of jax.pure_callback cannot "
"be used together. Please use the vmap_method argument.")
vmap_method = "legacy_vectorized" if vectorized else "sequential"
allowed_vmap_methods = ["sequential", "broadcast", "broadcast_fullrank",
"legacy_vectorized", None]
if vmap_method not in allowed_vmap_methods:
raise ValueError(
f"vmap_method must be on of the allowed methods {allowed_vmap_methods}, "
f"but got: {vmap_method}")
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
tree_util.tree_map(_check_shape_dtype, result_shape_dtypes)
result_avals = tree_util.tree_map(
@ -329,6 +433,7 @@ def pure_callback(
result_avals=tuple(flat_result_avals),
sharding=sharding,
vectorized=vectorized,
vmap_method=vmap_method,
)
return tree_util.tree_unflatten(out_tree, out_flat)

View File

@ -23,6 +23,7 @@ from typing import Any
import numpy as np
from jax._src import core
from jax._src import deprecations
from jax._src import dispatch
from jax._src import effects
from jax._src import util
@ -34,7 +35,8 @@ from jax._src.layout import DeviceLocalLayout
from jax._src.lib import jaxlib
from jax._src.lib import xla_client
from jax._src.lib.mlir import ir
from jax._src.typing import Array, ArrayLike, DuckTypedArray, Shape
from jax._src.typing import (Array, ArrayLike, DeprecatedArg, DuckTypedArray,
Shape)
map, unsafe_map = util.safe_map, map
FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None
@ -199,23 +201,22 @@ def ffi_call(
target_name: str,
result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
*args: ArrayLike,
vectorized: bool = False,
has_side_effect: bool = False,
vmap_method: str | None = None,
vectorized: bool | DeprecatedArg = DeprecatedArg(),
**kwargs: Any,
) -> Array | list[Array]:
"""Call a foreign function interface (FFI) target.
Like :func:`~jax.pure_callback`, the behavior of ``ffi_call`` under
:func:`~jax.vmap` depends on the value of ``vectorized``. When ``vectorized``
is ``True``, the FFI target is assumed to satisfy: ``ffi_call(xs) ==
jnp.stack([ffi_call(x) for x in xs])``. In other words, calling the FFI target
with an extra leading dimension should return the same result as calling it
within a loop and stacking along the zeroth axis. Therefore, the FFI target
will be called directly on batched inputs (where the batch axes are the
leading dimensions). Additionally, the callbacks should return outputs that
have corresponding leading batch axes. If ``vectorized`` is ``False`` (the
default behavior), transforming this ``ffi_call`` under :func:`~jax.vmap` will
result in a :func:`~jax.lax.scan` with the ``ffi_call`` in the body.
:func:`~jax.vmap` depends on the value of ``vmap_method``. See the
:func:`~jax.pure_callback` documentation for more details about the allowed
values and examples of their behavior.
The current default behavior is to use ``vmap_method="sequential"`` when
not specified, but this behavior is deprecated, and in the future, the
default will be to raise a ``NotImplementedError`` unless ``vmap_method`` is
explicitly specified.
Args:
target_name: the name of the XLA FFI custom call target that was registered
@ -226,11 +227,11 @@ def ffi_call(
used to define the elements of ``result_shape_dtypes``.
``jax.core.abstract_token`` may be used to represent a token-typed output.
*args: the arguments passed to the custom call.
vectorized: boolean specifying whether the FFI call can operate in a
vectorized manner, as described above.
has_side_effect: boolean specifying whether the custom call has side
effects. When ``True``, the FFI call will be executed even when the
outputs are not used.
vmap_method: string specifying how the FFI call transforms under
:func:`~jax.vmap` as described above.
**kwargs: keyword arguments that are passed as named attributes to the
custom call using XLA's FFI interface.
@ -238,6 +239,25 @@ def ffi_call(
One or more :class:`~jax.Array` objects whose shapes and dtypes match
``result_shape_dtypes``.
"""
if not isinstance(vectorized, DeprecatedArg) and vectorized is not None:
deprecations.warn(
"jax-callback-vectorized",
"The vectorized argument of ffi_call is deprecated and setting "
"it will soon raise an error. To avoid an error in the future, and to "
"suppress this warning, please use the vmap_method argument instead.",
stacklevel=2)
if vmap_method is not None:
raise ValueError(
"the vectorized and vmap_method arguments of ffi_call cannot "
"be used together. Please use the vmap_method argument.")
vmap_method = "legacy_vectorized" if vectorized else "sequential"
allowed_vmap_methods = ["sequential", "broadcast", "broadcast_fullrank",
"legacy_vectorized", None]
if vmap_method not in allowed_vmap_methods:
raise ValueError(
f"vmap_method must be on of the allowed methods {allowed_vmap_methods}, "
f"but got: {vmap_method}")
if isinstance(result_shape_dtypes, Sequence):
multiple_results = True
result_avals = _result_avals(result_shape_dtypes)
@ -248,6 +268,7 @@ def ffi_call(
*args,
result_avals=result_avals,
vectorized=vectorized,
vmap_method=vmap_method,
target_name=target_name,
has_side_effect=has_side_effect,
**_wrap_kwargs_hashable(kwargs),
@ -342,11 +363,12 @@ def ffi_call_abstract_eval(
*avals_in,
result_avals: tuple[core.AbstractValue, ...],
target_name: str,
vectorized: bool,
vectorized: bool | DeprecatedArg,
vmap_method: str | None,
has_side_effect: bool,
**kwargs: Any,
):
del avals_in, target_name, vectorized, kwargs
del avals_in, target_name, vectorized, vmap_method, kwargs
effects = {_FfiEffect} if has_side_effect else core.no_effects
return result_avals, effects
@ -370,11 +392,12 @@ def ffi_call_lowering(
*operands: ir.Value,
result_avals: tuple[core.AbstractValue, ...],
target_name: str,
vectorized: bool,
vectorized: bool | DeprecatedArg,
vmap_method: str | None,
has_side_effect: bool,
**kwargs: Any,
) -> Sequence[ir.Value]:
del result_avals, vectorized
del result_avals, vectorized, vmap_method
rule = ffi_lowering(target_name, has_side_effect=has_side_effect)
return rule(ctx, *operands, **_unwrap_kwargs_hashable(kwargs))

View File

@ -245,10 +245,11 @@ class FfiTest(jtu.JaxTestCase):
@jtu.sample_product(
shape=[(1,), (4,), (5,)],
dtype=(np.int32,),
vectorized=(False, True),
vmap_method=("broadcast", "broadcast_fullrank", "sequential",
"legacy_vectorized"),
)
@jtu.run_on_devices("gpu")
def testFfiCallBatching(self, shape, dtype, vectorized):
def testFfiCallBatching(self, shape, dtype, vmap_method):
shape = (10,) + shape
pivots_size = shape[-1]
permutation_size = 2 * pivots_size
@ -256,15 +257,29 @@ class FfiTest(jtu.JaxTestCase):
pivots = jnp.broadcast_to(pivots, shape)
expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
actual = jax.vmap(lambda x: ffi_call_lu_pivots_to_permutation(
x, permutation_size, vectorized=vectorized))(pivots)
x, permutation_size, vmap_method=vmap_method))(pivots)
self.assertArraysEqual(actual, expected)
@jtu.run_on_devices("gpu")
def testVectorizedDeprecation(self):
pivots_size = 4
shape = (10, pivots_size)
permutation_size = 2 * pivots_size
pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1,
dtype=np.int32)
pivots = jnp.broadcast_to(pivots, shape)
with self.assertWarns(DeprecationWarning):
ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True)
with self.assertWarns(DeprecationWarning):
jax.vmap(
lambda x: ffi_call_lu_pivots_to_permutation(x, permutation_size))(pivots)
# TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation`
# custom call target because that's the only one in jaxlib that uses the
# new FFI interface. Once more are available, consider using something that
# can be run on multiple platforms.
def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True):
def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, **kwargs):
return jex.ffi.ffi_call(
"cu_lu_pivots_to_permutation",
jax.ShapeDtypeStruct(
@ -272,7 +287,7 @@ def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True)
dtype=pivots.dtype,
),
pivots,
vectorized=vectorized,
**kwargs,
)

View File

@ -1724,7 +1724,8 @@ class HostCallbackCallTest(jtu.JaxTestCase):
"batching rules are implemented only for id_tap, not for call"):
jax.vmap(fun)(np.ones((2, 3)))
else:
jax.vmap(fun)(np.ones((2, 3)))
with jtu.ignore_warning(category=DeprecationWarning):
jax.vmap(fun)(np.ones((2, 3)))
def test_call_error_bad_result_shape(self):
with self.assertRaisesRegex(

View File

@ -635,13 +635,13 @@ class PureCallbackTest(jtu.JaxTestCase):
@jax.jit
@jax.vmap
def f(x):
return jax.pure_callback(np.sin, x, x)
return jax.pure_callback(np.sin, x, x, vmap_method="sequential")
out = f(jnp.arange(4.))
np.testing.assert_allclose(out, np.sin(np.arange(4.)))
@jax.jit
def g(x):
return jax.pure_callback(np.sin, x, x)
return jax.pure_callback(np.sin, x, x, vmap_method="sequential")
out = jax.vmap(g, in_axes=1)(jnp.arange(8.).reshape((4, 2)))
np.testing.assert_allclose(out, np.sin(np.arange(8.).reshape((4, 2))).T)
@ -649,7 +649,8 @@ class PureCallbackTest(jtu.JaxTestCase):
@functools.partial(jax.vmap, in_axes=(0, None))
def h(x, y):
out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y)
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y,
vmap_method="sequential")
out = h(jnp.arange(4.), 4.)
self.assertArraysAllClose(out, np.sin(np.arange(4.)) + 4.,
rtol=1E-7, check_dtypes=False)
@ -658,7 +659,8 @@ class PureCallbackTest(jtu.JaxTestCase):
@functools.partial(jax.vmap)
def h(x, y):
out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y)
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y,
vmap_method="sequential")
out = h(jnp.arange(4.), jnp.arange(10., 14.))
self.assertArraysAllClose(out, np.sin(np.arange(4.)) + np.arange(10., 14.),
rtol=1E-7, check_dtypes=False)
@ -667,7 +669,8 @@ class PureCallbackTest(jtu.JaxTestCase):
@functools.partial(jax.vmap, in_axes=1, out_axes=1)
def h(x, y):
out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y)
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y,
vmap_method="sequential")
out = h(jnp.arange(4.)[None], jnp.arange(10., 14.)[None])
self.assertArraysAllClose(out, np.sin(np.arange(4.)) + np.arange(10.,
14.)[None],
@ -682,7 +685,7 @@ class PureCallbackTest(jtu.JaxTestCase):
@jax.jit
@jax.vmap
def f(x):
return jax.pure_callback(cb, x, x)
return jax.pure_callback(cb, x, x, vmap_method="sequential")
np.testing.assert_allclose(f(jnp.arange(4.)), np.sin(np.arange(4.)))
@ -693,7 +696,7 @@ class PureCallbackTest(jtu.JaxTestCase):
@jax.jit
@jax.vmap
def g(x):
return jax.pure_callback(cb2, x, x, vectorized=True)
return jax.pure_callback(cb2, x, x, vmap_method="broadcast")
np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.)))
@ -701,7 +704,7 @@ class PureCallbackTest(jtu.JaxTestCase):
@functools.partial(jax.vmap, in_axes=(0, None))
def h(x, y):
return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
vectorized=True)
vmap_method="broadcast")
out = h(jnp.arange(4.), 4.)
np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)
@ -709,7 +712,7 @@ class PureCallbackTest(jtu.JaxTestCase):
@functools.partial(jax.vmap, in_axes=(1, None), out_axes=1)
def h(x, y):
return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
vectorized=True)
vmap_method="legacy_vectorized")
out = h(jnp.arange(4.)[None], 4.)
np.testing.assert_allclose(out, np.sin(np.arange(4.)[None]) + 4.)
@ -722,7 +725,7 @@ class PureCallbackTest(jtu.JaxTestCase):
@jax.jit
@jax.vmap
def f(x):
return jax.pure_callback(cb, x, x, vectorized=True)
return jax.pure_callback(cb, x, x, vmap_method="broadcast")
with self.assertRaises(RuntimeError):
f(jnp.arange(4.))
@ -981,6 +984,52 @@ class PureCallbackTest(jtu.JaxTestCase):
out = jax.pure_callback(f, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
np.testing.assert_allclose(out, 2 * jnp.log(x + 1))
def test_vmap_method_raise(self):
@jax.vmap
def f(x):
# Setting vectorized to None disables the current default behavior of
# falling back on sequential.
return jax.pure_callback(np.sin, x, x, vectorized=None)
with self.assertRaisesRegex(NotImplementedError, "vmap is only supported"):
f(jnp.arange(4.))
def test_deprecated_vectorized(self):
def f(x, **kwargs):
return jax.pure_callback(np.sin, x, x, **kwargs)
with self.assertWarnsRegex(DeprecationWarning, "The default behavior"):
jax.vmap(f)(jnp.arange(4.0))
with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"):
f(jnp.arange(4.0), vectorized=True)
with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"):
f(jnp.arange(4.0), vectorized=False)
def test_vmap_method_broadcast(self):
def callback(x, y):
self.assertTupleEqual(x.shape, (4,))
self.assertTupleEqual(y.shape, (1,))
return x + y
def f(x, y):
return jax.pure_callback(callback, x, x, y, vmap_method="broadcast")
jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error
def test_vmap_method_broadcast_fullrank(self):
def callback(x, y):
self.assertTupleEqual(x.shape, (4,))
self.assertTupleEqual(y.shape, (4,))
return x + y
def f(x, y):
return jax.pure_callback(callback, x, x, y,
vmap_method="broadcast_fullrank")
jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error
class IOCallbackTest(jtu.JaxTestCase):