mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00
Add a discussion of sharding to the FFI tutorial.
This commit is contained in:
parent
07f4fd3e51
commit
729418094e
273
docs/ffi.ipynb
273
docs/ffi.ipynb
@ -28,6 +28,24 @@
|
||||
"\n",
|
||||
"The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi).\n",
|
||||
"\n",
|
||||
"Because we demonstrate how FFI calls can be sharded at the end of this tutorial, let's first set up our environment to be treated by JAX as having multiple CPUs:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"XLA_FLAGS\"] = \"--xla_force_host_platform_device_count=4\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## A simple example\n",
|
||||
"\n",
|
||||
"To demonstrate the use of the FFI interface, we will implement a simple \"root-mean-square (RMS)\" normalization function.\n",
|
||||
@ -304,7 +322,7 @@
|
||||
"\n",
|
||||
"\n",
|
||||
"# Test that this gives the same result as our reference implementation\n",
|
||||
"x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))\n",
|
||||
"x = jnp.linspace(-0.5, 0.5, 32).reshape((8, 4))\n",
|
||||
"np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)"
|
||||
]
|
||||
},
|
||||
@ -612,8 +630,257 @@
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"and there will be no runtime overhead to using {func}`jax.lax.platform_dependent`, and the compiled program won't include any references to unavailable FFI targets.\n",
|
||||
"and there will be no runtime overhead to using {func}`jax.lax.platform_dependent`, and the compiled program won't include any references to unavailable FFI targets."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Sharding\n",
|
||||
"\n",
|
||||
"Most large scale users of JAX use its APIs for distributed computation across multiple devices.\n",
|
||||
"As discussed in {ref}`sharded-computation`, parallelism in JAX is controlled by sharding data across devices, and most JAX operations can be used within any of the supported parallel programming paradigms (from automatic to fully manual).\n",
|
||||
"But, the story is a little bit more complicated for FFI calls.\n",
|
||||
"Since the internals of an FFI call are opaque to both JAX and XLA, FFI calls won't typically show optimal (or even good) performance when the data are sharded.\n",
|
||||
"\n",
|
||||
"Before getting into the FFI details, let's consider the behavior of our pure-JAX reference implementation of RMS normalization (the `rms_norm_ref` function defined at the top of this document) with a sharded input.\n",
|
||||
"As discussed above, our implementation treats all leading axes of the input as _batch_ dimensions, and the normalization is performed along the last axis.\n",
|
||||
"This means that if the data are sharded along any batch dimensions, but replicated on the last dimension, no communication is required.\n",
|
||||
"This can be seen by sharding our 2-dimensional test data from above along its first dimension and checking the compiled HLO for operations like `all-gather`, `all-reduce`, etc.:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from jax.sharding import PartitionSpec as P\n",
|
||||
"\n",
|
||||
"assert len(jax.devices()) == 4 # Set using the XLA_FLAGS environment variable\n",
|
||||
"mesh = jax.make_mesh((4,), (\"x\",))\n",
|
||||
"\n",
|
||||
"batch_shd = jax.NamedSharding(mesh, P(\"x\", None))\n",
|
||||
"x_batch_shd = jax.device_put(x, batch_shd)\n",
|
||||
"hlo_batch = jax.jit(rms_norm_ref, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text()\n",
|
||||
"assert \"all-\" not in hlo_batch"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"However, if the data are sharded along the last axis, communication (in this case an `all-reduce`) is required to compute the sum in the normalization:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"data_shd = jax.NamedSharding(mesh, P(None, \"x\"))\n",
|
||||
"x_data_shd = jax.device_put(x, data_shd)\n",
|
||||
"hlo_data = jax.jit(rms_norm_ref, out_shardings=data_shd).lower(x_data_shd).compile().as_text()\n",
|
||||
"assert \"all-reduce\" in hlo_data"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Now, if we try to naively use our FFI version of the same model, it runs fine and gets the right answer:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"output = jax.jit(rms_norm, out_shardings=batch_shd)(x_batch_shd)\n",
|
||||
"np.testing.assert_allclose(output, rms_norm_ref(x), rtol=1e-5)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"But, if you look at the compiled HLO (omitting a helper functions for clarity), you'll see that\n",
|
||||
"\n",
|
||||
"1. the data are first fully replicated onto each device via an `all-gather` operation,\n",
|
||||
"2. the FFI call is executed on the full dataset on each device, and\n",
|
||||
"3. the output is sliced to discard the unused portions."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"hlo = jax.jit(rms_norm, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text().strip()\n",
|
||||
"print(hlo.split(\"\\n\\n\")[-1])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"This clearly (to us!) isn't the optimal partitioning of this function, but it's the best that JAX/XLA can do with the information given.\n",
|
||||
"\n",
|
||||
"To generate better partitioning logic, we can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`, and we discuss both options here.\n",
|
||||
"That being said, it's not straightforward to generate _optimal_ partitioning for all inputs, because sometimes this would require algorithmic changes.\n",
|
||||
"Specifically, let's add support for \"batch partitioning\", which handles the case where the data are sharded on batch dimensions, but sharding on the last dimension will always require in re-sharding.\n",
|
||||
"\n",
|
||||
"### Using `shard_map`\n",
|
||||
"\n",
|
||||
"If you are using manual sharding control via {func}`~jax.experimental.shard_map.shard_map`, any FFI calls in your program should already partition appropriately:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from functools import partial\n",
|
||||
"from jax.experimental.shard_map import shard_map\n",
|
||||
"\n",
|
||||
"@partial(shard_map, mesh=mesh, in_specs=P(\"x\", None), out_specs=P(\"x\", None))\n",
|
||||
"def rms_norm_shmap(x):\n",
|
||||
" return rms_norm(x)\n",
|
||||
"\n",
|
||||
"np.testing.assert_allclose(rms_norm_shmap(x_batch_shd), rms_norm_ref(x), rtol=1e-5)\n",
|
||||
"print(jax.jit(rms_norm_shmap, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text().strip())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"As you can see in this program, if the input and output shardings match the `shard_map` specs, no communication is required and the FFI call is executed on the appropriately sharded subset of the data.\n",
|
||||
"\n",
|
||||
"You can also use inputs and outputs with shardings that don't match the `shard_map` specs, but (unrelated to the FFI) this will require re-sharding, as seen by the `all-to-all` operations in the compiled HLO:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"hlo_data_shmap = jax.jit(rms_norm_shmap, out_shardings=data_shd).lower(x_data_shd).compile().as_text()\n",
|
||||
"assert \"all-to-all\" in hlo_data_shmap"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Using `custom partitioning`\n",
|
||||
"\n",
|
||||
"If you can't use {func}`~jax.experimental.shard_map.shard_map`, an alternative approach is to use {func}`~jax.experimental.custom_partitioning.custom_partitioning`, which supports automatic parallelization via {func}`jax.jit`.\n",
|
||||
"{func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges.\n",
|
||||
"We won't go into too much detail on the caveats here, but the main issues that you should be aware of are:\n",
|
||||
"\n",
|
||||
"1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either.\n",
|
||||
"2. Debugging `custom_partitioning` logic can be tedious because Python errors don't always get propagated, instead causing your Python process to exit. That being said, any exceptions will show up in the process logs, so you should be able to track them down there.\n",
|
||||
"\n",
|
||||
"All that being said, here's how we can wrap our FFI implementation of `rms_norm` using {func}`~jax.experimental.custom_partitioning.custom_partitioning`:"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from jax.experimental.custom_partitioning import custom_partitioning\n",
|
||||
"\n",
|
||||
"@partial(custom_partitioning, static_argnums=(1,))\n",
|
||||
"def rms_norm_partitioned(x, eps=1e-5):\n",
|
||||
" return rms_norm(x, eps=eps)\n",
|
||||
"\n",
|
||||
"def replicate_sharding_on_last_dim(mesh, sharding, target_info):\n",
|
||||
" # Our implementation supports trivial sharding on any batch dimensions, but the data\n",
|
||||
" # must be replicated on the last (non-batch) dimension.\n",
|
||||
" rank = len(target_info.shape)\n",
|
||||
" num_batch_dims = min(len(sharding.spec), rank - 1)\n",
|
||||
"\n",
|
||||
" # The Nones here indicate which dimensions should be replicated.\n",
|
||||
" names = tuple(sharding.spec[:num_batch_dims]) + (None,) * (rank - num_batch_dims)\n",
|
||||
" return jax.NamedSharding(mesh, P(*names))\n",
|
||||
"\n",
|
||||
"def rms_norm_infer_sharding_from_operands(eps, mesh, args_info, result_info):\n",
|
||||
" del eps # unused\n",
|
||||
" arg_info, = args_info\n",
|
||||
" result_sharding = replicate_sharding_on_last_dim(mesh, arg_info.sharding, result_info)\n",
|
||||
"\n",
|
||||
" # In this case, we only have a single output, but the return value from this function\n",
|
||||
" # must have the same pytree structure as the output from the underlying function\n",
|
||||
" # (`rms_norm` in this case).\n",
|
||||
" return result_sharding\n",
|
||||
"\n",
|
||||
"def rms_norm_partition(eps, mesh, args_info, result_info):\n",
|
||||
" arg_info, = args_info\n",
|
||||
" arg_sharding = replicate_sharding_on_last_dim(mesh, arg_info.sharding, arg_info)\n",
|
||||
" result_sharding = replicate_sharding_on_last_dim(mesh, arg_info.sharding, result_info)\n",
|
||||
"\n",
|
||||
" # This is the function that computes the partitioned model on the appropriate subset\n",
|
||||
" # of the data.\n",
|
||||
" def partitioned_rms_norm(x):\n",
|
||||
" return rms_norm(x, eps=eps)\n",
|
||||
"\n",
|
||||
" # Note that the third element of our returned tuple must be the shardings for the\n",
|
||||
" # _outputs_ and its pytree structure must match the output of `rms_norm`. Similarly,\n",
|
||||
" # the fourth element must have the same pytree structure as the _inputs_ to\n",
|
||||
" # `rms_norm`. In this case, there is only one input, but it must be returned within\n",
|
||||
" # a `tuple` anyways.\n",
|
||||
" return mesh, partitioned_rms_norm, result_sharding, (arg_sharding,)\n",
|
||||
"\n",
|
||||
"rms_norm_partitioned.def_partition(\n",
|
||||
" infer_sharding_from_operands=rms_norm_infer_sharding_from_operands,\n",
|
||||
" partition=rms_norm_partition,\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"output = jax.jit(rms_norm_partitioned, out_shardings=batch_shd)(x_batch_shd)\n",
|
||||
"np.testing.assert_allclose(output, rms_norm_ref(x), rtol=1e-5)\n",
|
||||
"print(jax.jit(rms_norm_partitioned, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text().strip())"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"As you can see from the compiled program above, this `custom_partitioning` logic produces exactly the same program as the `shard_map` version above when the input is sharded on the batch dimension.\n",
|
||||
"\n",
|
||||
"However, it's worth noting that the behavior is _different_ when the input is sharded along the data dimension.\n",
|
||||
"When used under `shard_map`, the data are resharded on the batch dimension, whereas with `custom_partitioning` the data are gathered onto each device."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"hlo_data_partitioned = jax.jit(rms_norm_partitioned, out_shardings=data_shd).lower(x_data_shd).compile().as_text().strip()\n",
|
||||
"assert \"all-gather\" in hlo_data_partitioned"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"To also support automatic parallelization of the backwards pass, we would also need to write (similar) {func}`~jax.experimental.custom_partitioning.custom_partitioning` rules for `rms_norm_fwd` and `rms_norm_bwd`, but we leave those as an exercise for the reader."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Advanced topics\n",
|
||||
"\n",
|
||||
"This tutorial covers most of the basic steps that are required to get up and running with JAX's FFI, but advanced use cases may require more features.\n",
|
||||
@ -621,8 +888,6 @@
|
||||
"\n",
|
||||
"* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.\n",
|
||||
"\n",
|
||||
"* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.\n",
|
||||
"\n",
|
||||
"* **Stateful foreign functions**: It is also possible to use the FFI to wrap functions with associated state. There is a [low-level example included in the XLA test suite](https://github.com/openxla/xla/blob/737a7da3c5405583dc95773ac0bb11b1349fc9ea/xla/service/gpu/custom_call_test.cc#L794-L845), and a future tutorial will include more details."
|
||||
]
|
||||
}
|
||||
|
175
docs/ffi.md
175
docs/ffi.md
@ -36,6 +36,14 @@ We start by presenting the FFI on CPU, and discuss generalizations to GPU or mul
|
||||
|
||||
The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi).
|
||||
|
||||
Because we demonstrate how FFI calls can be sharded at the end of this tutorial, let's first set up our environment to be treated by JAX as having multiple CPUs:
|
||||
|
||||
```{code-cell} ipython3
|
||||
import os
|
||||
|
||||
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"
|
||||
```
|
||||
|
||||
## A simple example
|
||||
|
||||
To demonstrate the use of the FFI interface, we will implement a simple "root-mean-square (RMS)" normalization function.
|
||||
@ -265,7 +273,7 @@ def rms_norm(x, eps=1e-5):
|
||||
|
||||
|
||||
# Test that this gives the same result as our reference implementation
|
||||
x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))
|
||||
x = jnp.linspace(-0.5, 0.5, 32).reshape((8, 4))
|
||||
np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)
|
||||
```
|
||||
|
||||
@ -493,6 +501,169 @@ print(jax.jit(rms_norm_cross_platform).lower(x).as_text(dialect="hlo").strip())
|
||||
|
||||
and there will be no runtime overhead to using {func}`jax.lax.platform_dependent`, and the compiled program won't include any references to unavailable FFI targets.
|
||||
|
||||
+++
|
||||
|
||||
## Sharding
|
||||
|
||||
Most large scale users of JAX use its APIs for distributed computation across multiple devices.
|
||||
As discussed in {ref}`sharded-computation`, parallelism in JAX is controlled by sharding data across devices, and most JAX operations can be used within any of the supported parallel programming paradigms (from automatic to fully manual).
|
||||
But, the story is a little bit more complicated for FFI calls.
|
||||
Since the internals of an FFI call are opaque to both JAX and XLA, FFI calls won't typically show optimal (or even good) performance when the data are sharded.
|
||||
|
||||
Before getting into the FFI details, let's consider the behavior of our pure-JAX reference implementation of RMS normalization (the `rms_norm_ref` function defined at the top of this document) with a sharded input.
|
||||
As discussed above, our implementation treats all leading axes of the input as _batch_ dimensions, and the normalization is performed along the last axis.
|
||||
This means that if the data are sharded along any batch dimensions, but replicated on the last dimension, no communication is required.
|
||||
This can be seen by sharding our 2-dimensional test data from above along its first dimension and checking the compiled HLO for operations like `all-gather`, `all-reduce`, etc.:
|
||||
|
||||
```{code-cell} ipython3
|
||||
from jax.sharding import PartitionSpec as P
|
||||
|
||||
assert len(jax.devices()) == 4 # Set using the XLA_FLAGS environment variable
|
||||
mesh = jax.make_mesh((4,), ("x",))
|
||||
|
||||
batch_shd = jax.NamedSharding(mesh, P("x", None))
|
||||
x_batch_shd = jax.device_put(x, batch_shd)
|
||||
hlo_batch = jax.jit(rms_norm_ref, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text()
|
||||
assert "all-" not in hlo_batch
|
||||
```
|
||||
|
||||
However, if the data are sharded along the last axis, communication (in this case an `all-reduce`) is required to compute the sum in the normalization:
|
||||
|
||||
```{code-cell} ipython3
|
||||
data_shd = jax.NamedSharding(mesh, P(None, "x"))
|
||||
x_data_shd = jax.device_put(x, data_shd)
|
||||
hlo_data = jax.jit(rms_norm_ref, out_shardings=data_shd).lower(x_data_shd).compile().as_text()
|
||||
assert "all-reduce" in hlo_data
|
||||
```
|
||||
|
||||
Now, if we try to naively use our FFI version of the same model, it runs fine and gets the right answer:
|
||||
|
||||
```{code-cell} ipython3
|
||||
output = jax.jit(rms_norm, out_shardings=batch_shd)(x_batch_shd)
|
||||
np.testing.assert_allclose(output, rms_norm_ref(x), rtol=1e-5)
|
||||
```
|
||||
|
||||
But, if you look at the compiled HLO (omitting a helper functions for clarity), you'll see that
|
||||
|
||||
1. the data are first fully replicated onto each device via an `all-gather` operation,
|
||||
2. the FFI call is executed on the full dataset on each device, and
|
||||
3. the output is sliced to discard the unused portions.
|
||||
|
||||
```{code-cell} ipython3
|
||||
hlo = jax.jit(rms_norm, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text().strip()
|
||||
print(hlo.split("\n\n")[-1])
|
||||
```
|
||||
|
||||
This clearly (to us!) isn't the optimal partitioning of this function, but it's the best that JAX/XLA can do with the information given.
|
||||
|
||||
To generate better partitioning logic, we can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`, and we discuss both options here.
|
||||
That being said, it's not straightforward to generate _optimal_ partitioning for all inputs, because sometimes this would require algorithmic changes.
|
||||
Specifically, let's add support for "batch partitioning", which handles the case where the data are sharded on batch dimensions, but sharding on the last dimension will always require in re-sharding.
|
||||
|
||||
### Using `shard_map`
|
||||
|
||||
If you are using manual sharding control via {func}`~jax.experimental.shard_map.shard_map`, any FFI calls in your program should already partition appropriately:
|
||||
|
||||
```{code-cell} ipython3
|
||||
from functools import partial
|
||||
from jax.experimental.shard_map import shard_map
|
||||
|
||||
@partial(shard_map, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None))
|
||||
def rms_norm_shmap(x):
|
||||
return rms_norm(x)
|
||||
|
||||
np.testing.assert_allclose(rms_norm_shmap(x_batch_shd), rms_norm_ref(x), rtol=1e-5)
|
||||
print(jax.jit(rms_norm_shmap, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text().strip())
|
||||
```
|
||||
|
||||
As you can see in this program, if the input and output shardings match the `shard_map` specs, no communication is required and the FFI call is executed on the appropriately sharded subset of the data.
|
||||
|
||||
You can also use inputs and outputs with shardings that don't match the `shard_map` specs, but (unrelated to the FFI) this will require re-sharding, as seen by the `all-to-all` operations in the compiled HLO:
|
||||
|
||||
```{code-cell} ipython3
|
||||
hlo_data_shmap = jax.jit(rms_norm_shmap, out_shardings=data_shd).lower(x_data_shd).compile().as_text()
|
||||
assert "all-to-all" in hlo_data_shmap
|
||||
```
|
||||
|
||||
### Using `custom partitioning`
|
||||
|
||||
If you can't use {func}`~jax.experimental.shard_map.shard_map`, an alternative approach is to use {func}`~jax.experimental.custom_partitioning.custom_partitioning`, which supports automatic parallelization via {func}`jax.jit`.
|
||||
{func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges.
|
||||
We won't go into too much detail on the caveats here, but the main issues that you should be aware of are:
|
||||
|
||||
1. `custom_partitioning` can cause unexpected cache misses when used with the JAX's [Persistent compilation cache](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either.
|
||||
2. Debugging `custom_partitioning` logic can be tedious because Python errors don't always get propagated, instead causing your Python process to exit. That being said, any exceptions will show up in the process logs, so you should be able to track them down there.
|
||||
|
||||
All that being said, here's how we can wrap our FFI implementation of `rms_norm` using {func}`~jax.experimental.custom_partitioning.custom_partitioning`:
|
||||
|
||||
```{code-cell} ipython3
|
||||
from jax.experimental.custom_partitioning import custom_partitioning
|
||||
|
||||
@partial(custom_partitioning, static_argnums=(1,))
|
||||
def rms_norm_partitioned(x, eps=1e-5):
|
||||
return rms_norm(x, eps=eps)
|
||||
|
||||
def replicate_sharding_on_last_dim(mesh, sharding, target_info):
|
||||
# Our implementation supports trivial sharding on any batch dimensions, but the data
|
||||
# must be replicated on the last (non-batch) dimension.
|
||||
rank = len(target_info.shape)
|
||||
num_batch_dims = min(len(sharding.spec), rank - 1)
|
||||
|
||||
# The Nones here indicate which dimensions should be replicated.
|
||||
names = tuple(sharding.spec[:num_batch_dims]) + (None,) * (rank - num_batch_dims)
|
||||
return jax.NamedSharding(mesh, P(*names))
|
||||
|
||||
def rms_norm_infer_sharding_from_operands(eps, mesh, args_info, result_info):
|
||||
del eps # unused
|
||||
arg_info, = args_info
|
||||
result_sharding = replicate_sharding_on_last_dim(mesh, arg_info.sharding, result_info)
|
||||
|
||||
# In this case, we only have a single output, but the return value from this function
|
||||
# must have the same pytree structure as the output from the underlying function
|
||||
# (`rms_norm` in this case).
|
||||
return result_sharding
|
||||
|
||||
def rms_norm_partition(eps, mesh, args_info, result_info):
|
||||
arg_info, = args_info
|
||||
arg_sharding = replicate_sharding_on_last_dim(mesh, arg_info.sharding, arg_info)
|
||||
result_sharding = replicate_sharding_on_last_dim(mesh, arg_info.sharding, result_info)
|
||||
|
||||
# This is the function that computes the partitioned model on the appropriate subset
|
||||
# of the data.
|
||||
def partitioned_rms_norm(x):
|
||||
return rms_norm(x, eps=eps)
|
||||
|
||||
# Note that the third element of our returned tuple must be the shardings for the
|
||||
# _outputs_ and its pytree structure must match the output of `rms_norm`. Similarly,
|
||||
# the fourth element must have the same pytree structure as the _inputs_ to
|
||||
# `rms_norm`. In this case, there is only one input, but it must be returned within
|
||||
# a `tuple` anyways.
|
||||
return mesh, partitioned_rms_norm, result_sharding, (arg_sharding,)
|
||||
|
||||
rms_norm_partitioned.def_partition(
|
||||
infer_sharding_from_operands=rms_norm_infer_sharding_from_operands,
|
||||
partition=rms_norm_partition,
|
||||
)
|
||||
|
||||
output = jax.jit(rms_norm_partitioned, out_shardings=batch_shd)(x_batch_shd)
|
||||
np.testing.assert_allclose(output, rms_norm_ref(x), rtol=1e-5)
|
||||
print(jax.jit(rms_norm_partitioned, out_shardings=batch_shd).lower(x_batch_shd).compile().as_text().strip())
|
||||
```
|
||||
|
||||
As you can see from the compiled program above, this `custom_partitioning` logic produces exactly the same program as the `shard_map` version above when the input is sharded on the batch dimension.
|
||||
|
||||
However, it's worth noting that the behavior is _different_ when the input is sharded along the data dimension.
|
||||
When used under `shard_map`, the data are resharded on the batch dimension, whereas with `custom_partitioning` the data are gathered onto each device.
|
||||
|
||||
```{code-cell} ipython3
|
||||
hlo_data_partitioned = jax.jit(rms_norm_partitioned, out_shardings=data_shd).lower(x_data_shd).compile().as_text().strip()
|
||||
assert "all-gather" in hlo_data_partitioned
|
||||
```
|
||||
|
||||
To also support automatic parallelization of the backwards pass, we would also need to write (similar) {func}`~jax.experimental.custom_partitioning.custom_partitioning` rules for `rms_norm_fwd` and `rms_norm_bwd`, but we leave those as an exercise for the reader.
|
||||
|
||||
+++
|
||||
|
||||
## Advanced topics
|
||||
|
||||
This tutorial covers most of the basic steps that are required to get up and running with JAX's FFI, but advanced use cases may require more features.
|
||||
@ -500,6 +671,4 @@ We will leave these topics to future tutorials, but here are some possibly usefu
|
||||
|
||||
* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.
|
||||
|
||||
* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.
|
||||
|
||||
* **Stateful foreign functions**: It is also possible to use the FFI to wrap functions with associated state. There is a [low-level example included in the XLA test suite](https://github.com/openxla/xla/blob/737a7da3c5405583dc95773ac0bb11b1349fc9ea/xla/service/gpu/custom_call_test.cc#L794-L845), and a future tutorial will include more details.
|
||||
|
Loading…
x
Reference in New Issue
Block a user