Refactor FFI examples to consolidate several examples into one submodule.

This commit is contained in:
Dan Foreman-Mackey 2024-11-21 12:44:27 -05:00
parent 69e3f0d37d
commit 84a9cba85b
11 changed files with 122 additions and 189 deletions

View File

@ -15,8 +15,7 @@ find_package(nanobind CONFIG REQUIRED)
set(
JAX_FFI_EXAMPLE_PROJECTS
"rms_norm"
"attrs"
"counter"
"cpu_examples"
)
foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS})
@ -27,9 +26,9 @@ endforeach()
if(JAX_FFI_EXAMPLE_ENABLE_CUDA)
enable_language(CUDA)
add_library(_cuda_e2e SHARED "src/jax_ffi_example/cuda_e2e.cu")
set_target_properties(_cuda_e2e PROPERTIES POSITION_INDEPENDENT_CODE ON
CUDA_STANDARD 17)
target_include_directories(_cuda_e2e PUBLIC ${XLA_DIR})
install(TARGETS _cuda_e2e LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
add_library(_cuda_examples SHARED "src/jax_ffi_example/cuda_examples.cu")
set_target_properties(_cuda_examples PROPERTIES POSITION_INDEPENDENT_CODE ON
CUDA_STANDARD 17)
target_include_directories(_cuda_examples PUBLIC ${XLA_DIR})
install(TARGETS _cuda_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
endif()

View File

@ -11,18 +11,19 @@ Within the example project, there are several example calls:
demonstrates the most basic use of the FFI. It also includes customization of
behavior under automatic differentiation using `jax.custom_vjp`.
2. `counter`: This example demonstrates a common pattern for how an FFI call can
use global cache to maintain state between calls. This pattern is useful when
an FFI call requires an expensive initialization step which shouldn't be
run on every execution, or if there is other shared state that could be
reused between calls. In this simple example we just count the number of
times the call was executed.
2. `cpu_examples`: This submodule includes several smaller examples:
3. `attrs`: An example demonstrating the different ways that attributes can be
passed to the FFI. For example, we can pass arrays, variadic attributes, and
user-defined types. Full support of user-defined types isn't yet supported
by XLA, so that example will be added in the future.
* `counter`: This example demonstrates a common pattern for how an FFI call
can use global cache to maintain state between calls. This pattern is
useful when an FFI call requires an expensive initialization step which
shouldn't be run on every execution, or if there is other shared state
that could be reused between calls. In this simple example we just count
the number of times the call was executed.
* `attrs`: An example demonstrating the different ways that attributes can be
passed to the FFI. For example, we can pass arrays, variadic attributes,
and user-defined types. Full support of user-defined types isn't yet
supported by XLA, so that example will be added in the future.
4. `cuda_e2e`: An end-to-end example demonstrating the use of the JAX FFI with
CUDA. The specifics of the kernels are not very important, but the general
structure, and packaging of the extension are useful for testing.
3. `cuda_examples`: An end-to-end example demonstrating the use of the JAX FFI
with CUDA. The specifics of the kernels are not very important, but the
general structure, and packaging of the extension are useful for testing.

View File

@ -1,53 +0,0 @@
/* Copyright 2024 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <mutex>
#include <string_view>
#include <unordered_map>
#include "nanobind/nanobind.h"
#include "xla/ffi/api/ffi.h"
namespace nb = nanobind;
namespace ffi = xla::ffi;
ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0<ffi::S32> out) {
static std::mutex mutex;
static auto& cache = *new std::unordered_map<int64_t, int32_t>();
{
const std::lock_guard<std::mutex> lock(mutex);
auto it = cache.find(index);
if (it != cache.end()) {
out->typed_data()[0] = ++it->second;
} else {
cache.insert({index, 0});
out->typed_data()[0] = 0;
}
}
return ffi::Error::Success();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(
Counter, CounterImpl,
ffi::Ffi::Bind().Attr<int64_t>("index").Ret<ffi::BufferR0<ffi::S32>>());
NB_MODULE(_counter, m) {
m.def("registrations", []() {
nb::dict registrations;
registrations["counter"] = nb::capsule(reinterpret_cast<void*>(Counter));
return registrations;
});
}

View File

@ -1,38 +0,0 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example demonstrating how an FFI call can maintain "state" between calls
In this case, the ``counter`` call simply accumulates the number of times it
was executed, but this pattern can also be used for more advanced use cases.
For example, this pattern is used in jaxlib for:
1. The GPU solver linear algebra kernels which require an expensive "handler"
initialization, and
2. The ``triton_call`` function which caches the compiled triton modules after
their first use.
"""
import jax
import jax.extend as jex
from jax_ffi_example import _counter
for name, target in _counter.registrations().items():
jex.ffi.register_ffi_target(name, target)
def counter(index):
return jex.ffi.ffi_call(
"counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index))

View File

@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include <cstdint>
#include <mutex>
#include <string_view>
#include <unordered_map>
#include "nanobind/nanobind.h"
#include "xla/ffi/api/ffi.h"
@ -21,6 +24,17 @@ limitations under the License.
namespace nb = nanobind;
namespace ffi = xla::ffi;
// ----------
// Attributes
// ----------
//
// An example demonstrating the different ways that attributes can be passed to
// the FFI.
//
// For example, we can pass arrays, variadic attributes, and user-defined types.
// Full support of user-defined types isn't yet supported by XLA, so that
// example will be added in the future.
ffi::Error ArrayAttrImpl(ffi::Span<const int32_t> array,
ffi::ResultBufferR0<ffi::S32> res) {
int64_t total = 0;
@ -54,13 +68,52 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl,
.Ret<ffi::BufferR0<ffi::S32>>()
.Ret<ffi::BufferR0<ffi::S32>>());
NB_MODULE(_attrs, m) {
// -------
// Counter
// -------
//
// An example demonstrating how an FFI call can maintain "state" between calls
//
// In this case, the ``Counter`` call simply accumulates the number of times it
// was executed, but this pattern can also be used for more advanced use cases.
// For example, this pattern is used in jaxlib for:
//
// 1. The GPU solver linear algebra kernels which require an expensive "handler"
// initialization, and
// 2. The ``triton_call`` function which caches the compiled triton modules
// after their first use.
ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0<ffi::S32> out) {
static std::mutex mutex;
static auto &cache = *new std::unordered_map<int64_t, int32_t>();
{
const std::lock_guard<std::mutex> lock(mutex);
auto it = cache.find(index);
if (it != cache.end()) {
out->typed_data()[0] = ++it->second;
} else {
cache.insert({index, 0});
out->typed_data()[0] = 0;
}
}
return ffi::Error::Success();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(
Counter, CounterImpl,
ffi::Ffi::Bind().Attr<int64_t>("index").Ret<ffi::BufferR0<ffi::S32>>());
// Boilerplate for exposing handlers to Python
NB_MODULE(_cpu_examples, m) {
m.def("registrations", []() {
nb::dict registrations;
registrations["array_attr"] =
nb::capsule(reinterpret_cast<void *>(ArrayAttr));
registrations["dictionary_attr"] =
nb::capsule(reinterpret_cast<void *>(DictionaryAttr));
registrations["counter"] = nb::capsule(reinterpret_cast<void *>(Counter));
return registrations;
});
}

View File

@ -12,22 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""An example demonstrating the different ways that attributes can be passed to
the FFI.
For example, we can pass arrays, variadic attributes, and user-defined types.
Full support of user-defined types isn't yet supported by XLA, so that example
will be added in the future.
"""
import numpy as np
import jax
import jax.extend as jex
from jax_ffi_example import _attrs
from jax_ffi_example import _cpu_examples
for name, target in _attrs.registrations().items():
for name, target in _cpu_examples.registrations().items():
jex.ffi.register_ffi_target(name, target)
@ -43,3 +35,8 @@ def dictionary_attr(**kwargs):
"dictionary_attr",
(jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)),
)(**kwargs)
def counter(index):
return jex.ffi.ffi_call(
"counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index))

View File

@ -27,7 +27,7 @@ import jax.numpy as jnp
import jax.extend as jex
# Load the shared library with the FFI target definitions
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_e2e.so")
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_examples.so")
library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
jex.ffi.register_ffi_target("foo-fwd", jex.ffi.pycapsule(library.FooFwd),

View File

@ -1,55 +0,0 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
from jax._src import test_util as jtu
from jax_ffi_example import counter
jax.config.parse_flags_with_absl()
class CounterTests(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if not jtu.test_device_matches(["cpu"]):
self.skipTest("Unsupported platform")
def test_basic(self):
self.assertEqual(counter.counter(0), 0)
self.assertEqual(counter.counter(0), 1)
self.assertEqual(counter.counter(0), 2)
self.assertEqual(counter.counter(1), 0)
self.assertEqual(counter.counter(0), 3)
def test_jit(self):
@jax.jit
def counter_fun(x):
return x, counter.counter(2)
self.assertEqual(counter_fun(0)[1], 0)
self.assertEqual(counter_fun(0)[1], 1)
# Persists across different cache hits
self.assertEqual(counter_fun(1)[1], 2)
# Persists after the cache is cleared
counter_fun.clear_cache()
self.assertEqual(counter_fun(0)[1], 3)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())

View File

@ -18,7 +18,7 @@ import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax_ffi_example import attrs
from jax_ffi_example import cpu_examples
jax.config.parse_flags_with_absl()
@ -30,11 +30,11 @@ class AttrsTests(jtu.JaxTestCase):
self.skipTest("Unsupported platform")
def test_array_attr(self):
self.assertEqual(attrs.array_attr(5), jnp.arange(5).sum())
self.assertEqual(attrs.array_attr(3), jnp.arange(3).sum())
self.assertEqual(cpu_examples.array_attr(5), jnp.arange(5).sum())
self.assertEqual(cpu_examples.array_attr(3), jnp.arange(3).sum())
def test_array_attr_jit_cache(self):
jit_array_attr = jax.jit(attrs.array_attr, static_argnums=(0,))
jit_array_attr = jax.jit(cpu_examples.array_attr, static_argnums=(0,))
with jtu.count_jit_and_pmap_lowerings() as count:
jit_array_attr(5)
self.assertEqual(count[0], 1) # compiles once the first time
@ -44,22 +44,51 @@ class AttrsTests(jtu.JaxTestCase):
def test_array_attr_no_jit(self):
with jax.disable_jit():
attrs.array_attr(5) # doesn't crash
cpu_examples.array_attr(5) # doesn't crash
def test_dictionary_attr(self):
secret, count = attrs.dictionary_attr(secret=5)
secret, count = cpu_examples.dictionary_attr(secret=5)
self.assertEqual(secret, 5)
self.assertEqual(count, 1)
secret, count = attrs.dictionary_attr(secret=3, a_string="hello")
secret, count = cpu_examples.dictionary_attr(secret=3, a_string="hello")
self.assertEqual(secret, 3)
self.assertEqual(count, 2)
with self.assertRaisesRegex(Exception, "Unexpected attribute"):
attrs.dictionary_attr()
cpu_examples.dictionary_attr()
with self.assertRaisesRegex(Exception, "Wrong attribute type"):
attrs.dictionary_attr(secret="invalid")
cpu_examples.dictionary_attr(secret="invalid")
class CounterTests(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if not jtu.test_device_matches(["cpu"]):
self.skipTest("Unsupported platform")
def test_basic(self):
self.assertEqual(cpu_examples.counter(0), 0)
self.assertEqual(cpu_examples.counter(0), 1)
self.assertEqual(cpu_examples.counter(0), 2)
self.assertEqual(cpu_examples.counter(1), 0)
self.assertEqual(cpu_examples.counter(0), 3)
def test_jit(self):
@jax.jit
def counter_fun(x):
return x, cpu_examples.counter(2)
self.assertEqual(counter_fun(0)[1], 0)
self.assertEqual(counter_fun(0)[1], 1)
# Persists across different cache hits
self.assertEqual(counter_fun(1)[1], 2)
# Persists after the cache is cleared
counter_fun.clear_cache()
self.assertEqual(counter_fun(0)[1], 3)
if __name__ == "__main__":

View File

@ -28,8 +28,8 @@ class CudaE2eTests(jtu.JaxTestCase):
self.skipTest("Unsupported platform")
# Import here to avoid trying to load the library when it's not built.
from jax_ffi_example import cuda_e2e
self.foo = cuda_e2e.foo
from jax_ffi_example import cuda_examples
self.foo = cuda_examples.foo
def test_fwd_interpretable(self):
shape = (2, 3)