Deprecate the vectorized argument to pure_callback and ffi_call.
This commit is contained in: parent 816947b656, commit 1d27d420ac
@ -46,6 +46,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
|
||||
* `jax.lib.xla_client.Device` is deprecated; use `jax.Device` instead.
|
||||
* `jax.lib.xla_client.XlaRuntimeError` has been deprecated. Use
|
||||
`jax.errors.JaxRuntimeError` instead.
|
||||
* The default behavior of {func}`jax.pure_callback` and
|
||||
{func}`jax.extend.ffi.ffi_call` under `vmap` has been deprecated and so has
|
||||
the `vectorized` parameter to those functions. The `vmap_method` parameter
|
||||
should be used instead for better defined behavior. See the discussion in
|
||||
{jax-issue}`#23881` for more details.
|
||||
|
||||
* Deletion:
|
||||
* `jax.xla_computation` is deleted. It's been 3 months since its deprecation
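For readers updating their own code, a minimal migration sketch (not part of this commit; `f` is a hypothetical example). Per the new mapping in `pure_callback`, `vectorized=True` corresponds to `vmap_method="legacy_vectorized"` and the old default corresponds to `vmap_method="sequential"`:

```python
import jax
import jax.numpy as jnp
import numpy as np

def f(x):
    out_type = jax.ShapeDtypeStruct(jnp.shape(x), x.dtype)
    # Before: jax.pure_callback(np.sin, out_type, x, vectorized=True)
    # After: spell the batching behavior explicitly with vmap_method.
    return jax.pure_callback(np.sin, out_type, x, vmap_method="legacy_vectorized")

jax.vmap(f)(jnp.arange(4.0))  # the callback receives the whole batch at once
```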
|
||||
|
@ -303,9 +303,9 @@
|
||||
" # type (which corresponds to numpy's `float32` type), and it must be a\n",
|
||||
" # static parameter (i.e. not a JAX array).\n",
|
||||
" eps=np.float32(eps),\n",
|
||||
" # The `vectorized` parameter controls this function's behavior under `vmap`\n",
|
||||
" # The `vmap_method` parameter controls this function's behavior under `vmap`\n",
|
||||
" # as discussed below.\n",
|
||||
" vectorized=True,\n",
|
||||
" vmap_method=\"broadcast_fullrank\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
@ -325,7 +325,7 @@
|
||||
"Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.\n",
|
||||
"Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n",
|
||||
"\n",
|
||||
"The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
|
||||
"The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
|
||||
"\n",
|
||||
"```{tip}\n",
|
||||
"If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.\n",
|
||||
@ -336,19 +336,29 @@
|
||||
"(ffi-call-vmap)=\n",
|
||||
"### Batching with `vmap`\n",
|
||||
"\n",
|
||||
"All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient.\n",
|
||||
"By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
|
||||
"This default implementation is general purpose, but it doesn't parallelize very well.\n",
|
||||
"But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation.\n",
|
||||
"{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n",
|
||||
"The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.\n",
|
||||
"\n",
|
||||
"The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes.\n",
|
||||
"The simplest `vmap_method` is `\"sequential\"`.\n",
|
||||
"In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
|
||||
"This implementation is general purpose, but it doesn't parallelize very well.\n",
|
||||
"Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"broadcast\"` or `\"broadcast_fullrank\"` methods can be used to expose a better implementation.\n",
|
||||
"\n",
|
||||
"In this case, since we only have one input argument, `\"broadcast\"` and `\"broadcast_fullrank\"` actually have the same behavior.\n",
|
||||
"The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.\n",
|
||||
"Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:\n",
|
||||
"\n",
|
||||
"```python\n",
|
||||
"ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box:"
|
||||
"```{tip}\n",
|
||||
"Note that things get a bit more complicated when we have multiple input arguments.\n",
|
||||
"For simplicity, we will use the `\"broadcast_fullrank\"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"broadcast\"` method.\n",
|
||||
"The documentation for {func}`~jax.pure_callback` includes some examples of this\n",
|
||||
"```\n",
|
||||
"\n",
|
||||
"Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method=\"broadcast_fullrank\"` out of the box:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -380,7 +390,7 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:"
|
||||
"Using `vmap_method=\"sequential\"`, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:"
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -389,24 +399,24 @@
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def rms_norm_not_vectorized(x, eps=1e-5):\n",
|
||||
"def rms_norm_sequential(x, eps=1e-5):\n",
|
||||
" return jex.ffi.ffi_call(\n",
|
||||
" \"rms_norm\",\n",
|
||||
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
|
||||
" x,\n",
|
||||
" eps=np.float32(eps),\n",
|
||||
" vectorized=False, # This is the default behavior\n",
|
||||
" vmap_method=\"sequential\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)"
|
||||
"jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)."
|
||||
"If your foreign function provides an efficient batching rule that isn't supported by this simple `vmap_method` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -454,7 +464,7 @@
|
||||
" ),\n",
|
||||
" x,\n",
|
||||
" eps=np.float32(eps),\n",
|
||||
" vectorized=True,\n",
|
||||
" vmap_method=\"broadcast_fullrank\",\n",
|
||||
" )\n",
|
||||
" return y, (res, x)\n",
|
||||
"\n",
|
||||
@ -471,7 +481,7 @@
|
||||
" res,\n",
|
||||
" x,\n",
|
||||
" ct,\n",
|
||||
" vectorized=True,\n",
|
||||
" vmap_method=\"broadcast_fullrank\",\n",
|
||||
" ),\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
@ -561,7 +571,7 @@
|
||||
" out_type,\n",
|
||||
" x,\n",
|
||||
" eps=np.float32(eps),\n",
|
||||
" vectorized=True,\n",
|
||||
" vmap_method=\"broadcast_fullrank\",\n",
|
||||
" )\n",
|
||||
"\n",
|
||||
" return jax.lax.platform_dependent(x, cpu=impl(\"rms_norm\"), cuda=impl(\"rms_norm_cuda\"))\n",
|
||||
|
44 docs/ffi.md
@ -264,9 +264,9 @@ def rms_norm(x, eps=1e-5):
|
||||
# type (which corresponds to numpy's `float32` type), and it must be a
|
||||
# static parameter (i.e. not a JAX array).
|
||||
eps=np.float32(eps),
|
||||
# The `vectorized` parameter controls this function's behavior under `vmap`
|
||||
# The `vmap_method` parameter controls this function's behavior under `vmap`
|
||||
# as discussed below.
|
||||
vectorized=True,
|
||||
vmap_method="broadcast_fullrank",
|
||||
)
|
||||
|
||||
|
||||
@ -282,7 +282,7 @@ It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_cal
|
||||
Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.
|
||||
Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.
|
||||
|
||||
The `vectorized` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.
|
||||
The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.
|
||||
|
||||
```{tip}
|
||||
If you are familiar with the earlier "custom call" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.
|
||||
@ -293,19 +293,29 @@ One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support so
|
||||
(ffi-call-vmap)=
|
||||
### Batching with `vmap`
|
||||
|
||||
All uses of {func}`~jax.extend.ffi.ffi_call` support {func}`~jax.vmap` out of the box, but this implementation won't necessarily be very efficient.
|
||||
By default, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
|
||||
This default implementation is general purpose, but it doesn't parallelize very well.
|
||||
But, many FFI calls provide more efficient batching behavior and, in some simple cases, the `vectorized` parameter to {func}`~jax.extend.ffi.ffi_call` can be used to expose a better implementation.
|
||||
{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.
|
||||
The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.
|
||||
|
||||
The specific assumption required to use the `vectorized` parameter is that all leading dimensions of the inputs should be treated as batch axes.
|
||||
The simplest `vmap_method` is `"sequential"`.
|
||||
In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
|
||||
This implementation is general purpose, but it doesn't parallelize very well.
|
||||
Many FFI calls provide more efficient batching behavior and, in some simple cases, the `"broadcast"` or `"broadcast_fullrank"` methods can be used to expose a better implementation.
|
||||
|
||||
In this case, since we only have one input argument, `"broadcast"` and `"broadcast_fullrank"` actually have the same behavior.
|
||||
The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.
|
||||
Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:
|
||||
|
||||
```python
|
||||
ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])
|
||||
```
|
||||
|
||||
Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vectorized=True` out of the box:
|
||||
```{tip}
|
||||
Note that things get a bit more complicated when we have multiple input arguments.
|
||||
For simplicity, we will use the `"broadcast_fullrank"` method throughout this tutorial, which guarantees that all inputs will be broadcast to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `"broadcast"` method.
|
||||
The documentation for {func}`~jax.pure_callback` includes some examples of this.
|
||||
```
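To make the difference concrete, here is a small sketch using {func}`~jax.pure_callback` (mirroring the examples added to its docstring in this change); an `ffi_call` behaves the same way under these `vmap_method` settings:

```python
import functools
import jax
import jax.numpy as jnp

def callback(x, y):
    # Runs on the host; report the shapes the callback actually receives.
    print(jnp.shape(x), jnp.shape(y))
    return x + y

def fun(x, y, *, vmap_method):
    shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
    out_type = jax.ShapeDtypeStruct(shape, jnp.result_type(x, y))
    return jax.pure_callback(callback, out_type, x, y, vmap_method=vmap_method)

x, y = jnp.arange(4.0), 1.0
# "broadcast": the unbatched argument y gets a new leading axis of size 1.
jax.vmap(functools.partial(fun, vmap_method="broadcast"), in_axes=(0, None))(x, y)
# prints: (4,) (1,)
# "broadcast_fullrank": y is tiled to the full batch size instead.
jax.vmap(functools.partial(fun, vmap_method="broadcast_fullrank"), in_axes=(0, None))(x, y)
# prints: (4,) (4,)
```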
|
||||
|
||||
Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method="broadcast_fullrank"` out of the box:
|
||||
|
||||
```{code-cell} ipython3
|
||||
np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
|
||||
@ -317,23 +327,23 @@ We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms
|
||||
jax.make_jaxpr(jax.vmap(rms_norm))(x)
|
||||
```
|
||||
|
||||
If `vectorized` is `False` or omitted, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:
|
||||
Using `vmap_method="sequential"`, `vmap`ping a `ffi_call` will fall back on a {func}`jax.lax.scan` with the `ffi_call` in the body:
|
||||
|
||||
```{code-cell} ipython3
|
||||
def rms_norm_not_vectorized(x, eps=1e-5):
|
||||
def rms_norm_sequential(x, eps=1e-5):
|
||||
return jex.ffi.ffi_call(
|
||||
"rms_norm",
|
||||
jax.ShapeDtypeStruct(x.shape, x.dtype),
|
||||
x,
|
||||
eps=np.float32(eps),
|
||||
vectorized=False, # This is the default behavior
|
||||
vmap_method="sequential",
|
||||
)
|
||||
|
||||
|
||||
jax.make_jaxpr(jax.vmap(rms_norm_not_vectorized))(x)
|
||||
jax.make_jaxpr(jax.vmap(rms_norm_sequential))(x)
|
||||
```
|
||||
|
||||
If your foreign function provides an efficient batching rule that isn't supported by this simple `vectorized` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues).
|
||||
If your foreign function provides an efficient batching rule that isn't supported by this simple `vmap_method` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues).
|
||||
|
||||
+++
|
||||
|
||||
@ -372,7 +382,7 @@ def rms_norm_fwd(x, eps=1e-5):
|
||||
),
|
||||
x,
|
||||
eps=np.float32(eps),
|
||||
vectorized=True,
|
||||
vmap_method="broadcast_fullrank",
|
||||
)
|
||||
return y, (res, x)
|
||||
|
||||
@ -389,7 +399,7 @@ def rms_norm_bwd(eps, res, ct):
|
||||
res,
|
||||
x,
|
||||
ct,
|
||||
vectorized=True,
|
||||
vmap_method="broadcast_fullrank",
|
||||
),
|
||||
)
|
||||
|
||||
@ -469,7 +479,7 @@ def rms_norm_cross_platform(x, eps=1e-5):
|
||||
out_type,
|
||||
x,
|
||||
eps=np.float32(eps),
|
||||
vectorized=True,
|
||||
vmap_method="broadcast_fullrank",
|
||||
)
|
||||
|
||||
return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda"))
|
||||
|
@ -60,8 +60,7 @@ def rms_norm(x, eps=1e-5):
|
||||
# type (which corresponds to numpy's `float32` type), and it must be a
|
||||
# static parameter (i.e. not a JAX array).
|
||||
eps=np.float32(eps),
|
||||
# The `vectorized` parameter controls this function's behavior under `vmap`.
|
||||
vectorized=True,
|
||||
vmap_method="broadcast_fullrank",
|
||||
)
|
||||
|
||||
|
||||
@ -74,7 +73,7 @@ def rms_norm_fwd(x, eps=1e-5):
|
||||
),
|
||||
x,
|
||||
eps=np.float32(eps),
|
||||
vectorized=True,
|
||||
vmap_method="broadcast_fullrank",
|
||||
)
|
||||
return y, (res, x)
|
||||
|
||||
@ -91,7 +90,7 @@ def rms_norm_bwd(eps, res, ct):
|
||||
res,
|
||||
x,
|
||||
ct,
|
||||
vectorized=True,
|
||||
vmap_method="broadcast_fullrank",
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -22,6 +22,7 @@ from typing import Any
|
||||
|
||||
import jax
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import dispatch
|
||||
from jax._src import dtypes
|
||||
from jax._src import effects
|
||||
@ -31,13 +32,18 @@ from jax._src import util
|
||||
from jax._src.interpreters import ad
|
||||
from jax._src.interpreters import batching
|
||||
from jax._src.interpreters import mlir
|
||||
from jax._src.lax import lax
|
||||
from jax._src.lax.control_flow.loops import map as lax_map
|
||||
from jax._src.lib import xla_client as xc
|
||||
from jax._src.sharding_impls import SingleDeviceSharding
|
||||
from jax._src.typing import DeprecatedArg
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TODO(dfm): Remove after 6 months.
|
||||
# Added Oct 1, 2024
|
||||
deprecations.register("jax-callback-vectorized")
|
||||
|
||||
# `pure_callback_p` is the main primitive for staging out Python pure callbacks.
|
||||
pure_callback_p = core.Primitive("pure_callback")
|
||||
@ -45,6 +51,7 @@ pure_callback_p.multiple_results = True
|
||||
dispatch.prim_requires_devices_during_lowering.add(pure_callback_p)
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
zip, unsafe_zip = util.safe_zip, zip
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
@ -69,9 +76,10 @@ def pure_callback_impl(
|
||||
result_avals,
|
||||
callback: _FlatCallback,
|
||||
sharding: SingleDeviceSharding | None,
|
||||
vectorized: bool,
|
||||
vectorized: bool | DeprecatedArg,
|
||||
vmap_method: str | None,
|
||||
):
|
||||
del sharding, vectorized, result_avals
|
||||
del sharding, vectorized, vmap_method, result_avals
|
||||
try:
|
||||
cpu_device, *_ = jax.local_devices(backend="cpu")
|
||||
except RuntimeError as e:
|
||||
@ -99,9 +107,10 @@ def pure_callback_abstract_eval(
|
||||
callback: _FlatCallback,
|
||||
result_avals,
|
||||
sharding: SingleDeviceSharding | None,
|
||||
vectorized: bool,
|
||||
vectorized: bool | DeprecatedArg,
|
||||
vmap_method: str | None,
|
||||
):
|
||||
del avals, callback, sharding, vectorized
|
||||
del avals, callback, sharding, vectorized, vmap_method
|
||||
return result_avals
|
||||
|
||||
|
||||
@ -129,25 +138,51 @@ def callback_batching_rule(
|
||||
args,
|
||||
dims,
|
||||
*,
|
||||
vectorized: bool,
|
||||
vectorized: bool | None | DeprecatedArg,
|
||||
vmap_method: str | None,
|
||||
result_avals: Sequence[core.ShapedArray],
|
||||
**kwargs: Any,
|
||||
):
|
||||
axis_size = next(a.shape[d] for a, d in zip(args, dims)
|
||||
if d is not batching.not_mapped)
|
||||
if isinstance(vectorized, DeprecatedArg) and vmap_method is None:
|
||||
deprecations.warn(
|
||||
"jax-callback-vectorized",
|
||||
f"The default behavior of {prim.name} under vmap will soon "
|
||||
"change. Currently, the default behavior is to generate a sequential "
|
||||
"vmap (i.e. a loop), but in the future the default will be to raise "
|
||||
"an error. To keep the current default, set vmap_method='sequential'.",
|
||||
stacklevel=6)
|
||||
vmap_method = "sequential"
|
||||
|
||||
axis_size, = {a.shape[d] for a, d in zip(args, dims)
|
||||
if d is not batching.not_mapped}
|
||||
new_args = [arg if dim is batching.not_mapped else
|
||||
batching.moveaxis(arg, dim, 0) for arg, dim in zip(args, dims)]
|
||||
if vectorized:
|
||||
result_avals = tuple(
|
||||
core.unmapped_aval(axis_size, core.no_axis_name, 0, aval) # type: ignore
|
||||
for aval in result_avals)
|
||||
batched_result_avals = tuple(
|
||||
core.unmapped_aval(axis_size, core.no_axis_name, 0, aval)
|
||||
for aval in result_avals)
|
||||
if vmap_method == "legacy_vectorized":
|
||||
# This method is kept to support the behavior that was previously exposed
|
||||
# when using `vectorized=True`.
|
||||
outvals = prim.bind(
|
||||
*new_args,
|
||||
vectorized=vectorized,
|
||||
result_avals=result_avals,
|
||||
vmap_method=vmap_method,
|
||||
result_avals=batched_result_avals,
|
||||
**kwargs,
|
||||
)
|
||||
else:
|
||||
elif vmap_method == "broadcast" or vmap_method == "broadcast_fullrank":
|
||||
size = axis_size if vmap_method == "broadcast_fullrank" else 1
|
||||
bcast_args = [
|
||||
lax.broadcast(x, (size,)) if d is batching.not_mapped else x
|
||||
for x, d in zip(new_args, dims)]
|
||||
outvals = prim.bind(
|
||||
*bcast_args,
|
||||
vectorized=vectorized,
|
||||
vmap_method=vmap_method,
|
||||
result_avals=batched_result_avals,
|
||||
**kwargs,
|
||||
)
|
||||
elif vmap_method == "sequential":
|
||||
is_batched = [d is not batching.not_mapped for d in dims]
|
||||
unbatched_args, batched_args = util.partition_list(is_batched, new_args)
|
||||
def _batch_fun(batched_args):
|
||||
@ -156,9 +191,15 @@ def callback_batching_rule(
|
||||
*merged_args,
|
||||
result_avals=result_avals,
|
||||
vectorized=vectorized,
|
||||
vmap_method=vmap_method,
|
||||
**kwargs,
|
||||
)
|
||||
outvals = lax_map(_batch_fun, batched_args)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"vmap is only supported for the {prim.name} primitive when vmap_method "
|
||||
"is one of 'sequential', 'broadcast', 'broadcast_fullrank', or "
|
||||
"'legacy_vectorized'.")
|
||||
return tuple(outvals), (0,) * len(outvals)
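As implemented above, the "legacy_vectorized" path passes unbatched arguments through with their original shapes, while the "broadcast" methods give them a leading axis before the bind. A minimal sketch (not part of this commit; `callback` and `f` are hypothetical names) showing the shapes that reach the host:

```python
import functools
import jax
import jax.numpy as jnp
import numpy as np

def callback(x, y):
    print(np.shape(x), np.shape(y))  # shapes as seen by the host callback
    return x + y

def f(x, y, *, vmap_method):
    out_type = jax.ShapeDtypeStruct(jnp.shape(x), jnp.result_type(x, y))
    return jax.pure_callback(callback, out_type, x, y, vmap_method=vmap_method)

x, y = jnp.arange(4.0), 1.0
# "legacy_vectorized": unbatched args keep their original shape (a scalar here).
jax.vmap(functools.partial(f, vmap_method="legacy_vectorized"), in_axes=(0, None))(x, y)
# prints: (4,) ()
# "broadcast": unbatched args gain a leading axis of size 1 before the call.
jax.vmap(functools.partial(f, vmap_method="broadcast"), in_axes=(0, None))(x, y)
# prints: (4,) (1,)
```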
|
||||
|
||||
|
||||
@ -261,7 +302,8 @@ def pure_callback(
|
||||
result_shape_dtypes: Any,
|
||||
*args: Any,
|
||||
sharding: SingleDeviceSharding | None = None,
|
||||
vectorized: bool = False,
|
||||
vectorized: bool | None | DeprecatedArg = DeprecatedArg(),
|
||||
vmap_method: str | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Calls a pure Python callback. Works under :func:`jit`/:func:`~vmap`/etc.
|
||||
@ -279,17 +321,25 @@ def pure_callback(
|
||||
`jit`-decorated function has no data dependence on its value. Pure callbacks
|
||||
may also be reordered if data-dependence allows.
|
||||
|
||||
When `vmap`-ed the behavior will depend on the value of the
|
||||
``vectorized`` keyword argument. When ``vectorized`` is ``True``, the callback
|
||||
is assumed to obey
|
||||
``jax.vmap(callback)(xs) == callback(xs) == jnp.stack([callback(x) for x in xs])``.
|
||||
Therefore, the callback will be called directly on batched inputs (where the
|
||||
batch axes are the leading dimensions). Additionally, the callbacks should
|
||||
return outputs that have corresponding leading batch axes. If not vectorized
|
||||
``callback`` will be mapped sequentially across the batched axis.
|
||||
For example, if ``callback = lambda x, y: np.matmul(x, y)``, then we are free
|
||||
to set ``vectorized=True`` because the ``np.matmul`` function handles
|
||||
arbitrary leading batch dimensions.
|
||||
When `vmap`-ed, the behavior will depend on the value of the ``vmap_method`` argument.
|
||||
|
||||
* Calling :func:`~jax.vmap` on a callback without an explicit ``vmap_method``
|
||||
is deprecated and will eventually raise ``NotImplementedError``.
|
||||
* ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over
|
||||
the batched arguments, calling ``callback`` once for each batch element.
|
||||
* ``vmap_method="broadcast"`` calls ``callback`` with new axes of size ``1``
|
||||
added as the leading dimension of unbatched inputs.
|
||||
* ``vmap_method="broadcast_fullrank"`` behaves like ``broadcast``, but the
|
||||
inputs are tiled to the expected batched shape.
|
||||
|
||||
If necessary, the legacy behavior provided by the deprecated
|
||||
``vectorized=True`` argument can be recovered using
|
||||
``vmap_method="legacy_vectorized"``.
|
||||
|
||||
The current default behavior is to use ``vmap_method="sequential"`` when
|
||||
not specified, but this behavior is deprecated, and in the future, the
|
||||
default will be to raise a ``NotImplementedError`` unless ``vmap_method`` is
|
||||
explicitly specified.
|
||||
|
||||
Args:
|
||||
callback: function to execute on the host. The callback is assumed to be a pure
|
||||
@ -303,8 +353,8 @@ def pure_callback(
|
||||
*args: arguments to be passed to the callback function
|
||||
sharding: optional sharding that specifies the device from which the callback should
|
||||
be invoked.
|
||||
vectorized: boolean specifying whether the callback function can operate in a
|
||||
vectorized manner.
|
||||
vmap_method: string specifying how the callback transforms under
|
||||
:func:`~jax.vmap` as described above.
|
||||
**kwargs: keyword arguments to be passed to the callback function
|
||||
|
||||
Returns:
|
||||
@ -316,8 +366,62 @@ def pure_callback(
|
||||
- :func:`jax.debug.callback`: callback designed for general-purpose debugging.
|
||||
- :func:`jax.debug.print`: callback designed for printing.
|
||||
|
||||
Examples:
|
||||
The behavior of ``pure_callback`` under :func:`~jax.vmap` is controlled by
|
||||
the ``vmap_method`` argument as described above. It is useful to consider
|
||||
some explicit examples that demonstrate the semantics. For example,
|
||||
consider the following function:
|
||||
|
||||
>>> def callback(x, y):
|
||||
... print(jnp.shape(x), jnp.shape(y))
|
||||
... return x + y
|
||||
|
||||
>>> def fun(x, y, *, vmap_method):
|
||||
... shape = jnp.broadcast_shapes(jnp.shape(x), jnp.shape(y))
|
||||
... dtype = jnp.result_type(x, y)
|
||||
... out_type = jax.ShapeDtypeStruct(shape, dtype)
|
||||
... return jax.pure_callback(callback, out_type, x, y,
|
||||
... vmap_method=vmap_method)
|
||||
|
||||
Calling this with ``vmap_method="broadcast"`` adds a new axis of size ``1``
|
||||
to ``y``:
|
||||
|
||||
>>> from functools import partial
|
||||
>>> x = jnp.arange(4)
|
||||
>>> y = 1.0
|
||||
>>> jax.vmap(partial(fun, vmap_method="broadcast"), in_axes=(0, None))(x, y)
|
||||
(4,) (1,)
|
||||
Array([1., 2., 3., 4.], dtype=float32)
|
||||
|
||||
Whereas, ``vmap_method="broadcast_fullrank"`` adds an axis of size ``4`` to
|
||||
``y``:
|
||||
|
||||
>>> jax.vmap(partial(fun, vmap_method="broadcast_fullrank"),
|
||||
... in_axes=(0, None))(x, y)
|
||||
(4,) (4,)
|
||||
Array([1., 2., 3., 4.], dtype=float32)
|
||||
|
||||
.. _External Callbacks: https://jax.readthedocs.io/en/latest/notebooks/external_callbacks.html
|
||||
"""
|
||||
if not isinstance(vectorized, DeprecatedArg) and vectorized is not None:
|
||||
deprecations.warn(
|
||||
"jax-callback-vectorized",
|
||||
"The vectorized argument of jax.pure_callback is deprecated and setting "
|
||||
"it will soon raise an error. To avoid an error in the future, and to "
|
||||
"suppress this warning, please use the vmap_method argument instead.",
|
||||
stacklevel=2)
|
||||
if vmap_method is not None:
|
||||
raise ValueError(
|
||||
"the vectorized and vmap_method arguments of jax.pure_callback cannot "
|
||||
"be used together. Please use the vmap_method argument.")
|
||||
vmap_method = "legacy_vectorized" if vectorized else "sequential"
|
||||
allowed_vmap_methods = ["sequential", "broadcast", "broadcast_fullrank",
|
||||
"legacy_vectorized", None]
|
||||
if vmap_method not in allowed_vmap_methods:
|
||||
raise ValueError(
|
||||
f"vmap_method must be on of the allowed methods {allowed_vmap_methods}, "
|
||||
f"but got: {vmap_method}")
|
||||
|
||||
flat_args, in_tree = tree_util.tree_flatten((args, kwargs))
|
||||
tree_util.tree_map(_check_shape_dtype, result_shape_dtypes)
|
||||
result_avals = tree_util.tree_map(
|
||||
@ -329,6 +433,7 @@ def pure_callback(
|
||||
result_avals=tuple(flat_result_avals),
|
||||
sharding=sharding,
|
||||
vectorized=vectorized,
|
||||
vmap_method=vmap_method,
|
||||
)
|
||||
return tree_util.tree_unflatten(out_tree, out_flat)
|
||||
|
||||
|
@ -23,6 +23,7 @@ from typing import Any
|
||||
import numpy as np
|
||||
|
||||
from jax._src import core
|
||||
from jax._src import deprecations
|
||||
from jax._src import dispatch
|
||||
from jax._src import effects
|
||||
from jax._src import util
|
||||
@ -34,7 +35,8 @@ from jax._src.layout import DeviceLocalLayout
|
||||
from jax._src.lib import jaxlib
|
||||
from jax._src.lib import xla_client
|
||||
from jax._src.lib.mlir import ir
|
||||
from jax._src.typing import Array, ArrayLike, DuckTypedArray, Shape
|
||||
from jax._src.typing import (Array, ArrayLike, DeprecatedArg, DuckTypedArray,
|
||||
Shape)
|
||||
|
||||
map, unsafe_map = util.safe_map, map
|
||||
FfiLayoutOptions = Sequence[int] | DeviceLocalLayout | None
|
||||
@ -199,23 +201,22 @@ def ffi_call(
|
||||
target_name: str,
|
||||
result_shape_dtypes: ResultMetadata | Sequence[ResultMetadata],
|
||||
*args: ArrayLike,
|
||||
vectorized: bool = False,
|
||||
has_side_effect: bool = False,
|
||||
vmap_method: str | None = None,
|
||||
vectorized: bool | DeprecatedArg = DeprecatedArg(),
|
||||
**kwargs: Any,
|
||||
) -> Array | list[Array]:
|
||||
"""Call a foreign function interface (FFI) target.
|
||||
|
||||
Like :func:`~jax.pure_callback`, the behavior of ``ffi_call`` under
|
||||
:func:`~jax.vmap` depends on the value of ``vectorized``. When ``vectorized``
|
||||
is ``True``, the FFI target is assumed to satisfy: ``ffi_call(xs) ==
|
||||
jnp.stack([ffi_call(x) for x in xs])``. In other words, calling the FFI target
|
||||
with an extra leading dimension should return the same result as calling it
|
||||
within a loop and stacking along the zeroth axis. Therefore, the FFI target
|
||||
will be called directly on batched inputs (where the batch axes are the
|
||||
leading dimensions). Additionally, the callbacks should return outputs that
|
||||
have corresponding leading batch axes. If ``vectorized`` is ``False`` (the
|
||||
default behavior), transforming this ``ffi_call`` under :func:`~jax.vmap` will
|
||||
result in a :func:`~jax.lax.scan` with the ``ffi_call`` in the body.
|
||||
:func:`~jax.vmap` depends on the value of ``vmap_method``. See the
|
||||
:func:`~jax.pure_callback` documentation for more details about the allowed
|
||||
values and examples of their behavior.
|
||||
|
||||
The current default behavior is to use ``vmap_method="sequential"`` when
|
||||
not specified, but this behavior is deprecated, and in the future, the
|
||||
default will be to raise a ``NotImplementedError`` unless ``vmap_method`` is
|
||||
explicitly specified.
|
||||
|
||||
Args:
|
||||
target_name: the name of the XLA FFI custom call target that was registered
|
||||
@ -226,11 +227,11 @@ def ffi_call(
|
||||
used to define the elements of ``result_shape_dtypes``.
|
||||
``jax.core.abstract_token`` may be used to represent a token-typed output.
|
||||
*args: the arguments passed to the custom call.
|
||||
vectorized: boolean specifying whether the FFI call can operate in a
|
||||
vectorized manner, as described above.
|
||||
has_side_effect: boolean specifying whether the custom call has side
|
||||
effects. When ``True``, the FFI call will be executed even when the
|
||||
outputs are not used.
|
||||
vmap_method: string specifying how the FFI call transforms under
|
||||
:func:`~jax.vmap` as described above.
|
||||
**kwargs: keyword arguments that are passed as named attributes to the
|
||||
custom call using XLA's FFI interface.
|
||||
|
||||
@ -238,6 +239,25 @@ def ffi_call(
|
||||
One or more :class:`~jax.Array` objects whose shapes and dtypes match
|
||||
``result_shape_dtypes``.
|
||||
"""
|
||||
if not isinstance(vectorized, DeprecatedArg) and vectorized is not None:
|
||||
deprecations.warn(
|
||||
"jax-callback-vectorized",
|
||||
"The vectorized argument of ffi_call is deprecated and setting "
|
||||
"it will soon raise an error. To avoid an error in the future, and to "
|
||||
"suppress this warning, please use the vmap_method argument instead.",
|
||||
stacklevel=2)
|
||||
if vmap_method is not None:
|
||||
raise ValueError(
|
||||
"the vectorized and vmap_method arguments of ffi_call cannot "
|
||||
"be used together. Please use the vmap_method argument.")
|
||||
vmap_method = "legacy_vectorized" if vectorized else "sequential"
|
||||
allowed_vmap_methods = ["sequential", "broadcast", "broadcast_fullrank",
|
||||
"legacy_vectorized", None]
|
||||
if vmap_method not in allowed_vmap_methods:
|
||||
raise ValueError(
|
||||
f"vmap_method must be on of the allowed methods {allowed_vmap_methods}, "
|
||||
f"but got: {vmap_method}")
|
||||
|
||||
if isinstance(result_shape_dtypes, Sequence):
|
||||
multiple_results = True
|
||||
result_avals = _result_avals(result_shape_dtypes)
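For reference, a hedged usage sketch of the new keyword with `ffi_call`, assuming an XLA FFI target named "rms_norm" has already been registered (as in the docs/ffi tutorial updated in this same change) and that it handles leading batch dimensions itself:

```python
import numpy as np
import jax
import jax.numpy as jnp
import jax.extend as jex

def rms_norm(x, eps=1e-5):
    # eps is passed as a static FFI attribute; the "rms_norm" target must have
    # been registered beforehand (see docs/ffi in this change).
    return jex.ffi.ffi_call(
        "rms_norm",
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x,
        eps=np.float32(eps),
        vmap_method="broadcast_fullrank",
    )

x = jnp.ones((8, 16), dtype=jnp.float32)
jax.vmap(rms_norm)(x)  # a single FFI call on the batched input, no lax.scan
```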
|
||||
@ -248,6 +268,7 @@ def ffi_call(
|
||||
*args,
|
||||
result_avals=result_avals,
|
||||
vectorized=vectorized,
|
||||
vmap_method=vmap_method,
|
||||
target_name=target_name,
|
||||
has_side_effect=has_side_effect,
|
||||
**_wrap_kwargs_hashable(kwargs),
|
||||
@ -342,11 +363,12 @@ def ffi_call_abstract_eval(
|
||||
*avals_in,
|
||||
result_avals: tuple[core.AbstractValue, ...],
|
||||
target_name: str,
|
||||
vectorized: bool,
|
||||
vectorized: bool | DeprecatedArg,
|
||||
vmap_method: str | None,
|
||||
has_side_effect: bool,
|
||||
**kwargs: Any,
|
||||
):
|
||||
del avals_in, target_name, vectorized, kwargs
|
||||
del avals_in, target_name, vectorized, vmap_method, kwargs
|
||||
effects = {_FfiEffect} if has_side_effect else core.no_effects
|
||||
return result_avals, effects
|
||||
|
||||
@ -370,11 +392,12 @@ def ffi_call_lowering(
|
||||
*operands: ir.Value,
|
||||
result_avals: tuple[core.AbstractValue, ...],
|
||||
target_name: str,
|
||||
vectorized: bool,
|
||||
vectorized: bool | DeprecatedArg,
|
||||
vmap_method: str | None,
|
||||
has_side_effect: bool,
|
||||
**kwargs: Any,
|
||||
) -> Sequence[ir.Value]:
|
||||
del result_avals, vectorized
|
||||
del result_avals, vectorized, vmap_method
|
||||
rule = ffi_lowering(target_name, has_side_effect=has_side_effect)
|
||||
return rule(ctx, *operands, **_unwrap_kwargs_hashable(kwargs))
|
||||
|
||||
|
@ -245,10 +245,11 @@ class FfiTest(jtu.JaxTestCase):
|
||||
@jtu.sample_product(
|
||||
shape=[(1,), (4,), (5,)],
|
||||
dtype=(np.int32,),
|
||||
vectorized=(False, True),
|
||||
vmap_method=("broadcast", "broadcast_fullrank", "sequential",
|
||||
"legacy_vectorized"),
|
||||
)
|
||||
@jtu.run_on_devices("gpu")
|
||||
def testFfiCallBatching(self, shape, dtype, vectorized):
|
||||
def testFfiCallBatching(self, shape, dtype, vmap_method):
|
||||
shape = (10,) + shape
|
||||
pivots_size = shape[-1]
|
||||
permutation_size = 2 * pivots_size
|
||||
@ -256,15 +257,29 @@ class FfiTest(jtu.JaxTestCase):
|
||||
pivots = jnp.broadcast_to(pivots, shape)
|
||||
expected = lax.linalg.lu_pivots_to_permutation(pivots, permutation_size)
|
||||
actual = jax.vmap(lambda x: ffi_call_lu_pivots_to_permutation(
|
||||
x, permutation_size, vectorized=vectorized))(pivots)
|
||||
x, permutation_size, vmap_method=vmap_method))(pivots)
|
||||
self.assertArraysEqual(actual, expected)
|
||||
|
||||
@jtu.run_on_devices("gpu")
|
||||
def testVectorizedDeprecation(self):
|
||||
pivots_size = 4
|
||||
shape = (10, pivots_size)
|
||||
permutation_size = 2 * pivots_size
|
||||
pivots = jnp.arange(permutation_size - 1, pivots_size - 1, -1,
|
||||
dtype=np.int32)
|
||||
pivots = jnp.broadcast_to(pivots, shape)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True)
|
||||
with self.assertWarns(DeprecationWarning):
|
||||
jax.vmap(
|
||||
lambda x: ffi_call_lu_pivots_to_permutation(x, permutation_size))(pivots)
|
||||
|
||||
|
||||
# TODO(dfm): For now this test uses the `cu_lu_pivots_to_permutation`
|
||||
# custom call target because that's the only one in jaxlib that uses the
|
||||
# new FFI interface. Once more are available, consider using something that
|
||||
# can be run on multiple platforms.
|
||||
def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True):
|
||||
def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, **kwargs):
|
||||
return jex.ffi.ffi_call(
|
||||
"cu_lu_pivots_to_permutation",
|
||||
jax.ShapeDtypeStruct(
|
||||
@ -272,7 +287,7 @@ def ffi_call_lu_pivots_to_permutation(pivots, permutation_size, vectorized=True)
|
||||
dtype=pivots.dtype,
|
||||
),
|
||||
pivots,
|
||||
vectorized=vectorized,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
@ -1724,7 +1724,8 @@ class HostCallbackCallTest(jtu.JaxTestCase):
|
||||
"batching rules are implemented only for id_tap, not for call"):
|
||||
jax.vmap(fun)(np.ones((2, 3)))
|
||||
else:
|
||||
jax.vmap(fun)(np.ones((2, 3)))
|
||||
with jtu.ignore_warning(category=DeprecationWarning):
|
||||
jax.vmap(fun)(np.ones((2, 3)))
|
||||
|
||||
def test_call_error_bad_result_shape(self):
|
||||
with self.assertRaisesRegex(
|
||||
|
@ -635,13 +635,13 @@ class PureCallbackTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@jax.vmap
|
||||
def f(x):
|
||||
return jax.pure_callback(np.sin, x, x)
|
||||
return jax.pure_callback(np.sin, x, x, vmap_method="sequential")
|
||||
out = f(jnp.arange(4.))
|
||||
np.testing.assert_allclose(out, np.sin(np.arange(4.)))
|
||||
|
||||
@jax.jit
|
||||
def g(x):
|
||||
return jax.pure_callback(np.sin, x, x)
|
||||
return jax.pure_callback(np.sin, x, x, vmap_method="sequential")
|
||||
out = jax.vmap(g, in_axes=1)(jnp.arange(8.).reshape((4, 2)))
|
||||
np.testing.assert_allclose(out, np.sin(np.arange(8.).reshape((4, 2))).T)
|
||||
|
||||
@ -649,7 +649,8 @@ class PureCallbackTest(jtu.JaxTestCase):
|
||||
@functools.partial(jax.vmap, in_axes=(0, None))
|
||||
def h(x, y):
|
||||
out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
|
||||
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y)
|
||||
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y,
|
||||
vmap_method="sequential")
|
||||
out = h(jnp.arange(4.), 4.)
|
||||
self.assertArraysAllClose(out, np.sin(np.arange(4.)) + 4.,
|
||||
rtol=1E-7, check_dtypes=False)
|
||||
@ -658,7 +659,8 @@ class PureCallbackTest(jtu.JaxTestCase):
|
||||
@functools.partial(jax.vmap)
|
||||
def h(x, y):
|
||||
out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
|
||||
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y)
|
||||
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y,
|
||||
vmap_method="sequential")
|
||||
out = h(jnp.arange(4.), jnp.arange(10., 14.))
|
||||
self.assertArraysAllClose(out, np.sin(np.arange(4.)) + np.arange(10., 14.),
|
||||
rtol=1E-7, check_dtypes=False)
|
||||
@ -667,7 +669,8 @@ class PureCallbackTest(jtu.JaxTestCase):
|
||||
@functools.partial(jax.vmap, in_axes=1, out_axes=1)
|
||||
def h(x, y):
|
||||
out_shape = jax.ShapeDtypeStruct(x.shape, np.result_type(x.dtype, y.dtype))
|
||||
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y)
|
||||
return jax.pure_callback(lambda x, y: np.sin(x) + y, out_shape, x, y,
|
||||
vmap_method="sequential")
|
||||
out = h(jnp.arange(4.)[None], jnp.arange(10., 14.)[None])
|
||||
self.assertArraysAllClose(out, np.sin(np.arange(4.)) + np.arange(10.,
|
||||
14.)[None],
|
||||
@ -682,7 +685,7 @@ class PureCallbackTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@jax.vmap
|
||||
def f(x):
|
||||
return jax.pure_callback(cb, x, x)
|
||||
return jax.pure_callback(cb, x, x, vmap_method="sequential")
|
||||
|
||||
np.testing.assert_allclose(f(jnp.arange(4.)), np.sin(np.arange(4.)))
|
||||
|
||||
@ -693,7 +696,7 @@ class PureCallbackTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@jax.vmap
|
||||
def g(x):
|
||||
return jax.pure_callback(cb2, x, x, vectorized=True)
|
||||
return jax.pure_callback(cb2, x, x, vmap_method="broadcast")
|
||||
|
||||
np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.)))
|
||||
|
||||
@ -701,7 +704,7 @@ class PureCallbackTest(jtu.JaxTestCase):
|
||||
@functools.partial(jax.vmap, in_axes=(0, None))
|
||||
def h(x, y):
|
||||
return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
|
||||
vectorized=True)
|
||||
vmap_method="broadcast")
|
||||
out = h(jnp.arange(4.), 4.)
|
||||
np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)
|
||||
|
||||
@ -709,7 +712,7 @@ class PureCallbackTest(jtu.JaxTestCase):
|
||||
@functools.partial(jax.vmap, in_axes=(1, None), out_axes=1)
|
||||
def h(x, y):
|
||||
return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
|
||||
vectorized=True)
|
||||
vmap_method="legacy_vectorized")
|
||||
out = h(jnp.arange(4.)[None], 4.)
|
||||
np.testing.assert_allclose(out, np.sin(np.arange(4.)[None]) + 4.)
|
||||
|
||||
@ -722,7 +725,7 @@ class PureCallbackTest(jtu.JaxTestCase):
|
||||
@jax.jit
|
||||
@jax.vmap
|
||||
def f(x):
|
||||
return jax.pure_callback(cb, x, x, vectorized=True)
|
||||
return jax.pure_callback(cb, x, x, vmap_method="broadcast")
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
f(jnp.arange(4.))
|
||||
@ -981,6 +984,52 @@ class PureCallbackTest(jtu.JaxTestCase):
|
||||
out = jax.pure_callback(f, jax.ShapeDtypeStruct(x.shape, x.dtype), x)
|
||||
np.testing.assert_allclose(out, 2 * jnp.log(x + 1))
|
||||
|
||||
def test_vmap_method_raise(self):
|
||||
@jax.vmap
|
||||
def f(x):
|
||||
# Setting vectorized to None disables the current default behavior of
|
||||
# falling back on sequential.
|
||||
return jax.pure_callback(np.sin, x, x, vectorized=None)
|
||||
|
||||
with self.assertRaisesRegex(NotImplementedError, "vmap is only supported"):
|
||||
f(jnp.arange(4.))
|
||||
|
||||
def test_deprecated_vectorized(self):
|
||||
def f(x, **kwargs):
|
||||
return jax.pure_callback(np.sin, x, x, **kwargs)
|
||||
|
||||
with self.assertWarnsRegex(DeprecationWarning, "The default behavior"):
|
||||
jax.vmap(f)(jnp.arange(4.0))
|
||||
|
||||
with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"):
|
||||
f(jnp.arange(4.0), vectorized=True)
|
||||
|
||||
with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"):
|
||||
f(jnp.arange(4.0), vectorized=False)
|
||||
|
||||
def test_vmap_method_broadcast(self):
|
||||
def callback(x, y):
|
||||
self.assertTupleEqual(x.shape, (4,))
|
||||
self.assertTupleEqual(y.shape, (1,))
|
||||
return x + y
|
||||
|
||||
def f(x, y):
|
||||
return jax.pure_callback(callback, x, x, y, vmap_method="broadcast")
|
||||
|
||||
jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error
|
||||
|
||||
def test_vmap_method_broadcast_fullrank(self):
|
||||
def callback(x, y):
|
||||
self.assertTupleEqual(x.shape, (4,))
|
||||
self.assertTupleEqual(y.shape, (4,))
|
||||
return x + y
|
||||
|
||||
def f(x, y):
|
||||
return jax.pure_callback(callback, x, x, y,
|
||||
vmap_method="broadcast_fullrank")
|
||||
|
||||
jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error
|
||||
|
||||
|
||||
class IOCallbackTest(jtu.JaxTestCase):
|
||||
|
||||
|