mirror of https://github.com/ROCm/jax.git
Add dtype dispatching to FFI example.
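The FFI example's handlers previously accepted only float32 buffers. This change templates the C++ kernels and implementation functions over the element type, switches the XLA bindings to `ffi::AnyBuffer`, and adds an `ELEMENT_TYPE_DISPATCH` macro that selects the `float`, `double`, `std::complex<float>`, or `std::complex<double>` instantiation from the buffer's runtime element type. The Python example drops its float32-only check, and the tests are parameterized over floating and complex dtypes.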
parent 0ef5eb0623
commit ca23b7495e
@@ -14,6 +14,7 @@ limitations under the License.
 ==============================================================================*/

 #include <cmath>
+#include <complex>
 #include <cstdint>
 #include <functional>
 #include <numeric>
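The new `<complex>` include supports the `std::complex<float>` and `std::complex<double>` cases introduced by the dispatch macro below.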
@@ -30,12 +31,14 @@ namespace ffi = xla::ffi;
 // This is the example "library function" that we want to expose to JAX. This
 // isn't meant to be a particularly good implementation; it's just here as a
 // placeholder for the purposes of this tutorial.
-float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
-  float sm = 0.0f;
+template <typename T>
+T ComputeRmsNorm(float eps, int64_t size, const T *x, T *y) {
+  T sm = static_cast<T>(0);
   for (int64_t n = 0; n < size; ++n) {
     sm += x[n] * x[n];
   }
-  float scale = 1.0f / std::sqrt(sm / static_cast<float>(size) + eps);
+  T scale = static_cast<T>(1) /
+            std::sqrt(sm / static_cast<T>(size) + static_cast<T>(eps));
   for (int64_t n = 0; n < size; ++n) {
     y[n] = x[n] * scale;
   }
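As an aside, the templated kernel can be exercised on its own, away from the FFI machinery. A minimal, self-contained sketch (not part of the commit; the `return scale;` tail and the `main` driver are assumptions for illustration, consistent with how the forward handler below stores the kernel's return value):

#include <cmath>
#include <cstdint>
#include <cstdio>

template <typename T>
T ComputeRmsNorm(float eps, int64_t size, const T *x, T *y) {
  T sm = static_cast<T>(0);
  for (int64_t n = 0; n < size; ++n) {
    sm += x[n] * x[n];
  }
  // scale = 1 / sqrt(mean(x^2) + eps); the forward pass saves it as `res`.
  T scale = static_cast<T>(1) /
            std::sqrt(sm / static_cast<T>(size) + static_cast<T>(eps));
  for (int64_t n = 0; n < size; ++n) {
    y[n] = x[n] * scale;
  }
  return scale;  // assumed tail, matching `res_data[idx] = ComputeRmsNorm(...)` below
}

int main() {
  const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  float y[4];
  float scale = ComputeRmsNorm(1e-5f, 4, x, y);
  std::printf("scale = %f\n", scale);        // 1/sqrt(7.5 + 1e-5), about 0.3651
  for (float v : y) std::printf("%f\n", v);  // each x[n] * scale
}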
@@ -46,94 +49,140 @@ float ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
 // In this example, we treat all leading dimensions as batch dimensions, so this
 // function returns the total number of elements in the buffer, and the size of
 // the last dimension.
-template <ffi::DataType T>
-std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
-  auto dims = buffer.dimensions();
+std::pair<int64_t, int64_t> GetDims(const ffi::AnyBuffer buffer) {
+  const ffi::AnyBuffer::Dimensions dims = buffer.dimensions();
   if (dims.size() == 0) {
     return std::make_pair(0, 0);
   }
   return std::make_pair(buffer.element_count(), dims.back());
 }
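For example, a buffer of shape (3, 5) yields (15, 5), so the wrappers below apply the kernel to each of the three length-5 rows; a rank-0 buffer yields (0, 0), which the dispatch functions reject with `InvalidArgument`.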

+#define ELEMENT_TYPE_DISPATCH(element_type, fn, ...)                      \
+  switch (element_type) {                                                 \
+    case ffi::F32:                                                        \
+      return fn<float>(__VA_ARGS__);                                      \
+    case ffi::F64:                                                        \
+      return fn<double>(__VA_ARGS__);                                     \
+    case ffi::C64:                                                        \
+      return fn<std::complex<float>>(__VA_ARGS__);                        \
+    case ffi::C128:                                                       \
+      return fn<std::complex<double>>(__VA_ARGS__);                       \
+    default:                                                              \
+      return ffi::Error::InvalidArgument("Unsupported input data type."); \
+  }
+
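Each case of the macro's switch `return`s from the enclosing function, so a wrapper that invokes it needs no code afterwards, and unsupported dtypes fall through to the `InvalidArgument` default. A self-contained analog of the pattern, using a stand-in enum instead of the XLA FFI types (illustrative only):

#include <complex>
#include <cstdio>

// Stand-ins for the XLA FFI element-type tags, purely for illustration.
enum class ElementType { F32, F64, C64, C128 };

// The templated "implementation" the dispatcher selects at runtime.
template <typename T>
int ReportSize() {
  std::printf("element size: %zu bytes\n", sizeof(T));
  return 0;
}

// Mirrors what ELEMENT_TYPE_DISPATCH expands to: every case returns from
// the enclosing function, so nothing runs after the switch.
int Dispatch(ElementType t) {
  switch (t) {
    case ElementType::F32:
      return ReportSize<float>();
    case ElementType::F64:
      return ReportSize<double>();
    case ElementType::C64:
      return ReportSize<std::complex<float>>();
    case ElementType::C128:
      return ReportSize<std::complex<double>>();
    default:
      return -1;  // stands in for ffi::Error::InvalidArgument(...)
  }
}

int main() { return Dispatch(ElementType::C128) == 0 ? 0 : 1; }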
 // A wrapper function providing the interface between the XLA FFI call and our
 // library function `ComputeRmsNorm` above. This function handles the batch
 // dimensions by calling `ComputeRmsNorm` within a loop.
-ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
-                       ffi::ResultBuffer<ffi::F32> y) {
-  auto [totalSize, lastDim] = GetDims(x);
-  if (lastDim == 0) {
-    return ffi::Error::InvalidArgument("RmsNorm input must be an array");
-  }
-  for (int64_t n = 0; n < totalSize; n += lastDim) {
-    ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
-  }
-  return ffi::Error::Success();
+template <typename T>
+ffi::Error RmsNormImpl(int64_t totalSize, int64_t lastDim, float eps,
+                       ffi::AnyBuffer x, ffi::Result<ffi::AnyBuffer> y) {
+  T *x_data = x.typed_data<T>();
+  T *y_data = y->typed_data<T>();
+  for (int64_t n = 0; n < totalSize; n += lastDim) {
+    ComputeRmsNorm(eps, lastDim, x_data + n, y_data + n);
+  }
+  return ffi::Error::Success();
+}
+
+ffi::Error RmsNormDispatch(float eps, ffi::AnyBuffer x,
+                           ffi::Result<ffi::AnyBuffer> y) {
+  auto [totalSize, lastDim] = GetDims(x);
+  if (lastDim == 0) {
+    return ffi::Error::InvalidArgument("RmsNorm input must be an array");
+  }
+  ELEMENT_TYPE_DISPATCH(x.element_type(), RmsNormImpl, totalSize, lastDim, eps,
+                        x, y);
 }
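Note the division of labor: `RmsNormDispatch` validates the shape and inspects `x.element_type()` once, then the macro forwards `totalSize` and `lastDim` to the chosen `RmsNormImpl<T>` instantiation, whose embedded `return` statements provide the function's result.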

 // Wrap `RmsNormImpl` and specify the interface to XLA. If you need to declare
 // this handler in a header, you can use the `XLA_FFI_DECLARE_HANDLER_SYMBOL`
 // macro: `XLA_FFI_DECLARE_HANDLER_SYMBOL(RmsNorm)`.
-XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl,
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormDispatch,
                               ffi::Ffi::Bind()
                                   .Attr<float>("eps")
-                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
-                                  .Ret<ffi::Buffer<ffi::F32>>()  // y
+                                  .Arg<ffi::AnyBuffer>()  // x
+                                  .Ret<ffi::AnyBuffer>()  // y
 );
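Because the binding now declares `ffi::AnyBuffer` rather than `ffi::Buffer<ffi::F32>`, the FFI layer no longer rejects non-float32 inputs when decoding arguments; dtype checking moves into `RmsNormDispatch`.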

-ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::F32> x,
-                          ffi::ResultBuffer<ffi::F32> y,
-                          ffi::ResultBuffer<ffi::F32> res) {
-  auto [totalSize, lastDim] = GetDims(x);
-  if (lastDim == 0) {
-    return ffi::Error::InvalidArgument("RmsNormFwd input must be an array");
-  }
-  for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
-    res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]),
-                                            &(y->typed_data()[n]));
-  }
-  return ffi::Error::Success();
+template <typename T>
+ffi::Error RmsNormFwdImpl(int64_t totalSize, int64_t lastDim, float eps,
+                          ffi::AnyBuffer x, ffi::Result<ffi::AnyBuffer> y,
+                          ffi::Result<ffi::AnyBuffer> res) {
+  T *x_data = x.typed_data<T>();
+  T *y_data = y->typed_data<T>();
+  T *res_data = res->typed_data<T>();
+  for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
+    res_data[idx] = ComputeRmsNorm(eps, lastDim, x_data + n, y_data + n);
+  }
+  return ffi::Error::Success();
+}
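As before, the forward pass also emits `res`, the per-row scale returned by `ComputeRmsNorm`, one element per batch row, as a residual for the backward pass.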
+
+ffi::Error RmsNormFwdDispatch(float eps, ffi::AnyBuffer x,
+                              ffi::Result<ffi::AnyBuffer> y,
+                              ffi::Result<ffi::AnyBuffer> res) {
+  auto [totalSize, lastDim] = GetDims(x);
+  if (lastDim == 0) {
+    return ffi::Error::InvalidArgument("RmsNormFwd input must be an array");
+  }
+  ELEMENT_TYPE_DISPATCH(x.element_type(), RmsNormFwdImpl, totalSize, lastDim,
+                        eps, x, y, res);
 }

-XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdImpl,
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdDispatch,
                               ffi::Ffi::Bind()
                                   .Attr<float>("eps")
-                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
-                                  .Ret<ffi::Buffer<ffi::F32>>()  // y
-                                  .Ret<ffi::Buffer<ffi::F32>>()  // res
+                                  .Arg<ffi::AnyBuffer>()  // x
+                                  .Ret<ffi::AnyBuffer>()  // y
+                                  .Ret<ffi::AnyBuffer>()  // res
 );

-void ComputeRmsNormBwd(int64_t size, float res, const float *x,
-                       const float *ct_y, float *ct_x) {
-  float ct_res = 0.0f;
+template <typename T>
+void ComputeRmsNormBwd(int64_t size, T res, const T *x, const T *ct_y,
+                       T *ct_x) {
+  T ct_res = static_cast<T>(0);
   for (int64_t n = 0; n < size; ++n) {
     ct_res += x[n] * ct_y[n];
   }
-  float factor = ct_res * res * res * res / static_cast<float>(size);
+  T factor = ct_res * res * res * res / static_cast<T>(size);
   for (int64_t n = 0; n < size; ++n) {
     ct_x[n] = res * ct_y[n] - factor * x[n];
   }
 }
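For reference, a sketch of the math this kernel implements (real-valued case), writing res for the saved scale and N for the row length. With \(y_n = \mathrm{res}\,x_n\) and \(\mathrm{res} = \big(\tfrac{1}{N}\sum_m x_m^2 + \epsilon\big)^{-1/2}\):

\[
\frac{\partial\,\mathrm{res}}{\partial x_n} = -\frac{\mathrm{res}^3\,x_n}{N},
\qquad
\bar{x}_n = \mathrm{res}\,\bar{y}_n
  - \underbrace{\frac{\mathrm{res}^3}{N}\Big(\sum_m x_m\,\bar{y}_m\Big)}_{\texttt{factor}}\,x_n
\]

so `ct_res` accumulates \(\sum_m x_m\,\bar{y}_m\) and `factor` matches the underbraced term. As written, no conjugation is applied, so the complex case reuses the same expression.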

-ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::F32> res, ffi::Buffer<ffi::F32> x,
-                          ffi::Buffer<ffi::F32> ct_y,
-                          ffi::ResultBuffer<ffi::F32> ct_x) {
-  auto [totalSize, lastDim] = GetDims(x);
-  if (lastDim == 0) {
-    return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays");
-  }
-  for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
-    ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]),
-                      &(ct_y.typed_data()[n]), &(ct_x->typed_data()[n]));
-  }
-  return ffi::Error::Success();
+template <typename T>
+ffi::Error RmsNormBwdImpl(int64_t totalSize, int64_t lastDim,
+                          ffi::AnyBuffer res, ffi::AnyBuffer x,
+                          ffi::AnyBuffer ct_y,
+                          ffi::Result<ffi::AnyBuffer> ct_x) {
+  T *res_data = res.typed_data<T>();
+  T *x_data = x.typed_data<T>();
+  T *ct_y_data = ct_y.typed_data<T>();
+  T *ct_x_data = ct_x->typed_data<T>();
+  for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
+    ComputeRmsNormBwd(lastDim, res_data[idx], x_data + n, ct_y_data + n,
+                      ct_x_data + n);
+  }
+  return ffi::Error::Success();
+}
+
+ffi::Error RmsNormBwdDispatch(ffi::AnyBuffer res, ffi::AnyBuffer x,
+                              ffi::AnyBuffer ct_y,
+                              ffi::Result<ffi::AnyBuffer> ct_x) {
+  auto [totalSize, lastDim] = GetDims(x);
+  if (lastDim == 0) {
+    return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays");
+  }
+  ELEMENT_TYPE_DISPATCH(x.element_type(), RmsNormBwdImpl, totalSize, lastDim,
+                        res, x, ct_y, ct_x);
 }

-XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdImpl,
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdDispatch,
                               ffi::Ffi::Bind()
-                                  .Arg<ffi::Buffer<ffi::F32>>()  // res
-                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
-                                  .Arg<ffi::Buffer<ffi::F32>>()  // ct_y
-                                  .Ret<ffi::Buffer<ffi::F32>>()  // ct_x
+                                  .Arg<ffi::AnyBuffer>()  // res
+                                  .Arg<ffi::AnyBuffer>()  // x
+                                  .Arg<ffi::AnyBuffer>()  // ct_y
+                                  .Ret<ffi::AnyBuffer>()  // ct_x
 );

+template <typename T>
@@ -26,7 +26,6 @@ from functools import partial
 import numpy as np

 import jax
-import jax.numpy as jnp

 from jax_ffi_example import _rms_norm

@@ -36,14 +35,6 @@ for name, target in _rms_norm.registrations().items():

 @partial(jax.custom_vjp, nondiff_argnums=(1,))
 def rms_norm(x, eps=1e-5):
-  # We only implemented the `float32` version of this function, so we start by
-  # checking the dtype. This check isn't strictly necessary because type
-  # checking is also performed by the FFI when decoding input and output
-  # buffers, but it can be useful to check types in Python to raise more
-  # informative errors.
-  if x.dtype != jnp.float32:
-    raise ValueError("Only the float32 dtype is implemented by rms_norm")
-
   # In this case, the output of our FFI function is just a single array with the
   # same shape and dtype as the input.
   out_type = jax.ShapeDtypeStruct(x.shape, x.dtype)
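The Python-side dtype check is dropped because the handlers now dispatch on the element type themselves; an unsupported dtype surfaces as the `InvalidArgument` error from the C++ dispatch rather than a Python `ValueError`.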
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from absl.testing import absltest
+from absl.testing import absltest, parameterized

 import jax
 import jax.numpy as jnp
@@ -25,6 +25,7 @@ jax.config.parse_flags_with_absl()


 def rms_norm_ref(x, eps=1e-5):
+  eps = jnp.float32(eps).astype(x.dtype)
   scale = jnp.sqrt(jnp.mean(jnp.square(x), axis=-1, keepdims=True) + eps)
   return x / scale

@@ -36,18 +37,21 @@ class RmsNormTests(jtu.JaxTestCase):
     if not jtu.test_device_matches(["cpu"]):
       self.skipTest("Unsupported platform")

-  def test_basic(self):
-    x = jnp.linspace(-0.5, 0.5, 15, dtype=jnp.float32)
+  @parameterized.parameters(jtu.dtypes.floating + jtu.dtypes.complex)
+  def test_basic(self, dtype):
+    x = jnp.linspace(-0.5, 0.5, 15, dtype=dtype)
     self.assertAllClose(rms_norm.rms_norm(x), rms_norm_ref(x))

-  def test_batching(self):
-    x = jnp.linspace(-0.5, 0.5, 15, dtype=jnp.float32).reshape((3, 5))
+  @parameterized.parameters(jtu.dtypes.floating + jtu.dtypes.complex)
+  def test_batching(self, dtype):
+    x = jnp.linspace(-0.5, 0.5, 15, dtype=dtype).reshape((3, 5))
     self.assertAllClose(
-        jax.vmap(rms_norm.rms_norm)(x), jax.vmap(rms_norm_ref)(x)
-    )
+        jax.vmap(rms_norm.rms_norm)(x),
+        jax.vmap(rms_norm_ref)(x))

-  def test_grads(self):
-    x = jnp.linspace(-0.5, 0.5, 15, dtype=jnp.float32).reshape((3, 5))
+  @parameterized.parameters(jtu.dtypes.floating + jtu.dtypes.complex)
+  def test_grads(self, dtype):
+    x = jnp.linspace(-0.5, 0.5, 15, dtype=dtype).reshape((3, 5))
     jtu.check_grads(rms_norm.rms_norm, (x,), order=1, modes=("rev",))
