Add "sequential_unrolled" vmap method for callbacks.

Like the `sequential` method, this loops over calls to the callback, but in this case, the loop is unrolled.
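
For illustration, a minimal sketch of how the new method is used, assuming a JAX build that includes this change; `host_mean` and `mean` are hypothetical names, not part of this commit:

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical host-side callback that only handles one unbatched vector.
def host_mean(x):
  return np.mean(x).astype(x.dtype)

def mean(x):
  out_type = jax.ShapeDtypeStruct((), x.dtype)
  # With vmap_method="sequential_unrolled", vmap still calls the callback
  # once per batch element, but the loop is unrolled at trace time rather
  # than staged out as a single rolled loop.
  return jax.pure_callback(host_mean, out_type, x,
                           vmap_method="sequential_unrolled")

out = jax.vmap(mean)(jnp.ones((4, 3)))  # four unrolled callback calls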

PiperOrigin-RevId: 725601366
Authored by Dan Foreman-Mackey, 2025-02-11 06:08:32 -08:00; committed by jax authors
parent 849ea268a1
commit c502332ed5
3 changed files with 20 additions and 11 deletions


@@ -35,6 +35,7 @@ from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
 from jax._src.lax import lax
 from jax._src.lax.control_flow.loops import map as lax_map
+from jax._src.lax.control_flow.loops import scan
 from jax._src.lib import xla_client as xc
 from jax._src.sharding_impls import SingleDeviceSharding
 from jax._src.typing import DeprecatedArg
@@ -163,7 +164,10 @@ def callback_batching_rule(
   # For FFI calls we must update the layouts. We handle the output layouts
   # here, but the input layout updates depend on the vmap_method parameter.
-  if vmap_method != "sequential" and kwargs.get("output_layouts") is not None:
+  if (
+      vmap_method not in ("sequential", "sequential_unrolled") and
+      kwargs.get("output_layouts") is not None
+  ):
     kwargs["output_layouts"] = tuple(
         None if layout is None else tuple(n + 1 for n in layout) + (0,)
         for layout in kwargs["output_layouts"])
@@ -199,7 +203,7 @@ def callback_batching_rule(
         result_avals=batched_result_avals,
         **kwargs,
     )
-  elif vmap_method == "sequential":
+  elif vmap_method == "sequential" or vmap_method == "sequential_unrolled":
     is_batched = [d is not batching.not_mapped for d in dims]
     unbatched_args, batched_args = util.partition_list(is_batched, new_args)
     def _batch_fun(batched_args):
@@ -211,12 +215,14 @@ def callback_batching_rule(
           vmap_method=vmap_method,
           **kwargs,
       )
-    outvals = lax_map(_batch_fun, batched_args)
+    unroll = vmap_method == "sequential_unrolled"
+    g = lambda _, x: ((), _batch_fun(x))
+    _, outvals = scan(g, (), batched_args, unroll=unroll)
   else:
     raise NotImplementedError(
         f"vmap is only supported for the {prim.name} primitive when vmap_method "
-        "is one of 'sequential', 'expand_dims', 'broadcast_all', or "
-        "'legacy_vectorized'.")
+        "is one of 'sequential', 'sequential_unrolled', 'expand_dims', "
+        f"'broadcast_all', or 'legacy_vectorized'. Got {vmap_method=}.")
   return tuple(outvals), (0,) * len(outvals)
@@ -371,6 +377,8 @@ def pure_callback(
       is deprecated and it will eventually raise ``NotImplementedError``.
     * ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over
       the batched arguments, calling ``callback`` once for each batch element.
+    * ``vmap_method="sequential_unrolled"`` is like ``sequential``, but the loop
+      is unrolled.
     * ``vmap_method="expand_dims"`` calls ``callback`` with new axes of size ``1``
       added as the leading dimension of unbatched inputs.
     * ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the
@@ -459,8 +467,8 @@ def pure_callback(
           "the vectorized and vmap_method arguments of jax.pure_callback cannot "
           "be used together. Please use the vmap_method argument.")
     vmap_method = "legacy_vectorized" if vectorized else "sequential"
-  allowed_vmap_methods = ["sequential", "expand_dims", "broadcast_all",
-                          "legacy_vectorized", None]
+  allowed_vmap_methods = ["sequential", "sequential_unrolled", "expand_dims",
+                          "broadcast_all", "legacy_vectorized", None]
   if vmap_method not in allowed_vmap_methods:
     raise ValueError(
         f"vmap_method must be one of the allowed methods {allowed_vmap_methods}, "


@@ -393,8 +393,8 @@ def ffi_call(
           "the vectorized and vmap_method arguments of ffi_call cannot "
           "be used together. Please use the vmap_method argument.")
     vmap_method = "legacy_vectorized" if vectorized else "sequential"
-  allowed_vmap_methods = ["sequential", "expand_dims", "broadcast_all",
-                          "legacy_vectorized", None]
+  allowed_vmap_methods = ["sequential", "sequential_unrolled", "expand_dims",
+                          "broadcast_all", "legacy_vectorized", None]
   if vmap_method not in allowed_vmap_methods:
     raise ValueError(
         f"vmap_method must be one of the allowed methods {allowed_vmap_methods}, "


@@ -181,7 +181,8 @@ class FfiTest(jtu.JaxTestCase):
   @jtu.sample_product(
       shape=[(6, 5), (4, 5, 6)],
-      vmap_method=["expand_dims", "broadcast_all", "sequential"],
+      vmap_method=["expand_dims", "broadcast_all", "sequential",
+                   "sequential_unrolled"],
   )
   @jtu.run_on_devices("gpu", "cpu")
   def test_ffi_call_batching(self, shape, vmap_method):
@@ -190,7 +191,7 @@ class FfiTest(jtu.JaxTestCase):
     expected = lax_linalg_internal.geqrf(x)
     actual = jax.vmap(partial(ffi_call_geqrf, vmap_method=vmap_method))(x)
     for a, b in zip(actual, expected):
-      if vmap_method == "sequential" and len(shape) == 3:
+      if vmap_method.startswith("sequential") and len(shape) == 3:
         # On GPU, the batched FFI call to geqrf uses an algorithm with
         # different numerics than the unbatched version (which is used when
         # vmap_method="sequential"). Therefore, we need to include floating