Add dtype dispatching to FFI example.

This commit is contained in:
Dan Foreman-Mackey 2025-02-19 11:30:33 -05:00
parent 0ef5eb0623
commit ca23b7495e
3 changed files with 108 additions and 64 deletions

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include <cmath>
#include <complex>
#include <cstdint>
#include <functional>
#include <numeric>
@ -30,12 +31,14 @@ namespace ffi = xla::ffi;
// This is the example "library function" that we want to expose to JAX. This
// isn't meant to be a particularly good implementation, it's just here as a
// placeholder for the purposes of this tutorial.
float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
float sm = 0.0f;
template <typename T>
T ComputeRmsNorm(float eps, int64_t size, const T *x, T *y) {
T sm = static_cast<T>(0);
for (int64_t n = 0; n < size; ++n) {
sm += x[n] * x[n];
}
float scale = 1.0f / std::sqrt(sm / static_cast<float>(size) + eps);
T scale = static_cast<T>(1) /
std::sqrt(sm / static_cast<T>(size) + static_cast<T>(eps));
for (int64_t n = 0; n < size; ++n) {
y[n] = x[n] * scale;
}
@ -46,94 +49,140 @@ float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
// In this example, we treat all leading dimensions as batch dimensions, so this
// function returns the total number of elements in the buffer, and the size of
// the last dimension.
template <ffi::DataType T>
std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
auto dims = buffer.dimensions();
std::pair<int64_t, int64_t> GetDims(const ffi::AnyBuffer buffer) {
const ffi::AnyBuffer::Dimensions dims = buffer.dimensions();
if (dims.size() == 0) {
return std::make_pair(0, 0);
}
return std::make_pair(buffer.element_count(), dims.back());
}
#define ELEMENT_TYPE_DISPATCH(element_type, fn, ...) \
switch (element_type) { \
case ffi::F32: \
return fn<float>(__VA_ARGS__); \
case ffi::F64: \
return fn<double>(__VA_ARGS__); \
case ffi::C64: \
return fn<std::complex<float>>(__VA_ARGS__); \
case ffi::C128: \
return fn<std::complex<double>>(__VA_ARGS__); \
default: \
return ffi::Error::InvalidArgument("Unsupported input data type."); \
}
// A wrapper function providing the interface between the XLA FFI call and our
// library function `ComputeRmsNorm` above. This function handles the batch
// dimensions by calling `ComputeRmsNorm` within a loop.
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y) {
template <typename T>
ffi::Error RmsNormImpl(int64_t totalSize, int64_t lastDim, float eps,
ffi::AnyBuffer x, ffi::Result<ffi::AnyBuffer> y) {
T *x_data = x.typed_data<T>();
T *y_data = y->typed_data<T>();
for (int64_t n = 0; n < totalSize; n += lastDim) {
ComputeRmsNorm(eps, lastDim, x_data + n, y_data + n);
}
return ffi::Error::Success();
}
ffi::Error RmsNormDispatch(float eps, ffi::AnyBuffer x,
ffi::Result<ffi::AnyBuffer> y) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error::InvalidArgument("RmsNorm input must be an array");
}
for (int64_t n = 0; n < totalSize; n += lastDim) {
ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
}
return ffi::Error::Success();
ELEMENT_TYPE_DISPATCH(x.element_type(), RmsNormImpl, totalSize, lastDim, eps,
x, y);
}
// Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
// this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`
// macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl,
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormDispatch,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::F32>>() // x
.Ret<ffi::Buffer<ffi::F32>>() // y
.Arg<ffi::AnyBuffer>() // x
.Ret<ffi::AnyBuffer>() // y
);
ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y,
ffi::ResultBuffer<ffi::F32> res) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error::InvalidArgument("RmsNormFwd input must be an array");
}
template <typename T>
ffi::Error RmsNormFwdImpl(int64_t totalSize, int64_t lastDim, float eps,
ffi::AnyBuffer x, ffi::Result<ffi::AnyBuffer> y,
ffi::Result<ffi::AnyBuffer> res) {
T *x_data = x.typed_data<T>();
T *y_data = y->typed_data<T>();
T *res_data = res->typed_data<T>();
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]),
&(y->typed_data()[n]));
res_data[idx] = ComputeRmsNorm(eps, lastDim, x_data + n, y_data + n);
}
return ffi::Error::Success();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdImpl,
ffi::Error RmsNormFwdDispatch(float eps, ffi::AnyBuffer x,
ffi::Result<ffi::AnyBuffer> y,
ffi::Result<ffi::AnyBuffer> res) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error::InvalidArgument("RmsNormFwd input must be an array");
}
ELEMENT_TYPE_DISPATCH(x.element_type(), RmsNormFwdImpl, totalSize, lastDim,
eps, x, y, res);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdDispatch,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::F32>>() // x
.Ret<ffi::Buffer<ffi::F32>>() // y
.Ret<ffi::Buffer<ffi::F32>>() // res
.Arg<ffi::AnyBuffer>() // x
.Ret<ffi::AnyBuffer>() // y
.Ret<ffi::AnyBuffer>() // res
);
void ComputeRmsNormBwd(int64_t size, float res, const float *x,
const float *ct_y, float *ct_x) {
float ct_res = 0.0f;
template <typename T>
void ComputeRmsNormBwd(int64_t size, T res, const T *x, const T *ct_y,
T *ct_x) {
T ct_res = static_cast<T>(0);
for (int64_t n = 0; n < size; ++n) {
ct_res += x[n] * ct_y[n];
}
float factor = ct_res * res * res * res / static_cast<float>(size);
T factor = ct_res * res * res * res / static_cast<T>(size);
for (int64_t n = 0; n < size; ++n) {
ct_x[n] = res * ct_y[n] - factor * x[n];
}
}
ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::F32> res, ffi::Buffer<ffi::F32> x,
ffi::Buffer<ffi::F32> ct_y,
ffi::ResultBuffer<ffi::F32> ct_x) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays");
}
template <typename T>
ffi::Error RmsNormBwdImpl(int64_t totalSize, int64_t lastDim,
ffi::AnyBuffer res, ffi::AnyBuffer x,
ffi::AnyBuffer ct_y,
ffi::Result<ffi::AnyBuffer> ct_x) {
T *res_data = res.typed_data<T>();
T *x_data = x.typed_data<T>();
T *ct_y_data = ct_y.typed_data<T>();
T *ct_x_data = ct_x->typed_data<T>();
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]),
&(ct_y.typed_data()[n]), &(ct_x->typed_data()[n]));
ComputeRmsNormBwd(lastDim, res_data[idx], x_data + n, ct_y_data + n,
ct_x_data + n);
}
return ffi::Error::Success();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdImpl,
ffi::Error RmsNormBwdDispatch(ffi::AnyBuffer res, ffi::AnyBuffer x,
ffi::AnyBuffer ct_y,
ffi::Result<ffi::AnyBuffer> ct_x) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays");
}
ELEMENT_TYPE_DISPATCH(x.element_type(), RmsNormBwdImpl, totalSize, lastDim,
res, x, ct_y, ct_x);
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdDispatch,
ffi::Ffi::Bind()
.Arg<ffi::Buffer<ffi::F32>>() // res
.Arg<ffi::Buffer<ffi::F32>>() // x
.Arg<ffi::Buffer<ffi::F32>>() // ct_y
.Ret<ffi::Buffer<ffi::F32>>() // ct_x
.Arg<ffi::AnyBuffer>() // res
.Arg<ffi::AnyBuffer>() // x
.Arg<ffi::AnyBuffer>() // ct_y
.Ret<ffi::AnyBuffer>() // ct_x
);
template <typename T>

View File

@ -26,7 +26,6 @@ from functools import partial
import numpy as np
import jax
import jax.numpy as jnp
from jax_ffi_example import _rms_norm
@ -36,14 +35,6 @@ for name, target in _rms_norm.registrations().items():
@partial(jax.custom_vjp, nondiff_argnums=(1,))
def rms_norm(x, eps=1e-5):
# We only implemented the `float32` version of this function, so we start by
# checking the dtype. This check isn't strictly necessary because type
# checking is also performed by the FFI when decoding input and output
# buffers, but it can be useful to check types in Python to raise more
# informative errors.
if x.dtype != jnp.float32:
raise ValueError("Only the float32 dtype is implemented by rms_norm")
# In this case, the output of our FFI function is just a single array with the
# same shape and dtype as the input.
out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from absl.testing import absltest
from absl.testing import absltest, parameterized
import jax
import jax.numpy as jnp
@ -25,6 +25,7 @@ jax.config.parse_flags_with_absl()
def rms_norm_ref(x, eps=1e-5):
eps = jnp.float32(eps).astype(x.dtype)
scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
return x / scale
@ -36,18 +37,21 @@ class RmsNormTests(jtu.JaxTestCase):
if not jtu.test_device_matches(["cpu"]):
self.skipTest("Unsupported platform")
def test_basic(self):
x = jnp.linspace(-0.5, 0.5, 15, dtype=jnp.float32)
@parameterized.parameters(jtu.dtypes.floating + jtu.dtypes.complex)
def test_basic(self, dtype):
x = jnp.linspace(-0.5, 0.5, 15, dtype=dtype)
self.assertAllClose(rms_norm.rms_norm(x), rms_norm_ref(x))
def test_batching(self):
x = jnp.linspace(-0.5, 0.5, 15, dtype=jnp.float32).reshape((3, 5))
@parameterized.parameters(jtu.dtypes.floating + jtu.dtypes.complex)
def test_batching(self, dtype):
x = jnp.linspace(-0.5, 0.5, 15, dtype=dtype).reshape((3, 5))
self.assertAllClose(
jax.vmap(rms_norm.rms_norm)(x), jax.vmap(rms_norm_ref)(x)
)
jax.vmap(rms_norm.rms_norm)(x),
jax.vmap(rms_norm_ref)(x))
def test_grads(self):
x = jnp.linspace(-0.5, 0.5, 15, dtype=jnp.float32).reshape((3, 5))
@parameterized.parameters(jtu.dtypes.floating + jtu.dtypes.complex)
def test_grads(self, dtype):
x = jnp.linspace(-0.5, 0.5, 15, dtype=dtype).reshape((3, 5))
jtu.check_grads(rms_norm.rms_norm, (x,), order=1, modes=("rev",))