Merge pull request #23409 from dfm:ffi-examples

PiperOrigin-RevId: 678690801
This commit is contained in:
jax authors 2024-09-25 07:23:26 -07:00
commit 9d277e61ce
9 changed files with 386 additions and 2 deletions

View File

@ -210,4 +210,37 @@ jobs:
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py
ffi:
name: FFI example
runs-on: ubuntu-latest
timeout-minutes: 5
steps:
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
- name: Set up Python 3.11
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
with:
python-version: 3.11
- name: Get pip cache dir
id: pip-cache
run: |
python -m pip install --upgrade pip wheel
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
- name: pip cache
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4
with:
path: ${{ steps.pip-cache.outputs.dir }}
key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}
- name: Install JAX
run: pip install .
- name: Build and install example project
run: python -m pip install -v ./examples/ffi[test]
env:
# We test building using GCC instead of clang. All other JAX builds use
# clang, but it is useful to make sure that FFI users can compile using
# a different toolchain. GCC is the default compiler on the
# 'ubuntu-latest' runner, but we still set this explicitly just to be
# clear.
CMAKE_ARGS: -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++
- name: Run tests
run: python -m pytest examples/ffi/tests

View File

@ -0,0 +1,15 @@
cmake_minimum_required(VERSION 3.15...3.30)
project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX)
find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module)
execute_process(
COMMAND "${Python_EXECUTABLE}"
"-c" "from jax.extend import ffi; print(ffi.include_dir())"
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR)
message(STATUS "XLA include directory: ${XLA_DIR}")
find_package(nanobind CONFIG REQUIRED)
nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc")
target_include_directories(_rms_norm PUBLIC ${XLA_DIR})
install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})

9
examples/ffi/README.md Normal file
View File

@ -0,0 +1,9 @@
# End-to-end example usage for JAX's foreign function interface
This directory includes an example project demonstrating the use of JAX's
foreign function interface (FFI). The JAX docs provide more information about
this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html),
but the example in this directory explicitly demonstrates:
1. One way to package and distribute FFI targets, and
2. Some more advanced use cases.

View File

@ -0,0 +1,12 @@
[build-system]
requires = ["scikit-build-core", "nanobind", "jax>=0.4.31"]
build-backend = "scikit_build_core.build"
[project]
name = "jax_ffi_example"
version = "0.0.1"
requires-python = ">=3.10"
dependencies = ["jax"]
[project.optional-dependencies]
test = ["pytest", "absl-py"]

View File

@ -0,0 +1,13 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

View File

@ -0,0 +1,157 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cmath>
#include <cstdint>
#include <functional>
#include <numeric>
#include <type_traits>
#include <utility>
#include "nanobind/nanobind.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/api/ffi.h"
namespace nb = nanobind;
namespace ffi = xla::ffi;
// This is the example "library function" that we want to expose to JAX. This
// isn't meant to be a particularly good implementation, it's just here as a
// placeholder for the purposes of this tutorial.
float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
float sm = 0.0f;
for (int64_t n = 0; n < size; ++n) {
sm += x[n] * x[n];
}
float scale = 1.0f / std::sqrt(sm / float(size) + eps);
for (int64_t n = 0; n < size; ++n) {
y[n] = x[n] * scale;
}
return scale;
}
// A helper function for extracting the relevant dimensions from `ffi::Buffer`s.
// In this example, we treat all leading dimensions as batch dimensions, so this
// function returns the total number of elements in the buffer, and the size of
// the last dimension.
template <ffi::DataType T>
std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
auto dims = buffer.dimensions();
if (dims.size() == 0) {
return std::make_pair(0, 0);
}
return std::make_pair(buffer.element_count(), dims.back());
}
// A wrapper function providing the interface between the XLA FFI call and our
// library function `ComputeRmsNorm` above. This function handles the batch
// dimensions by calling `ComputeRmsNorm` within a loop.
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::Result<ffi::Buffer<ffi::F32>> y) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNorm input must be an array");
}
for (int64_t n = 0; n < totalSize; n += lastDim) {
ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
}
return ffi::Error::Success();
}
// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
// this handler in a header, you can use the `XLA_FFI_DECLASE_HANDLER_SYMBOL`
// macro: `XLA_FFI_DECLASE_HANDLER_SYMBOL(RmsNorm)`.
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::F32>>() // x
.Ret<ffi::Buffer<ffi::F32>>() // y
);
ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::Result<ffi::Buffer<ffi::F32>> y,
ffi::Result<ffi::Buffer<ffi::F32>> res) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNormFwd input must be an array");
}
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]),
&(y->typed_data()[n]));
}
return ffi::Error::Success();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdImpl,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::F32>>() // x
.Ret<ffi::Buffer<ffi::F32>>() // y
.Ret<ffi::Buffer<ffi::F32>>() // res
);
void ComputeRmsNormBwd(int64_t size, float res, const float *x,
const float *ct_y, float *ct_x) {
float ct_res = 0.0f;
for (int64_t n = 0; n < size; ++n) {
ct_res += x[n] * ct_y[n];
}
float factor = ct_res * res * res * res / float(size);
for (int64_t n = 0; n < size; ++n) {
ct_x[n] = res * ct_y[n] - factor * x[n];
}
}
ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::F32> res, ffi::Buffer<ffi::F32> x,
ffi::Buffer<ffi::F32> ct_y,
ffi::Result<ffi::Buffer<ffi::F32>> ct_x) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNormBwd inputs must be arrays");
}
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]),
&(ct_y.typed_data()[n]), &(ct_x->typed_data()[n]));
}
return ffi::Error::Success();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdImpl,
ffi::Ffi::Bind()
.Arg<ffi::Buffer<ffi::F32>>() // res
.Arg<ffi::Buffer<ffi::F32>>() // x
.Arg<ffi::Buffer<ffi::F32>>() // ct_y
.Ret<ffi::Buffer<ffi::F32>>() // ct_x
);
template <typename T>
nb::capsule EncapsulateFfiHandler(T *fn) {
static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
"Encapsulated function must be and XLA FFI handler");
return nb::capsule(reinterpret_cast<void *>(fn));
}
NB_MODULE(_rms_norm, m) {
m.def("registrations", []() {
nb::dict registrations;
registrations["rms_norm"] = EncapsulateFfiHandler(RmsNorm);
registrations["rms_norm_fwd"] = EncapsulateFfiHandler(RmsNormFwd);
registrations["rms_norm_bwd"] = EncapsulateFfiHandler(RmsNormBwd);
return registrations;
});
}

View File

@ -0,0 +1,99 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example demontrating the basic end-to-end use of the JAX FFI.
This example is exactly the same as the one in the `FFI tutorial
<https://jax.readthedocs.io/en/latest/ffi.html>`, so more details can be found
on that page. But, the high level summary is that we implement our custom
extension in ``rms_norm.cc``, then call it usin ``jax.extend.ffi.ffi_call`` in
this module. The behavior under autodiff is implemented using
``jax.custom_vjp``.
"""
from functools import partial
import numpy as np
import jax
import jax.extend as jex
import jax.numpy as jnp
from jax_ffi_example import _rms_norm
for name, target in _rms_norm.registrations().items():
jex.ffi.register_ffi_target(name, target)
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def rms_norm(x, eps=1e-5):
# We only implemented the `float32` version of this function, so we start by
# checking the dtype. This check isn't strictly necessary because type
# checking is also performed by the FFI when decoding input and output
# buffers, but it can be useful to check types in Python to raise more
# informative errors.
if x.dtype != jnp.float32:
raise ValueError("Only the float32 dtype is implemented by rms_norm")
# In this case, the output of our FFI function is just a single array with the
# same shape and dtype as the input.
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
return jex.ffi.ffi_call(
# The target name must be the same string as we used to register the target
# above in `register_ffi_target`
"rms_norm",
out_type,
x,
# Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
# type (which corresponds to numpy's `float32` type), and it must be a
# static parameter (i.e. not a JAX array).
eps=np.float32(eps),
# The `vectorized` parameter controls this function's behavior under `vmap`.
vectorized=True,
)
def rms_norm_fwd(x, eps=1e-5):
y, res = jex.ffi.ffi_call(
"rms_norm_fwd",
(
jax.ShapeDtypeStruct(x.shape, x.dtype),
jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
),
x,
eps=np.float32(eps),
vectorized=True,
)
return y, (res, x)
def rms_norm_bwd(eps, res, ct):
del eps
res, x = res
assert res.shape == ct.shape[:-1]
assert x.shape == ct.shape
return (
jex.ffi.ffi_call(
"rms_norm_bwd",
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
res,
x,
ct,
vectorized=True,
),
)
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)

View File

@ -0,0 +1,46 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax_ffi_example import rms_norm
jax.config.parse_flags_with_absl()
def rms_norm_ref(x, eps=1e-5):
scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
return x / scale
class RmsNormTests(jtu.JaxTestCase):
def test_basic(self):
x = jnp.linspace(-0.5, 0.5, 15)
self.assertAllClose(rms_norm.rms_norm(x), rms_norm_ref(x))
def test_batching(self):
x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))
self.assertAllClose(jax.vmap(rms_norm.rms_norm)(x), jax.vmap(rms_norm_ref)(x))
def test_grads(self):
x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))
jtu.check_grads(rms_norm.rms_norm, (x,), order=1, modes=("rev",))
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -70,7 +70,7 @@ doctest_optionflags = [
"NUMBER",
"NORMALIZE_WHITESPACE"
]
addopts = "--doctest-glob='*.rst'"
addopts = "--doctest-glob='*.rst' --ignore='examples/ffi'"
[tool.pylint.master]
extension-pkg-whitelist = "numpy"