Merge pull request #24425 from dfm:rename-vmap-methods

PiperOrigin-RevId: 688547393
jax authors 2024-10-22 07:51:29 -07:00
commit a2e4aff897
7 changed files with 38 additions and 38 deletions

View File

@@ -299,7 +299,7 @@
"\n",
" # The `vmap_method` parameter controls this function's behavior under `vmap`\n",
" # as discussed below.\n",
" vmap_method=\"broadcast_fullrank\",\n",
" vmap_method=\"broadcast_all\",\n",
" )\n",
"\n",
" # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for\n",
@@ -342,9 +342,9 @@
"The simplest `vmap_method` is `\"sequential\"`.\n",
"In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
"This implementation is general purpose, but it doesn't parallelize very well.\n",
"Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"broadcast\"` or `\"broadcast_fullrank\"` methods can be used to expose a better implementation.\n",
"Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"expand_dims\"` or `\"broadcast_all\"` methods can be used to expose a better implementation.\n",
"\n",
"In this case, since we only have one input argument, `\"broadcast\"` and `\"broadcast_fullrank\"` actually have the same behavior.\n",
"In this case, since we only have one input argument, `\"expand_dims\"` and `\"broadcast_all\"` actually have the same behavior.\n",
"The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.\n",
"Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:\n",
"\n",
@@ -354,11 +354,11 @@
"\n",
"```{tip}\n",
"Note that things get a bit more complicated when we have multiple input arguments.\n",
"For simplicity, we will use the `\"broadcast_fullrank\"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"broadcast\"` method.\n",
"For simplicity, we will use the `\"broadcast_all\"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"expand_dims\"` method.\n",
"The documentation for {func}`~jax.pure_callback` includes some examples of this\n",
"```\n",
"\n",
"Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method=\"broadcast_fullrank\"` out of the box:"
"Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method=\"broadcast_all\"` out of the box:"
]
},
{
@@ -460,7 +460,7 @@
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
" jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),\n",
" ),\n",
" vmap_method=\"broadcast_fullrank\",\n",
" vmap_method=\"broadcast_all\",\n",
" )(x, eps=np.float32(eps))\n",
" return y, (res, x)\n",
"\n",
@@ -474,7 +474,7 @@
" jex.ffi.ffi_call(\n",
" \"rms_norm_bwd\",\n",
" jax.ShapeDtypeStruct(ct.shape, ct.dtype),\n",
" vmap_method=\"broadcast_fullrank\",\n",
" vmap_method=\"broadcast_all\",\n",
" )(res, x, ct),\n",
" )\n",
"\n",
@@ -562,7 +562,7 @@
" return lambda x: jex.ffi.ffi_call(\n",
" target_name,\n",
" out_type,\n",
" vmap_method=\"broadcast_fullrank\",\n",
" vmap_method=\"broadcast_all\",\n",
" )(x, eps=np.float32(eps))\n",
"\n",
" return jax.lax.platform_dependent(x, cpu=impl(\"rms_norm\"), cuda=impl(\"rms_norm_cuda\"))\n",

View File

@@ -260,7 +260,7 @@ def rms_norm(x, eps=1e-5):
# The `vmap_method` parameter controls this function's behavior under `vmap`
# as discussed below.
vmap_method="broadcast_fullrank",
vmap_method="broadcast_all",
)
# Note that here we're using `numpy` (not `jax.numpy`) to specify a dtype for
@@ -299,9 +299,9 @@ The docs for {func}`~jax.pure_callback` provide more details about the `vmap_met
The simplest `vmap_method` is `"sequential"`.
In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
This implementation is general purpose, but it doesn't parallelize very well.
Many FFI calls provide more efficient batching behavior and, in some simple cases, the `"broadcast"` or `"broadcast_fullrank"` methods can be used to expose a better implementation.
Many FFI calls provide more efficient batching behavior and, in some simple cases, the `"expand_dims"` or `"broadcast_all"` methods can be used to expose a better implementation.
In this case, since we only have one input argument, `"broadcast"` and `"broadcast_fullrank"` actually have the same behavior.
In this case, since we only have one input argument, `"expand_dims"` and `"broadcast_all"` actually have the same behavior.
The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.
Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:
@@ -311,11 +311,11 @@ ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])
```{tip}
Note that things get a bit more complicated when we have multiple input arguments.
For simplicity, we will use the `"broadcast_fullrank"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `"broadcast"` method.
For simplicity, we will use the `"broadcast_all"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `"expand_dims"` method.
The documentation for {func}`~jax.pure_callback` includes some examples of this
```
Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method="broadcast_fullrank"` out of the box:
Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method="broadcast_all"` out of the box:
```{code-cell} ipython3
np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
@@ -378,7 +378,7 @@ def rms_norm_fwd(x, eps=1e-5):
jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
),
vmap_method="broadcast_fullrank",
vmap_method="broadcast_all",
)(x, eps=np.float32(eps))
return y, (res, x)
@@ -392,7 +392,7 @@ def rms_norm_bwd(eps, res, ct):
jex.ffi.ffi_call(
"rms_norm_bwd",
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
vmap_method="broadcast_fullrank",
vmap_method="broadcast_all",
)(res, x, ct),
)
@@ -470,7 +470,7 @@ def rms_norm_cross_platform(x, eps=1e-5):
return lambda x: jex.ffi.ffi_call(
target_name,
out_type,
vmap_method="broadcast_fullrank",
vmap_method="broadcast_all",
)(x, eps=np.float32(eps))
return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda"))
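For reference, the difference that the tip above describes only shows up once a call has more than one input. The following minimal sketch is not part of this commit; it uses `jax.pure_callback` (which shares the `vmap_method` batching machinery) and assumes a JAX build that already includes this rename. The shapes printed by the host callback show how the two renamed methods treat an unbatched argument:

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np


def callback(x, y):
    # Runs on the host: report the shapes that each vmap_method produces.
    print(x.shape, y.shape)
    return np.asarray(x) + np.asarray(y)


def fun(x, y, *, vmap_method):
    # `x` doubles as the result shape/dtype prototype, as in the tests below.
    return jax.pure_callback(callback, x, x, y, vmap_method=vmap_method)


x = jnp.arange(4.0)
y = 1.0

# "expand_dims": the unbatched `y` gains a leading axis of size 1 -> prints (4,) (1,)
jax.vmap(functools.partial(fun, vmap_method="expand_dims"), in_axes=(0, None))(x, y)

# "broadcast_all": `y` is broadcast to the full batch size -> prints (4,) (4,)
jax.vmap(functools.partial(fun, vmap_method="broadcast_all"), in_axes=(0, None))(x, y)
```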

View File

@@ -58,7 +58,7 @@ def rms_norm(x, eps=1e-5):
# above in `register_ffi_target`
"rms_norm",
out_type,
vmap_method="broadcast_fullrank",
vmap_method="broadcast_all",
)(x, eps=np.float32(eps))
@@ -69,7 +69,7 @@ def rms_norm_fwd(x, eps=1e-5):
jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
),
vmap_method="broadcast_fullrank",
vmap_method="broadcast_all",
)(x, eps=np.float32(eps))
return y, (res, x)
@@ -83,7 +83,7 @@ def rms_norm_bwd(eps, res, ct):
jex.ffi.ffi_call(
"rms_norm_bwd",
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
vmap_method="broadcast_fullrank",
vmap_method="broadcast_all",
)(res, x, ct),
)

View File

@@ -170,8 +170,8 @@ def callback_batching_rule(
result_avals=batched_result_avals,
**kwargs,
)
elif vmap_method == "broadcast" or vmap_method == "broadcast_fullrank":
size = axis_size if vmap_method == "broadcast_fullrank" else 1
elif vmap_method == "expand_dims" or vmap_method == "broadcast_all":
size = axis_size if vmap_method == "broadcast_all" else 1
bcast_args = [
lax.broadcast(x, (size,)) if d is batching.not_mapped else x
for x, d in zip(new_args, dims)]
@@ -198,7 +198,7 @@ def callback_batching_rule(
else:
raise NotImplementedError(
f"vmap is only supported for the {prim.name} primitive when vmap_method "
"is one of 'sequential', 'broadcast', 'broadcast_fullrank', or "
"is one of 'sequential', 'expand_dims', 'broadcast_all', or "
"'legacy_vectorized'.")
return tuple(outvals), (0,) * len(outvals)
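As a concrete illustration of the size logic in the branch above (a standalone sketch, not code from this commit): an argument with no mapped dimension gets a new leading axis via `lax.broadcast`, of size 1 for `"expand_dims"` and of the full batch size for `"broadcast_all"`:

```python
import jax.numpy as jnp
from jax import lax

axis_size = 4  # hypothetical batch size of the vmapped call
unmapped = jnp.asarray(1.0)                        # argument with no batch dimension
mapped = jnp.arange(axis_size, dtype=jnp.float32)  # already carries the batch axis

for vmap_method in ("expand_dims", "broadcast_all"):
    # Mirrors the rule above: size 1 for "expand_dims", axis_size for "broadcast_all".
    size = axis_size if vmap_method == "broadcast_all" else 1
    bcast = lax.broadcast(unmapped, (size,))  # lax.broadcast prepends leading dims
    print(vmap_method, bcast.shape, mapped.shape)

# expand_dims (1,) (4,)
# broadcast_all (4,) (4,)
```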
@@ -327,9 +327,9 @@ def pure_callback(
is deprecated and it will eventually raise ``NotImplementedError``.
* ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over
the batched arguments, calling ``callback`` once for each batch element.
* ``vmap_method="broadcast"`` calls ``callback`` with new axes of size ``1``
* ``vmap_method="expand_dims"`` calls ``callback`` with new axes of size ``1``
added as the leading dimension of unbatched inputs.
* ``vmap_method="broadcast_fullrank"`` behaves like ``broadcast``, but the
* ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the
inputs are tiled to the expected batched shape.
If necessary, the legacy behavior provided by the deprecated
@@ -383,20 +383,20 @@ def pure_callback(
... return jax.pure_callback(callback, out_type, x, y,
... vmap_method=vmap_method)
Calling this with ``vmap_method="broadcast"`` adds a new axis of size ``1``
Calling this with ``vmap_method="expand_dims"`` adds a new axis of size ``1``
to ``y``:
>>> from functools import partial
>>> x = jnp.arange(4)
>>> y = 1.0
>>> jax.vmap(partial(fun, vmap_method="broadcast"), in_axes=(0, None))(x, y)
>>> jax.vmap(partial(fun, vmap_method="expand_dims"), in_axes=(0, None))(x, y)
(4,) (1,)
Array([1., 2., 3., 4.], dtype=float32)
Whereas, ``vmap_method="broadcast_fullrank"`` adds an axis of size ``4`` to
Whereas, ``vmap_method="broadcast_all"`` adds an axis of size ``4`` to
``y``:
>>> jax.vmap(partial(fun, vmap_method="broadcast_fullrank"),
>>> jax.vmap(partial(fun, vmap_method="broadcast_all"),
... in_axes=(0, None))(x, y)
(4,) (4,)
Array([1., 2., 3., 4.], dtype=float32)
@@ -415,7 +415,7 @@ def pure_callback(
"the vectorized and vmap_method arguments of jax.pure_callback cannot "
"be used together. Please use the vmap_method argument.")
vmap_method = "legacy_vectorized" if vectorized else "sequential"
allowed_vmap_methods = ["sequential", "broadcast", "broadcast_fullrank",
allowed_vmap_methods = ["sequential", "expand_dims", "broadcast_all",
"legacy_vectorized", None]
if vmap_method not in allowed_vmap_methods:
raise ValueError(

View File

@@ -256,7 +256,7 @@ def ffi_call(
"the vectorized and vmap_method arguments of ffi_call cannot "
"be used together. Please use the vmap_method argument.")
vmap_method = "legacy_vectorized" if vectorized else "sequential"
allowed_vmap_methods = ["sequential", "broadcast", "broadcast_fullrank",
allowed_vmap_methods = ["sequential", "expand_dims", "broadcast_all",
"legacy_vectorized", None]
if vmap_method not in allowed_vmap_methods:
raise ValueError(

View File

@@ -245,7 +245,7 @@ class FfiTest(jtu.JaxTestCase):
@jtu.sample_product(
shape=[(1,), (4,), (5,)],
dtype=(np.int32,),
vmap_method=("broadcast", "broadcast_fullrank", "sequential",
vmap_method=("expand_dims", "broadcast_all", "sequential",
"legacy_vectorized"),
)
@jtu.run_on_devices("gpu")

View File

@@ -696,7 +696,7 @@ class PureCallbackTest(jtu.JaxTestCase):
@jax.jit
@jax.vmap
def g(x):
return jax.pure_callback(cb2, x, x, vmap_method="broadcast")
return jax.pure_callback(cb2, x, x, vmap_method="expand_dims")
np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.)))
@@ -704,7 +704,7 @@ class PureCallbackTest(jtu.JaxTestCase):
@functools.partial(jax.vmap, in_axes=(0, None))
def h(x, y):
return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
vmap_method="broadcast")
vmap_method="expand_dims")
out = h(jnp.arange(4.), 4.)
np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)
@@ -725,7 +725,7 @@ class PureCallbackTest(jtu.JaxTestCase):
@jax.jit
@jax.vmap
def f(x):
return jax.pure_callback(cb, x, x, vmap_method="broadcast")
return jax.pure_callback(cb, x, x, vmap_method="expand_dims")
with self.assertRaises(RuntimeError):
f(jnp.arange(4.))
@@ -1007,18 +1007,18 @@ class PureCallbackTest(jtu.JaxTestCase):
with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"):
f(jnp.arange(4.0), vectorized=False)
def test_vmap_method_broadcast(self):
def test_vmap_method_expand_dims(self):
def callback(x, y):
self.assertTupleEqual(x.shape, (4,))
self.assertTupleEqual(y.shape, (1,))
return x + y
def f(x, y):
return jax.pure_callback(callback, x, x, y, vmap_method="broadcast")
return jax.pure_callback(callback, x, x, y, vmap_method="expand_dims")
jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error
def test_vmap_method_broadcast_fullrank(self):
def test_vmap_method_broadcast_all(self):
def callback(x, y):
self.assertTupleEqual(x.shape, (4,))
self.assertTupleEqual(y.shape, (4,))
@@ -1026,7 +1026,7 @@ class PureCallbackTest(jtu.JaxTestCase):
def f(x, y):
return jax.pure_callback(callback, x, x, y,
vmap_method="broadcast_fullrank")
vmap_method="broadcast_all")
jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error