Merge pull request #24425 from dfm:rename-vmap-methods
PiperOrigin-RevId: 688547393
Commit: a2e4aff897
@@ -299,7 +299,7 @@
 "\n",
 " # The `vmap_method` parameter controls this function's behavior under `vmap`\n",
 " # as discussed below.\n",
-" vmap_method=\"broadcast_fullrank\",\n",
+" vmap_method=\"broadcast_all\",\n",
 " )\n",
 "\n",
 " # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for\n",
@@ -342,9 +342,9 @@
 "The simplest `vmap_method` is `\"sequential\"`.\n",
 "In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
 "This implementation is general purpose, but it doesn't parallelize very well.\n",
-"Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"broadcast\"` or `\"broadcast_fullrank\"` methods can be used to expose a better implementation.\n",
+"Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"expand_dims\"` or `\"broadcast_all\"` methods can be used to expose a better implementation.\n",
 "\n",
-"In this case, since we only have one input argument, `\"broadcast\"` and `\"broadcast_fullrank\"` actually have the same behavior.\n",
+"In this case, since we only have one input argument, `\"expand_dims\"` and `\"broadcast_all\"` actually have the same behavior.\n",
 "The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.\n",
 "Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:\n",
 "\n",
@@ -354,11 +354,11 @@
 "\n",
 "```{tip}\n",
 "Note that things get a bit more complicated when we have multiple input arguments.\n",
-"For simplicity, we will use the `\"broadcast_fullrank\"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"broadcast\"` method.\n",
+"For simplicity, we will use the `\"broadcast_all\"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"expand_dims\"` method.\n",
 "The documentation for {func}`~jax.pure_callback` includes some examples of this\n",
 "```\n",
 "\n",
-"Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method=\"broadcast_fullrank\"` out of the box:"
+"Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method=\"broadcast_all\"` out of the box:"
 ]
 },
 {
@@ -460,7 +460,7 @@
 " jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
 " jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),\n",
 " ),\n",
-" vmap_method=\"broadcast_fullrank\",\n",
+" vmap_method=\"broadcast_all\",\n",
 " )(x, eps=np.float32(eps))\n",
 " return y, (res, x)\n",
 "\n",
@@ -474,7 +474,7 @@
 " jex.ffi.ffi_call(\n",
 " \"rms_norm_bwd\",\n",
 " jax.ShapeDtypeStruct(ct.shape, ct.dtype),\n",
-" vmap_method=\"broadcast_fullrank\",\n",
+" vmap_method=\"broadcast_all\",\n",
 " )(res, x, ct),\n",
 " )\n",
 "\n",
@@ -562,7 +562,7 @@
 " return lambda x: jex.ffi.ffi_call(\n",
 " target_name,\n",
 " out_type,\n",
-" vmap_method=\"broadcast_fullrank\",\n",
+" vmap_method=\"broadcast_all\",\n",
 " )(x, eps=np.float32(eps))\n",
 "\n",
 " return jax.lax.platform_dependent(x, cpu=impl(\"rms_norm\"), cuda=impl(\"rms_norm_cuda\"))\n",
docs/ffi.md (16 changed lines)
@@ -260,7 +260,7 @@ def rms_norm(x, eps=1e-5):

 # The `vmap_method` parameter controls this function's behavior under `vmap`
 # as discussed below.
-vmap_method="broadcast_fullrank",
+vmap_method="broadcast_all",
 )

 # Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
@@ -299,9 +299,9 @@ The docs for {func}`~jax.pure_callback` provide more details about the `vmap_met
 The simplest `vmap_method` is `"sequential"`.
 In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
 This implementation is general purpose, but it doesn't parallelize very well.
-Many FFI calls provide more efficient batching behavior and, in some simple cases, the `"broadcast"` or `"broadcast_fullrank"` methods can be used to expose a better implementation.
+Many FFI calls provide more efficient batching behavior and, in some simple cases, the `"expand_dims"` or `"broadcast_all"` methods can be used to expose a better implementation.

-In this case, since we only have one input argument, `"broadcast"` and `"broadcast_fullrank"` actually have the same behavior.
+In this case, since we only have one input argument, `"expand_dims"` and `"broadcast_all"` actually have the same behavior.
 The specific assumption required to use these methods is that the foreign function knows how to handle batch dimensions.
 Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:

@@ -311,11 +311,11 @@ ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])

 ```{tip}
 Note that things get a bit more complicated when we have multiple input arguments.
-For simplicity, we will use the `"broadcast_fullrank"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `"broadcast"` method.
+For simplicity, we will use the `"broadcast_all"` throughout this tutorial, which guarantees that all inputs will be broadcasted to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `"expand_dims"` method.
 The documentation for {func}`~jax.pure_callback` includes some examples of this
 ```

-Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method="broadcast_fullrank"` out of the box:
+Our implementation of `rms_norm` has the appropriate semantics, and it supports `vmap` with `vmap_method="broadcast_all"` out of the box:

 ```{code-cell} ipython3
 np.testing.assert_allclose(jax.vmap(rms_norm)(x), jax.vmap(rms_norm_ref)(x), rtol=1e-5)
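As a quick illustration of the stacking identity described in the hunks above, here is a minimal runnable sketch that uses `jax.pure_callback` as a stand-in for a compiled FFI target. The `batched_callback` and `rms_norm_like` names, shapes, and tolerance are ours, not part of this change, and it assumes a JAX version that accepts the `vmap_method` argument with the renamed values.

```python
import jax
import jax.numpy as jnp
import numpy as np

# A host-side function that already understands a leading batch dimension,
# mimicking a foreign function with built-in batching support.
def batched_callback(x):
    x = np.asarray(x)
    return x / np.sqrt(np.mean(np.square(x), axis=-1, keepdims=True))

def rms_norm_like(x):
    return jax.pure_callback(
        batched_callback,
        jax.ShapeDtypeStruct(x.shape, x.dtype),
        x,
        vmap_method="broadcast_all",  # renamed from "broadcast_fullrank"
    )

xs = jnp.linspace(1.0, 2.0, 15).reshape(3, 5)
# Under vmap, "broadcast_all" hands the callback the full (3, 5) batch in one
# call, and the result matches stacking per-row applications.
np.testing.assert_allclose(
    jax.vmap(rms_norm_like)(xs),
    jnp.stack([rms_norm_like(x) for x in xs]),
    rtol=1e-5,
)
```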
@@ -378,7 +378,7 @@ def rms_norm_fwd(x, eps=1e-5):
 jax.ShapeDtypeStruct(x.shape, x.dtype),
 jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
 ),
-vmap_method="broadcast_fullrank",
+vmap_method="broadcast_all",
 )(x, eps=np.float32(eps))
 return y, (res, x)

@@ -392,7 +392,7 @@ def rms_norm_bwd(eps, res, ct):
 jex.ffi.ffi_call(
 "rms_norm_bwd",
 jax.ShapeDtypeStruct(ct.shape, ct.dtype),
-vmap_method="broadcast_fullrank",
+vmap_method="broadcast_all",
 )(res, x, ct),
 )

@@ -470,7 +470,7 @@ def rms_norm_cross_platform(x, eps=1e-5):
 return lambda x: jex.ffi.ffi_call(
 target_name,
 out_type,
-vmap_method="broadcast_fullrank",
+vmap_method="broadcast_all",
 )(x, eps=np.float32(eps))

 return jax.lax.platform_dependent(x, cpu=impl("rms_norm"), cuda=impl("rms_norm_cuda"))
@@ -58,7 +58,7 @@ def rms_norm(x, eps=1e-5):
 # above in `register_ffi_target`
 "rms_norm",
 out_type,
-vmap_method="broadcast_fullrank",
+vmap_method="broadcast_all",
 )(x, eps=np.float32(eps))


@@ -69,7 +69,7 @@ def rms_norm_fwd(x, eps=1e-5):
 jax.ShapeDtypeStruct(x.shape, x.dtype),
 jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
 ),
-vmap_method="broadcast_fullrank",
+vmap_method="broadcast_all",
 )(x, eps=np.float32(eps))
 return y, (res, x)

@@ -83,7 +83,7 @@ def rms_norm_bwd(eps, res, ct):
 jex.ffi.ffi_call(
 "rms_norm_bwd",
 jax.ShapeDtypeStruct(ct.shape, ct.dtype),
-vmap_method="broadcast_fullrank",
+vmap_method="broadcast_all",
 )(res, x, ct),
 )

@@ -170,8 +170,8 @@ def callback_batching_rule(
 result_avals=batched_result_avals,
 **kwargs,
 )
-elif vmap_method == "broadcast" or vmap_method == "broadcast_fullrank":
-size = axis_size if vmap_method == "broadcast_fullrank" else 1
+elif vmap_method == "expand_dims" or vmap_method == "broadcast_all":
+size = axis_size if vmap_method == "broadcast_all" else 1
 bcast_args = [
 lax.broadcast(x, (size,)) if d is batching.not_mapped else x
 for x, d in zip(new_args, dims)]
@@ -198,7 +198,7 @@ def callback_batching_rule(
 else:
 raise NotImplementedError(
 f"vmap is only supported for the {prim.name} primitive when vmap_method "
-"is one of 'sequential', 'broadcast', 'broadcast_fullrank', or "
+"is one of 'sequential', 'expand_dims', 'broadcast_all', or "
 "'legacy_vectorized'.")
 return tuple(outvals), (0,) * len(outvals)

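For readers skimming the batching rule above, here is a standalone sketch of what the renamed branch does to unmapped arguments. The helper name `broadcast_unbatched` and the use of `None` for an unmapped axis are ours; the real rule uses `batching.not_mapped` and internal plumbing.

```python
import jax.numpy as jnp
from jax import lax

def broadcast_unbatched(args, dims, axis_size, vmap_method):
    # "broadcast_all" tiles unmapped args up to the full batch size, while
    # "expand_dims" only prepends an axis of size 1.
    size = axis_size if vmap_method == "broadcast_all" else 1
    return [
        lax.broadcast(x, (size,)) if d is None else x
        for x, d in zip(args, dims)
    ]

xs = jnp.arange(4.0)   # mapped along axis 0
y = jnp.asarray(1.0)   # unmapped
print(broadcast_unbatched([xs, y], [0, None], 4, "expand_dims")[1].shape)    # (1,)
print(broadcast_unbatched([xs, y], [0, None], 4, "broadcast_all")[1].shape)  # (4,)
```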
@@ -327,9 +327,9 @@ def pure_callback(
 is deprecated and it will eventually raise ``NotImplementedError``.
 * ``vmap_method="sequential"`` uses :func:`~jax.lax.map` to loop over
 the batched arugments, calling ``callback`` once for each batch element.
-* ``vmap_method="broadcast"`` calls ``callback`` with new axes of size ``1``
+* ``vmap_method="expand_dims"`` calls ``callback`` with new axes of size ``1``
 added as the leading dimension unbatched inputs.
-* ``vmap_method="broadcast_fullrank"`` behaves like ``broadcast``, but the
+* ``vmap_method="broadcast_all"`` behaves like ``expand_dims``, but the
 inputs are tiled to the expected batched shape.

 If necessary, the legacy behavior provided by the deprecated
@@ -383,20 +383,20 @@ def pure_callback(
 ... return jax.pure_callback(callback, out_type, x, y,
 ... vmap_method=vmap_method)

-Calling this with ``vmap_method="broadcast"`` adds a new axis of size ``1``
+Calling this with ``vmap_method="expand_dims"`` adds a new axis of size ``1``
 to ``y``:

 >>> from functools import partial
 >>> x = jnp.arange(4)
 >>> y = 1.0
->>> jax.vmap(partial(fun, vmap_method="broadcast"), in_axes=(0, None))(x, y)
+>>> jax.vmap(partial(fun, vmap_method="expand_dims"), in_axes=(0, None))(x, y)
 (4,) (1,)
 Array([1., 2., 3., 4.], dtype=float32)

-Whereas, ``vmap_method="broadcast_fullrank"`` adds an axis of size ``4`` to
+Whereas, ``vmap_method="broadcast_all"`` adds an axis of size ``4`` to
 ``y``:

->>> jax.vmap(partial(fun, vmap_method="broadcast_fullrank"),
+>>> jax.vmap(partial(fun, vmap_method="broadcast_all"),
 ... in_axes=(0, None))(x, y)
 (4,) (4,)
 Array([1., 2., 3., 4.], dtype=float32)
@@ -415,7 +415,7 @@ def pure_callback(
 "the vectorized and vmap_method arguments of jax.pure_callback cannot "
 "be used together. Please use the vmap_method argument.")
 vmap_method = "legacy_vectorized" if vectorized else "sequential"
-allowed_vmap_methods = ["sequential", "broadcast", "broadcast_fullrank",
+allowed_vmap_methods = ["sequential", "expand_dims", "broadcast_all",
 "legacy_vectorized", None]
 if vmap_method not in allowed_vmap_methods:
 raise ValueError(
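Because the old spellings are dropped from `allowed_vmap_methods` in both `pure_callback` and `ffi_call` (see the hunks above and below), callers that still pass them will hit the `ValueError` path, assuming no compatibility shim elsewhere. A hypothetical helper for updating downstream call sites; the mapping reflects this diff, but the function itself is our suggestion, not part of the change:

```python
# Renames introduced by this change:
#   "broadcast"          -> "expand_dims"
#   "broadcast_fullrank" -> "broadcast_all"
_VMAP_METHOD_RENAMES = {
    "broadcast": "expand_dims",
    "broadcast_fullrank": "broadcast_all",
}

def migrate_vmap_method(vmap_method):
    """Translate pre-rename vmap_method values to their new spellings."""
    return _VMAP_METHOD_RENAMES.get(vmap_method, vmap_method)

assert migrate_vmap_method("broadcast_fullrank") == "broadcast_all"
assert migrate_vmap_method("sequential") == "sequential"
```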
@@ -256,7 +256,7 @@ def ffi_call(
 "the vectorized and vmap_method arguments of ffi_call cannot "
 "be used together. Please use the vmap_method argument.")
 vmap_method = "legacy_vectorized" if vectorized else "sequential"
-allowed_vmap_methods = ["sequential", "broadcast", "broadcast_fullrank",
+allowed_vmap_methods = ["sequential", "expand_dims", "broadcast_all",
 "legacy_vectorized", None]
 if vmap_method not in allowed_vmap_methods:
 raise ValueError(
@@ -245,7 +245,7 @@ class FfiTest(jtu.JaxTestCase):
 @jtu.sample_product(
 shape=[(1,), (4,), (5,)],
 dtype=(np.int32,),
-vmap_method=("broadcast", "broadcast_fullrank", "sequential",
+vmap_method=("expand_dims", "broadcast_all", "sequential",
 "legacy_vectorized"),
 )
 @jtu.run_on_devices("gpu")
@@ -696,7 +696,7 @@ class PureCallbackTest(jtu.JaxTestCase):
 @jax.jit
 @jax.vmap
 def g(x):
-return jax.pure_callback(cb2, x, x, vmap_method="broadcast")
+return jax.pure_callback(cb2, x, x, vmap_method="expand_dims")

 np.testing.assert_allclose(g(jnp.arange(4.)), np.sin(np.arange(4.)))

@@ -704,7 +704,7 @@ class PureCallbackTest(jtu.JaxTestCase):
 @functools.partial(jax.vmap, in_axes=(0, None))
 def h(x, y):
 return jax.pure_callback(lambda x, y: np.sin(x) + y, x, x, y,
-vmap_method="broadcast")
+vmap_method="expand_dims")
 out = h(jnp.arange(4.), 4.)
 np.testing.assert_allclose(out, np.sin(np.arange(4.)) + 4.)

@@ -725,7 +725,7 @@ class PureCallbackTest(jtu.JaxTestCase):
 @jax.jit
 @jax.vmap
 def f(x):
-return jax.pure_callback(cb, x, x, vmap_method="broadcast")
+return jax.pure_callback(cb, x, x, vmap_method="expand_dims")

 with self.assertRaises(RuntimeError):
 f(jnp.arange(4.))
@@ -1007,18 +1007,18 @@ class PureCallbackTest(jtu.JaxTestCase):
 with self.assertWarnsRegex(DeprecationWarning, "The vectorized argument"):
 f(jnp.arange(4.0), vectorized=False)

-def test_vmap_method_broadcast(self):
+def test_vmap_method_expand_dims(self):
 def callback(x, y):
 self.assertTupleEqual(x.shape, (4,))
 self.assertTupleEqual(y.shape, (1,))
 return x + y

 def f(x, y):
-return jax.pure_callback(callback, x, x, y, vmap_method="broadcast")
+return jax.pure_callback(callback, x, x, y, vmap_method="expand_dims")

 jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error

-def test_vmap_method_broadcast_fullrank(self):
+def test_vmap_method_broadcast_all(self):
 def callback(x, y):
 self.assertTupleEqual(x.shape, (4,))
 self.assertTupleEqual(y.shape, (4,))
@@ -1026,7 +1026,7 @@ class PureCallbackTest(jtu.JaxTestCase):

 def f(x, y):
 return jax.pure_callback(callback, x, x, y,
-vmap_method="broadcast_fullrank")
+vmap_method="broadcast_all")

 jax.vmap(f, in_axes=(0, None))(jnp.arange(4.0), 1.0) # doesn't error

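The shape behavior exercised by the renamed tests can be reproduced outside the test harness with a few lines. This is a sketch under the same assumptions as the docstring example above; the printed shapes are what the callback observes.

```python
import jax
import jax.numpy as jnp

def callback(x, y):
    # The mapped `x` keeps its batch axis; the unmapped `y` is either given a
    # size-1 axis ("expand_dims") or tiled to the batch size ("broadcast_all").
    print("x:", x.shape, "y:", y.shape)
    return x + y

def f(x, y, vmap_method):
    return jax.pure_callback(callback, x, x, y, vmap_method=vmap_method)

x = jnp.arange(4.0)
print(jax.vmap(lambda a, b: f(a, b, "expand_dims"), in_axes=(0, None))(x, 1.0))
print(jax.vmap(lambda a, b: f(a, b, "broadcast_all"), in_axes=(0, None))(x, 1.0))
# Expected: the callback sees x: (4,) y: (1,), then x: (4,) y: (4,);
# both calls return Array([1., 2., 3., 4.], dtype=float32).
```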