Merge pull request #23805 from dfm:ffi-examples-state

PiperOrigin-RevId: 696383873
This commit is contained in:
jax authors 2024-11-13 21:43:41 -08:00
commit 83700828c5
5 changed files with 179 additions and 9 deletions

View File

@ -12,13 +12,18 @@ message(STATUS "XLA include directory: ${XLA_DIR}")
find_package(nanobind CONFIG REQUIRED)
nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc")
target_include_directories(_rms_norm PUBLIC ${XLA_DIR})
install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
set(
JAX_FFI_EXAMPLE_PROJECTS
"rms_norm"
"attrs"
"counter"
)
nanobind_add_module(_attrs NB_STATIC "src/jax_ffi_example/attrs.cc")
target_include_directories(_attrs PUBLIC ${XLA_DIR})
install(TARGETS _attrs LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS})
nanobind_add_module("_${PROJECT}" NB_STATIC "src/jax_ffi_example/${PROJECT}.cc")
target_include_directories("_${PROJECT}" PUBLIC ${XLA_DIR})
install(TARGETS "_${PROJECT}" LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
endforeach()
if(JAX_FFI_EXAMPLE_ENABLE_CUDA)
enable_language(CUDA)

View File

@ -3,7 +3,26 @@
This directory includes an example project demonstrating the use of JAX's
foreign function interface (FFI). The JAX docs provide more information about
this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html),
but the example in this directory explicitly demonstrates:
but the example in this directory complements that document by demonstrating
(and testing!) the full packaging workflow, and some more advanced use cases.
Within the example project, there are several example calls:
1. One way to package and distribute FFI targets, and
2. Some more advanced use cases.
1. `rms_norm`: This is the example from the tutorial on the JAX docs, and it
demonstrates the most basic use of the FFI. It also includes customization of
behavior under automatic differentiation using `jax.custom_vjp`.
2. `counter`: This example demonstrates a common pattern for how an FFI call can
use global cache to maintain state between calls. This pattern is useful when
an FFI call requires an expensive initialization step which shouldn't be
run on every execution, or if there is other shared state that could be
reused between calls. In this simple example we just count the number of
times the call was executed.
3. `attrs`: An example demonstrating the different ways that attributes can be
passed to the FFI. For example, we can pass arrays, variadic attributes, and
user-defined types. Full support of user-defined types isn't yet supported
by XLA, so that example will be added in the future.
4. `cuda_e2e`: An end-to-end example demonstrating the use of the JAX FFI with
CUDA. The specifics of the kernels are not very important, but the general
structure, and packaging of the extension are useful for testing.

View File

@ -0,0 +1,53 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <mutex>
#include <string_view>
#include <unordered_map>
#include "nanobind/nanobind.h"
#include "xla/ffi/api/ffi.h"
namespace nb = nanobind;
namespace ffi = xla::ffi;
ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0<ffi::S32> out) {
static std::mutex mutex;
static auto& cache = *new std::unordered_map<int64_t, int32_t>();
{
const std::lock_guard<std::mutex> lock(mutex);
auto it = cache.find(index);
if (it != cache.end()) {
out->typed_data()[0] = ++it->second;
} else {
cache.insert({index, 0});
out->typed_data()[0] = 0;
}
}
return ffi::Error::Success();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(
Counter, CounterImpl,
ffi::Ffi::Bind().Attr<int64_t>("index").Ret<ffi::BufferR0<ffi::S32>>());
NB_MODULE(_counter, m) {
m.def("registrations", []() {
nb::dict registrations;
registrations["counter"] = nb::capsule(reinterpret_cast<void*>(Counter));
return registrations;
});
}

View File

@ -0,0 +1,38 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example demonstrating how an FFI call can maintain "state" between calls
In this case, the ``counter`` call simply accumulates the number of times it
was executed, but this pattern can also be used for more advanced use cases.
For example, this pattern is used in jaxlib for:
1. The GPU solver linear algebra kernels which require an expensive "handler"
initialization, and
2. The ``triton_call`` function which caches the compiled triton modules after
their first use.
"""
import jax
import jax.extend as jex
from jax_ffi_example import _counter
for name, target in _counter.registrations().items():
jex.ffi.register_ffi_target(name, target)
def counter(index):
return jex.ffi.ffi_call(
"counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index))

View File

@ -0,0 +1,55 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax_ffi_example import counter
jax.config.parse_flags_with_absl()
class CounterTests(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if not jtu.test_device_matches(["cpu"]):
self.skipTest("Unsupported platform")
def test_basic(self):
self.assertEqual(counter.counter(0), 0)
self.assertEqual(counter.counter(0), 1)
self.assertEqual(counter.counter(0), 2)
self.assertEqual(counter.counter(1), 0)
self.assertEqual(counter.counter(0), 3)
def test_jit(self):
@jax.jit
def counter_fun(x):
return x, counter.counter(2)
self.assertEqual(counter_fun(0)[1], 0)
self.assertEqual(counter_fun(0)[1], 1)
# Persists across different cache hits
self.assertEqual(counter_fun(1)[1], 2)
# Persists after the cache is cleared
counter_fun.clear_cache()
self.assertEqual(counter_fun(0)[1], 3)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())