Add "sequential_unrolled" vmap method for callbacks.
Like the `sequential` method, this loops over calls to the callback, but in
this case, the loop is unrolled.

PiperOrigin-RevId: 725601366
parent 849ea268a1
commit c502332ed5
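For orientation, a minimal usage sketch of the new method (not part of this
commit; the callback and shapes are illustrative, and it requires a jax build
that includes this change):

import jax
import jax.numpy as jnp
import numpy as np

# A host callback that operates on one unbatched input at a time.
def host_fn(x):
  return np.sin(x)

def f(x):
  return jax.pure_callback(
      host_fn,
      jax.ShapeDtypeStruct(x.shape, x.dtype),
      x,
      # New in this change: like "sequential", but the loop over the batch
      # is unrolled instead of being lowered as a single loop.
      vmap_method="sequential_unrolled",
  )

# vmap emits one callback call per batch element, unrolled.
xs = jnp.linspace(0.0, 1.0, 4)
print(jax.vmap(f)(xs))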
@@ -35,6 +35,7 @@ from jax._src.interpreters import batching
 from jax._src.interpreters import mlir
 from jax._src.lax import lax
 from jax._src.lax.control_flow.loops import map as lax_map
+from jax._src.lax.control_flow.loops import scan
 from jax._src.lib import xla_client as xc
 from jax._src.sharding_impls import SingleDeviceSharding
 from jax._src.typing import DeprecatedArg
@@ -163,7 +164,10 @@ def callback_batching_rule(
 
   # For FFI calls we must update the layouts. We handle the output layouts
   # here, but the input layout updates depend on the vmap_method parameter.
-  if vmap_method != "sequential" and kwargs.get("output_layouts") is not None:
+  if (
+      vmap_method not in ("sequential", "sequential_unrolled") and
+      kwargs.get("output_layouts") is not None
+  ):
     kwargs["output_layouts"] = tuple(
         None if layout is None else tuple(n + 1 for n in layout) + (0,)
         for layout in kwargs["output_layouts"])
@@ -199,7 +203,7 @@ def callback_batching_rule(
         result_avals=batched_result_avals,
         **kwargs,
     )
-  elif vmap_method == "sequential":
+  elif vmap_method == "sequential" or vmap_method == "sequential_unrolled":
     is_batched = [d is not batching.not_mapped for d in dims]
     unbatched_args, batched_args = util.partition_list(is_batched, new_args)
     def _batch_fun(batched_args):
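For readers unfamiliar with the helper used above: util.partition_list splits
a list by a boolean mask into (unselected, selected) halves. A plain-Python
equivalent, sketched here for illustration rather than taken from JAX's
internals:

def partition_list(mask, xs):
  # Elements where mask is False go left; elements where it is True go right.
  lhs = [x for m, x in zip(mask, xs) if not m]
  rhs = [x for m, x in zip(mask, xs) if m]
  return lhs, rhs

is_batched = [False, True, True]
unbatched, batched = partition_list(is_batched, ["a", "b", "c"])
print(unbatched, batched)  # ['a'] ['b', 'c']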
@@ -211,12 +215,14 @@ def callback_batching_rule(
           vmap_method=vmap_method,
           **kwargs,
       )
-    outvals = lax_map(_batch_fun, batched_args)
+    unroll = vmap_method == "sequential_unrolled"
+    g = lambda _, x: ((), _batch_fun(x))
+    _, outvals = scan(g, (), batched_args, unroll=unroll)
   else:
     raise NotImplementedError(
         f"vmap is only supported for the {prim.name} primitive when vmap_method "
-        "is one of 'sequential', 'expand_dims', 'broadcast_all', or "
-        "'legacy_vectorized'.")
+        "is one of 'sequential', 'sequential_unrolled', 'expand_dims', "
+        f"'broadcast_all', or 'legacy_vectorized'. Got {vmap_method=}.")
   return tuple(outvals), (0,) * len(outvals)
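The substitution above works because jax.lax.scan with an empty carry
degenerates to a map over the leading axis, and its unroll argument controls
whether the loop body is emitted once per element. A standalone sketch of the
same pattern (the body function is illustrative):

import jax
import jax.numpy as jnp

def body(x):
  return x * 2.0

# With an empty carry, scan is just a map; unroll=True asks the compiler to
# emit one copy of the body per element instead of a single loop.
g = lambda carry, x: (carry, body(x))
_, ys = jax.lax.scan(g, (), jnp.arange(4.0), unroll=True)
print(ys)  # [0. 2. 4. 6.]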
@@ -371,6 +377,8 @@ def pure_callback(
       is deprecated and it will eventually raise ``NotImplementedError``.
     * ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over
       the batched arguments, calling ``callback`` once for each batch element.
+    * ``vmap_method="sequential_unrolled"`` is like ``sequential``, but the loop
+      is unrolled.
     * ``vmap_method="expand_dims"`` calls ``callback`` with new axes of size ``1``
       added as the leading dimension of unbatched inputs.
     * ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the
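To make the documented options concrete, a hedged sketch of the shapes the
callback observes under a batch of size 3 (the helper below is illustrative,
not from this change):

import jax
import jax.numpy as jnp
import numpy as np

def show_shape(x):
  print("callback saw shape:", x.shape)
  return np.asarray(x)

def f(x, method):
  return jax.pure_callback(
      show_shape, jax.ShapeDtypeStruct(x.shape, x.dtype), x,
      vmap_method=method)

xs = jnp.ones((3, 2))
# "sequential" and "sequential_unrolled": one call per element, shape (2,).
jax.vmap(lambda x: f(x, "sequential"))(xs)
# "expand_dims": a single call with the batch as a leading axis, shape (3, 2).
jax.vmap(lambda x: f(x, "expand_dims"))(xs)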
@@ -459,8 +467,8 @@ def pure_callback(
         "the vectorized and vmap_method arguments of jax.pure_callback cannot "
         "be used together. Please use the vmap_method argument.")
     vmap_method = "legacy_vectorized" if vectorized else "sequential"
-  allowed_vmap_methods = ["sequential", "expand_dims", "broadcast_all",
-                          "legacy_vectorized", None]
+  allowed_vmap_methods = ["sequential", "sequential_unrolled", "expand_dims",
+                          "broadcast_all", "legacy_vectorized", None]
   if vmap_method not in allowed_vmap_methods:
     raise ValueError(
         f"vmap_method must be one of the allowed methods {allowed_vmap_methods}, "
@@ -393,8 +393,8 @@ def ffi_call(
         "the vectorized and vmap_method arguments of ffi_call cannot "
         "be used together. Please use the vmap_method argument.")
     vmap_method = "legacy_vectorized" if vectorized else "sequential"
-  allowed_vmap_methods = ["sequential", "expand_dims", "broadcast_all",
-                          "legacy_vectorized", None]
+  allowed_vmap_methods = ["sequential", "sequential_unrolled", "expand_dims",
+                          "broadcast_all", "legacy_vectorized", None]
   if vmap_method not in allowed_vmap_methods:
     raise ValueError(
         f"vmap_method must be one of the allowed methods {allowed_vmap_methods}, "
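Both pure_callback and ffi_call now accept the new name, and anything outside
the allowed list fails fast. A sketch of the validation behavior (the invalid
name below is deliberate):

import jax
import jax.numpy as jnp
import numpy as np

def cb(x):
  return np.asarray(x)

try:
  jax.pure_callback(cb, jax.ShapeDtypeStruct((), jnp.float32),
                    jnp.float32(1.0), vmap_method="unrolled")  # not allowed
except ValueError as e:
  print(e)  # vmap_method must be one of the allowed methods [...]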
@@ -181,7 +181,8 @@ class FfiTest(jtu.JaxTestCase):
 
   @jtu.sample_product(
       shape=[(6, 5), (4, 5, 6)],
-      vmap_method=["expand_dims", "broadcast_all", "sequential"],
+      vmap_method=["expand_dims", "broadcast_all", "sequential",
+                   "sequential_unrolled"],
   )
   @jtu.run_on_devices("gpu", "cpu")
   def test_ffi_call_batching(self, shape, vmap_method):
@@ -190,7 +191,7 @@ class FfiTest(jtu.JaxTestCase):
     expected = lax_linalg_internal.geqrf(x)
     actual = jax.vmap(partial(ffi_call_geqrf, vmap_method=vmap_method))(x)
     for a, b in zip(actual, expected):
-      if vmap_method == "sequential" and len(shape) == 3:
+      if vmap_method.startswith("sequential") and len(shape) == 3:
         # On GPU, the batched FFI call to geqrf uses an algorithm with
         # different numerics than the unbatched version (which is used when
         # vmap_method="sequential"). Therefore, we need to include floating