mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #23805 from dfm:ffi-examples-state
PiperOrigin-RevId: 696383873
This commit is contained in:
commit
83700828c5
@ -12,13 +12,18 @@ message(STATUS "XLA include directory: ${XLA_DIR}")
|
||||
|
||||
find_package(nanobind CONFIG REQUIRED)
|
||||
|
||||
nanobind_add_module(_rms_norm NB_STATIC "src/jax_ffi_example/rms_norm.cc")
|
||||
target_include_directories(_rms_norm PUBLIC ${XLA_DIR})
|
||||
install(TARGETS _rms_norm LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
|
||||
set(
|
||||
JAX_FFI_EXAMPLE_PROJECTS
|
||||
"rms_norm"
|
||||
"attrs"
|
||||
"counter"
|
||||
)
|
||||
|
||||
nanobind_add_module(_attrs NB_STATIC "src/jax_ffi_example/attrs.cc")
|
||||
target_include_directories(_attrs PUBLIC ${XLA_DIR})
|
||||
install(TARGETS _attrs LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
|
||||
foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS})
|
||||
nanobind_add_module("_${PROJECT}" NB_STATIC "src/jax_ffi_example/${PROJECT}.cc")
|
||||
target_include_directories("_${PROJECT}" PUBLIC ${XLA_DIR})
|
||||
install(TARGETS "_${PROJECT}" LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
|
||||
endforeach()
|
||||
|
||||
if(JAX_FFI_EXAMPLE_ENABLE_CUDA)
|
||||
enable_language(CUDA)
|
||||
|
@ -3,7 +3,26 @@
|
||||
This directory includes an example project demonstrating the use of JAX's
|
||||
foreign function interface (FFI). The JAX docs provide more information about
|
||||
this interface in [the FFI tutorial](https://jax.readthedocs.io/en/latest/ffi.html),
|
||||
but the example in this directory explicitly demonstrates:
|
||||
but the example in this directory complements that document by demonstrating
|
||||
(and testing!) the full packaging workflow, and some more advanced use cases.
|
||||
Within the example project, there are several example calls:
|
||||
|
||||
1. One way to package and distribute FFI targets, and
|
||||
2. Some more advanced use cases.
|
||||
1. `rms_norm`: This is the example from the tutorial on the JAX docs, and it
|
||||
demonstrates the most basic use of the FFI. It also includes customization of
|
||||
behavior under automatic differentiation using `jax.custom_vjp`.
|
||||
|
||||
2. `counter`: This example demonstrates a common pattern for how an FFI call can
|
||||
use global cache to maintain state between calls. This pattern is useful when
|
||||
an FFI call requires an expensive initialization step which shouldn't be
|
||||
run on every execution, or if there is other shared state that could be
|
||||
reused between calls. In this simple example we just count the number of
|
||||
times the call was executed.
|
||||
|
||||
3. `attrs`: An example demonstrating the different ways that attributes can be
|
||||
passed to the FFI. For example, we can pass arrays, variadic attributes, and
|
||||
user-defined types. Full support of user-defined types isn't yet supported
|
||||
by XLA, so that example will be added in the future.
|
||||
|
||||
4. `cuda_e2e`: An end-to-end example demonstrating the use of the JAX FFI with
|
||||
CUDA. The specifics of the kernels are not very important, but the general
|
||||
structure, and packaging of the extension are useful for testing.
|
||||
|
53
examples/ffi/src/jax_ffi_example/counter.cc
Normal file
53
examples/ffi/src/jax_ffi_example/counter.cc
Normal file
@ -0,0 +1,53 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <mutex>
|
||||
#include <string_view>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
namespace ffi = xla::ffi;
|
||||
|
||||
ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0<ffi::S32> out) {
|
||||
static std::mutex mutex;
|
||||
static auto& cache = *new std::unordered_map<int64_t, int32_t>();
|
||||
{
|
||||
const std::lock_guard<std::mutex> lock(mutex);
|
||||
auto it = cache.find(index);
|
||||
if (it != cache.end()) {
|
||||
out->typed_data()[0] = ++it->second;
|
||||
} else {
|
||||
cache.insert({index, 0});
|
||||
out->typed_data()[0] = 0;
|
||||
}
|
||||
}
|
||||
return ffi::Error::Success();
|
||||
}
|
||||
|
||||
XLA_FFI_DEFINE_HANDLER_SYMBOL(
|
||||
Counter, CounterImpl,
|
||||
ffi::Ffi::Bind().Attr<int64_t>("index").Ret<ffi::BufferR0<ffi::S32>>());
|
||||
|
||||
NB_MODULE(_counter, m) {
|
||||
m.def("registrations", []() {
|
||||
nb::dict registrations;
|
||||
registrations["counter"] = nb::capsule(reinterpret_cast<void*>(Counter));
|
||||
return registrations;
|
||||
});
|
||||
}
|
38
examples/ffi/src/jax_ffi_example/counter.py
Normal file
38
examples/ffi/src/jax_ffi_example/counter.py
Normal file
@ -0,0 +1,38 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""An example demonstrating how an FFI call can maintain "state" between calls
|
||||
|
||||
In this case, the ``counter`` call simply accumulates the number of times it
|
||||
was executed, but this pattern can also be used for more advanced use cases.
|
||||
For example, this pattern is used in jaxlib for:
|
||||
|
||||
1. The GPU solver linear algebra kernels which require an expensive "handler"
|
||||
initialization, and
|
||||
2. The ``triton_call`` function which caches the compiled triton modules after
|
||||
their first use.
|
||||
"""
|
||||
|
||||
import jax
|
||||
import jax.extend as jex
|
||||
|
||||
from jax_ffi_example import _counter
|
||||
|
||||
for name, target in _counter.registrations().items():
|
||||
jex.ffi.register_ffi_target(name, target)
|
||||
|
||||
|
||||
def counter(index):
|
||||
return jex.ffi.ffi_call(
|
||||
"counter", jax.ShapeDtypeStruct((), jax.numpy.int32))(index=int(index))
|
55
examples/ffi/tests/counter_test.py
Normal file
55
examples/ffi/tests/counter_test.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax_ffi_example import counter
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class CounterTests(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
if not jtu.test_device_matches(["cpu"]):
|
||||
self.skipTest("Unsupported platform")
|
||||
|
||||
def test_basic(self):
|
||||
self.assertEqual(counter.counter(0), 0)
|
||||
self.assertEqual(counter.counter(0), 1)
|
||||
self.assertEqual(counter.counter(0), 2)
|
||||
self.assertEqual(counter.counter(1), 0)
|
||||
self.assertEqual(counter.counter(0), 3)
|
||||
|
||||
def test_jit(self):
|
||||
@jax.jit
|
||||
def counter_fun(x):
|
||||
return x, counter.counter(2)
|
||||
|
||||
self.assertEqual(counter_fun(0)[1], 0)
|
||||
self.assertEqual(counter_fun(0)[1], 1)
|
||||
|
||||
# Persists across different cache hits
|
||||
self.assertEqual(counter_fun(1)[1], 2)
|
||||
|
||||
# Persists after the cache is cleared
|
||||
counter_fun.clear_cache()
|
||||
self.assertEqual(counter_fun(0)[1], 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
Loading…
x
Reference in New Issue
Block a user