mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Refactor FFI examples to consolidate several examples into one submodule.
This commit is contained in:
parent
69e3f0d37d
commit
84a9cba85b
@ -15,8 +15,7 @@ find_package(nanobind CONFIG REQUIRED)
|
||||
set(
|
||||
JAX_FFI_EXAMPLE_PROJECTS
|
||||
"rms_norm"
|
||||
"attrs"
|
||||
"counter"
|
||||
"cpu_examples"
|
||||
)
|
||||
|
||||
foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS})
|
||||
@ -27,9 +26,9 @@ endforeach()
|
||||
|
||||
if(JAX_FFI_EXAMPLE_ENABLE_CUDA)
|
||||
enable_language(CUDA)
|
||||
add_library(_cuda_e2e SHARED "src/jax_ffi_example/cuda_e2e.cu")
|
||||
set_target_properties(_cuda_e2e PROPERTIES POSITION_INDEPENDENT_CODE ON
|
||||
CUDA_STANDARD 17)
|
||||
target_include_directories(_cuda_e2e PUBLIC ${XLA_DIR})
|
||||
install(TARGETS _cuda_e2e LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
|
||||
add_library(_cuda_examples SHARED "src/jax_ffi_example/cuda_examples.cu")
|
||||
set_target_properties(_cuda_examples PROPERTIES POSITION_INDEPENDENT_CODE ON
|
||||
CUDA_STANDARD 17)
|
||||
target_include_directories(_cuda_examples PUBLIC ${XLA_DIR})
|
||||
install(TARGETS _cuda_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
|
||||
endif()
|
||||
|
@ -11,18 +11,19 @@ Within the example project, there are several example calls:
|
||||
demonstrates the most basic use of the FFI. It also includes customization of
|
||||
behavior under automatic differentiation using `jax.custom_vjp`.
|
||||
|
||||
2. `counter`: This example demonstrates a common pattern for how an FFI call can
|
||||
use a global cache to maintain state between calls. This pattern is useful when
|
||||
an FFI call requires an expensive initialization step which shouldn't be
|
||||
run on every execution, or if there is other shared state that could be
|
||||
reused between calls. In this simple example we just count the number of
|
||||
times the call was executed.
|
||||
2. `cpu_examples`: This submodule includes several smaller examples:
|
||||
|
||||
3. `attrs`: An example demonstrating the different ways that attributes can be
|
||||
passed to the FFI. For example, we can pass arrays, variadic attributes, and
|
||||
user-defined types. User-defined types aren't yet fully supported
|
||||
by XLA, so that example will be added in the future.
|
||||
* `counter`: This example demonstrates a common pattern for how an FFI call
|
||||
can use a global cache to maintain state between calls. This pattern is
|
||||
useful when an FFI call requires an expensive initialization step which
|
||||
shouldn't be run on every execution, or if there is other shared state
|
||||
that could be reused between calls. In this simple example we just count
|
||||
the number of times the call was executed.
|
||||
* `attrs`: An example demonstrating the different ways that attributes can be
|
||||
passed to the FFI. For example, we can pass arrays, variadic attributes,
|
||||
and user-defined types. User-defined types aren't yet fully
|
||||
supported by XLA, so that example will be added in the future.
|
||||
|
||||
4. `cuda_e2e`: An end-to-end example demonstrating the use of the JAX FFI with
|
||||
CUDA. The specifics of the kernels are not very important, but the general
|
||||
structure and packaging of the extension are useful for testing.
|
||||
3. `cuda_examples`: An end-to-end example demonstrating the use of the JAX FFI
|
||||
with CUDA. The specifics of the kernels are not very important, but the
|
||||
general structure and packaging of the extension are useful for testing.
|
||||
|
@ -1,53 +0,0 @@
|
||||
/* Copyright 2024 The JAX Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <mutex>
|
||||
#include <string_view>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
|
||||
namespace nb = nanobind;
|
||||
namespace ffi = xla::ffi;
|
||||
|
||||
// Writes to `out` how many earlier calls were made with the same `index`.
//
// The first call for a given `index` stores 0 and reports 0; each later call
// increments the stored value and reports the incremented value.
ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0<ffi::S32> out) {
  // Function-local statics persist across invocations. The map is allocated
  // with `new` and is never deleted in this function.
  static std::mutex counts_mutex;
  static auto& counts = *new std::unordered_map<int64_t, int32_t>();

  // Hold the lock while the map is read and updated.
  const std::lock_guard<std::mutex> guard(counts_mutex);

  // `emplace` only inserts when `index` is absent; `inserted` tells us which
  // of the two cases we are in.
  auto [entry, inserted] = counts.emplace(index, 0);
  out->typed_data()[0] = inserted ? 0 : ++entry->second;
  return ffi::Error::Success();
}
|
||||
|
||||
// Defines the `Counter` handler symbol backed by `CounterImpl`. The binding
// declares one `int64_t` attribute named "index" and one `BufferR0<S32>`
// result, matching the parameters of `CounterImpl`.
XLA_FFI_DEFINE_HANDLER_SYMBOL(Counter, CounterImpl,
                              ffi::Ffi::Bind()
                                  .Attr<int64_t>("index")
                                  .Ret<ffi::BufferR0<ffi::S32>>());
|
||||
|
||||
// Python extension module `_counter`. It exposes a single function,
// `registrations()`, which returns a dict mapping the name "counter" to a
// capsule wrapping the `Counter` handler.
NB_MODULE(_counter, m) {
  m.def("registrations", []() -> nb::dict {
    nb::dict targets;
    // The handler is passed to Python as an opaque pointer inside a capsule.
    targets["counter"] = nb::capsule(reinterpret_cast<void*>(Counter));
    return targets;
  });
}
|
@ -1,38 +0,0 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""An example demonstrating how an FFI call can maintain "state" between calls.
|
||||
|
||||
In this case, the ``counter`` call simply accumulates the number of times it
|
||||
was executed, but this pattern can also be used for more advanced use cases.
|
||||
For example, this pattern is used in jaxlib for:
|
||||
|
||||
1. The GPU solver linear algebra kernels which require an expensive "handler"
|
||||
initialization, and
|
||||
2. The ``triton_call`` function which caches the compiled triton modules after
|
||||
their first use.
|
||||
"""
|
||||
|
||||
import jax
|
||||
import jax.extend as jex
|
||||
|
||||
from jax_ffi_example import _counter
|
||||
|
||||
for name, target in _counter.registrations().items():
|
||||
jex.ffi.register_ffi_target(name, target)
|
||||
|
||||
|
||||
def counter(index):
  """Calls the "counter" FFI target, passing ``index`` as an ``int`` attribute.

  The call's result is declared as a scalar (shape ``()``) ``int32``.
  """
  result_type = jax.ShapeDtypeStruct((), jax.numpy.int32)
  call = jex.ffi.ffi_call("counter", result_type)
  return call(index=int(index))
|
@ -14,6 +14,9 @@ limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <mutex>
|
||||
#include <string_view>
|
||||
#include <unordered_map>
|
||||
|
||||
#include "nanobind/nanobind.h"
|
||||
#include "xla/ffi/api/ffi.h"
|
||||
@ -21,6 +24,17 @@ limitations under the License.
|
||||
namespace nb = nanobind;
|
||||
namespace ffi = xla::ffi;
|
||||
|
||||
// ----------
|
||||
// Attributes
|
||||
// ----------
|
||||
//
|
||||
// An example demonstrating the different ways that attributes can be passed to
|
||||
// the FFI.
|
||||
//
|
||||
// For example, we can pass arrays, variadic attributes, and user-defined types.
|
||||
// User-defined types aren't yet fully supported by XLA, so that
|
||||
// example will be added in the future.
|
||||
|
||||
ffi::Error ArrayAttrImpl(ffi::Span<const int32_t> array,
|
||||
ffi::ResultBufferR0<ffi::S32> res) {
|
||||
int64_t total = 0;
|
||||
@ -54,13 +68,52 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DictionaryAttr, DictionaryAttrImpl,
|
||||
.Ret<ffi::BufferR0<ffi::S32>>()
|
||||
.Ret<ffi::BufferR0<ffi::S32>>());
|
||||
|
||||
NB_MODULE(_attrs, m) {
|
||||
// -------
|
||||
// Counter
|
||||
// -------
|
||||
//
|
||||
// An example demonstrating how an FFI call can maintain "state" between calls.
|
||||
//
|
||||
// In this case, the ``Counter`` call simply accumulates the number of times it
|
||||
// was executed, but this pattern can also be used for more advanced use cases.
|
||||
// For example, this pattern is used in jaxlib for:
|
||||
//
|
||||
// 1. The GPU solver linear algebra kernels which require an expensive "handler"
|
||||
// initialization, and
|
||||
// 2. The ``triton_call`` function which caches the compiled triton modules
|
||||
// after their first use.
|
||||
|
||||
// Writes to `out` how many earlier calls were made with the same `index`.
//
// The first call for a given `index` stores 0 and reports 0; each later call
// increments the stored value and reports the incremented value.
ffi::Error CounterImpl(int64_t index, ffi::ResultBufferR0<ffi::S32> out) {
  // Function-local statics persist across invocations. The map is allocated
  // with `new` and is never deleted in this function.
  static std::mutex counts_mutex;
  static auto &counts = *new std::unordered_map<int64_t, int32_t>();

  // Hold the lock while the map is read and updated.
  const std::lock_guard<std::mutex> guard(counts_mutex);

  // `emplace` only inserts when `index` is absent; `inserted` tells us which
  // of the two cases we are in.
  auto [entry, inserted] = counts.emplace(index, 0);
  out->typed_data()[0] = inserted ? 0 : ++entry->second;
  return ffi::Error::Success();
}
|
||||
|
||||
// Defines the `Counter` handler symbol backed by `CounterImpl`. The binding
// declares one `int64_t` attribute named "index" and one `BufferR0<S32>`
// result, matching the parameters of `CounterImpl`.
XLA_FFI_DEFINE_HANDLER_SYMBOL(Counter, CounterImpl,
                              ffi::Ffi::Bind()
                                  .Attr<int64_t>("index")
                                  .Ret<ffi::BufferR0<ffi::S32>>());
|
||||
|
||||
// Boilerplate for exposing handlers to Python
|
||||
// Python extension module `_cpu_examples`. It exposes a single function,
// `registrations()`, which returns a dict mapping each FFI target name to a
// capsule wrapping the corresponding handler.
NB_MODULE(_cpu_examples, m) {
  m.def("registrations", []() -> nb::dict {
    nb::dict targets;

    // Attribute examples.
    targets["array_attr"] = nb::capsule(reinterpret_cast<void *>(ArrayAttr));
    targets["dictionary_attr"] =
        nb::capsule(reinterpret_cast<void *>(DictionaryAttr));

    // Counter example.
    targets["counter"] = nb::capsule(reinterpret_cast<void *>(Counter));

    return targets;
  });
}
|
@ -12,22 +12,14 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""An example demonstrating the different ways that attributes can be passed to
|
||||
the FFI.
|
||||
|
||||
For example, we can pass arrays, variadic attributes, and user-defined types.
|
||||
User-defined types aren't yet fully supported by XLA, so that example
|
||||
will be added in the future.
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
import jax.extend as jex
|
||||
|
||||
from jax_ffi_example import _attrs
|
||||
from jax_ffi_example import _cpu_examples
|
||||
|
||||
for name, target in _attrs.registrations().items():
|
||||
for name, target in _cpu_examples.registrations().items():
|
||||
jex.ffi.register_ffi_target(name, target)
|
||||
|
||||
|
||||
@ -43,3 +35,8 @@ def dictionary_attr(**kwargs):
|
||||
"dictionary_attr",
|
||||
(jax.ShapeDtypeStruct((), np.int32), jax.ShapeDtypeStruct((), np.int32)),
|
||||
)(**kwargs)
|
||||
|
||||
|
||||
def counter(index):
  """Calls the "counter" FFI target, passing ``index`` as an ``int`` attribute.

  The call's result is declared as a scalar (shape ``()``) ``int32``.
  """
  result_type = jax.ShapeDtypeStruct((), jax.numpy.int32)
  call = jex.ffi.ffi_call("counter", result_type)
  return call(index=int(index))
|
@ -27,7 +27,7 @@ import jax.numpy as jnp
|
||||
import jax.extend as jex
|
||||
|
||||
# Load the shared library with the FFI target definitions
|
||||
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_e2e.so")
|
||||
SHARED_LIBRARY = os.path.join(os.path.dirname(__file__), "lib_cuda_examples.so")
|
||||
library = ctypes.cdll.LoadLibrary(SHARED_LIBRARY)
|
||||
|
||||
jex.ffi.register_ffi_target("foo-fwd", jex.ffi.pycapsule(library.FooFwd),
|
@ -1,55 +0,0 @@
|
||||
# Copyright 2024 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from absl.testing import absltest
|
||||
|
||||
import jax
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax_ffi_example import counter
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
|
||||
class CounterTests(jtu.JaxTestCase):
  """Tests that ``counter.counter`` keeps a separate running count per index."""

  def setUp(self):
    super().setUp()
    # Skip on any test device other than CPU.
    on_cpu = jtu.test_device_matches(["cpu"])
    if not on_cpu:
      self.skipTest("Unsupported platform")

  def test_basic(self):
    # (index, expected result) pairs, in call order. Index 1 starts from zero
    # and does not disturb the count for index 0.
    calls = [(0, 0), (0, 1), (0, 2), (1, 0), (0, 3)]
    for index, expected in calls:
      self.assertEqual(counter.counter(index), expected)

  def test_jit(self):
    def counter_fun(x):
      return x, counter.counter(2)

    jitted = jax.jit(counter_fun)

    _, first = jitted(0)
    self.assertEqual(first, 0)
    _, second = jitted(0)
    self.assertEqual(second, 1)

    # Persists across different cache hits
    _, third = jitted(1)
    self.assertEqual(third, 2)

    # Persists after the cache is cleared
    jitted.clear_cache()
    _, fourth = jitted(0)
    self.assertEqual(fourth, 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
@ -18,7 +18,7 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
from jax._src import test_util as jtu
|
||||
|
||||
from jax_ffi_example import attrs
|
||||
from jax_ffi_example import cpu_examples
|
||||
|
||||
jax.config.parse_flags_with_absl()
|
||||
|
||||
@ -30,11 +30,11 @@ class AttrsTests(jtu.JaxTestCase):
|
||||
self.skipTest("Unsupported platform")
|
||||
|
||||
def test_array_attr(self):
|
||||
self.assertEqual(attrs.array_attr(5), jnp.arange(5).sum())
|
||||
self.assertEqual(attrs.array_attr(3), jnp.arange(3).sum())
|
||||
self.assertEqual(cpu_examples.array_attr(5), jnp.arange(5).sum())
|
||||
self.assertEqual(cpu_examples.array_attr(3), jnp.arange(3).sum())
|
||||
|
||||
def test_array_attr_jit_cache(self):
|
||||
jit_array_attr = jax.jit(attrs.array_attr, static_argnums=(0,))
|
||||
jit_array_attr = jax.jit(cpu_examples.array_attr, static_argnums=(0,))
|
||||
with jtu.count_jit_and_pmap_lowerings() as count:
|
||||
jit_array_attr(5)
|
||||
self.assertEqual(count[0], 1) # compiles once the first time
|
||||
@ -44,22 +44,51 @@ class AttrsTests(jtu.JaxTestCase):
|
||||
|
||||
def test_array_attr_no_jit(self):
|
||||
with jax.disable_jit():
|
||||
attrs.array_attr(5) # doesn't crash
|
||||
cpu_examples.array_attr(5) # doesn't crash
|
||||
|
||||
def test_dictionary_attr(self):
|
||||
secret, count = attrs.dictionary_attr(secret=5)
|
||||
secret, count = cpu_examples.dictionary_attr(secret=5)
|
||||
self.assertEqual(secret, 5)
|
||||
self.assertEqual(count, 1)
|
||||
|
||||
secret, count = attrs.dictionary_attr(secret=3, a_string="hello")
|
||||
secret, count = cpu_examples.dictionary_attr(secret=3, a_string="hello")
|
||||
self.assertEqual(secret, 3)
|
||||
self.assertEqual(count, 2)
|
||||
|
||||
with self.assertRaisesRegex(Exception, "Unexpected attribute"):
|
||||
attrs.dictionary_attr()
|
||||
cpu_examples.dictionary_attr()
|
||||
|
||||
with self.assertRaisesRegex(Exception, "Wrong attribute type"):
|
||||
attrs.dictionary_attr(secret="invalid")
|
||||
cpu_examples.dictionary_attr(secret="invalid")
|
||||
|
||||
|
||||
class CounterTests(jtu.JaxTestCase):
  """Tests that ``cpu_examples.counter`` keeps a separate count per index."""

  def setUp(self):
    super().setUp()
    # Skip on any test device other than CPU.
    on_cpu = jtu.test_device_matches(["cpu"])
    if not on_cpu:
      self.skipTest("Unsupported platform")

  def test_basic(self):
    # (index, expected result) pairs, in call order. Index 1 starts from zero
    # and does not disturb the count for index 0.
    calls = [(0, 0), (0, 1), (0, 2), (1, 0), (0, 3)]
    for index, expected in calls:
      self.assertEqual(cpu_examples.counter(index), expected)

  def test_jit(self):
    def counter_fun(x):
      return x, cpu_examples.counter(2)

    jitted = jax.jit(counter_fun)

    _, first = jitted(0)
    self.assertEqual(first, 0)
    _, second = jitted(0)
    self.assertEqual(second, 1)

    # Persists across different cache hits
    _, third = jitted(1)
    self.assertEqual(third, 2)

    # Persists after the cache is cleared
    jitted.clear_cache()
    _, fourth = jitted(0)
    self.assertEqual(fourth, 3)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
@ -28,8 +28,8 @@ class CudaE2eTests(jtu.JaxTestCase):
|
||||
self.skipTest("Unsupported platform")
|
||||
|
||||
# Import here to avoid trying to load the library when it's not built.
|
||||
from jax_ffi_example import cuda_e2e
|
||||
self.foo = cuda_e2e.foo
|
||||
from jax_ffi_example import cuda_examples
|
||||
self.foo = cuda_examples.foo
|
||||
|
||||
def test_fwd_interpretable(self):
|
||||
shape = (2, 3)
|
Loading…
x
Reference in New Issue
Block a user