Add FFI example demonstrating the use of XLA's FFI state.

Support for this was added in JAX v0.5.0.

PiperOrigin-RevId: 722597704
This commit is contained in:
Dan Foreman-Mackey 2025-02-03 04:05:31 -08:00 committed by jax authors
parent cb188a0cb1
commit 28afd25259
4 changed files with 136 additions and 2 deletions

View File

@ -13,12 +13,12 @@ message(STATUS "XLA include directory: ${XLA_DIR}")
find_package(nanobind CONFIG REQUIRED)
set(
JAX_FFI_EXAMPLE_PROJECTS
JAX_FFI_EXAMPLE_CPU_PROJECTS
"rms_norm"
"cpu_examples"
)
foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS})
foreach(PROJECT ${JAX_FFI_EXAMPLE_CPU_PROJECTS})
nanobind_add_module("_${PROJECT}" NB_STATIC "src/jax_ffi_example/${PROJECT}.cc")
target_include_directories("_${PROJECT}" PUBLIC ${XLA_DIR})
install(TARGETS "_${PROJECT}" LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
@ -26,9 +26,16 @@ endforeach()
if(JAX_FFI_EXAMPLE_ENABLE_CUDA)
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
add_library(_cuda_examples SHARED "src/jax_ffi_example/cuda_examples.cu")
set_target_properties(_cuda_examples PROPERTIES POSITION_INDEPENDENT_CODE ON
CUDA_STANDARD 17)
target_include_directories(_cuda_examples PUBLIC ${XLA_DIR})
install(TARGETS _cuda_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
nanobind_add_module(_gpu_examples NB_STATIC "src/jax_ffi_example/gpu_examples.cc")
target_include_directories(_gpu_examples PUBLIC ${XLA_DIR})
target_link_libraries(_gpu_examples PRIVATE CUDA::cudart)
install(TARGETS _gpu_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
endif()

View File

@ -0,0 +1,62 @@
/* Copyright 2025 The JAX Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdint>
#include <memory>
#include "nanobind/nanobind.h"
#include "cuda_runtime_api.h"
#include "xla/ffi/api/ffi.h"
namespace nb = nanobind;
namespace ffi = xla::ffi;
struct State {
static xla::ffi::TypeId id;
explicit State(int32_t value) : value(value) {}
int32_t value;
};
ffi::TypeId State::id = {};
static ffi::ErrorOr<std::unique_ptr<State>> StateInstantiate() {
return std::make_unique<State>(42);
}
static ffi::Error StateExecute(cudaStream_t stream, State* state,
ffi::ResultBufferR0<ffi::S32> out) {
cudaMemcpyAsync(out->typed_data(), &state->value, sizeof(int32_t),
cudaMemcpyHostToDevice, stream);
cudaStreamSynchronize(stream);
return ffi::Error::Success();
}
XLA_FFI_DEFINE_HANDLER(kStateInstantiate, StateInstantiate,
ffi::Ffi::BindInstantiate());
XLA_FFI_DEFINE_HANDLER(kStateExecute, StateExecute,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<cudaStream_t>>()
.Ctx<ffi::State<State>>()
.Ret<ffi::BufferR0<ffi::S32>>());
NB_MODULE(_gpu_examples, m) {
m.def("type_id",
[]() { return nb::capsule(reinterpret_cast<void*>(&State::id)); });
m.def("handler", []() {
nb::dict d;
d["instantiate"] = nb::capsule(reinterpret_cast<void*>(kStateInstantiate));
d["execute"] = nb::capsule(reinterpret_cast<void*>(kStateExecute));
return d;
});
}

View File

@ -0,0 +1,24 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jax
from jax_ffi_example import _gpu_examples
import jax.numpy as jnp
jax.ffi.register_ffi_target("state", _gpu_examples.handler(), platform="CUDA")
jax.ffi.register_ffi_type_id("state", _gpu_examples.type_id(), platform="CUDA")
def read_state():
return jax.ffi.ffi_call("state", jax.ShapeDtypeStruct((), jnp.int32))()

View File

@ -0,0 +1,41 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
import jax
from jax._src import test_util as jtu
jax.config.parse_flags_with_absl()
class GpuExamplesTest(jtu.JaxTestCase):
def setUp(self):
super().setUp()
if not jtu.test_device_matches(["cuda"]):
self.skipTest("Unsupported platform")
# Import here to avoid trying to load the library when it's not built.
from jax_ffi_example import gpu_examples # pylint: disable=g-import-not-at-top
self.read_state = gpu_examples.read_state
def test_basic(self):
self.assertEqual(self.read_state(), 42)
self.assertEqual(jax.jit(self.read_state)(), 42)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())