commit 28afd25259 (parent cb188a0cb1)

Add FFI example demonstrating the use of XLA's FFI state.

Support for this was added in JAX v0.5.0.

PiperOrigin-RevId: 722597704
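At a glance, the pattern this commit demonstrates: a CUDA FFI handler with an "instantiate" stage that creates per-executable state and an "execute" stage that reads it back. The sketch below is a condensed restatement of the gpu_examples.py file added further down, not additional API; it assumes the _gpu_examples extension has been built and a CUDA device is available.

    # Condensed sketch of the state-backed FFI pattern added by this commit
    # (see gpu_examples.py below for the actual file). Requires JAX >= 0.5.0.
    import jax
    import jax.numpy as jnp
    from jax_ffi_example import _gpu_examples  # extension target added in the CMake hunk below

    # handler() returns {"instantiate": <capsule>, "execute": <capsule>};
    # type_id() identifies the C++ State type so XLA can pass it back to execute.
    jax.ffi.register_ffi_target("state", _gpu_examples.handler(), platform="CUDA")
    jax.ffi.register_ffi_type_id("state", _gpu_examples.type_id(), platform="CUDA")

    def read_state():
      # instantiate constructs State(42) once per executable; execute copies the
      # value into this scalar int32 output buffer.
      return jax.ffi.ffi_call("state", jax.ShapeDtypeStruct((), jnp.int32))()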
examples/ffi/CMakeLists.txt
@@ -13,12 +13,12 @@ message(STATUS "XLA include directory: ${XLA_DIR}")
 find_package(nanobind CONFIG REQUIRED)
 
 set(
-  JAX_FFI_EXAMPLE_PROJECTS
-  "rms_norm"
+  JAX_FFI_EXAMPLE_CPU_PROJECTS
+  "cpu_examples"
 )
 
-foreach(PROJECT ${JAX_FFI_EXAMPLE_PROJECTS})
+foreach(PROJECT ${JAX_FFI_EXAMPLE_CPU_PROJECTS})
   nanobind_add_module("_${PROJECT}" NB_STATIC "src/jax_ffi_example/${PROJECT}.cc")
   target_include_directories("_${PROJECT}" PUBLIC ${XLA_DIR})
   install(TARGETS "_${PROJECT}" LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
@@ -26,9 +26,16 @@ endforeach()
 
 if(JAX_FFI_EXAMPLE_ENABLE_CUDA)
   enable_language(CUDA)
+  find_package(CUDAToolkit REQUIRED)
+
   add_library(_cuda_examples SHARED "src/jax_ffi_example/cuda_examples.cu")
   set_target_properties(_cuda_examples PROPERTIES POSITION_INDEPENDENT_CODE ON
                                                   CUDA_STANDARD 17)
   target_include_directories(_cuda_examples PUBLIC ${XLA_DIR})
   install(TARGETS _cuda_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
+
+  nanobind_add_module(_gpu_examples NB_STATIC "src/jax_ffi_example/gpu_examples.cc")
+  target_include_directories(_gpu_examples PUBLIC ${XLA_DIR})
+  target_link_libraries(_gpu_examples PRIVATE CUDA::cudart)
+  install(TARGETS _gpu_examples LIBRARY DESTINATION ${SKBUILD_PROJECT_NAME})
 endif()
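A side note on ${XLA_DIR}: it must point at the directory that ships the xla/ffi/api headers included by gpu_examples.cc. If you are rebuilding this example against an installed JAX rather than a source tree, a minimal sketch for locating that directory (assuming a recent JAX release where jax.ffi.include_dir is available) is:

    # Print the header directory bundled with the installed jax; this is the
    # path that XLA_DIR is expected to resolve to when building the example.
    from jax import ffi
    print(ffi.include_dir())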
examples/ffi/src/jax_ffi_example/gpu_examples.cc (new file, 62 lines)
@@ -0,0 +1,62 @@
/* Copyright 2025 The JAX Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>
#include <memory>

#include "nanobind/nanobind.h"
#include "cuda_runtime_api.h"
#include "xla/ffi/api/ffi.h"

namespace nb = nanobind;
namespace ffi = xla::ffi;

struct State {
  static xla::ffi::TypeId id;
  explicit State(int32_t value) : value(value) {}
  int32_t value;
};
ffi::TypeId State::id = {};

static ffi::ErrorOr<std::unique_ptr<State>> StateInstantiate() {
  return std::make_unique<State>(42);
}

static ffi::Error StateExecute(cudaStream_t stream, State* state,
                               ffi::ResultBufferR0<ffi::S32> out) {
  cudaMemcpyAsync(out->typed_data(), &state->value, sizeof(int32_t),
                  cudaMemcpyHostToDevice, stream);
  cudaStreamSynchronize(stream);
  return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER(kStateInstantiate, StateInstantiate,
                       ffi::Ffi::BindInstantiate());
XLA_FFI_DEFINE_HANDLER(kStateExecute, StateExecute,
                       ffi::Ffi::Bind()
                           .Ctx<ffi::PlatformStream<cudaStream_t>>()
                           .Ctx<ffi::State<State>>()
                           .Ret<ffi::BufferR0<ffi::S32>>());

NB_MODULE(_gpu_examples, m) {
  m.def("type_id",
        []() { return nb::capsule(reinterpret_cast<void*>(&State::id)); });
  m.def("handler", []() {
    nb::dict d;
    d["instantiate"] = nb::capsule(reinterpret_cast<void*>(kStateInstantiate));
    d["execute"] = nb::capsule(reinterpret_cast<void*>(kStateExecute));
    return d;
  });
}
examples/ffi/src/jax_ffi_example/gpu_examples.py (new file, 24 lines)
@@ -0,0 +1,24 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import jax
from jax_ffi_example import _gpu_examples
import jax.numpy as jnp

jax.ffi.register_ffi_target("state", _gpu_examples.handler(), platform="CUDA")
jax.ffi.register_ffi_type_id("state", _gpu_examples.type_id(), platform="CUDA")


def read_state():
  return jax.ffi.ffi_call("state", jax.ShapeDtypeStruct((), jnp.int32))()
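A minimal way to exercise the new module on a machine with a CUDA device (this mirrors the test added below; the 42 comes from the instantiate stage in gpu_examples.cc):

    # Both calls should return 42; under jit, the compiled executable reuses
    # the State created by the instantiate handler. Assumes the extension is built.
    import jax
    from jax_ffi_example import gpu_examples

    print(gpu_examples.read_state())            # direct call
    print(jax.jit(gpu_examples.read_state)())   # same value via the compiled executable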
examples/ffi/tests/gpu_examples_test.py (new file, 41 lines)
@@ -0,0 +1,41 @@
# Copyright 2025 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from absl.testing import absltest
import jax
from jax._src import test_util as jtu

jax.config.parse_flags_with_absl()


class GpuExamplesTest(jtu.JaxTestCase):

  def setUp(self):
    super().setUp()
    if not jtu.test_device_matches(["cuda"]):
      self.skipTest("Unsupported platform")

    # Import here to avoid trying to load the library when it's not built.
    from jax_ffi_example import gpu_examples  # pylint: disable=g-import-not-at-top

    self.read_state = gpu_examples.read_state

  def test_basic(self):
    self.assertEqual(self.read_state(), 42)
    self.assertEqual(jax.jit(self.read_state)(), 42)


if __name__ == "__main__":
  absltest.main(testLoader=jtu.JaxTestLoader())