mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #23409 from dfm:ffi-examples
PiperOrigin-RevId: 678690801
This commit is contained in:
commit
9d277e61ce
35
.github/workflows/ci-build.yaml
vendored
35
.github/workflows/ci-build.yaml
vendored
@ -210,4 +210,37 @@ jobs:
|
||||
echo "JAX_ENABLE_CHECKS=$JAX_ENABLE_CHECKS"
|
||||
echo "JAX_SKIP_SLOW_TESTS=$JAX_SKIP_SLOW_TESTS"
|
||||
pytest -n auto --tb=short --maxfail=20 jax/experimental/jax2tf/tests/jax2tf_test.py
|
||||
|
||||
|
||||
ffi:
|
||||
name: FFI example
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 5
|
||||
steps:
|
||||
- uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # ratchet:actions/checkout@v4
|
||||
- name: Set up Python 3.11
|
||||
uses: actions/setup-python@39cd14951b08e74b54015e9e001cdefcf80e669f # ratchet:actions/setup-python@v5
|
||||
with:
|
||||
python-version: 3.11
|
||||
- name: Get pip cache dir
|
||||
id: pip-cache
|
||||
run: |
|
||||
python -m pip install --upgrade pip wheel
|
||||
echo "dir=$(pip cache dir)" >> $GITHUB_OUTPUT
|
||||
- name: pip cache
|
||||
uses: actions/cache@0c45773b623bea8c8e75f6c82b208c3cf94ea4f9 # ratchet: actions/cache@v4
|
||||
with:
|
||||
path: ${{ steps.pip-cache.outputs.dir }}
|
||||
key: ${{ runner.os }}-pip-ffi-examples-${{ hashFiles('**/setup.py', '**/requirements.txt', '**/test-requirements.txt', 'examples/**/pyproject.toml') }}
|
||||
- name: Install JAX
|
||||
run: pip install .
|
||||
- name: Build and install example project
|
||||
run: python -m pip install -v ./examples/ffi[test]
|
||||
env:
|
||||
# We test building using GCC instead of clang. All other JAX builds use
|
||||
# clang, but it is useful to make sure that FFI users can compile using
|
||||
# a different toolchain. GCC is the default compiler on the
|
||||
# 'ubuntu-latest' runner, but we still set this explicitly just to be
|
||||
# clear.
|
||||
CMAKE_ARGS: -DCMAKE_C_COMPILER=gcc -DCMAKE_CXX_COMPILER=g++
|
||||
- name: Run tests
|
||||
run: python -m pytest examples/ffi/tests
|
||||
|
15
examples/ffi/CMakeLists.txt
Normal file
15
examples/ffi/CMakeLists.txt
Normal file
@ -0,0 +1,15 @@
|
||||
cmake_minimum_required(VERSION 3.15...3.30)
|
||||
project(${SKBUILD_PROJECT_NAME} LANGUAGES CXX)
|
||||
|
||||
find_package(Python 3.10 REQUIRED COMPONENTS Interpreter Development.Module)
|
||||
execute_process(
|
||||
COMMAND "${Python_EXECUTABLE}"
|
||||
"-c" "from jax.extend import ffi; print(ffi.include_dir())"
|
||||
OUTPUT_STRIP_TRAILING_WHITESPACE OUTPUT_VARIABLE XLA_DIR)
|
||||
message(STATUS "XLA include directory: ${XLA_DIR}")
|
||||
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
|
||||
nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc")
|
||||
target_include_directories(_rms_norm PUBLIC ${XLA_DIR})
|
||||
install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
|
9
examples/ffi/README.md
Normal file
9
examples/ffi/README.md
Normal file
@ -0,0 +1,9 @@
|
||||
# End-to-end example usage for JAX's foreign function interface
|
||||
|
||||
This directory includes an example project demonstrating the use of JAX's
|
||||
foreign function interface (FFI). The JAX docs provide more information about
|
||||
this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html),
|
||||
but the example in this directory explicitly demonstrates:
|
||||
|
||||
1. One way to package and distribute FFI targets, and
|
||||
2. Some more advanced use cases.
|
12
examples/ffi/pyproject.toml
Normal file
12
examples/ffi/pyproject.toml
Normal file
@ -0,0 +1,12 @@
|
||||
[build-system]
|
||||
requires = ["scikit-build-core", "nanobind", "jax>=0.4.31"]
|
||||
build-backend = "scikit_build_core.build"
|
||||
|
||||
[project]
|
||||
name = "jax_ffi_example"
|
||||
version = "0.0.1"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = ["jax"]
|
||||
|
||||
[project.optional-dependencies]
|
||||
test = ["pytest", "absl-py"]
|
13
examples/ffi/src/jax_ffi_example/__init__.py
Normal file
13
examples/ffi/src/jax_ffi_example/__init__.py
Normal file
@ -0,0 +1,13 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
157
examples/ffi/src/jax_ffi_example/rms_norm.cc
Normal file
157
examples/ffi/src/jax_ffi_example/rms_norm.cc
Normal file
@ -0,0 +1,157 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "xla/ffi/api/c_api.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
namespace ffi = xla::ffi;
|
||||
|
||||
// This is the example "library function" that we want to expose to JAX. This
|
||||
// isn't meant to be a particularly good implementation, it's just here as a
|
||||
// placeholder for the purposes of this tutorial.
|
||||
float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
|
||||
float sm = 0.0f;
|
||||
for (int64_t n = 0; n < size; ++n) {
|
||||
sm += x[n] * x[n];
|
||||
}
|
||||
float scale = 1.0f / std::sqrt(sm / float(size) + eps);
|
||||
for (int64_t n = 0; n < size; ++n) {
|
||||
y[n] = x[n] * scale;
|
||||
}
|
||||
return scale;
|
||||
}
|
||||
|
||||
// A helper function for extracting the relevant dimensions from `ffi::Buffer`s.
|
||||
// In this example, we treat all leading dimensions as batch dimensions, so this
|
||||
// function returns the total number of elements in the buffer, and the size of
|
||||
// the last dimension.
|
||||
template <ffi::DataType T>
|
||||
std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
|
||||
auto dims = buffer.dimensions();
|
||||
if (dims.size() == 0) {
|
||||
return std::make_pair(0, 0);
|
||||
}
|
||||
return std::make_pair(buffer.element_count(), dims.back());
|
||||
}
|
||||
|
||||
// A wrapper function providing the interface between the XLA FFI call and our
|
||||
// library function `ComputeRmsNorm` above. This function handles the batch
|
||||
// dimensions by calling `ComputeRmsNorm` within a loop.
|
||||
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
|
||||
ffi::Result<ffi::Buffer<ffi::F32>> y) {
|
||||
auto [totalSize, lastDim] = GetDims(x);
|
||||
if (lastDim == 0) {
|
||||
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
|
||||
"RmsNorm input must be an array");
|
||||
}
|
||||
for (int64_t n = 0; n < totalSize; n += lastDim) {
|
||||
ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
|
||||
// this handler in a header, you can use the `XLA_FFI_DECLASE_HANDLER_SYMBOL`
|
||||
// macro: `XLA_FFI_DECLASE_HANDLER_SYMBOL(RmsNorm)`.
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl,
|
||||
ffi::Ffi::Bind()
|
||||
.Attr<float>("eps")
|
||||
.Arg<ffi::Buffer<ffi::F32>>() // x
|
||||
.Ret<ffi::Buffer<ffi::F32>>() // y
|
||||
);
|
||||
|
||||
ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::F32> x,
|
||||
ffi::Result<ffi::Buffer<ffi::F32>> y,
|
||||
ffi::Result<ffi::Buffer<ffi::F32>> res) {
|
||||
auto [totalSize, lastDim] = GetDims(x);
|
||||
if (lastDim == 0) {
|
||||
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
|
||||
"RmsNormFwd input must be an array");
|
||||
}
|
||||
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
|
||||
res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]),
|
||||
&(y->typed_data()[n]));
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdImpl,
|
||||
ffi::Ffi::Bind()
|
||||
.Attr<float>("eps")
|
||||
.Arg<ffi::Buffer<ffi::F32>>() // x
|
||||
.Ret<ffi::Buffer<ffi::F32>>() // y
|
||||
.Ret<ffi::Buffer<ffi::F32>>() // res
|
||||
);
|
||||
|
||||
void ComputeRmsNormBwd(int64_t size, float res, const float *x,
|
||||
const float *ct_y, float *ct_x) {
|
||||
float ct_res = 0.0f;
|
||||
for (int64_t n = 0; n < size; ++n) {
|
||||
ct_res += x[n] * ct_y[n];
|
||||
}
|
||||
float factor = ct_res * res * res * res / float(size);
|
||||
for (int64_t n = 0; n < size; ++n) {
|
||||
ct_x[n] = res * ct_y[n] - factor * x[n];
|
||||
}
|
||||
}
|
||||
|
||||
ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::F32> res, ffi::Buffer<ffi::F32> x,
|
||||
ffi::Buffer<ffi::F32> ct_y,
|
||||
ffi::Result<ffi::Buffer<ffi::F32>> ct_x) {
|
||||
auto [totalSize, lastDim] = GetDims(x);
|
||||
if (lastDim == 0) {
|
||||
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
|
||||
"RmsNormBwd inputs must be arrays");
|
||||
}
|
||||
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
|
||||
ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]),
|
||||
&(ct_y.typed_data()[n]), &(ct_x->typed_data()[n]));
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdImpl,
|
||||
ffi::Ffi::Bind()
|
||||
.Arg<ffi::Buffer<ffi::F32>>() // res
|
||||
.Arg<ffi::Buffer<ffi::F32>>() // x
|
||||
.Arg<ffi::Buffer<ffi::F32>>() // ct_y
|
||||
.Ret<ffi::Buffer<ffi::F32>>() // ct_x
|
||||
);
|
||||
|
||||
template <typename T>
|
||||
nb::capsule EncapsulateFfiHandler(T *fn) {
|
||||
static_assert(std::is_invocable_r_v<XLA_FFI_Error *, T, XLA_FFI_CallFrame *>,
|
||||
"Encapsulated function must be and XLA FFI handler");
|
||||
return nb::capsule(reinterpret_cast<void *>(fn));
|
||||
}
|
||||
|
||||
NB_MODULE(_rms_norm, m) {
|
||||
m.def("registrations", []() {
|
||||
nb::dict registrations;
|
||||
registrations["rms_norm"] = EncapsulateFfiHandler(RmsNorm);
|
||||
registrations["rms_norm_fwd"] = EncapsulateFfiHandler(RmsNormFwd);
|
||||
registrations["rms_norm_bwd"] = EncapsulateFfiHandler(RmsNormBwd);
|
||||
return registrations;
|
||||
});
|
||||
}
|
99
examples/ffi/src/jax_ffi_example/rms_norm.py
Normal file
99
examples/ffi/src/jax_ffi_example/rms_norm.py
Normal file
@ -0,0 +1,99 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""An example demontrating the basic end-to-end use of the JAX FFI.
|
||||
|
||||
This example is exactly the same as the one in the `FFI tutorial
|
||||
<https://jax.readthedocs.io/en/latest/ffi.html>`, so more details can be found
|
||||
on that page. But, the high level summary is that we implement our custom
|
||||
extension in ``rms_norm.cc``, then call it usin ``jax.extend.ffi.ffi_call`` in
|
||||
this module. The behavior under autodiff is implemented using
|
||||
``jax.custom_vjp``.
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.extend as jex
|
||||
import jax.numpy as jnp
|
||||
|
||||
from jax_ffi_example import _rms_norm
|
||||
|
||||
for name, target in _rms_norm.registrations().items():
|
||||
jex.ffi.register_ffi_target(name, target)
|
||||
|
||||
|
||||
@partial(jax.custom_vjp, nondiff_argnums=(1,))
|
||||
def rms_norm(x, eps=1e-5):
|
||||
# We only implemented the `float32` version of this function, so we start by
|
||||
# checking the dtype. This check isn't strictly necessary because type
|
||||
# checking is also performed by the FFI when decoding input and output
|
||||
# buffers, but it can be useful to check types in Python to raise more
|
||||
# informative errors.
|
||||
if x.dtype != jnp.float32:
|
||||
raise ValueError("Only the float32 dtype is implemented by rms_norm")
|
||||
|
||||
# In this case, the output of our FFI function is just a single array with the
|
||||
# same shape and dtype as the input.
|
||||
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
|
||||
|
||||
return jex.ffi.ffi_call(
|
||||
# The target name must be the same string as we used to register the target
|
||||
# above in `register_ffi_target`
|
||||
"rms_norm",
|
||||
out_type,
|
||||
x,
|
||||
# Note that here we're use `numpy` (not `jax.numpy`) to specify a dtype for
|
||||
# the attribute `eps`. Our FFI function expects this to have the C++ `float`
|
||||
# type (which corresponds to numpy's `float32` type), and it must be a
|
||||
# static parameter (i.e. not a JAX array).
|
||||
eps=np.float32(eps),
|
||||
# The `vectorized` parameter controls this function's behavior under `vmap`.
|
||||
vectorized=True,
|
||||
)
|
||||
|
||||
|
||||
def rms_norm_fwd(x, eps=1e-5):
|
||||
y, res = jex.ffi.ffi_call(
|
||||
"rms_norm_fwd",
|
||||
(
|
||||
jax.ShapeDtypeStruct(x.shape, x.dtype),
|
||||
jax.ShapeDtypeStruct(x.shape[:-1], x.dtype),
|
||||
),
|
||||
x,
|
||||
eps=np.float32(eps),
|
||||
vectorized=True,
|
||||
)
|
||||
return y, (res, x)
|
||||
|
||||
|
||||
def rms_norm_bwd(eps, res, ct):
|
||||
del eps
|
||||
res, x = res
|
||||
assert res.shape == ct.shape[:-1]
|
||||
assert x.shape == ct.shape
|
||||
return (
|
||||
jex.ffi.ffi_call(
|
||||
"rms_norm_bwd",
|
||||
jax.ShapeDtypeStruct(ct.shape, ct.dtype),
|
||||
res,
|
||||
x,
|
||||
ct,
|
||||
vectorized=True,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
|
46
examples/ffi/tests/rms_norm_test.py
Normal file
46
examples/ffi/tests/rms_norm_test.py
Normal file
@ -0,0 +1,46 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax_ffi_example import rms_norm
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
def rms_norm_ref(x, eps=1e-5):
|
||||
scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
|
||||
return x / scale
|
||||
|
||||
|
||||
class RmsNormTests(jtu.JaxTestCase):
|
||||
def test_basic(self):
|
||||
x = jnp.linspace(-0.5, 0.5, 15)
|
||||
self.assertAllClose(rms_norm.rms_norm(x), rms_norm_ref(x))
|
||||
|
||||
def test_batching(self):
|
||||
x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))
|
||||
self.assertAllClose(jax.vmap(rms_norm.rms_norm)(x), jax.vmap(rms_norm_ref)(x))
|
||||
|
||||
def test_grads(self):
|
||||
x = jnp.linspace(-0.5, 0.5, 15).reshape((3, 5))
|
||||
jtu.check_grads(rms_norm.rms_norm, (x,), order=1, modes=("rev",))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
@ -70,7 +70,7 @@ doctest_optionflags = [
|
||||
"NUMBER",
|
||||
"NORMALIZE_WHITESPACE"
|
||||
]
|
||||
addopts = "--doctest-glob='*.rst'"
|
||||
addopts = "--doctest-glob='*.rst' --ignore='examples/ffi'"
|
||||
|
||||
[tool.pylint.master]
|
||||
extension-pkg-whitelist = "numpy"
|
||||
|
Loading…
x
Reference in New Issue
Block a user