From ca23b7495ef7e1ad0f6377a6da6c844937442f8c Mon Sep 17 00:00:00 2001
From: Dan Foreman-Mackey <danfm@google.com>
Date: Wed, 19 Feb 2025 11:30:33 -0500
Subject: [PATCH] Add dtype dispatching to FFI example.

---
 examples/ffi/src/jax_ffi_example/rms_norm.cc | 141 +++++++++++++------
 examples/ffi/src/jax_ffi_example/rms_norm.py |   9 --
 examples/ffi/tests/rms_norm_test.py          |  22 +--
 3 files changed, 108 insertions(+), 64 deletions(-)

diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.cc b/examples/ffi/src/jax_ffi_example/rms_norm.cc
index 8314219c4..819f3b9f8 100644
--- a/examples/ffi/src/jax_ffi_example/rms_norm.cc
+++ b/examples/ffi/src/jax_ffi_example/rms_norm.cc
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/
 
 #include <cmath>
+#include <complex>
 #include <cstdint>
 #include <functional>
 #include <numeric>
@@ -30,12 +31,14 @@ namespace ffi = xla::ffi;
 // This is the example "library function" that we want to expose to JAX. This
 // isn't meant to be a particularly good implementation, it's just here as a
 // placeholder for the purposes of this tutorial.
-float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
-  float sm = 0.0f;
+template <typename T>
+T ComputeRmsNorm(float eps, int64_t size, const T *x, T *y) {
+  T sm = static_cast<T>(0);
   for (int64_t n = 0; n < size; ++n) {
     sm += x[n] * x[n];
   }
-  float scale = 1.0f / std::sqrt(sm / static_cast<float>(size) + eps);
+  T scale = static_cast<T>(1) /
+            std::sqrt(sm / static_cast<T>(size) + static_cast<T>(eps));
   for (int64_t n = 0; n < size; ++n) {
     y[n] = x[n] * scale;
   }
@@ -46,94 +49,140 @@ float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
 // In this example, we treat all leading dimensions as batch dimensions, so this
 // function returns the total number of elements in the buffer, and the size of
 // the last dimension.
-template <ffi::DataType T>
-std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
-  auto dims = buffer.dimensions();
+std::pair<int64_t, int64_t> GetDims(const ffi::AnyBuffer buffer) {
+  const ffi::AnyBuffer::Dimensions dims = buffer.dimensions();
   if (dims.size() == 0) {
     return std::make_pair(0, 0);
   }
   return std::make_pair(buffer.element_count(), dims.back());
 }
 
+#define ELEMENT_TYPE_DISPATCH(element_type, fn, ...)                      \
+  switch (element_type) {                                                 \
+    case ffi::F32:                                                        \
+      return fn<float>(__VA_ARGS__);                                      \
+    case ffi::F64:                                                        \
+      return fn<double>(__VA_ARGS__);                                     \
+    case ffi::C64:                                                        \
+      return fn<std::complex<float>>(__VA_ARGS__);                        \
+    case ffi::C128:                                                       \
+      return fn<std::complex<double>>(__VA_ARGS__);                       \
+    default:                                                              \
+      return ffi::Error::InvalidArgument("Unsupported input data type."); \
+  }
+
 // A wrapper function providing the interface between the XLA FFI call and our
 // library function `ComputeRmsNorm` above. This function handles the batch
 // dimensions by calling `ComputeRmsNorm` within a loop.
-ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
-                       ffi::ResultBuffer<ffi::F32> y) {
+template <typename T>
+ffi::Error RmsNormImpl(int64_t totalSize, int64_t lastDim, float eps,
+                       ffi::AnyBuffer x, ffi::Result<ffi::AnyBuffer> y) {
+  T *x_data = x.typed_data<T>();
+  T *y_data = y->typed_data<T>();
+  for (int64_t n = 0; n < totalSize; n += lastDim) {
+    ComputeRmsNorm(eps, lastDim, x_data + n, y_data + n);
+  }
+  return ffi::Error::Success();
+}
+
+ffi::Error RmsNormDispatch(float eps, ffi::AnyBuffer x,
+                           ffi::Result<ffi::AnyBuffer> y) {
   auto [totalSize, lastDim] = GetDims(x);
   if (lastDim == 0) {
     return ffi::Error::InvalidArgument("RmsNorm input must be an array");
   }
-  for (int64_t n = 0; n < totalSize; n += lastDim) {
-    ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
-  }
-  return ffi::Error::Success();
+  ELEMENT_TYPE_DISPATCH(x.element_type(), RmsNormImpl, totalSize, lastDim, eps,
+                        x, y);
 }
 
 // Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
 // this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`
 // macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.
-XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl,
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormDispatch,
                               ffi::Ffi::Bind()
                                   .Attr<float>("eps")
-                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
-                                  .Ret<ffi::Buffer<ffi::F32>>()  // y
+                                  .Arg<ffi::AnyBuffer>()  // x
+                                  .Ret<ffi::AnyBuffer>()  // y
 );
 
-ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::F32> x,
-                          ffi::ResultBuffer<ffi::F32> y,
-                          ffi::ResultBuffer<ffi::F32> res) {
-  auto [totalSize, lastDim] = GetDims(x);
-  if (lastDim == 0) {
-    return ffi::Error::InvalidArgument("RmsNormFwd input must be an array");
-  }
+template <typename T>
+ffi::Error RmsNormFwdImpl(int64_t totalSize, int64_t lastDim, float eps,
+                          ffi::AnyBuffer x, ffi::Result<ffi::AnyBuffer> y,
+                          ffi::Result<ffi::AnyBuffer> res) {
+  T *x_data = x.typed_data<T>();
+  T *y_data = y->typed_data<T>();
+  T *res_data = res->typed_data<T>();
   for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
-    res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]),
-                                            &(y->typed_data()[n]));
+    res_data[idx] = ComputeRmsNorm(eps, lastDim, x_data + n, y_data + n);
   }
   return ffi::Error::Success();
 }
 
-XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdImpl,
+ffi::Error RmsNormFwdDispatch(float eps, ffi::AnyBuffer x,
+                              ffi::Result<ffi::AnyBuffer> y,
+                              ffi::Result<ffi::AnyBuffer> res) {
+  auto [totalSize, lastDim] = GetDims(x);
+  if (lastDim == 0) {
+    return ffi::Error::InvalidArgument("RmsNormFwd input must be an array");
+  }
+  ELEMENT_TYPE_DISPATCH(x.element_type(), RmsNormFwdImpl, totalSize, lastDim,
+                        eps, x, y, res);
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdDispatch,
                               ffi::Ffi::Bind()
                                   .Attr<float>("eps")
-                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
-                                  .Ret<ffi::Buffer<ffi::F32>>()  // y
-                                  .Ret<ffi::Buffer<ffi::F32>>()  // res
+                                  .Arg<ffi::AnyBuffer>()  // x
+                                  .Ret<ffi::AnyBuffer>()  // y
+                                  .Ret<ffi::AnyBuffer>()  // res
 );
 
-void ComputeRmsNormBwd(int64_t size, float res, const float *x,
-                       const float *ct_y, float *ct_x) {
-  float ct_res = 0.0f;
+template <typename T>
+void ComputeRmsNormBwd(int64_t size, T res, const T *x, const T *ct_y,
+                       T *ct_x) {
+  T ct_res = static_cast<T>(0);
   for (int64_t n = 0; n < size; ++n) {
     ct_res += x[n] * ct_y[n];
   }
-  float factor = ct_res * res * res * res / static_cast<float>(size);
+  T factor = ct_res * res * res * res / static_cast<T>(size);
   for (int64_t n = 0; n < size; ++n) {
     ct_x[n] = res * ct_y[n] - factor * x[n];
   }
 }
 
-ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::F32> res, ffi::Buffer<ffi::F32> x,
-                          ffi::Buffer<ffi::F32> ct_y,
-                          ffi::ResultBuffer<ffi::F32> ct_x) {
-  auto [totalSize, lastDim] = GetDims(x);
-  if (lastDim == 0) {
-    return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays");
-  }
+template <typename T>
+ffi::Error RmsNormBwdImpl(int64_t totalSize, int64_t lastDim,
+                          ffi::AnyBuffer res, ffi::AnyBuffer x,
+                          ffi::AnyBuffer ct_y,
+                          ffi::Result<ffi::AnyBuffer> ct_x) {
+  T *res_data = res.typed_data<T>();
+  T *x_data = x.typed_data<T>();
+  T *ct_y_data = ct_y.typed_data<T>();
+  T *ct_x_data = ct_x->typed_data<T>();
   for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
-    ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]),
-                      &(ct_y.typed_data()[n]), &(ct_x->typed_data()[n]));
+    ComputeRmsNormBwd(lastDim, res_data[idx], x_data + n, ct_y_data + n,
+                      ct_x_data + n);
   }
   return ffi::Error::Success();
 }
 
-XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdImpl,
+ffi::Error RmsNormBwdDispatch(ffi::AnyBuffer res, ffi::AnyBuffer x,
+                              ffi::AnyBuffer ct_y,
+                              ffi::Result<ffi::AnyBuffer> ct_x) {
+  auto [totalSize, lastDim] = GetDims(x);
+  if (lastDim == 0) {
+    return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays");
+  }
+  ELEMENT_TYPE_DISPATCH(x.element_type(), RmsNormBwdImpl, totalSize, lastDim,
+                        res, x, ct_y, ct_x);
+}
+
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdDispatch,
                               ffi::Ffi::Bind()
-                                  .Arg<ffi::Buffer<ffi::F32>>()  // res
-                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
-                                  .Arg<ffi::Buffer<ffi::F32>>()  // ct_y
-                                  .Ret<ffi::Buffer<ffi::F32>>()  // ct_x
+                                  .Arg<ffi::AnyBuffer>()  // res
+                                  .Arg<ffi::AnyBuffer>()  // x
+                                  .Arg<ffi::AnyBuffer>()  // ct_y
+                                  .Ret<ffi::AnyBuffer>()  // ct_x
 );
 
 template <typename T>
diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.py b/examples/ffi/src/jax_ffi_example/rms_norm.py
index 851f1900c..6dbfe5043 100644
--- a/examples/ffi/src/jax_ffi_example/rms_norm.py
+++ b/examples/ffi/src/jax_ffi_example/rms_norm.py
@@ -26,7 +26,6 @@ from functools import partial
 import numpy as np
 
 import jax
-import jax.numpy as jnp
 
 from jax_ffi_example import _rms_norm
 
@@ -36,14 +35,6 @@ for name, target in _rms_norm.registrations().items():
 
 @partial(jax.custom_vjp, nondiff_argnums=(1,))
 def rms_norm(x, eps=1e-5):
-  # We only implemented the `float32` version of this function, so we start by
-  # checking the dtype. This check isn't strictly necessary because type
-  # checking is also performed by the FFI when decoding input and output
-  # buffers, but it can be useful to check types in Python to raise more
-  # informative errors.
-  if x.dtype != jnp.float32:
-    raise ValueError("Only the float32 dtype is implemented by rms_norm")
-
   # In this case, the output of our FFI function is just a single array with the
   # same shape and dtype as the input.
   out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
diff --git a/examples/ffi/tests/rms_norm_test.py b/examples/ffi/tests/rms_norm_test.py
index 3a1039cf4..9b6108e1f 100644
--- a/examples/ffi/tests/rms_norm_test.py
+++ b/examples/ffi/tests/rms_norm_test.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from absl.testing import absltest
+from absl.testing import absltest, parameterized
 
 import jax
 import jax.numpy as jnp
@@ -25,6 +25,7 @@ jax.config.parse_flags_with_absl()
 
 
 def rms_norm_ref(x, eps=1e-5):
+  eps = jnp.float32(eps).astype(x.dtype)
   scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
   return x / scale
 
@@ -36,18 +37,21 @@ class RmsNormTests(jtu.JaxTestCase):
     if not jtu.test_device_matches(["cpu"]):
       self.skipTest("Unsupported platform")
 
-  def test_basic(self):
-    x = jnp.linspace(-0.5, 0.5, 15, dtype=jnp.float32)
+  @parameterized.parameters(jtu.dtypes.floating + jtu.dtypes.complex)
+  def test_basic(self, dtype):
+    x = jnp.linspace(-0.5, 0.5, 15, dtype=dtype)
     self.assertAllClose(rms_norm.rms_norm(x), rms_norm_ref(x))
 
-  def test_batching(self):
-    x = jnp.linspace(-0.5, 0.5, 15, dtype=jnp.float32).reshape((3, 5))
+  @parameterized.parameters(jtu.dtypes.floating + jtu.dtypes.complex)
+  def test_batching(self, dtype):
+    x = jnp.linspace(-0.5, 0.5, 15, dtype=dtype).reshape((3, 5))
     self.assertAllClose(
-        jax.vmap(rms_norm.rms_norm)(x), jax.vmap(rms_norm_ref)(x)
-    )
+        jax.vmap(rms_norm.rms_norm)(x),
+        jax.vmap(rms_norm_ref)(x))
 
-  def test_grads(self):
-    x = jnp.linspace(-0.5, 0.5, 15, dtype=jnp.float32).reshape((3, 5))
+  @parameterized.parameters(jtu.dtypes.floating + jtu.dtypes.complex)
+  def test_grads(self, dtype):
+    x = jnp.linspace(-0.5, 0.5, 15, dtype=dtype).reshape((3, 5))
     jtu.check_grads(rms_norm.rms_norm, (x,), order=1, modes=("rev",))