Move jex.ffi to jax.ffi.

Dan Foreman-Mackey 2024-12-20 11:26:04 +00:00
parent 8eeedd1802
commit cb4d97aa1f
22 changed files with 528 additions and 417 deletions

View File

@ -30,6 +30,8 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings`
are now deprecated, having been replaced by symbols of the same name
in {mod}`jax.core`.
* The {mod}`jax.extend.ffi` submodule was moved to {mod}`jax.ffi`, and the
previous import path is deprecated.
* Deletions
* `jax_enable_memories` flag has been deleted and the behavior of that flag

View File

@ -362,4 +362,5 @@ rediraffe_redirects = {
'jax-101/index.rst': 'tutorials.rst',
'notebooks/external_callbacks.md': 'external-callbacks.md',
'notebooks/How_JAX_primitives_work.md': 'jax-primitives.md',
'jax.extend.ffi.rst': 'jax.ffi.rst',
}

View File

@ -21,7 +21,7 @@
"JAX's FFI support is provided in two parts:\n",
"\n",
"1. A header-only C++ library from XLA which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and\n",
"2. A Python front end, available in the `jax.extend.ffi` submodule.\n",
"2. A Python front end, available in the `jax.ffi` submodule.\n",
"\n",
"In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n",
"We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n",
@ -191,9 +191,9 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.extend.ffi.register_ffi_target` function.\n",
"With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.ffi.register_ffi_target` function.\n",
"This function expects our handler (a function pointer to the C++ function `RmsNorm`) to be wrapped in a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html).\n",
"JAX provides a helper function {func}`~jax.extend.ffi.pycapsule` to help with this:"
"JAX provides a helper function {func}`~jax.ffi.pycapsule` to help with this:"
]
},
{
@ -204,12 +204,11 @@
"source": [
"import ctypes\n",
"from pathlib import Path\n",
"import jax.extend as jex\n",
"\n",
"path = next(Path(\"ffi\").glob(\"librms_norm*\"))\n",
"rms_norm_lib = ctypes.cdll.LoadLibrary(path)\n",
"jex.ffi.register_ffi_target(\n",
" \"rms_norm\", jex.ffi.pycapsule(rms_norm_lib.RmsNorm), platform=\"cpu\")"
"jax.ffi.register_ffi_target(\n",
" \"rms_norm\", jax.ffi.pycapsule(rms_norm_lib.RmsNorm), platform=\"cpu\")"
]
},
{
@ -217,7 +216,7 @@
"metadata": {},
"source": [
"```{tip}\n",
"If you're familiar with the legacy \"custom call\" API, it's worth noting that you can also use {func}`~jax.extend.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.extend.ffi.register_ffi_target` is `1`, the new \"typed\" FFI API that we're using here.\n",
"If you're familiar with the legacy \"custom call\" API, it's worth noting that you can also use {func}`~jax.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.ffi.register_ffi_target` is `1`, the new \"typed\" FFI API that we're using here.\n",
"```\n",
"\n",
"**An alternative approach**:\n",
@ -251,7 +250,7 @@
"# Assuming that we compiled a nanobind extension called `rms_norm`:\n",
"import rms_norm as rms_norm_lib\n",
"\n",
"jex.ffi.register_ffi_target(\"rms_norm\", rms_norm_lib.rms_norm(), platform=\"cpu\")\n",
"jax.ffi.register_ffi_target(\"rms_norm\", rms_norm_lib.rms_norm(), platform=\"cpu\")\n",
"```"
]
},
@ -261,7 +260,7 @@
"source": [
"## Frontend code\n",
"\n",
"Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.extend.ffi.ffi_call` function:"
"Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.ffi.ffi_call` function:"
]
},
{
@ -282,7 +281,7 @@
" if x.dtype != jnp.float32:\n",
" raise ValueError(\"Only the float32 dtype is implemented by rms_norm\")\n",
"\n",
" call = jex.ffi.ffi_call(\n",
" call = jax.ffi.ffi_call(\n",
" # The target name must be the same string as we used to register the target\n",
" # above in `register_custom_call_target`\n",
" \"rms_norm\",\n",
@ -314,25 +313,25 @@
"metadata": {},
"source": [
"This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting.\n",
"Most of the heavy lifting here is done by the {func}`~jax.extend.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.\n",
"It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.\n",
"Most of the heavy lifting here is done by the {func}`~jax.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.\n",
"It's important to note that the first argument to {func}`~jax.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.\n",
"\n",
"Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.\n",
"Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.ffi.ffi_call`.\n",
"Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.\n",
"\n",
"The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
"The `vmap_method` argument to {func}`~jax.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.\n",
"\n",
"```{tip}\n",
"If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.\n",
"If you are familiar with the earlier \"custom call\" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.ffi.ffi_call`.\n",
"In this earlier API, the backend had no mechanism for receiving metadata about the input arrays, but since the FFI includes dimension information with the `Buffer` objects, we no longer need to compute this using Python when lowering.\n",
"One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.\n",
"One major perk of this change is {func}`~jax.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.\n",
"```\n",
"\n",
"(ffi-call-vmap)=\n",
"### Batching with `vmap`\n",
"\n",
"{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n",
"The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.\n",
"{func}`~jax.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.\n",
"The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.ffi.ffi_call`.\n",
"\n",
"The simplest `vmap_method` is `\"sequential\"`.\n",
"In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.\n",
@ -395,7 +394,7 @@
"outputs": [],
"source": [
"def rms_norm_sequential(x, eps=1e-5):\n",
" return jex.ffi.ffi_call(\n",
" return jax.ffi.ffi_call(\n",
" \"rms_norm\",\n",
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
" vmap_method=\"sequential\",\n",
@ -418,9 +417,9 @@
"source": [
"### Differentiation\n",
"\n",
"Unlike with batching, {func}`~jax.extend.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.\n",
"Unlike with batching, {func}`~jax.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.\n",
"As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.\n",
"Therefore, it is the {func}`~jax.extend.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n",
"Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule.\n",
"\n",
"More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.\n",
"In this case, we actually define two new FFI calls:\n",
@ -429,7 +428,7 @@
"2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n",
"\n",
"We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.\n",
"The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n",
"The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n",
"\n",
"This custom derivative rule can be wired in as follows:"
]
@ -440,16 +439,16 @@
"metadata": {},
"outputs": [],
"source": [
"jex.ffi.register_ffi_target(\n",
" \"rms_norm_fwd\", jex.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform=\"cpu\"\n",
"jax.ffi.register_ffi_target(\n",
" \"rms_norm_fwd\", jax.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform=\"cpu\"\n",
")\n",
"jex.ffi.register_ffi_target(\n",
" \"rms_norm_bwd\", jex.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform=\"cpu\"\n",
"jax.ffi.register_ffi_target(\n",
" \"rms_norm_bwd\", jax.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform=\"cpu\"\n",
")\n",
"\n",
"\n",
"def rms_norm_fwd(x, eps=1e-5):\n",
" y, res = jex.ffi.ffi_call(\n",
" y, res = jax.ffi.ffi_call(\n",
" \"rms_norm_fwd\",\n",
" (\n",
" jax.ShapeDtypeStruct(x.shape, x.dtype),\n",
@ -466,7 +465,7 @@
" assert res.shape == ct.shape[:-1]\n",
" assert x.shape == ct.shape\n",
" return (\n",
" jex.ffi.ffi_call(\n",
" jax.ffi.ffi_call(\n",
" \"rms_norm_bwd\",\n",
" jax.ShapeDtypeStruct(ct.shape, ct.dtype),\n",
" vmap_method=\"broadcast_all\",\n",
@ -533,7 +532,7 @@
"On the front end, the registration code would be updated to specify the appropriate platform:\n",
"\n",
"```python\n",
"jex.ffi.register_ffi_target(\n",
"jax.ffi.register_ffi_target(\n",
" \"rms_norm_cuda\", rms_norm_lib_cuda.rms_norm(), platform=\"CUDA\"\n",
")\n",
"```\n",
@ -554,7 +553,7 @@
" out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)\n",
"\n",
" def impl(target_name):\n",
" return lambda x: jex.ffi.ffi_call(\n",
" return lambda x: jax.ffi.ffi_call(\n",
" target_name,\n",
" out_type,\n",
" vmap_method=\"broadcast_all\",\n",
@ -620,9 +619,9 @@
"This tutorial covers most of the basic steps that are required to get up and running with JAX's FFI, but advanced use cases may require more features.\n",
"We will leave these topics to future tutorials, but here are some possibly useful references:\n",
"\n",
"* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.extend.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.\n",
"* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.\n",
"\n",
"* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.extend.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.\n",
"* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.\n",
"\n",
"* **Stateful foreign functions**: It is also possible to use the FFI to wrap functions with associated state. There is a [low-level example included in the XLA test suite](https://github.com/openxla/xla/blob/737a7da3c5405583dc95773ac0bb11b1349fc9ea/xla/service/gpu/custom_call_test.cc#L794-L845), and a future tutorial will include more details."
]

View File

@ -29,7 +29,7 @@ We will discuss some possible approaches below, but it is important to call this
JAX's FFI support is provided in two parts:
1. A header-only C++ library from XLA which is packaged as part of JAX as of v0.4.29 or available from the [openxla/xla](https://github.com/openxla/xla) project, and
2. A Python front end, available in the `jax.extend.ffi` submodule.
2. A Python front end, available in the `jax.ffi` submodule.
In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.
We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.
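As a quick sanity check of these two parts (a small sketch; the printed header path depends on your installation), the bundled XLA headers can be located from Python with {func}`~jax.ffi.include_dir`, and the front-end functions live directly on the `jax.ffi` module:

```python
import jax

# Part 1: the XLA FFI headers ship with jaxlib; this path can be added to a
# C++ build's include directories (e.g. via CMake).
print(jax.ffi.include_dir())

# Part 2: the Python front end is the jax.ffi module itself.
print([name for name in dir(jax.ffi) if not name.startswith("_")])
```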
@ -171,23 +171,22 @@ To compile the shared library, we're using CMake here, but you should be able to
!cmake --install ffi/_build
```
With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.extend.ffi.register_ffi_target` function.
With this compiled library in hand, we now need to register this handler with XLA via the {func}`~jax.ffi.register_ffi_target` function.
This function expects our handler (a function pointer to the C++ function `RmsNorm`) to be wrapped in a [`PyCapsule`](https://docs.python.org/3/c-api/capsule.html).
JAX provides a helper function {func}`~jax.extend.ffi.pycapsule` to help with this:
JAX provides a helper function {func}`~jax.ffi.pycapsule` to help with this:
```{code-cell} ipython3
import ctypes
from pathlib import Path
import jax.extend as jex
path = next(Path("ffi").glob("librms_norm*"))
rms_norm_lib = ctypes.cdll.LoadLibrary(path)
jex.ffi.register_ffi_target(
"rms_norm", jex.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu")
jax.ffi.register_ffi_target(
"rms_norm", jax.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu")
```
```{tip}
If you're familiar with the legacy "custom call" API, it's worth noting that you can also use {func}`~jax.extend.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.extend.ffi.register_ffi_target` is `1`, the new "typed" FFI API that we're using here.
If you're familiar with the legacy "custom call" API, it's worth noting that you can also use {func}`~jax.ffi.register_ffi_target` to register a custom call target by manually specifying the keyword argument `api_version=0`. The default `api_version` for {func}`~jax.ffi.register_ffi_target` is `1`, the new "typed" FFI API that we're using here.
```
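For concreteness, a registration with `api_version=0` might look like the following sketch, where the shared library `liblegacy_ops.so` and the symbol `LegacyOp` are hypothetical placeholders rather than part of this tutorial:

```python
import ctypes

import jax

# Hypothetical shared library exposing a handler written against the legacy
# (untyped) custom call API.
legacy_lib = ctypes.cdll.LoadLibrary("./liblegacy_ops.so")

jax.ffi.register_ffi_target(
    "legacy_op",
    jax.ffi.pycapsule(legacy_lib.LegacyOp),
    platform="cpu",
    api_version=0,  # opt in to the legacy custom call API instead of the typed FFI
)
```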
**An alternative approach**:
@ -221,14 +220,14 @@ Then, in Python we can register this handler using:
# Assuming that we compiled a nanobind extension called `rms_norm`:
import rms_norm as rms_norm_lib
jex.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm(), platform="cpu")
jax.ffi.register_ffi_target("rms_norm", rms_norm_lib.rms_norm(), platform="cpu")
```
+++
## Frontend code
Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.extend.ffi.ffi_call` function:
Now that we have registered our FFI handler, it is straightforward to call our C++ library from JAX using the {func}`~jax.ffi.ffi_call` function:
```{code-cell} ipython3
import numpy as np
@ -243,7 +242,7 @@ def rms_norm(x, eps=1e-5):
if x.dtype != jnp.float32:
raise ValueError("Only the float32 dtype is implemented by rms_norm")
call = jex.ffi.ffi_call(
call = jax.ffi.ffi_call(
# The target name must be the same string as we used to register the target
# above in `register_ffi_target`
"rms_norm",
@ -271,25 +270,25 @@ np.testing.assert_allclose(rms_norm(x), rms_norm_ref(x), rtol=1e-5)
```
This code cell includes a lot of inline comments which should explain most of what is happening here, but there are a few points that are worth explicitly highlighting.
Most of the heavy lifting here is done by the {func}`~jax.extend.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.
It's important to note that the first argument to {func}`~jax.extend.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_custom_call_target` above.
Most of the heavy lifting here is done by the {func}`~jax.ffi.ffi_call` function, which tells JAX how to call the foreign function for a particular set of inputs.
It's important to note that the first argument to {func}`~jax.ffi.ffi_call` must be a string that matches the target name that we used when calling `register_ffi_target` above.
Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.extend.ffi.ffi_call`.
Any attributes (defined using `Attr` in the C++ wrapper above) should be passed as keyword arguments to {func}`~jax.ffi.ffi_call`.
Note that we explicitly cast `eps` to `np.float32` because our FFI library expects a C `float`, and we can't use `jax.numpy` here, because these parameters must be static arguments.
The `vmap_method` argument to {func}`~jax.extend.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.
The `vmap_method` argument to {func}`~jax.ffi.ffi_call` defines how this FFI call interacts with {func}`~jax.vmap` as described next.
```{tip}
If you are familiar with the earlier "custom call" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.extend.ffi.ffi_call`.
If you are familiar with the earlier "custom call" interface, you might be surprised that we're not passing the problem dimensions as parameters (batch size, etc.) to {func}`~jax.ffi.ffi_call`.
In this earlier API, the backend had no mechanism for receiving metadata about the input arrays, but since the FFI includes dimension information with the `Buffer` objects, we no longer need to compute this using Python when lowering.
One major perk of this change is {func}`~jax.extend.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.
One major perk of this change is {func}`~jax.ffi.ffi_call` can support some simple {func}`~jax.vmap` semantics out of the box, as discussed below.
```
(ffi-call-vmap)=
### Batching with `vmap`
{func}`~jax.extend.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.
The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.extend.ffi.ffi_call`.
{func}`~jax.ffi.ffi_call` supports some simple {func}`~jax.vmap` semantics out of the box using the `vmap_method` parameter.
The docs for {func}`~jax.pure_callback` provide more details about the `vmap_method` parameter, and the same behavior applies to {func}`~jax.ffi.ffi_call`.
The simplest `vmap_method` is `"sequential"`.
In this case, when `vmap`ped, an `ffi_call` will be rewritten as a {func}`~jax.lax.scan` with the `ffi_call` in the body.
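As a rough mental model (this is only a sketch of the semantics, not the actual lowering), the `"sequential"` behavior is similar to mapping the unbatched call over the leading axis:

```python
import jax

def sequential_vmap_sketch(unbatched_call, xs):
    # Roughly what vmap_method="sequential" provides: apply the unbatched call
    # to one batch element at a time; jax.lax.map is itself built on scan.
    return jax.lax.map(unbatched_call, xs)
```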
@ -326,7 +325,7 @@ Using `vmap_method="sequential"`, `vmap`ping a `ffi_call` will fall back on a {f
```{code-cell} ipython3
def rms_norm_sequential(x, eps=1e-5):
return jex.ffi.ffi_call(
return jax.ffi.ffi_call(
"rms_norm",
jax.ShapeDtypeStruct(x.shape, x.dtype),
vmap_method="sequential",
@ -342,9 +341,9 @@ If your foreign function provides an efficient batching rule that isn't supporte
### Differentiation
Unlike with batching, {func}`~jax.extend.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.
Unlike with batching, {func}`~jax.ffi.ffi_call` doesn't provide any default support for automatic differentiation (AD) of foreign functions.
As far as JAX is concerned, the foreign function is a black box that can't be inspected to determine the appropriate behavior when differentiated.
Therefore, it is the {func}`~jax.extend.ffi.ffi_call` user's responsibility to define a custom derivative rule.
Therefore, it is the {func}`~jax.ffi.ffi_call` user's responsibility to define a custom derivative rule.
More details about custom derivative rules can be found in the [custom derivatives tutorial](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html), but the most common pattern used for implementing differentiation for foreign functions is to define a {func}`~jax.custom_vjp` which itself calls a foreign function.
In this case, we actually define two new FFI calls:
@ -353,21 +352,21 @@ In this case, we actually define two new FFI calls:
2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.
We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.
The main point to emphasize here is that the "residual" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.
The main point to emphasize here is that the "residual" computed has a different shape than the primal output; therefore, in the {func}`~jax.ffi.ffi_call` to `rms_norm_fwd`, the output type has two elements with different shapes.
This custom derivative rule can be wired in as follows:
```{code-cell} ipython3
jex.ffi.register_ffi_target(
"rms_norm_fwd", jex.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform="cpu"
jax.ffi.register_ffi_target(
"rms_norm_fwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormFwd), platform="cpu"
)
jex.ffi.register_ffi_target(
"rms_norm_bwd", jex.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform="cpu"
jax.ffi.register_ffi_target(
"rms_norm_bwd", jax.ffi.pycapsule(rms_norm_lib.RmsNormBwd), platform="cpu"
)
def rms_norm_fwd(x, eps=1e-5):
y, res = jex.ffi.ffi_call(
y, res = jax.ffi.ffi_call(
"rms_norm_fwd",
(
jax.ShapeDtypeStruct(x.shape, x.dtype),
@ -384,7 +383,7 @@ def rms_norm_bwd(eps, res, ct):
assert res.shape == ct.shape[:-1]
assert x.shape == ct.shape
return (
jex.ffi.ffi_call(
jax.ffi.ffi_call(
"rms_norm_bwd",
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
vmap_method="broadcast_all",
@ -447,7 +446,7 @@ Then, the `RmsNormImpl` can use the CUDA stream to launch CUDA kernels.
On the front end, the registration code would be updated to specify the appropriate platform:
```python
jex.ffi.register_ffi_target(
jax.ffi.register_ffi_target(
"rms_norm_cuda", rms_norm_lib_cuda.rms_norm(), platform="CUDA"
)
```
@ -462,7 +461,7 @@ def rms_norm_cross_platform(x, eps=1e-5):
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
def impl(target_name):
return lambda x: jex.ffi.ffi_call(
return lambda x: jax.ffi.ffi_call(
target_name,
out_type,
vmap_method="broadcast_all",
@ -499,8 +498,8 @@ and there will be no runtime overhead to using {func}`jax.lax.platform_dependent
This tutorial covers most of the basic steps that are required to get up and running with JAX's FFI, but advanced use cases may require more features.
We will leave these topics to future tutorials, but here are some possibly useful references:
* **Supporting multiple dtypes**: In this tutorial's example, we restricted to only support `float32` inputs and outputs, but many use cases require supporting multiple different input types. One option to handle this is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.extend.ffi.ffi_call` depending on the input types. But, this approach could get quickly unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include a `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.
* **Supporting multiple dtypes**: In this tutorial's example, we restricted our implementation to support only `float32` inputs and outputs, but many use cases require supporting multiple input types. One option is to register different FFI targets for all supported input types and then use Python to select the appropriate target for {func}`jax.ffi.ffi_call` depending on the input types (a sketch of this dispatch pattern follows this list). But this approach can quickly become unwieldy depending on the combinatorics of the supported cases. So it is also possible to define the C++ handler to accept `ffi::AnyBuffer` instead of `ffi::Buffer<Dtype>`. Then, the input buffer will include an `element_type()` method which can be used to define the appropriate dtype dispatching logic in the backend.
* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.extend.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.
* **Sharding**: When using JAX's automatic data-dependent parallelism within {func}`~jax.jit`, FFI calls implemented using {func}`~jax.ffi.ffi_call` don't have sufficient information to shard appropriately, so they result in a copy of the inputs to all devices and the FFI call gets executed on the full array on each device. To get around this limitation, you can use {func}`~jax.experimental.shard_map.shard_map` or {func}`~jax.experimental.custom_partitioning.custom_partitioning`.
* **Stateful foreign functions**: It is also possible to use the FFI to wrap functions with associated state. There is a [low-level example included in the XLA test suite](https://github.com/openxla/xla/blob/737a7da3c5405583dc95773ac0bb11b1349fc9ea/xla/service/gpu/custom_call_test.cc#L794-L845), and a future tutorial will include more details.
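To make the Python-side dispatch option from the first point above concrete, here is a small sketch; the target names `my_op_f32` and `my_op_f64` are hypothetical and would each need to be registered with {func}`~jax.ffi.register_ffi_target` for this to run:

```python
import numpy as np

import jax

# Hypothetical per-dtype FFI targets, assumed to have been registered elsewhere.
_TARGETS = {
    np.dtype(np.float32): "my_op_f32",
    np.dtype(np.float64): "my_op_f64",
}

def my_op(x):
    # Select the registered target based on the input dtype.
    target_name = _TARGETS[np.dtype(x.dtype)]
    return jax.ffi.ffi_call(
        target_name, jax.ShapeDtypeStruct(x.shape, x.dtype))(x)
```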

View File

@ -1,12 +0,0 @@
``jax.extend.ffi`` module
=========================
.. automodule:: jax.extend.ffi
.. autosummary::
:toctree: _autosummary
ffi_call
ffi_lowering
pycapsule
register_ffi_target

View File

@ -12,7 +12,6 @@ Modules
:maxdepth: 1
jax.extend.core
jax.extend.ffi
jax.extend.linear_util
jax.extend.mlir
jax.extend.random

docs/jax.ffi.rst (new file, 31 lines)
View File

@ -0,0 +1,31 @@
``jax.ffi`` module
==================
.. automodule:: jax.ffi
.. autosummary::
:toctree: _autosummary
ffi_call
ffi_lowering
pycapsule
register_ffi_target
``jax.extend.ffi`` module (deprecated)
======================================
The ``jax.extend.ffi`` module has been moved to ``jax.ffi``, and the new
import path should be used instead. These functions remain documented here
while the legacy import path is being deprecated.
.. automodule:: jax.extend.ffi
.. autosummary::
:toctree: _autosummary
ffi_call
ffi_lowering
pycapsule
register_ffi_target

View File

@ -18,6 +18,7 @@ Subpackages
jax.dlpack
jax.distributed
jax.dtypes
jax.ffi
jax.flatten_util
jax.image
jax.nn

View File

@ -15,28 +15,27 @@
import numpy as np
import jax
import jax.extend as jex
from jax_ffi_example import _cpu_examples
for name, target in _cpu_examples.registrations().items():
jex.ffi.register_ffi_target(name, target)
jax.ffi.register_ffi_target(name, target)
def array_attr(num: int):
return jex.ffi.ffi_call(
return jax.ffi.ffi_call(
"array_attr",
jax.ShapeDtypeStruct((), np.int32),
)(array=np.arange(num, dtype=np.int32))
def dictionary_attr(**kwargs):
return jex.ffi.ffi_call(
return jax.ffi.ffi_call(
"dictionary_attr",
(jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)),
)(**kwargs)
def counter(index):
return jex.ffi.ffi_call(
return jax.ffi.ffi_call(
"counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index))

View File

@ -24,15 +24,14 @@ import numpy as np
import jax
import jax.numpy as jnp
import jax.extend as jex
# Load the shared library with the FFI target definitions
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_examples.so")
library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
jex.ffi.register_ffi_target("foo-fwd", jex.ffi.pycapsule(library.FooFwd),
jax.ffi.register_ffi_target("foo-fwd", jax.ffi.pycapsule(library.FooFwd),
platform="CUDA")
jex.ffi.register_ffi_target("foo-bwd", jex.ffi.pycapsule(library.FooBwd),
jax.ffi.register_ffi_target("foo-bwd", jax.ffi.pycapsule(library.FooBwd),
platform="CUDA")
@ -42,7 +41,7 @@ def foo_fwd(a, b):
assert a.dtype == b.dtype
n = np.prod(a.shape).astype(np.uint64)
out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
c, b_plus_1 = jex.ffi.ffi_call("foo-fwd", (out_type, out_type))(a, b, n=n)
c, b_plus_1 = jax.ffi.ffi_call("foo-fwd", (out_type, out_type))(a, b, n=n)
return c, (a, b_plus_1)
@ -55,7 +54,7 @@ def foo_bwd(res, c_grad):
assert a.dtype == b_plus_1.dtype
n = np.prod(a.shape).astype(np.uint64)
out_type = jax.ShapeDtypeStruct(a.shape, a.dtype)
return jex.ffi.ffi_call("foo-bwd", (out_type, out_type))(c_grad, a, b_plus_1,
return jax.ffi.ffi_call("foo-bwd", (out_type, out_type))(c_grad, a, b_plus_1,
n=n)

View File

@ -16,7 +16,7 @@
This example is exactly the same as the one in the `FFI tutorial
<https://jax.readthedocs.io/en/latest/ffi.html>`, so more details can be found
on that page. But, the high level summary is that we implement our custom
extension in ``rms_norm.cc``, then call it usin ``jax.extend.ffi.ffi_call`` in
extension in ``rms_norm.cc``, then call it using ``jax.ffi.ffi_call`` in
this module. The behavior under autodiff is implemented using
``jax.custom_vjp``.
"""
@ -26,13 +26,12 @@ from functools import partial
import numpy as np
import jax
import jax.extend as jex
import jax.numpy as jnp
from jax_ffi_example import _rms_norm
for name, target in _rms_norm.registrations().items():
jex.ffi.register_ffi_target(name, target)
jax.ffi.register_ffi_target(name, target)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
@ -53,7 +52,7 @@ def rms_norm(x, eps=1e-5):
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
# type (which corresponds to numpy's `float32` type), and it must be a
# static parameter (i.e. not a JAX array).
return jex.ffi.ffi_call(
return jax.ffi.ffi_call(
# The target name must be the same string as we used to register the target
# above in `register_ffi_target`
"rms_norm",
@ -63,7 +62,7 @@ def rms_norm(x, eps=1e-5):
def rms_norm_fwd(x, eps=1e-5):
y, res = jex.ffi.ffi_call(
y, res = jax.ffi.ffi_call(
"rms_norm_fwd",
(
jax.ShapeDtypeStruct(x.shape, x.dtype),
@ -80,7 +79,7 @@ def rms_norm_bwd(eps, res, ct):
assert res.shape == ct.shape[:-1]
assert x.shape == ct.shape
return (
jex.ffi.ffi_call(
jax.ffi.ffi_call(
"rms_norm_bwd",
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
vmap_method="broadcast_all",

View File

@ -199,6 +199,7 @@ py_library_providing_imports_info(
"_src/dispatch.py",
"_src/dlpack.py",
"_src/earray.py",
"_src/ffi.py",
"_src/flatten_util.py",
"_src/interpreters/__init__.py",
"_src/interpreters/ad.py",
@ -730,7 +731,6 @@ py_library(
":jax",
":mlir",
"//jax/_src/lib",
"//jax/extend:ffi",
"//jaxlib/mlir:arithmetic_dialect",
"//jaxlib/mlir:builtin_dialect",
"//jaxlib/mlir:func_dialect",

View File

@ -160,6 +160,7 @@ from jax import debug as debug
from jax import dlpack as dlpack
from jax import dtypes as dtypes
from jax import errors as errors
from jax import ffi as ffi
from jax import image as image
from jax import lax as lax
from jax import monitoring as monitoring

View File

@ -76,8 +76,8 @@ def pycapsule(funcptr):
"""Wrap a ctypes function pointer in a PyCapsule.
The primary use of this function, and the reason why it lives within the
``jax.extend.ffi`` submodule, is to wrap function calls from external
compiled libraries to be registered as XLA custom calls.
``jax.ffi`` submodule, is to wrap function calls from external compiled
libraries to be registered as XLA custom calls.
Example usage::
@ -88,7 +88,7 @@ def pycapsule(funcptr):
libfoo = ctypes.cdll.LoadLibrary('./foo.so')
xla_client.register_custom_call_target(
name="bar",
fn=jax.extend.ffi.pycapsule(libfoo.bar),
fn=jax.ffi.pycapsule(libfoo.bar),
platform=PLATFORM,
api_version=API_VERSION
)
@ -145,7 +145,7 @@ def ffi_lowering(
Note that layouts passed to this function as tuples should be in
minor-to-major order (as expected by XLA) rather than major-to-minor as used
by :func:`~jax.extend.ffi.ffi_call` and ``DeviceLocalLayout``.
by :func:`~jax.ffi.ffi_call` and ``DeviceLocalLayout``.
If keyword arguments are passed to the lowering rule, these are treated as
attributes, and added to `backend_config`.
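To make the layout convention mentioned in this docstring concrete, here is a small illustrative sketch (not tied to any particular call):

```python
# For a rank-3 array, jax.ffi.ffi_call and DeviceLocalLayout take layouts in
# major-to-minor order (slowest- to fastest-varying dimension), while
# ffi_lowering expects XLA's minor-to-major order, i.e. the reverse permutation.
major_to_minor = (0, 1, 2)
minor_to_major = tuple(reversed(major_to_minor))  # (2, 1, 0)
```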
@ -310,7 +310,7 @@ def ffi_call(
Args:
target_name: the name of the XLA FFI custom call target that was registered
using :func:`~jax.extend.ffi.register_ffi_target`.
using :func:`~jax.ffi.register_ffi_target`.
result_shape_dtypes: an object, or sequence of objects, with ``shape`` and
``dtype`` attributes which are expected to match the shape and dtype of
the custom call output or outputs. :class:`~jax.ShapeDtypeStruct` is often

View File

@ -34,7 +34,7 @@ from jax._src import dtypes
from jax._src import util
from jax._src.core import (
Primitive, ShapedArray, is_constant_dim, is_constant_shape)
from jax._src.extend import ffi
from jax._src import ffi
from jax._src.interpreters import ad
from jax._src.interpreters import batching
from jax._src.interpreters import mlir

View File

@ -22,7 +22,6 @@ import warnings
import jax
from jax._src.lib import xla_client
from jax.extend import ffi
import jax.numpy as jnp
from jaxlib.mlir import ir
from jaxlib.mlir.dialects import arith
@ -54,7 +53,7 @@ P = ParamSpec("P")
def _event_record(args, *, copy_before):
flat_args, treedef = jax.tree.flatten(args)
event, *flat_outs = ffi.ffi_call(
event, *flat_outs = jax.ffi.ffi_call(
"mgpu_event_record",
result_shape_dtypes=(jax.core.ShapedArray((), jnp.uint64), *flat_args),
input_output_aliases={i: i + 1 for i in range(len(flat_args))},
@ -63,7 +62,7 @@ def _event_record(args, *, copy_before):
def _event_elapsed(start_event, end_event):
return ffi.ffi_call(
return jax.ffi.ffi_call(
"mgpu_event_elapsed",
result_shape_dtypes=jax.core.ShapedArray((), jnp.float32),
)(start_event, end_event)

View File

@ -52,7 +52,7 @@ from jax._src.api import _shared_code_pmap, _prepare_pmap
from jax._src.lax import (lax, parallel as lax_parallel, slicing,
windowed_reductions, convolution, fft, linalg,
special, control_flow, ann)
from jax._src.extend import ffi
from jax._src import ffi
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import sdy
from jax._src.util import (HashableFunction, HashablePartial, unzip2,

View File

@ -12,13 +12,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src import ffi as _ffi
from jax._src.extend.ffi import (
ffi_call as ffi_call,
ffi_lowering as ffi_lowering,
include_dir as include_dir,
pycapsule as pycapsule,
register_ffi_target as register_ffi_target,
)
_deprecations = {
# Added 2024-12-20
"ffi_call": (
"jax.extend.ffi.ffi_call is deprecated, use jax.ffi.ffi_call instead.",
_ffi.ffi_call,
),
"ffi_lowering": (
"jax.extend.ffi.ffi_lowering is deprecated, use jax.ffi.ffi_lowering instead.",
_ffi.ffi_lowering,
),
"include_dir": (
"jax.extend.ffi.include_dir is deprecated, use jax.ffi.include_dir instead.",
_ffi.include_dir,
),
"pycapsule": (
"jax.extend.ffi.pycapsule is deprecated, use jax.ffi.pycapsule instead.",
_ffi.pycapsule,
),
"register_ffi_target": (
"jax.extend.ffi.register_ffi_target is deprecated, use jax.ffi.register_ffi_target instead.",
_ffi.register_ffi_target,
),
}
import typing
if typing.TYPE_CHECKING:
ffi_call = _ffi.ffi_call
ffi_lowering = _ffi.ffi_lowering
include_dir = _ffi.include_dir
pycapsule = _ffi.pycapsule
register_ffi_target = _ffi.register_ffi_target
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)
del _deprecation_getattr
del typing
del _ffi
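As a usage sketch of the shim above (assuming the deprecation has not been accelerated into an error), accessing a name through the legacy path still resolves to the same object as the new path but emits a `DeprecationWarning`:

```python
import warnings

import jax
import jax.extend.ffi  # legacy path, kept importable during the deprecation period

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    legacy_fn = jax.extend.ffi.ffi_call  # resolves via the module-level __getattr__

assert legacy_fn is jax.ffi.ffi_call
assert any(issubclass(w.category, DeprecationWarning) for w in caught)
```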

jax/ffi.py (new file, 24 lines)
View File

@ -0,0 +1,24 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Note: import <name> as <name> is required for names to be exported.
# See PEP 484 & https://github.com/jax-ml/jax/issues/7570
from jax._src.ffi import (
ffi_call as ffi_call,
ffi_lowering as ffi_lowering,
include_dir as include_dir,
pycapsule as pycapsule,
register_ffi_target as register_ffi_target,
)

View File

@ -142,6 +142,13 @@ jax_multiplatform_test(
deps = ["//jax:extend"],
)
jax_multiplatform_test(
name = "ffi_test",
srcs = ["ffi_test.py"],
# TODO(dfm): Remove after removal of jex.ffi imports.
deps = ["//jax:extend"],
)
jax_multiplatform_test(
name = "fft_test",
srcs = ["fft_test.py"],

View File

@ -12,35 +12,19 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from functools import partial
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
import jax.extend as jex
import jax.numpy as jnp
import jax.sharding as shd
from jax._src import abstract_arrays
from jax._src import api
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import linear_util
from jax._src import prng
from jax._src import test_util as jtu
from jax._src import xla_bridge
from jax._src.interpreters import mlir
from jax._src.layout import DeviceLocalLayout
from jax._src.lib import lapack
from jax._src.lib.mlir.dialects import hlo
from jax._src.lax import linalg as lax_linalg_internal
from jax.experimental.shard_map import shard_map
jax.config.parse_flags_with_absl()
@ -110,294 +94,6 @@ class RandomTest(jtu.JaxTestCase):
self.assertEqual(repr(spec), f"PRNGSpec({spec_ref._impl.name!r})")
class FfiTest(jtu.JaxTestCase):
def find_custom_call_in_module(self, module):
for func in module.body.operations:
for block in func.body.blocks:
for op in block.operations:
if op.OPERATION_NAME == "stablehlo.custom_call":
return op
self.fail("No custom_call found in the lowered IR")
def testHeadersExist(self):
base_dir = os.path.join(jex.ffi.include_dir(), "xla", "ffi", "api")
for header in ["c_api.h", "api.h", "ffi.h"]:
self.assertTrue(os.path.exists(os.path.join(base_dir, header)))
@parameterized.parameters([
(tuple(range(3)), tuple(range(3))),
(None, tuple(reversed(range(3)))),
(DeviceLocalLayout(tuple(range(3))), tuple(reversed(range(3)))),
])
def testLoweringLayouts(self, layout_spec, expected_layout):
# Regression test to ensure that the lowering rule properly captures
# layouts.
def lowering_rule(ctx, x):
aval, = ctx.avals_in
return jex.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec],
result_layouts=[layout_spec])(ctx, x)
prim = core.Primitive("test_ffi")
prim.def_impl(lambda x: x)
prim.def_abstract_eval(lambda x: x)
mlir.register_lowering(prim, lowering_rule)
x = jnp.ones((3,) * len(expected_layout))
lowered = jax.jit(prim.bind).lower(x)
module = lowered.compiler_ir("stablehlo")
op = self.find_custom_call_in_module(module)
self.assertIn("operand_layouts", op.attributes)
self.assertIn("result_layouts", op.attributes)
text = lowered.as_text()
expected = ", ".join(map(str, expected_layout))
pattern = rf"operand_layouts = \[dense<\[{expected}\]>"
self.assertRegex(text, pattern)
pattern = rf"result_layouts = \[dense<\[{expected}\]>"
self.assertRegex(text, pattern)
@parameterized.parameters([
(True, mlir.ir.BoolAttr.get),
(1, mlir.i64_attr),
(5.0, lambda x: mlir.ir.FloatAttr.get(mlir.ir.F64Type.get(), x)),
("param", mlir.ir.StringAttr.get),
(np.float32(0.5),
lambda x: mlir.ir.FloatAttr.get(mlir.ir.F32Type.get(), x)),
])
def testParams(self, param, expected_builder):
def fun(x):
return jex.ffi.ffi_call("test_ffi", x)(x, param=param)
# Here we inspect the lowered IR to test that the parameter has been
# serialized with the appropriate type.
module = jax.jit(fun).lower(0.5).compiler_ir("stablehlo")
op = self.find_custom_call_in_module(module)
config = op.attributes["mhlo.backend_config"]
self.assertIsInstance(config, mlir.ir.DictAttr)
self.assertIn("param", config)
with mlir.make_ir_context(), mlir.ir.Location.unknown():
expected = expected_builder(param)
self.assertEqual(type(config["param"]), type(expected))
self.assertTrue(expected.type.isinstance(config["param"].type))
def testToken(self):
def fun():
token = lax.create_token()
return jex.ffi.ffi_call("test_ffi", core.abstract_token)(token)
# Ensure that token inputs and outputs are translated to the correct type
module = jax.jit(fun).lower().compiler_ir("stablehlo")
op = self.find_custom_call_in_module(module)
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
def testEffectsHlo(self):
# The target name must exist on the current platform, but we don't actually
# need to call it with the correct syntax, because we're only checking the
# compiled HLO.
if jtu.test_device_matches(["cpu"]):
target_name = "lapack_sgetrf_ffi"
elif jtu.test_device_matches(["rocm"]):
target_name = "hipsolver_getrf_ffi"
elif jtu.test_device_matches(["cuda", "gpu"]):
target_name = "cusolver_getrf_ffi"
else:
raise unittest.SkipTest("Unsupported device")
def fun():
jex.ffi.ffi_call(target_name, (), has_side_effect=True)()
hlo = jax.jit(fun).lower()
self.assertIn(target_name, hlo.as_text())
self.assertIn("has_side_effect = true", hlo.as_text())
self.assertIn(target_name, hlo.compile().as_text())
def testJvpError(self):
def fun(x):
return jex.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1})
with self.assertRaisesRegex(
ValueError, "The FFI call to `.+` cannot be differentiated."):
jax.jvp(fun, (0.5,), (0.5,))
def testNonHashableAttributes(self):
def fun(x):
return jex.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1})
self.assertIn("HashableDict", str(jax.make_jaxpr(fun)(jnp.ones(5))))
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertIn("non_hashable_arg = {a = 1", hlo)
# If non-hashable arguments aren't handled properly, this will raise a
# TypeError. We make sure it doesn't.
with self.assertRaises(Exception) as manager:
fun(jnp.ones(5))
self.assertNotIsInstance(manager.exception, TypeError)
def fun(x):
return jex.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg=np.arange(3))
self.assertIn("HashableArray", str(jax.make_jaxpr(fun)(jnp.ones(5))))
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertIn("non_hashable_arg = array<i64: 0, 1, 2>", hlo)
with self.assertRaises(Exception) as manager:
fun(jnp.ones(5))
self.assertNotIsInstance(manager.exception, TypeError)
@jtu.sample_product(shape=[(6, 5), (4, 5, 6)])
@jtu.run_on_devices("gpu", "cpu")
def testFfiCall(self, shape):
x = self.rng().randn(*shape).astype(np.float32)
expected = lax_linalg_internal.geqrf(x)
actual = ffi_call_geqrf(x)
for a, b in zip(actual, expected):
self.assertArraysEqual(a, b)
@jtu.sample_product(
shape=[(6, 5), (4, 5, 6)],
vmap_method=["expand_dims", "broadcast_all", "sequential"],
)
@jtu.run_on_devices("gpu", "cpu")
def testFfiCallBatching(self, shape, vmap_method):
shape = (10,) + shape
x = self.rng().randn(*shape).astype(np.float32)
expected = lax_linalg_internal.geqrf(x)
actual = jax.vmap(partial(ffi_call_geqrf, vmap_method=vmap_method))(x)
for a, b in zip(actual, expected):
if vmap_method == "sequential" and len(shape) == 3:
# On GPU, the batched FFI call to geqrf uses an algorithm with
# different numerics than the unbatched version (which is used when
# vmap_method="sequential"). Therefore, we need to include floating
# point tolerance for this check.
self.assertArraysAllClose(a, b)
else:
self.assertArraysEqual(a, b)
@jtu.run_on_devices("gpu", "cpu")
def testVectorizedDeprecation(self):
x = self.rng().randn(3, 5, 4).astype(np.float32)
with self.assertWarns(DeprecationWarning):
ffi_call_geqrf(x, vectorized=True)
with self.assertWarns(DeprecationWarning):
jax.vmap(ffi_call_geqrf)(x)
def testBackwardCompatSyntax(self):
def fun(x):
return jex.ffi.ffi_call("test_ffi", x, x, param=0.5)
msg = "Calling ffi_call directly with input arguments is deprecated"
if deprecations.is_accelerated("jax-ffi-call-args"):
with self.assertRaisesRegex(ValueError, msg):
jax.jit(fun).lower(jnp.ones(5))
else:
with self.assertWarnsRegex(DeprecationWarning, msg):
jax.jit(fun).lower(jnp.ones(5))
def testInputOutputAliases(self):
def fun(x):
return jex.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x)
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertRegex(hlo, r"output_operand_aliases = \[.*operand_index = 0.*\]")
def testInvalidInputOutputAliases(self):
def fun(x):
return jex.ffi.ffi_call("test", x, input_output_aliases={1: 0})(x)
with self.assertRaisesRegex(ValueError, "with input index"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jex.ffi.ffi_call("test", x, input_output_aliases={0: 1})(x)
with self.assertRaisesRegex(ValueError, "with output index"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape, np.int32),
input_output_aliases={0: 0})(x)
with self.assertRaisesRegex(ValueError,
"referring to an input with abstract value"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape + x.shape,
x.dtype),
input_output_aliases={0: 0})(x)
with self.assertRaisesRegex(ValueError,
"referring to an input with abstract value"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def testLegacyBackendConfig(self):
def fun(x):
return jex.ffi.ffi_call("test", x, custom_call_api_version=2,
legacy_backend_config="12345")(x)
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertRegex(hlo, 'backend_config = "12345"')
def testInvalidBackendConfig(self):
def fun(x):
return jex.ffi.ffi_call("test", x, legacy_backend_config="12345")(x)
with self.assertRaisesRegex(ValueError,
"The use of the legacy_backend_config"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jex.ffi.ffi_call("test", x,
custom_call_api_version=2)(x, attribute=1)
with self.assertRaisesRegex(ValueError,
"The use of ffi_call attributes requires"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def testAllow64(self):
if config.enable_x64.value:
self.skipTest("Requires enable_x64=False")
def fun():
return jex.ffi.ffi_call("test", jax.ShapeDtypeStruct((), np.int64))()
self.assertIn("tensor<i64>", jax.jit(fun).lower().as_text())
def testInvalidResultType(self):
with self.assertRaisesRegex(
ValueError, "All elements of result_shape_dtypes.*position 0"):
jex.ffi.ffi_call("test", None)()
with self.assertRaisesRegex(
ValueError, "All elements of result_shape_dtypes.*position 1"):
jex.ffi.ffi_call("test", (jax.ShapeDtypeStruct((), np.float32), ()))()
@jtu.run_on_devices("gpu", "cpu")
def testShardMap(self):
mesh = jtu.create_mesh((1,), ("i",))
x = self.rng().randn(8, 4, 5).astype(np.float32)
@partial(shard_map, mesh=mesh, in_specs=shd.PartitionSpec('i'),
out_specs=shd.PartitionSpec('i'))
def f(x):
return ffi_call_geqrf(x)
f(x) # eager mode doesn't crash
jax.jit(f)(x) # neither does JIT
self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text())
def ffi_call_geqrf(x, **kwargs):
if jtu.test_device_matches(["cpu"]):
lapack._lapack.initialize()
assert x.dtype == np.float32
ndim = x.ndim
x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2)
output_types = [
x, jax.ShapeDtypeStruct(x.shape[:-2] + (min(*x.shape[-2:]),), x.dtype)]
def call(platform, x):
target_name = dict(
cpu="lapack_sgeqrf_ffi",
rocm="hipsolver_geqrf_ffi",
cuda="cusolver_geqrf_ffi",
)[platform]
return jex.ffi.ffi_call(
target_name, output_types, input_output_aliases={0: 0},
input_layouts=[x_major_to_minor],
output_layouts=[x_major_to_minor, None],
**kwargs)(x)
return lax.platform_dependent(
x, cpu=partial(call, "cpu"), rocm=partial(call, "rocm"),
cuda=partial(call, "cuda"))
class MlirRegisterLoweringTest(jtu.JaxTestCase):
def test_unknown_platform_error(self):

tests/ffi_test.py (new file, 338 lines)
View File

@ -0,0 +1,338 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
from functools import partial
import numpy as np
from absl.testing import absltest
from absl.testing import parameterized
import jax
from jax import lax
import jax.extend as jex
import jax.numpy as jnp
import jax.sharding as shd
from jax._src import config
from jax._src import core
from jax._src import deprecations
from jax._src import test_util as jtu
from jax._src.interpreters import mlir
from jax._src.layout import DeviceLocalLayout
from jax._src.lib import lapack
from jax._src.lib.mlir.dialects import hlo
from jax._src.lax import linalg as lax_linalg_internal
from jax.experimental.shard_map import shard_map
jax.config.parse_flags_with_absl()
class FfiTest(jtu.JaxTestCase):
def find_custom_call_in_module(self, module):
for func in module.body.operations:
for block in func.body.blocks:
for op in block.operations:
if op.OPERATION_NAME == "stablehlo.custom_call":
return op
self.fail("No custom_call found in the lowered IR")
def test_headers_exist(self):
base_dir = os.path.join(jax.ffi.include_dir(), "xla", "ffi", "api")
for header in ["c_api.h", "api.h", "ffi.h"]:
self.assertTrue(os.path.exists(os.path.join(base_dir, header)))
@parameterized.parameters([
(tuple(range(3)), tuple(range(3))),
(None, tuple(reversed(range(3)))),
(DeviceLocalLayout(tuple(range(3))), tuple(reversed(range(3)))),
])
def test_lowering_layouts(self, layout_spec, expected_layout):
# Regression test to ensure that the lowering rule properly captures
# layouts.
def lowering_rule(ctx, x):
return jax.ffi.ffi_lowering("test_ffi", operand_layouts=[layout_spec],
result_layouts=[layout_spec])(ctx, x)
prim = core.Primitive("test_ffi")
prim.def_impl(lambda x: x)
prim.def_abstract_eval(lambda x: x)
mlir.register_lowering(prim, lowering_rule)
x = jnp.ones((3,) * len(expected_layout))
lowered = jax.jit(prim.bind).lower(x)
module = lowered.compiler_ir("stablehlo")
op = self.find_custom_call_in_module(module)
self.assertIn("operand_layouts", op.attributes)
self.assertIn("result_layouts", op.attributes)
text = lowered.as_text()
expected = ", ".join(map(str, expected_layout))
pattern = rf"operand_layouts = \[dense<\[{expected}\]>"
self.assertRegex(text, pattern)
pattern = rf"result_layouts = \[dense<\[{expected}\]>"
self.assertRegex(text, pattern)
@parameterized.parameters([
(True, mlir.ir.BoolAttr.get),
(1, mlir.i64_attr),
(5.0, lambda x: mlir.ir.FloatAttr.get(mlir.ir.F64Type.get(), x)),
("param", mlir.ir.StringAttr.get),
(np.float32(0.5),
lambda x: mlir.ir.FloatAttr.get(mlir.ir.F32Type.get(), x)),
])
def test_params(self, param, expected_builder):
def fun(x):
return jax.ffi.ffi_call("test_ffi", x)(x, param=param)
# Here we inspect the lowered IR to test that the parameter has been
# serialized with the appropriate type.
module = jax.jit(fun).lower(0.5).compiler_ir("stablehlo")
op = self.find_custom_call_in_module(module)
config = op.attributes["mhlo.backend_config"]
self.assertIsInstance(config, mlir.ir.DictAttr)
self.assertIn("param", config)
with mlir.make_ir_context(), mlir.ir.Location.unknown():
expected = expected_builder(param)
self.assertEqual(type(config["param"]), type(expected))
self.assertTrue(expected.type.isinstance(config["param"].type))
def test_token(self):
def fun():
token = lax.create_token()
return jax.ffi.ffi_call("test_ffi", core.abstract_token)(token)
# Ensure that token inputs and outputs are translated to the correct type
module = jax.jit(fun).lower().compiler_ir("stablehlo")
op = self.find_custom_call_in_module(module)
self.assertTrue(hlo.TokenType.isinstance(op.operands[0].type))
self.assertTrue(hlo.TokenType.isinstance(op.results[0].type))
def test_effects_hlo(self):
# The target name must exist on the current platform, but we don't actually
# need to call it with the correct syntax, because we're only checking the
# compiled HLO.
if jtu.test_device_matches(["cpu"]):
target_name = "lapack_sgetrf_ffi"
elif jtu.test_device_matches(["rocm"]):
target_name = "hipsolver_getrf_ffi"
elif jtu.test_device_matches(["cuda", "gpu"]):
target_name = "cusolver_getrf_ffi"
else:
raise unittest.SkipTest("Unsupported device")
def fun():
jax.ffi.ffi_call(target_name, (), has_side_effect=True)()
hlo = jax.jit(fun).lower()
self.assertIn(target_name, hlo.as_text())
self.assertIn("has_side_effect = true", hlo.as_text())
self.assertIn(target_name, hlo.compile().as_text())
def test_jvp_error(self):
def fun(x):
return jax.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1})
with self.assertRaisesRegex(
ValueError, "The FFI call to `.+` cannot be differentiated."):
jax.jvp(fun, (0.5,), (0.5,))
def test_non_hashable_attributes(self):
def fun(x):
return jax.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg={"a": 1})
self.assertIn("HashableDict", str(jax.make_jaxpr(fun)(jnp.ones(5))))
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertIn("non_hashable_arg = {a = 1", hlo)
# If non-hashable arguments aren't handled properly, this will raise a
# TypeError. We make sure it doesn't.
with self.assertRaises(Exception) as manager:
fun(jnp.ones(5))
self.assertNotIsInstance(manager.exception, TypeError)
def fun(x):
return jax.ffi.ffi_call("test_ffi", x)(x, non_hashable_arg=np.arange(3))
self.assertIn("HashableArray", str(jax.make_jaxpr(fun)(jnp.ones(5))))
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertIn("non_hashable_arg = array<i64: 0, 1, 2>", hlo)
with self.assertRaises(Exception) as manager:
fun(jnp.ones(5))
self.assertNotIsInstance(manager.exception, TypeError)
@jtu.sample_product(shape=[(6, 5), (4, 5, 6)])
@jtu.run_on_devices("gpu", "cpu")
def test_ffi_call(self, shape):
x = self.rng().randn(*shape).astype(np.float32)
expected = lax_linalg_internal.geqrf(x)
actual = ffi_call_geqrf(x)
for a, b in zip(actual, expected):
self.assertArraysEqual(a, b)
@jtu.sample_product(
shape=[(6, 5), (4, 5, 6)],
vmap_method=["expand_dims", "broadcast_all", "sequential"],
)
@jtu.run_on_devices("gpu", "cpu")
def test_ffi_call_batching(self, shape, vmap_method):
shape = (10,) + shape
x = self.rng().randn(*shape).astype(np.float32)
expected = lax_linalg_internal.geqrf(x)
actual = jax.vmap(partial(ffi_call_geqrf, vmap_method=vmap_method))(x)
for a, b in zip(actual, expected):
if vmap_method == "sequential" and len(shape) == 3:
# On GPU, the batched FFI call to geqrf uses an algorithm with
# different numerics than the unbatched version (which is used when
# vmap_method="sequential"). Therefore, we need to include floating
# point tolerance for this check.
self.assertArraysAllClose(a, b)
else:
self.assertArraysEqual(a, b)
@jtu.run_on_devices("gpu", "cpu")
def test_vectorized_deprecation(self):
x = self.rng().randn(3, 5, 4).astype(np.float32)
with self.assertWarns(DeprecationWarning):
ffi_call_geqrf(x, vectorized=True)
with self.assertWarns(DeprecationWarning):
jax.vmap(ffi_call_geqrf)(x)
def test_backward_compat_syntax(self):
def fun(x):
return jex.ffi.ffi_call("test_ffi", x, x, param=0.5)
msg = "Calling ffi_call directly with input arguments is deprecated"
if deprecations.is_accelerated("jax-ffi-call-args"):
with self.assertRaisesRegex(ValueError, msg):
jax.jit(fun).lower(jnp.ones(5))
else:
with self.assertWarnsRegex(DeprecationWarning, msg):
jax.jit(fun).lower(jnp.ones(5))
def test_input_output_aliases(self):
def fun(x):
return jax.ffi.ffi_call("test", x, input_output_aliases={0: 0})(x)
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertRegex(hlo, r"output_operand_aliases = \[.*operand_index = 0.*\]")
def test_invalid_input_output_aliases(self):
def fun(x):
return jax.ffi.ffi_call("test", x, input_output_aliases={1: 0})(x)
with self.assertRaisesRegex(ValueError, "with input index"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jax.ffi.ffi_call("test", x, input_output_aliases={0: 1})(x)
with self.assertRaisesRegex(ValueError, "with output index"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jax.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape, np.int32),
input_output_aliases={0: 0})(x)
with self.assertRaisesRegex(ValueError,
"referring to an input with abstract value"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jax.ffi.ffi_call("test", jax.ShapeDtypeStruct(x.shape + x.shape,
x.dtype),
input_output_aliases={0: 0})(x)
with self.assertRaisesRegex(ValueError,
"referring to an input with abstract value"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def test_legacy_backend_config(self):
def fun(x):
return jax.ffi.ffi_call("test", x, custom_call_api_version=2,
legacy_backend_config="12345")(x)
hlo = jax.jit(fun).lower(jnp.ones(5)).as_text()
self.assertRegex(hlo, 'backend_config = "12345"')
def test_invalid_backend_config(self):
def fun(x):
return jax.ffi.ffi_call("test", x, legacy_backend_config="12345")(x)
with self.assertRaisesRegex(ValueError,
"The use of the legacy_backend_config"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def fun(x):
return jax.ffi.ffi_call("test", x,
custom_call_api_version=2)(x, attribute=1)
with self.assertRaisesRegex(ValueError,
"The use of ffi_call attributes requires"):
jax.jit(fun).lower(jnp.ones(5)).as_text()
def test_allow_x64(self):
if config.enable_x64.value:
self.skipTest("Requires enable_x64=False")
def fun():
return jax.ffi.ffi_call("test", jax.ShapeDtypeStruct((), np.int64))()
self.assertIn("tensor<i64>", jax.jit(fun).lower().as_text())
def test_invalid_result_type(self):
with self.assertRaisesRegex(
ValueError, "All elements of result_shape_dtypes.*position 0"):
jax.ffi.ffi_call("test", None)()
with self.assertRaisesRegex(
ValueError, "All elements of result_shape_dtypes.*position 1"):
jax.ffi.ffi_call("test", (jax.ShapeDtypeStruct((), np.float32), ()))()
@jtu.run_on_devices("gpu", "cpu")
def test_shard_map(self):
mesh = jtu.create_mesh((1,), ("i",))
x = self.rng().randn(8, 4, 5).astype(np.float32)
@partial(shard_map, mesh=mesh, in_specs=shd.PartitionSpec('i'),
out_specs=shd.PartitionSpec('i'))
def f(x):
return ffi_call_geqrf(x)
f(x) # eager mode doesn't crash
jax.jit(f)(x) # neither does JIT
self.assertNotIn("all-gather", jax.jit(f).lower(x).compile().as_text())
@jtu.run_on_devices("gpu", "cpu")
@jtu.ignore_warning(category=DeprecationWarning)
def test_extend_import_shim(self):
ffi_call_geqrf(jnp.ones((4, 5), dtype=np.float32), _use_extend=True)
def ffi_call_geqrf(x, _use_extend=False, **kwargs):
if jtu.test_device_matches(["cpu"]):
lapack._lapack.initialize()
assert x.dtype == np.float32
ndim = x.ndim
x_major_to_minor = tuple(range(ndim - 2)) + (ndim - 1, ndim - 2)
output_types = [
x, jax.ShapeDtypeStruct(x.shape[:-2] + (min(*x.shape[-2:]),), x.dtype)]
def call(platform, x):
target_name = dict(
cpu="lapack_sgeqrf_ffi",
rocm="hipsolver_geqrf_ffi",
cuda="cusolver_geqrf_ffi",
)[platform]
f = jex.ffi.ffi_call if _use_extend else jax.ffi.ffi_call
return f(
target_name, output_types, input_output_aliases={0: 0},
input_layouts=[x_major_to_minor],
output_layouts=[x_major_to_minor, None],
**kwargs)(x)
return lax.platform_dependent(
x, cpu=partial(call, "cpu"), rocm=partial(call, "rocm"),
cuda=partial(call, "cuda"))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())