From ca23b7495ef7e1ad0f6377a6da6c844937442f8c Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey <danfm@google.com>
Date: Wed, 19 Feb 2025 11:30:33 -0500
Subject: [PATCH] Add dtype dispatching to FFI example.

---
 examples/ffi/src/jax_ffi_example/rms_norm.cc | 141 +++++++++++++------
 examples/ffi/src/jax_ffi_example/rms_norm.py |   9 --
 examples/ffi/tests/rms_norm_test.py          |  22 +--
 3 files changed, 108 insertions(+), 64 deletions(-)

diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.cc b/examples/ffi/src/jax_ffi_example/rms_norm.cc
index 8314219c4..819f3b9f8 100644
--- a/examples/ffi/src/jax_ffi_example/rms_norm.cc
+++ b/examples/ffi/src/jax_ffi_example/rms_norm.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include <cmath>
+#include <complex>
 #include <cstdint>
 #include <functional>
 #include <numeric>
@@ -30,12 +31,14 @@ namespace ffi = xla::ffi;
 // This is the example "library function" that we want to expose to JAX. This
 // isn't meant to be a particularly good implementation, it's just here as a
 // placeholder for the purposes of this tutorial.
-float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
-  float sm = 0.0f;
+template <typename T>
+T ComputeRmsNorm(float eps, int64_t size, const T *x, T *y) {
+  T sm = static_cast<T>(0);
   for (int64_t n = 0; n < size; ++n) {
     sm += x[n] * x[n];
   }
-  float scale = 1.0f / std::sqrt(sm / static_cast<float>(size) + eps);
+  T scale = static_cast<T>(1) /
+            std::sqrt(sm / static_cast<T>(size) + static_cast<T>(eps));
   for (int64_t n = 0; n < size; ++n) {
     y[n] = x[n] * scale;
   }
@@ -46,94 +49,140 @@ float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
 // In this example, we treat all leading dimensions as batch dimensions, so this
 // function returns the total number of elements in the buffer, and the size of
 // the last dimension.
-template <ffi::DataType T>
-std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
-  auto dims = buffer.dimensions();
+std::pair<int64_t, int64_t> GetDims(const ffi::AnyBuffer buffer) {
+  const ffi::AnyBuffer::Dimensions dims = buffer.dimensions();
   if (dims.size() == 0) {
     return std::make_pair(0, 0);
   }
   return std::make_pair(buffer.element_count(), dims.back());
 }
 
+#define ELEMENT_TYPE_DISPATCH(element_type, fn, ...)                        \
+  switch (element_type) {                                                   \
+    case ffi::F32:                                                          \
+      return fn<float>(__VA_ARGS__);                                        \
+    case ffi::F64:                                                          \
+      return fn<double>(__VA_ARGS__);                                       \
+    case ffi::C64:                                                          \
+      return fn<std::complex<float>>(__VA_ARGS__);                          \
+    case ffi::C128:                                                         \
+      return fn<std::complex<double>>(__VA_ARGS__);                         \
+    default:                                                                \
+      return ffi::Error::InvalidArgument("Unsupported input data type.");   \
+  }
+
 // A wrapper function providing the interface between the XLA FFI call and our
 // library function `ComputeRmsNorm` above. This function handles the batch
 // dimensions by calling `ComputeRmsNorm` within a loop.
-ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
-                       ffi::ResultBuffer<ffi::F32> y) {
+template <typename T>
+ffi::Error RmsNormImpl(int64_t totalSize, int64_t lastDim, float eps,
+                       ffi::AnyBuffer x, ffi::Result<ffi::AnyBuffer> y) {
+  T *x_data = x.typed_data<T>();
+  T *y_data = y->typed_data<T>();
+  for (int64_t n = 0; n < totalSize; n += lastDim) {
+    ComputeRmsNorm(eps, lastDim, x_data + n, y_data + n);
+  }
+  return ffi::Error::Success();
+}
+
+ffi::Error RmsNormDispatch(float eps, ffi::AnyBuffer x,
+                           ffi::Result<ffi::AnyBuffer> y) {
   auto [totalSize, lastDim] = GetDims(x);
   if (lastDim == 0) {
     return ffi::Error::InvalidArgument("RmsNorm input must be an array");
   }
-  for (int64_t n = 0; n < totalSize; n += lastDim) {
-    ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
-  }
-  return ffi::Error::Success();
+  ELEMENT_TYPE_DISPATCH(x.element_type(), RmsNormImpl, totalSize, lastDim, eps,
+                        x, y);
 }
 
-// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
+// Wrap `RmsNormDispatch` and specify the interface to XLA. If you need to declare
 // this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`
 // macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.
-XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl,
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormDispatch,
                               ffi::Ffi::Bind()
                                   .Attr<float>("eps")
-                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
-                                  .Ret<ffi::Buffer<ffi::F32>>()  // y
+                                  .Arg<ffi::AnyBuffer>()  // x
+                                  .Ret<ffi::AnyBuffer>()  // y
 );
 
-ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::F32> x,
-                          ffi::ResultBuffer<ffi::F32> y,
-                          ffi::ResultBuffer<ffi::F32> res) {
-  auto [totalSize, lastDim] = GetDims(x);
-  if (lastDim == 0) {
-    return ffi::Error::InvalidArgument("RmsNormFwd input must be an array");
-  }
+template <typename T>
+ffi::Error RmsNormFwdImpl(int64_t totalSize, int64_t lastDim, float eps,
+                          ffi::AnyBuffer x, ffi::Result<ffi::AnyBuffer> y,
+                          ffi::Result<ffi::AnyBuffer> res) {
+  T *x_data = x.typed_data<T>();
+  T *y_data = y->typed_data<T>();
+  T *res_data = res->typed_data<T>();
   for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
-    res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]),
-                                            &(y->typed_data()[n]));
+    res_data[idx] = ComputeRmsNorm(eps, lastDim, x_data + n, y_data + n);
   }
   return ffi::Error::Success();
 }
 
-XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdImpl,
+ffi::Error RmsNormFwdDispatch(float eps, ffi::AnyBuffer x,
+                              ffi::Result<ffi::AnyBuffer> y,
+                              ffi::Result<ffi::AnyBuffer> res) {
+  auto [totalSize, lastDim] = GetDims(x);
+  if (lastDim == 0) {
+    return ffi::Error::InvalidArgument("RmsNormFwd input must be an array");
+  }
+  ELEMENT_TYPE_DISPATCH(x.element_type(), RmsNormFwdImpl, totalSize, lastDim,
+                        eps, x, y, res);
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdDispatch,
                               ffi::Ffi::Bind()
                                   .Attr<float>("eps")
-                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
-                                  .Ret<ffi::Buffer<ffi::F32>>()  // y
-                                  .Ret<ffi::Buffer<ffi::F32>>()  // res
+                                  .Arg<ffi::AnyBuffer>()  // x
+                                  .Ret<ffi::AnyBuffer>()  // y
+                                  .Ret<ffi::AnyBuffer>()  // res
 );
 
-void ComputeRmsNormBwd(int64_t size, float res, const float *x,
-                       const float *ct_y, float *ct_x) {
-  float ct_res = 0.0f;
+template <typename T>
+void ComputeRmsNormBwd(int64_t size, T res, const T *x, const T *ct_y,
+                       T *ct_x) {
+  T ct_res = static_cast<T>(0);
   for (int64_t n = 0; n < size; ++n) {
     ct_res += x[n] * ct_y[n];
   }
-  float factor = ct_res * res * res * res / static_cast<float>(size);
+  T factor = ct_res * res * res * res / static_cast<T>(size);
   for (int64_t n = 0; n < size; ++n) {
     ct_x[n] = res * ct_y[n] - factor * x[n];
   }
 }
 
-ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::F32> res, ffi::Buffer<ffi::F32> x,
-                          ffi::Buffer<ffi::F32> ct_y,
-                          ffi::ResultBuffer<ffi::F32> ct_x) {
-  auto [totalSize, lastDim] = GetDims(x);
-  if (lastDim == 0) {
-    return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays");
-  }
+template <typename T>
+ffi::Error RmsNormBwdImpl(int64_t totalSize, int64_t lastDim,
+                          ffi::AnyBuffer res, ffi::AnyBuffer x,
+                          ffi::AnyBuffer ct_y,
+                          ffi::Result<ffi::AnyBuffer> ct_x) {
+  T *res_data = res.typed_data<T>();
+  T *x_data = x.typed_data<T>();
+  T *ct_y_data = ct_y.typed_data<T>();
+  T *ct_x_data = ct_x->typed_data<T>();
   for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
-    ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]),
-                      &(ct_y.typed_data()[n]), &(ct_x->typed_data()[n]));
+    ComputeRmsNormBwd(lastDim, res_data[idx], x_data + n, ct_y_data + n,
+                      ct_x_data + n);
   }
   return ffi::Error::Success();
 }
 
-XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdImpl,
+ffi::Error RmsNormBwdDispatch(ffi::AnyBuffer res, ffi::AnyBuffer x,
+                              ffi::AnyBuffer ct_y,
+                              ffi::Result<ffi::AnyBuffer> ct_x) {
+  auto [totalSize, lastDim] = GetDims(x);
+  if (lastDim == 0) {
+    return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays");
+  }
+  ELEMENT_TYPE_DISPATCH(x.element_type(), RmsNormBwdImpl, totalSize, lastDim,
+                        res, x, ct_y, ct_x);
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdDispatch,
                               ffi::Ffi::Bind()
-                                  .Arg<ffi::Buffer<ffi::F32>>()  // res
-                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
-                                  .Arg<ffi::Buffer<ffi::F32>>()  // ct_y
-                                  .Ret<ffi::Buffer<ffi::F32>>()  // ct_x
+                                  .Arg<ffi::AnyBuffer>()  // res
+                                  .Arg<ffi::AnyBuffer>()  // x
+                                  .Arg<ffi::AnyBuffer>()  // ct_y
+                                  .Ret<ffi::AnyBuffer>()  // ct_x
 );
 
 template <typename T>
diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.py b/examples/ffi/src/jax_ffi_example/rms_norm.py
index 851f1900c..6dbfe5043 100644
--- a/examples/ffi/src/jax_ffi_example/rms_norm.py
+++ b/examples/ffi/src/jax_ffi_example/rms_norm.py
@@ -26,7 +26,6 @@ from functools import partial
 import numpy as np
 
 import jax
-import jax.numpy as jnp
 
 from jax_ffi_example import _rms_norm
 
@@ -36,14 +35,6 @@
 
 @partial(jax.custom_vjp, nondiff_argnums=(1,))
 def rms_norm(x, eps=1e-5):
-  # We only implemented the `float32` version of this function, so we start by
-  # checking the dtype. This check isn't strictly necessary because type
-  # checking is also performed by the FFI when decoding input and output
-  # buffers, but it can be useful to check types in Python to raise more
-  # informative errors.
-  if x.dtype != jnp.float32:
-    raise ValueError("Only the float32 dtype is implemented by rms_norm")
-
   # In this case, the output of our FFI function is just a single array with the
   # same shape and dtype as the input.
   out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
diff --git a/examples/ffi/tests/rms_norm_test.py b/examples/ffi/tests/rms_norm_test.py
index 3a1039cf4..9b6108e1f 100644
--- a/examples/ffi/tests/rms_norm_test.py
+++ b/examples/ffi/tests/rms_norm_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from absl.testing import absltest
+from absl.testing import absltest, parameterized
 
 import jax
 import jax.numpy as jnp
@@ -25,6 +25,7 @@ jax.config.parse_flags_with_absl()
 
 
 def rms_norm_ref(x, eps=1e-5):
+  eps = jnp.float32(eps).astype(x.dtype)
   scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
   return x / scale
 
@@ -36,18 +37,21 @@ class RmsNormTests(jtu.JaxTestCase):
     if not jtu.test_device_matches(["cpu"]):
       self.skipTest("Unsupported platform")
 
-  def test_basic(self):
-    x = jnp.linspace(-0.5, 0.5, 15, dtype=jnp.float32)
+  @parameterized.parameters(jtu.dtypes.floating + jtu.dtypes.complex)
+  def test_basic(self, dtype):
+    x = jnp.linspace(-0.5, 0.5, 15, dtype=dtype)
     self.assertAllClose(rms_norm.rms_norm(x), rms_norm_ref(x))
 
-  def test_batching(self):
-    x = jnp.linspace(-0.5, 0.5, 15, dtype=jnp.float32).reshape((3, 5))
+  @parameterized.parameters(jtu.dtypes.floating + jtu.dtypes.complex)
+  def test_batching(self, dtype):
+    x = jnp.linspace(-0.5, 0.5, 15, dtype=dtype).reshape((3, 5))
     self.assertAllClose(
-        jax.vmap(rms_norm.rms_norm)(x), jax.vmap(rms_norm_ref)(x)
-    )
+        jax.vmap(rms_norm.rms_norm)(x),
+        jax.vmap(rms_norm_ref)(x))
 
-  def test_grads(self):
-    x = jnp.linspace(-0.5, 0.5, 15, dtype=jnp.float32).reshape((3, 5))
+  @parameterized.parameters(jtu.dtypes.floating + jtu.dtypes.complex)
+  def test_grads(self, dtype):
+    x = jnp.linspace(-0.5, 0.5, 15, dtype=dtype).reshape((3, 5))
     jtu.check_grads(rms_norm.rms_norm, (x,), order=1, modes=("rev",))
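
For readers who want to experiment with the dispatch pattern outside of the XLA FFI build, here is a minimal, self-contained C++ sketch of the same idea: a switch over a runtime dtype tag that forwards to a single templated kernel. The `DataType` enum, `Dispatch`, `RmsNorm`, and `main` below are hypothetical stand-ins for `ffi::DataType`, the `ELEMENT_TYPE_DISPATCH` macro, and the FFI handlers in the patch; only the structure mirrors the real code, and there is no XLA dependency.

// Sketch of runtime-dtype dispatch to a templated kernel (no XLA needed).
// Names here are illustrative stand-ins for the patch's ffi::DataType tags
// and ELEMENT_TYPE_DISPATCH macro.
#include <cmath>
#include <complex>
#include <cstdint>
#include <cstdio>
#include <vector>

enum class DataType { F32, F64, C64, C128 };

// Same shape as the patch's ComputeRmsNorm: one template, many dtypes.
// Like the patch, this uses x[n] * x[n] rather than |x[n]|^2 for complex T.
template <typename T>
void RmsNorm(float eps, int64_t size, const T *x, T *y) {
  T sm = static_cast<T>(0);
  for (int64_t n = 0; n < size; ++n) {
    sm += x[n] * x[n];
  }
  T scale = static_cast<T>(1) /
            std::sqrt(sm / static_cast<T>(size) + static_cast<T>(eps));
  for (int64_t n = 0; n < size; ++n) {
    y[n] = x[n] * scale;
  }
}

// The runtime tag selects the template instantiation, exactly like the
// switch generated by ELEMENT_TYPE_DISPATCH (minus the FFI error return).
void Dispatch(DataType dtype, float eps, int64_t size, const void *x,
              void *y) {
  switch (dtype) {
    case DataType::F32:
      return RmsNorm(eps, size, static_cast<const float *>(x),
                     static_cast<float *>(y));
    case DataType::F64:
      return RmsNorm(eps, size, static_cast<const double *>(x),
                     static_cast<double *>(y));
    case DataType::C64:
      return RmsNorm(eps, size, static_cast<const std::complex<float> *>(x),
                     static_cast<std::complex<float> *>(y));
    case DataType::C128:
      return RmsNorm(eps, size, static_cast<const std::complex<double> *>(x),
                     static_cast<std::complex<double> *>(y));
  }
}

int main() {
  std::vector<double> x = {0.5, -0.25, 0.125};
  std::vector<double> y(x.size());
  Dispatch(DataType::F64, 1e-5f, static_cast<int64_t>(x.size()), x.data(),
           y.data());
  for (double v : y) std::printf("%g\n", v);
  return 0;
}

Compiled with any C++11 compiler (e.g. `c++ -std=c++11 sketch.cc`, with a hypothetical file name), this prints the normalized values; switching the tag to `DataType::C128` with complex buffers exercises the complex path, which is what the new parameterized tests do through JAX. The void-pointer-plus-tag boundary is the same design trade-off the patch makes: the untyped `ffi::AnyBuffer` crosses the FFI boundary, and typing is recovered exactly once, at the dispatch switch.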