"_This tutorial requires JAX v0.4.31 or newer._\n",
"\n",
"While a wide range of numerical operations can be easily and efficiently implemented using JAX's built-in `jax.numpy` and `jax.lax` interfaces, it can sometimes be useful to explicitly call out to external compiled libraries via a \"foreign function interface\" (FFI).\n",
"This can be particularly useful when the relevant operations have already been implemented in an optimized C or CUDA library, and it would be non-trivial to reimplement them directly using JAX, but it can also be useful for optimizing the runtime or memory performance of JAX programs.\n",
"That being said, the FFI should typically be considered a last resort, because the XLA compiler that sits in JAX's backend, or the Pallas kernel language, which provides lower-level control, can typically produce performant code with lower development and maintenance costs.\n",
"\n",
"One point that should be taken into account when considering use of the FFI is that _JAX doesn't automatically know how to differentiate through foreign functions_.\n",
"This means that if you want to use JAX's autodifferentiation capabilities alongside a foreign function, you'll also need to provide an implementation of the relevant differentiation rules.\n",
"We will discuss some possible approaches below, but it is important to call this limitation out right from the start!\n",
"\n",
"JAX's FFI support is provided in two parts:\n",
"\n",
"1. A header-only C++ library from XLA, which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and\n",
"2. A Python front end, available in the `jax.ffi` submodule.\n",
"\n",
"In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n",
"We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n",
"The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi).\n",
"Because we demonstrate how FFI calls can be sharded at the end of this tutorial, let's first set up our environment to be treated by JAX as having multiple CPUs:"
"But, it's just non-trivial enough to be useful for demonstrating some key details of the FFI, while still being straightforward to understand.\n",
"We will use this reference implementation to test our FFI version below.\n",
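Concretely, a pure-JAX reference along these lines can be written as follows (a sketch; the exact `rms_norm_ref` used elsewhere in this tutorial may differ in details such as where `eps` enters):

```python
import jax.numpy as jnp

def rms_norm_ref(x, eps=1e-5):
  # Treat all leading axes as batch dimensions and normalize over the last axis.
  scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
  return x / scale
```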
"\n",
"## Backend code\n",
"\n",
"To begin with, we need an implementation of RMS normalization in C++ that we will expose using the FFI.\n",
"This implementation isn't meant to be particularly performant, but you could imagine that if you had a new, better implementation of RMS normalization in a C++ library, it might have an interface like the one below.\n",
"So, here's a simple implementation of RMS normalization in C++:\n",
"and, for our example, this is the function that we want to expose to JAX via the FFI."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### C++ interface\n",
"\n",
"To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla).\n",
"For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call).\n",
"The full source listing can be downloaded [here](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc), but the key implementation details are reproduced here:\n",
"Starting at the bottom, we're using the XLA-provided macro `XLA_FFI_DEFINE_HANDLER_SYMBOL` to generate some boilerplate which will expand into a function called `RmsNorm` with the appropriate signature.\n",
"But, the important stuff here is all in the call to `ffi::Ffi::Bind()`, where we define the input and output types, and the types of any parameters.\n",
"\n",
"Then, in `RmsNormImpl`, we accept `ffi::Buffer` arguments which include information about the buffer shape, and pointers to the underlying data.\n",
"In this implementation, we treat all leading dimensions of the buffer as batch dimensions, and perform RMS normalization over the last axis.\n",
"`GetDims` is a helper function providing support for this batching behavior.\n",
"We discuss this batching behavior in more detail [below](ffi-call-vmap), but the general idea is that it can be useful to transparently handle batching in the left-most dimensions of the input arguments.\n",
"In this case, we treat all but the last axis as batch dimensions, but other foreign functions may require a different number of non-batch dimensions."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Building and registering an FFI handler\n",
"\n",
"Now that we have our minimal FFI wrapper implemented, we need to expose this function (`RmsNorm`) to Python.\n",
"In this tutorial, we compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), but another common pattern is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) as discussed below.\n",
"The registration function, {func}`~jax.ffi.register_ffi_target`, expects our handler (a function pointer to the C++ function `RmsNorm`) to be wrapped in a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html).\n",
"If you're familiar with the legacy \"custom call\" API, it's worth noting that you can also use {func}`~jax.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.ffi.register_ffi_target` is `1`, the new \"typed\" FFI API that we're using here.\n",
"A common alternative pattern for exposing handlers to Python is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) to define a tiny Python extension which can be imported.\n",
"For our example here, the nanobind code would be:\n",
"\n",
"```c++\n",
"#include <type_traits>\n",
"\n",
"#include \"nanobind/nanobind.h\"\n",
"#include \"xla/ffi/api/c_api.h\"\n",
"\n",
"namespace nb = nanobind;\n",
"\n",
"template <typename T>\n",
"nb::capsule EncapsulateFfiCall(T *fn) {\n",
"  // This check is optional, but it can be helpful for avoiding invalid handlers.\n",
"  static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,\n",
"                \"Encapsulated function must be an XLA FFI handler\");\n",
"  return nb::capsule(reinterpret_cast<void *>(fn));\n",
"}\n",
"\n",
"NB_MODULE(rms_norm, m) {\n",
"  m.def(\"rms_norm\", []() { return EncapsulateFfiCall(RmsNorm); });\n",
"}\n",
"```\n",
"\n",
"This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting.\n",
"Most of the heavy lifting here is done by the {func}`~jax.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.\n",
"It's important to note that the first argument to {func}`~jax.ffi.ffi_call` must be a string that matches the target name that we used when calling {func}`~jax.ffi.register_ffi_target` above.\n",
"Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n",
"If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.ffi.ffi_call`.\n",
"In this earlier API, the backend had no mechanism for receiving metadata about the input arrays, but since the FFI includes dimension information with the `Buffer` objects, we no longer need to compute this using Python when lowering.\n",
"One major perk of this change is {func}`~jax.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.\n",
"{func}`~jax.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n",
"The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.ffi.ffi_call`.\n",
"Many FFI calls provide more efficient batching behavior and, in some simple cases, the `\"expand_dims\"` or `\"broadcast_all\"` methods can be used to expose a better implementation.\n",
"Another way of saying this is that the result of calling `ffi_call` on the batched inputs is assumed to be equal to stacking the repeated application of `ffi_call` to each element in the batched input, roughly:\n",
"\n",
"```python\n",
"ffi_call(xs) == jnp.stack([ffi_call(x) for x in xs])\n",
"```\n",
"\n",
"For simplicity, we will use the `\"broadcast_all\"` method throughout this tutorial, which guarantees that all inputs will be broadcast to have the same batch dimensions, but it would also be possible to implement a foreign function to handle the `\"expand_dims\"` method.\n",
"We can inspect the [jaxpr](jax-internals-jaxpr) of the {func}`~jax.vmap` of `rms_norm` to confirm that it isn't being rewritten using {func}`~jax.lax.scan`:"
"If your foreign function provides an efficient batching rule that isn't supported by this simple `vmap_method` parameter, it might also be possible to define more flexible custom `vmap` rules using the experimental `custom_vmap` interface, but it's worth also opening an issue describing your use case on [the JAX issue tracker](https://github.com/jax-ml/jax/issues)."
"As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.\n",
"More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n",
"In this case, we actually define two new FFI calls:\n",
"\n",
"1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n",
"2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n",
"We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.\n",
"The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output; therefore, in the {func}`~jax.ffi.ffi_call` to `rms_norm_fwd`, the output type has two elements with different shapes.\n",
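In pure JAX, the wiring of these two pieces into {func}`~jax.custom_vjp` can be sketched as follows (the `rms_norm_fwd` and `rms_norm_bwd` functions below are pure-JAX placeholders for the FFI calls, and the choice of residuals is illustrative rather than the exact one used in the C++ code):

```python
import jax
import jax.numpy as jnp

EPS = 1e-5

@jax.custom_vjp
def rms_norm(x):
  scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + EPS)
  return x / scale

def rms_norm_fwd(x):
  # Stand-in for the rms_norm_fwd FFI call: return the primal output and the
  # residuals (note that the residual's shape differs from the primal's).
  scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + EPS)
  return x / scale, (x, scale)

def rms_norm_bwd(res, ct):
  # Stand-in for the rms_norm_bwd FFI call: map the output cotangents to
  # input cotangents using the saved residuals.
  x, scale = res
  y = x / scale
  return (ct / scale - y * jnp.mean(ct * y, axis=-1, keepdims=True) / scale,)

rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
```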
"At this point, we can use our new `rms_norm` function transparently for many JAX applications, and it will transform appropriately under the standard JAX function transformations like {func}`~jax.vmap` and {func}`~jax.grad`.\n",
"One thing that this example doesn't support is forward-mode AD ({func}`jax.jvp`, for example) since {func}`~jax.custom_vjp` is restricted to reverse-mode.\n",
"JAX doesn't currently expose a public API for simultaneously customizing both forward-mode and reverse-mode AD, but such an API is on the roadmap, so please [open an issue](https://github.com/jax-ml/jax/issues) describing your use case if you hit this limitation in practice.\n",
"One other JAX feature that this example doesn't support is higher-order AD.\n",
"It would be possible to work around this by wrapping the `rms_norm_bwd` function above in a {func}`jax.custom_jvp` or {func}`jax.custom_vjp` decorator, but we won't go into the details of that advanced use case here.\n",
"\n",
"## FFI calls on a GPU\n",
"\n",
"So far, we have been interfacing only with foreign functions running on the CPU, but JAX's FFI also supports calls to GPU code.\n",
"Since this documentation page is automatically generated on a machine without access to a GPU, we can't execute any GPU-specific examples here, but we will go over the key points.\n",
"\n",
"When defining our FFI wrapper for CPU, the function signature that we used was:\n",
"To support running our `rms_norm` function on both GPU and CPU, we can combine our implementation above with the {func}`jax.lax.platform_dependent` function:"
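A sketch of this pattern, with pure-JAX placeholders standing in for the registered CPU and CUDA FFI wrappers:

```python
import jax
import jax.numpy as jnp

def rms_norm_ref(x, eps=1e-5):
  return x / jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)

# Placeholders: in the real program these would be the ffi_call wrappers
# targeting the handlers registered for the "cpu" and "CUDA" platforms.
rms_norm_cpu = rms_norm_ref
rms_norm_cuda = rms_norm_ref

def rms_norm(x):
  # The branch is selected at lowering time based on the target platform.
  return jax.lax.platform_dependent(x, cpu=rms_norm_cpu, cuda=rms_norm_cuda)
```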
"With this approach there will be no runtime overhead to using {func}`jax.lax.platform_dependent`, and the compiled program won't include any references to unavailable FFI targets."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sharding\n",
"\n",
"Most large scale users of JAX use its APIs for distributed computation across multiple devices.\n",
"As discussed in {ref}`sharded-computation`, parallelism in JAX is controlled by sharding data across devices, and most JAX operations can be used within any of the supported parallel programming paradigms (from automatic to fully manual).\n",
"But, the story is a little bit more complicated for FFI calls.\n",
"Since the internals of an FFI call are opaque to both JAX and XLA, FFI calls won't typically show optimal (or even good) performance when the data are sharded.\n",
"\n",
"Before getting into the FFI details, let's consider the behavior of our pure-JAX reference implementation of RMS normalization (the `rms_norm_ref` function defined at the top of this document) with a sharded input.\n",
"As discussed above, our implementation treats all leading axes of the input as _batch_ dimensions, and the normalization is performed along the last axis.\n",
"This means that if the data are sharded along any batch dimensions, but replicated on the last dimension, no communication is required.\n",
"This can be seen by sharding our 2-dimensional test data from above along its first dimension and checking the compiled HLO for operations like `all-gather`, `all-reduce`, etc.:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from jax.sharding import PartitionSpec as P\n",
"\n",
"assert len(jax.devices()) == 4 # Set using the XLA_FLAGS environment variable\n",
"However, if the data are sharded along the last axis, communication (in this case an `all-reduce`) is required to compute the sum in the normalization:"
"This clearly (to us!) isn't the optimal partitioning of this function, but it's the best that JAX/XLA can do with the information given.\n",
"\n",
"To generate better partitioning logic, we can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`, and we discuss both options here.\n",
"That being said, it's not straightforward to generate _optimal_ partitioning for all inputs, because sometimes this would require algorithmic changes.\n",
"Specifically, let's add support for \"batch partitioning\", which handles the case where the data are sharded on batch dimensions, while sharding on the last dimension will always require re-sharding.\n",
"\n",
"### Using `shard_map`\n",
"\n",
"If you are using manual sharding control via {func}`~jax.experimental.shard_map.shard_map`, any FFI calls in your program should already partition appropriately:"
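For example, here is a sketch of this pattern, using the pure-JAX reference implementation as a stand-in for the FFI call and a one-axis mesh over all available devices:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

def rms_norm_ref(x, eps=1e-5):
  return x / jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)

# One mesh axis spanning all available devices (4 CPUs in this tutorial's setup).
mesh = Mesh(np.array(jax.devices()), ("x",))

# Each program instance receives only its shard of rows. Because only the
# batch dimension is sharded, each shard is normalized locally with no
# communication.
rms_norm_sharded = shard_map(
    rms_norm_ref, mesh=mesh, in_specs=P("x", None), out_specs=P("x", None))

x = jnp.linspace(-1, 1, 32, dtype=jnp.float32).reshape(8, 4)
y = rms_norm_sharded(x)
```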
"As you can see in this program, if the input and output shardings match the `shard_map` specs, no communication is required and the FFI call is executed on the appropriately sharded subset of the data.\n",
"\n",
"You can also use inputs and outputs with shardings that don't match the `shard_map` specs, but (unrelated to the FFI) this will require re-sharding, as seen by the `all-to-all` operations in the compiled HLO:"
"If you can't use {func}`~jax.experimental.shard_map.shard_map`, an alternative approach is to use {func}`~jax.experimental.custom_partitioning.custom_partitioning`, which supports automatic parallelization via {func}`jax.jit`.\n",
"{func}`~jax.experimental.custom_partitioning.custom_partitioning` works by adding Python callbacks into the XLA compiler's partitioning pass, which allows very flexible logic, but also comes with some rough edges.\n",
"We won't go into too much detail on the caveats here, but the main issues that you should be aware of are:\n",
"\n",
"1. `custom_partitioning` can cause unexpected cache misses when used with JAX's [persistent compilation cache](https://jax.readthedocs.io/en/latest/persistent_compilation_cache.html). This can be mitigated using the `jax_remove_custom_partitioning_ptr_from_cache_key` configuration flag, but that isn't always appropriate either.\n",
"2. Debugging `custom_partitioning` logic can be tedious because Python errors don't always get propagated, instead causing your Python process to exit. That being said, any exceptions will show up in the process logs, so you should be able to track them down there.\n",
"\n",
"All that being said, here's how we can wrap our FFI implementation of `rms_norm` using {func}`~jax.experimental.custom_partitioning.custom_partitioning`:"
"As you can see from the compiled program above, this `custom_partitioning` logic produces exactly the same program as the `shard_map` version above when the input is sharded on the batch dimension.\n",
"However, it's worth noting that the behavior is _different_ when the input is sharded along the data dimension.\n",
"When used under `shard_map`, the data are resharded on the batch dimension, whereas with `custom_partitioning` the data are gathered onto each device."
"To also support automatic parallelization of the backwards pass, we would also need to write (similar) {func}`~jax.experimental.custom_partitioning.custom_partitioning` rules for `rms_norm_fwd` and `rms_norm_bwd`, but we leave those as an exercise for the reader."
"This tutorial covers most of the basic steps that are required to get up and running with JAX's FFI, but advanced use cases may require more features.\n",
"We will leave these topics to future tutorials, but here are some possibly useful references:\n",
"* **Supporting multiple dtypes**: In this tutorial's example, we restricted our implementation to support only `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.ffi.ffi_call` depending on the input types. But, this approach could quickly get unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include an `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.\n",
"* **Stateful foreign functions**: It is also possible to use the FFI to wrap functions with associated state. There is a [low-level example included in the XLA test suite](https://github.com/openxla/xla/blob/737a7da3c5405583dc95773ac0bb11b1349fc9ea/xla/service/gpu/custom_call_test.cc#L794-L845), and a future tutorial will include more details."