Update some outdated syntax in FFI tutorial.

Dan Foreman-Mackey 2024-11-08 11:48:45 -05:00
parent 9bb6366741
commit f757054267
6 changed files with 83 additions and 104 deletions

View File

@ -26,10 +26,7 @@
"In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n",
"We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n",
"\n",
"This tutorial comes with two supplementary files:\n",
"\n",
"* [`rms_norm.cc`](ffi/rms_norm.cc), which includes all the backend code, and\n",
"* [`CMakeLists.txt`](ffi/CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code.\n",
"The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi).\n",
"\n",
"## A simple example\n",
"\n",
@ -101,7 +98,7 @@
"\n",
"To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla).\n",
"For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call).\n",
"The full source listing can be downloaded [here](ffi/rms_norm.cc), but the key implementation details are reproduced here:\n",
"The full source listing can be downloaded [here](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc), but the key implementation details are reproduced here:\n",
"\n",
"```c++\n",
"#include <functional>\n",
@ -129,12 +126,11 @@
"// A wrapper function providing the interface between the XLA FFI call and our\n",
"// library function `ComputeRmsNorm` above. This function handles the batch\n",
"// dimensions by calling `ComputeRmsNorm` within a loop.\n",
"ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,\n",
" ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) {\n",
"ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,\n",
" ffi::ResultBuffer<ffi::F32> y) {\n",
" auto [totalSize, lastDim] = GetDims(x);\n",
" if (lastDim == 0) {\n",
" return ffi::Error(ffi::ErrorCode::kInvalidArgument,\n",
" \"RmsNorm input must be an array\");\n",
" return ffi::Error::InvalidArgument(\"RmsNorm input must be an array\");\n",
" }\n",
" for (int64_t n = 0; n < totalSize; n += lastDim) {\n",
" ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));\n",
@ -149,8 +145,8 @@
" RmsNorm, RmsNormImpl,\n",
" ffi::Ffi::Bind()\n",
" .Attr<float>(\"eps\")\n",
" .Arg<ffi::Buffer<ffi::DataType::F32>>() // x\n",
" .Ret<ffi::Buffer<ffi::DataType::F32>>() // y\n",
" .Arg<ffi::Buffer<ffi::F32>>() // x\n",
" .Ret<ffi::Buffer<ffi::F32>>() // y\n",
");\n",
"```\n",
"\n",
@ -173,8 +169,7 @@
"Now that we have our minimal FFI wrapper implemented, we need to expose this function (`RmsNorm`) to Python.\n",
"In this tutorial, we compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), but another common pattern is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) as discussed below.\n",
"\n",
"To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble.\n",
"The full `CMakeLists.txt` can be downloaded [here](ffi/CMakeLists.txt)."
"To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble."
]
},
{
@ -433,7 +428,7 @@
"1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n",
"2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n",
"\n",
"We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](ffi/rms_norm.cc) to see how these functions are implemented on the back end.\n",
"We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.\n",
"The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n",
"\n",
"This custom derivative rule can be wired in as follows:"
@ -508,16 +503,16 @@
"When defining our FFI wrapper for CPU, the function signature that we used was:\n",
"\n",
"```c++\n",
"ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,\n",
" ffi::Result<ffi::Buffer<ffi::DataType::F32>> y)\n",
"ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,\n",
" ffi::ResultBuffer<ffi::F32> y)\n",
"```\n",
"\n",
"To update this to interface with a CUDA kernel, this signature becomes:\n",
"\n",
"```c++\n",
"ffi::Error RmsNormImpl(cudaStream_t stream, float eps,\n",
" ffi::Buffer<ffi::DataType::F32> x,\n",
" ffi::Result<ffi::Buffer<ffi::DataType::F32>> y)\n",
" ffi::Buffer<ffi::F32> x,\n",
" ffi::ResultBuffer<ffi::F32> y)\n",
"```\n",
"\n",
"And the handler definition is updated to include a `Ctx` in its binding:\n",
@ -528,8 +523,8 @@
" ffi::Ffi::Bind()\n",
" .Ctx<ffi::PlatformStream<cudaStream_t>>()\n",
" .Attr<float>(\"eps\")\n",
" .Arg<ffi::Buffer<ffi::DataType::F32>>() // x\n",
" .Ret<ffi::Buffer<ffi::DataType::F32>>() // y\n",
" .Arg<ffi::Buffer<ffi::F32>>() // x\n",
" .Ret<ffi::Buffer<ffi::F32>>() // y\n",
");\n",
"```\n",
"\n",

View File

@ -34,10 +34,7 @@ JAX's FFI support is provided in two parts:
In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.
We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.
This tutorial comes with two supplementary files:
* [`rms_norm.cc`](ffi/rms_norm.cc), which includes all the backend code, and
* [`CMakeLists.txt`](ffi/CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code.
The end-to-end code for this example, along with some other more advanced use cases, can be found in the JAX FFI examples project at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi).
## A simple example
@ -96,7 +93,7 @@ and, for our example, this is the function that we want to expose to JAX via the
To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla).
For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call).
The full source listing can be downloaded [here](ffi/rms_norm.cc), but the key implementation details are reproduced here:
The full source listing can be downloaded [here](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc), but the key implementation details are reproduced below:
```c++
#include <functional>
@ -124,12 +121,11 @@ std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
// A wrapper function providing the interface between the XLA FFI call and our
// library function `ComputeRmsNorm` above. This function handles the batch
// dimensions by calling `ComputeRmsNorm` within a loop.
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) {
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNorm input must be an array");
return ffi::Error::InvalidArgument("RmsNorm input must be an array");
}
for (int64_t n = 0; n < totalSize; n += lastDim) {
ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
@ -144,8 +140,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
RmsNorm, RmsNormImpl,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::DataType::F32>>() // x
.Ret<ffi::Buffer<ffi::DataType::F32>>() // y
.Arg<ffi::Buffer<ffi::F32>>() // x
.Ret<ffi::Buffer<ffi::F32>>() // y
);
```
@ -166,7 +162,6 @@ Now that we have our minimal FFI wrapper implemented, we need to expose this fun
In this tutorial, we compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), but another common pattern is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) as discussed below.
To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble.
The full `CMakeLists.txt` can be downloaded [here](ffi/CMakeLists.txt).
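Once the shared library is built, loading it and registering the handler with XLA happens in Python. A minimal sketch of that step, assuming the build produced a `librms_norm` shared library in an `ffi/` subdirectory (the exact filename and suffix depend on your platform and build system):

```python
import ctypes
from pathlib import Path

import jax.extend as jex

# Locate the compiled shared library (librms_norm.so / .dylib / .dll).
path = next(Path("ffi").glob("librms_norm*"))
rms_norm_lib = ctypes.cdll.LoadLibrary(str(path))

# Wrap the raw function pointer in a PyCapsule and register it with XLA under
# the name that the JAX-side ffi_call will refer to.
jex.ffi.register_ffi_target(
    "rms_norm", jex.ffi.pycapsule(rms_norm_lib.RmsNorm), platform="cpu"
)
```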
```{code-cell} ipython3
:tags: [hide-output]
@ -357,7 +352,7 @@ In this case, we actually define two new FFI calls:
1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass.
2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.
We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](ffi/rms_norm.cc) to see how these functions are implemented on the back end.
We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.
The main point to emphasize here is that the "residuals" computed in the forward pass have a different shape than the primal output; therefore, in the {func}`~jax.extend.ffi.ffi_call` to `rms_norm_fwd`, the output type has two elements with different shapes.
This custom derivative rule can be wired in as follows:
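A condensed sketch of that wiring, assuming `rms_norm`, `rms_norm_fwd`, and `rms_norm_bwd` are the Python wrappers around the corresponding FFI targets (with `eps` as the second positional argument, the forward wrapper returning `(primal, residuals)`, and the backward wrapper returning the input co-tangents):

```python
import jax

# eps is a static configuration value rather than a traced array, so it is
# marked non-differentiable; custom_vjp then passes it as the first argument
# to the backward rule.
rms_norm = jax.custom_vjp(rms_norm, nondiff_argnums=(1,))

# rms_norm_fwd(x, eps) must return (y, residuals), and
# rms_norm_bwd(eps, residuals, ct_y) must return a tuple of co-tangents for x.
rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
```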
@ -422,16 +417,16 @@ Since this documentation page is automatically generated on a machine without ac
When defining our FFI wrapper for CPU, the function signature that we used was:
```c++
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> y)
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y)
```
To update this to interface with a CUDA kernel, this signature becomes:
```c++
ffi::Error RmsNormImpl(cudaStream_t stream, float eps,
ffi::Buffer<ffi::DataType::F32> x,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> y)
ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y)
```
And the handler definition is updated to include a `Ctx` in its binding:
@ -442,8 +437,8 @@ XLA_FFI_DEFINE_HANDLER(
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<cudaStream_t>>()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::DataType::F32>>() // x
.Ret<ffi::Buffer<ffi::DataType::F32>>() // y
.Arg<ffi::Buffer<ffi::F32>>() // x
.Ret<ffi::Buffer<ffi::F32>>() // y
);
```
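On the Python side, registering this CUDA handler should look just like the CPU case, with only the platform changed. A minimal sketch (the library and symbol names here are illustrative assumptions):

```python
import ctypes

import jax.extend as jex

# Hypothetical shared library containing the CUDA handler defined above.
lib = ctypes.cdll.LoadLibrary("./librms_norm_cuda.so")
jex.ffi.register_ffi_target(
    "rms_norm_cuda", jex.ffi.pycapsule(lib.RmsNorm), platform="CUDA"
)
```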

View File

@ -56,12 +56,11 @@ std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
// A wrapper function providing the interface between the XLA FFI call and our
// library function `ComputeRmsNorm` above. This function handles the batch
// dimensions by calling `ComputeRmsNorm` within a loop.
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) {
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNorm input must be an array");
return ffi::Error::InvalidArgument("RmsNorm input must be an array");
}
for (int64_t n = 0; n < totalSize; n += lastDim) {
ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
@ -75,17 +74,16 @@ ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::DataType::F32>>() // x
.Ret<ffi::Buffer<ffi::DataType::F32>>() // y
.Arg<ffi::Buffer<ffi::F32>>() // x
.Ret<ffi::Buffer<ffi::F32>>() // y
);
ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> y,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> res) {
ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::ResultBuffer<ffi::F32> y,
ffi::ResultBuffer<ffi::F32> res) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNormFwd input must be an array");
return ffi::Error::InvalidArgument("RmsNormFwd input must be an array");
}
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]),
@ -94,13 +92,12 @@ ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
return ffi::Error::Success();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(
RmsNormFwd, RmsNormFwdImpl,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::DataType::F32>>() // x
.Ret<ffi::Buffer<ffi::DataType::F32>>() // y
.Ret<ffi::Buffer<ffi::DataType::F32>>() // res
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdImpl,
ffi::Ffi::Bind()
.Attr<float>("eps")
.Arg<ffi::Buffer<ffi::F32>>() // x
.Ret<ffi::Buffer<ffi::F32>>() // y
.Ret<ffi::Buffer<ffi::F32>>() // res
);
void ComputeRmsNormBwd(int64_t size, float res, const float *x,
@ -115,14 +112,12 @@ void ComputeRmsNormBwd(int64_t size, float res, const float *x,
}
}
ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::DataType::F32> res,
ffi::Buffer<ffi::DataType::F32> x,
ffi::Buffer<ffi::DataType::F32> ct_y,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> ct_x) {
ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::F32> res, ffi::Buffer<ffi::F32> x,
ffi::Buffer<ffi::F32> ct_y,
ffi::ResultBuffer<ffi::F32> ct_x) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNormBwd inputs must be arrays");
return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays");
}
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]),
@ -131,11 +126,10 @@ ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::DataType::F32> res,
return ffi::Error::Success();
}
XLA_FFI_DEFINE_HANDLER_SYMBOL(
RmsNormBwd, RmsNormBwdImpl,
ffi::Ffi::Bind()
.Arg<ffi::Buffer<ffi::DataType::F32>>() // res
.Arg<ffi::Buffer<ffi::DataType::F32>>() // x
.Arg<ffi::Buffer<ffi::DataType::F32>>() // ct_y
.Ret<ffi::Buffer<ffi::DataType::F32>>() // ct_x
XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdImpl,
ffi::Ffi::Bind()
.Arg<ffi::Buffer<ffi::F32>>() // res
.Arg<ffi::Buffer<ffi::F32>>() // x
.Arg<ffi::Buffer<ffi::F32>>() // ct_y
.Ret<ffi::Buffer<ffi::F32>>() // ct_x
);

View File

@ -22,7 +22,7 @@ namespace nb = nanobind;
namespace ffi = xla::ffi;
ffi::Error ArrayAttrImpl(ffi::Span<const int32_t> array,
ffi::Result<ffi::BufferR0<ffi::S32>> res) {
ffi::ResultBufferR0<ffi::S32> res) {
int64_t total = 0;
for (int32_t x : array) {
total += x;
@ -37,8 +37,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ArrayAttr, ArrayAttrImpl,
.Ret<ffi::BufferR0<ffi::S32>>());
ffi::Error DictionaryAttrImpl(ffi::Dictionary attrs,
ffi::Result<ffi::BufferR0<ffi::S32>> secret,
ffi::Result<ffi::BufferR0<ffi::S32>> count) {
ffi::ResultBufferR0<ffi::S32> secret,
ffi::ResultBufferR0<ffi::S32> count) {
auto maybe_secret = attrs.get<int64_t>("secret");
if (maybe_secret.has_error()) {
return maybe_secret.error();

View File

@ -44,11 +44,9 @@ __global__ void FooFwdKernel(const float *a, const float *b, float *c,
// Buffer type provides buffer dimensions, so the "n" argument here is not
// strictly necessary, but it allows us to demonstrate the use of attributes
// (.Attr in the FFI handler definition above).
ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer<ffi::DataType::F32> a,
ffi::Buffer<ffi::DataType::F32> b,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> c,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_plus_1,
size_t n) {
ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer<ffi::F32> a,
ffi::Buffer<ffi::F32> b, ffi::ResultBuffer<ffi::F32> c,
ffi::ResultBuffer<ffi::F32> b_plus_1, size_t n) {
const int block_dim = 128;
const int grid_dim = 1;
// Note how we access regular Buffer data vs Result Buffer data:
@ -73,12 +71,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
FooFwd, FooFwdHost,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<cudaStream_t>>() // stream
.Arg<ffi::Buffer<ffi::DataType::F32>>() // a
.Arg<ffi::Buffer<ffi::DataType::F32>>() // b
.Ret<ffi::Buffer<ffi::DataType::F32>>() // c
.Ret<ffi::Buffer<ffi::DataType::F32>>() // b_plus_1
.Arg<ffi::Buffer<ffi::F32>>() // a
.Arg<ffi::Buffer<ffi::F32>>() // b
.Ret<ffi::Buffer<ffi::F32>>() // c
.Ret<ffi::Buffer<ffi::F32>>() // b_plus_1
.Attr<size_t>("n"),
{xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
{xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
//----------------------------------------------------------------------------//
// Backward pass //
@ -106,11 +104,11 @@ __global__ void FooBwdKernel(const float *c_grad, // incoming gradient wrt c
}
ffi::Error FooBwdHost(cudaStream_t stream,
ffi::Buffer<ffi::DataType::F32> c_grad,
ffi::Buffer<ffi::DataType::F32> a,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_plus_1,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> a_grad,
ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_grad,
ffi::Buffer<ffi::F32> c_grad,
ffi::Buffer<ffi::F32> a,
ffi::ResultBuffer<ffi::F32> b_plus_1,
ffi::ResultBuffer<ffi::F32> a_grad,
ffi::ResultBuffer<ffi::F32> b_grad,
size_t n) {
const int block_dim = 128;
const int grid_dim = 1;
@ -131,10 +129,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
FooBwd, FooBwdHost,
ffi::Ffi::Bind()
.Ctx<ffi::PlatformStream<cudaStream_t>>() // stream
.Arg<ffi::Buffer<ffi::DataType::F32>>() // c_grad
.Arg<ffi::Buffer<ffi::DataType::F32>>() // a
.Arg<ffi::Buffer<ffi::DataType::F32>>() // b_plus_1
.Ret<ffi::Buffer<ffi::DataType::F32>>() // a_grad
.Ret<ffi::Buffer<ffi::DataType::F32>>() // b_grad
.Arg<ffi::Buffer<ffi::F32>>() // c_grad
.Arg<ffi::Buffer<ffi::F32>>() // a
.Arg<ffi::Buffer<ffi::F32>>() // b_plus_1
.Ret<ffi::Buffer<ffi::F32>>() // a_grad
.Ret<ffi::Buffer<ffi::F32>>() // b_grad
.Attr<size_t>("n"),
{xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled
{xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled

View File

@ -59,11 +59,10 @@ std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
// library function `ComputeRmsNorm` above. This function handles the batch
// dimensions by calling `ComputeRmsNorm` within a loop.
ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::Result<ffi::Buffer<ffi::F32>> y) {
ffi::ResultBuffer<ffi::F32> y) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNorm input must be an array");
return ffi::Error::InvalidArgument("RmsNorm input must be an array");
}
for (int64_t n = 0; n < totalSize; n += lastDim) {
ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
@ -82,12 +81,11 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl,
);
ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::F32> x,
ffi::Result<ffi::Buffer<ffi::F32>> y,
ffi::Result<ffi::Buffer<ffi::F32>> res) {
ffi::ResultBuffer<ffi::F32> y,
ffi::ResultBuffer<ffi::F32> res) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNormFwd input must be an array");
return ffi::Error::InvalidArgument("RmsNormFwd input must be an array");
}
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]),
@ -118,11 +116,10 @@ void ComputeRmsNormBwd(int64_t size, float res, const float *x,
ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::F32> res, ffi::Buffer<ffi::F32> x,
ffi::Buffer<ffi::F32> ct_y,
ffi::Result<ffi::Buffer<ffi::F32>> ct_x) {
ffi::ResultBuffer<ffi::F32> ct_x) {
auto [totalSize, lastDim] = GetDims(x);
if (lastDim == 0) {
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
"RmsNormBwd inputs must be arrays");
return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays");
}
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]),