mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
Merge pull request #24792 from dfm:ffi-tutorial-outdated
PiperOrigin-RevId: 695801160
This commit is contained in:
commit
8e224dbc71
@@ -26,10 +26,7 @@
 "In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n",
 "We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n",
 "\n",
-"This tutorial comes with two supplementary files:\n",
-"\n",
-"* [`rms_norm.cc`](ffi/rms_norm.cc), which includes all the backend code, and\n",
-"* [`CMakeLists.txt`](ffi/CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code.\n",
+"The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi).\n",
 "\n",
 "## A simple example\n",
 "\n",
@@ -101,7 +98,7 @@
 "\n",
 "To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla).\n",
 "For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call).\n",
-"The full source listing can be downloaded [here](ffi/rms_norm.cc), but the key implementation details are reproduced here:\n",
+"The full source listing can be downloaded [here](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc), but the key implementation details are reproduced here:\n",
 "\n",
 "```c++\n",
 "#include <functional>\n",
@@ -129,12 +126,11 @@
 "// A wrapper function providing the interface between the XLA FFI call and our\n",
 "// library function `ComputeRmsNorm` above. This function handles the batch\n",
 "// dimensions by calling `ComputeRmsNorm` within a loop.\n",
-"ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,\n",
-"                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) {\n",
+"ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,\n",
+"                       ffi::ResultBuffer<ffi::F32> y) {\n",
 "  auto [totalSize, lastDim] = GetDims(x);\n",
 "  if (lastDim == 0) {\n",
-"    return ffi::Error(ffi::ErrorCode::kInvalidArgument,\n",
-"                      \"RmsNorm input must be an array\");\n",
+"    return ffi::Error::InvalidArgument(\"RmsNorm input must be an array\");\n",
 "  }\n",
 "  for (int64_t n = 0; n < totalSize; n += lastDim) {\n",
 "    ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));\n",
@@ -149,8 +145,8 @@
 "    RmsNorm, RmsNormImpl,\n",
 "    ffi::Ffi::Bind()\n",
 "        .Attr<float>(\"eps\")\n",
-"        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x\n",
-"        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // y\n",
+"        .Arg<ffi::Buffer<ffi::F32>>()  // x\n",
+"        .Ret<ffi::Buffer<ffi::F32>>()  // y\n",
 ");\n",
 "```\n",
 "\n",
@@ -173,8 +169,7 @@
 "Now that we have our minimal FFI wrapper implemented, we need to expose this function (`RmsNorm`) to Python.\n",
 "In this tutorial, we compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), but another common pattern is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) as discussed below.\n",
 "\n",
-"To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble.\n",
-"The full `CMakeLists.txt` can be downloaded [here](ffi/CMakeLists.txt)."
+"To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble."
   ]
  },
  {
@@ -433,7 +428,7 @@
 "1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n",
 "2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n",
 "\n",
-"We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](ffi/rms_norm.cc) to see how these functions are implemented on the back end.\n",
+"We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.\n",
 "The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output; therefore, in the {func}`~jax.extend.ffi.ffi_call` to `rms_norm_fwd`, the output type has two elements with different shapes.\n",
 "\n",
 "This custom derivative rule can be wired in as follows:"
@@ -508,16 +503,16 @@
 "When defining our FFI wrapper for CPU, the function signature that we used was:\n",
 "\n",
 "```c++\n",
-"ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,\n",
-"                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y)\n",
+"ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,\n",
+"                       ffi::ResultBuffer<ffi::F32> y)\n",
 "```\n",
 "\n",
 "To update this to interface with a CUDA kernel, this signature becomes:\n",
 "\n",
 "```c++\n",
 "ffi::Error RmsNormImpl(cudaStream_t stream, float eps,\n",
-"                       ffi::Buffer<ffi::DataType::F32> x,\n",
-"                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y)\n",
+"                       ffi::Buffer<ffi::F32> x,\n",
+"                       ffi::ResultBuffer<ffi::F32> y)\n",
 "```\n",
 "\n",
 "And the handler definition is updated to include a `Ctx` in its binding:\n",
@@ -528,8 +523,8 @@
 "    ffi::Ffi::Bind()\n",
 "        .Ctx<ffi::PlatformStream<cudaStream_t>>()\n",
 "        .Attr<float>(\"eps\")\n",
-"        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x\n",
-"        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // y\n",
+"        .Arg<ffi::Buffer<ffi::F32>>()  // x\n",
+"        .Ret<ffi::Buffer<ffi::F32>>()  // y\n",
 ");\n",
 "```\n",
 "\n",
docs/ffi.md (33 changed lines)
@@ -34,10 +34,7 @@ JAX's FFI support is provided in two parts:
 In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.
 We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.
 
-This tutorial comes with two supplementary files:
-
-* [`rms_norm.cc`](ffi/rms_norm.cc), which includes all the backend code, and
-* [`CMakeLists.txt`](ffi/CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code.
+The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi).
 
 ## A simple example
 
@@ -96,7 +93,7 @@ and, for our example, this is the function that we want to expose to JAX via the
 
 To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla).
 For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call).
-The full source listing can be downloaded [here](ffi/rms_norm.cc), but the key implementation details are reproduced here:
+The full source listing can be downloaded [here](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc), but the key implementation details are reproduced here:
 
 ```c++
 #include <functional>
@@ -124,12 +121,11 @@ std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
 // A wrapper function providing the interface between the XLA FFI call and our
 // library function `ComputeRmsNorm` above. This function handles the batch
 // dimensions by calling `ComputeRmsNorm` within a loop.
-ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
-                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) {
+ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
+                       ffi::ResultBuffer<ffi::F32> y) {
   auto [totalSize, lastDim] = GetDims(x);
   if (lastDim == 0) {
-    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
-                      "RmsNorm input must be an array");
+    return ffi::Error::InvalidArgument("RmsNorm input must be an array");
   }
   for (int64_t n = 0; n < totalSize; n += lastDim) {
     ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
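One recurring change in this hunk is the move from the two-argument `ffi::Error(ffi::ErrorCode::kInvalidArgument, …)` constructor to the `ffi::Error::InvalidArgument(…)` factory. As a rough, self-contained illustration of the factory style (this `Error` class is a hypothetical stand-in for demonstration only, not the XLA type):

```cpp
#include <cassert>
#include <string>
#include <utility>

// Hypothetical stand-in for xla::ffi::Error, for illustration only.
enum class ErrorCode { kOk, kInvalidArgument };

class Error {
 public:
  Error() : code_(ErrorCode::kOk) {}
  Error(ErrorCode code, std::string message)
      : code_(code), message_(std::move(message)) {}

  // Factory helpers like this are what the newer style uses: one call
  // names both the error code and the message.
  static Error InvalidArgument(std::string message) {
    return Error(ErrorCode::kInvalidArgument, std::move(message));
  }
  static Error Success() { return Error(); }

  bool ok() const { return code_ == ErrorCode::kOk; }
  const std::string &message() const { return message_; }

 private:
  ErrorCode code_;
  std::string message_;
};
```

The factory reads as a single self-describing call at the return site, which is why the diff collapses each two-line constructor into one line.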
@@ -144,8 +140,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
     RmsNorm, RmsNormImpl,
     ffi::Ffi::Bind()
         .Attr<float>("eps")
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x
-        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // y
+        .Arg<ffi::Buffer<ffi::F32>>()  // x
+        .Ret<ffi::Buffer<ffi::F32>>()  // y
 );
 ```
 
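Outside the diff context, the batching strategy that `RmsNormImpl` uses (flatten all leading dimensions and loop over rows of length `lastDim`) can be sketched without any XLA headers. `ComputeRmsNorm` here is a plausible reading of the library function the tutorial wraps, not a copy of it:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// One row: y[i] = x[i] / sqrt(mean(x^2) + eps).
void ComputeRmsNorm(float eps, int64_t size, const float *x, float *y) {
  float sumsq = 0.0f;
  for (int64_t i = 0; i < size; ++i) sumsq += x[i] * x[i];
  float scale = 1.0f / std::sqrt(sumsq / static_cast<float>(size) + eps);
  for (int64_t i = 0; i < size; ++i) y[i] = x[i] * scale;
}

// Batched wrapper: treat the flattened input as (size / lastDim) rows
// and normalize each row independently, as the FFI handler's loop does.
std::vector<float> BatchedRmsNorm(float eps, int64_t lastDim,
                                  const std::vector<float> &x) {
  std::vector<float> y(x.size());
  for (int64_t n = 0; n < static_cast<int64_t>(x.size()); n += lastDim) {
    ComputeRmsNorm(eps, lastDim, &x[n], &y[n]);
  }
  return y;
}
```

This is the same trick the handler relies on: because rows are contiguous in the last dimension, arbitrary batch shapes reduce to a stride-`lastDim` loop over one flat buffer.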
@@ -166,7 +162,6 @@ Now that we have our minimal FFI wrapper implemented, we need to expose this function (`RmsNorm`) to Python.
 In this tutorial, we compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), but another common pattern is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) as discussed below.
 
 To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble.
-The full `CMakeLists.txt` can be downloaded [here](ffi/CMakeLists.txt).
 
 ```{code-cell} ipython3
 :tags: [hide-output]
@@ -357,7 +352,7 @@ In this case, we actually define two new FFI calls:
 1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass.
 2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.
 
-We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](ffi/rms_norm.cc) to see how these functions are implemented on the back end.
+We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.
 The main point to emphasize here is that the "residual" computed has a different shape than the primal output; therefore, in the {func}`~jax.extend.ffi.ffi_call` to `rms_norm_fwd`, the output type has two elements with different shapes.
 
 This custom derivative rule can be wired in as follows:
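The fwd/bwd split in the two numbered items can be sketched in plain C++. Two assumptions here: the residual is taken to be the per-row scale 1/sqrt(mean(x^2) + eps) (the diff does not show what `ComputeRmsNorm` returns), and the cotangent formula is the standard RMS-norm derivative, not code copied from the example:

```cpp
#include <cassert>
#include <cmath>
#include <utility>
#include <vector>

// Forward pass over one row: returns (y, res), where res is the scale
// s = 1/sqrt(mean(x^2) + eps) saved as the "residual" for the backward pass.
std::pair<std::vector<float>, float> RmsNormFwd(float eps,
                                                const std::vector<float> &x) {
  float sumsq = 0.0f;
  for (float v : x) sumsq += v * v;
  float s = 1.0f / std::sqrt(sumsq / static_cast<float>(x.size()) + eps);
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) y[i] = s * x[i];
  return {y, s};
}

// Backward pass: given the residual s, the input x, and the output
// cotangent ct_y, the input cotangent is
//   ct_x[j] = s * ct_y[j] - (s^3 / N) * x[j] * sum_i(ct_y[i] * x[i]).
std::vector<float> RmsNormBwd(float s, const std::vector<float> &x,
                              const std::vector<float> &ct_y) {
  float dot = 0.0f;
  for (size_t i = 0; i < x.size(); ++i) dot += ct_y[i] * x[i];
  float n = static_cast<float>(x.size());
  std::vector<float> ct_x(x.size());
  for (size_t j = 0; j < x.size(); ++j) {
    ct_x[j] = s * ct_y[j] - (s * s * s / n) * x[j] * dot;
  }
  return ct_x;
}
```

Note the shape mismatch the tutorial emphasizes: `y` has the shape of `x`, while the residual is one scalar per row, which is exactly why the `ffi_call` output type lists two elements with different shapes.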
@@ -422,16 +417,16 @@ Since this documentation page is automatically generated on a machine without ac
 When defining our FFI wrapper for CPU, the function signature that we used was:
 
 ```c++
-ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
-                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y)
+ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
+                       ffi::ResultBuffer<ffi::F32> y)
 ```
 
 To update this to interface with a CUDA kernel, this signature becomes:
 
 ```c++
 ffi::Error RmsNormImpl(cudaStream_t stream, float eps,
-                       ffi::Buffer<ffi::DataType::F32> x,
-                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y)
+                       ffi::Buffer<ffi::F32> x,
+                       ffi::ResultBuffer<ffi::F32> y)
 ```
 
 And the handler definition is updated to include a `Ctx` in its binding:
@@ -442,8 +437,8 @@ XLA_FFI_DEFINE_HANDLER(
     ffi::Ffi::Bind()
         .Ctx<ffi::PlatformStream<cudaStream_t>>()
         .Attr<float>("eps")
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x
-        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // y
+        .Arg<ffi::Buffer<ffi::F32>>()  // x
+        .Ret<ffi::Buffer<ffi::F32>>()  // y
 );
 ```
 
@@ -56,12 +56,11 @@ std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
 // A wrapper function providing the interface between the XLA FFI call and our
 // library function `ComputeRmsNorm` above. This function handles the batch
 // dimensions by calling `ComputeRmsNorm` within a loop.
-ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
-                       ffi::Result<ffi::Buffer<ffi::DataType::F32>> y) {
+ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
+                       ffi::ResultBuffer<ffi::F32> y) {
   auto [totalSize, lastDim] = GetDims(x);
   if (lastDim == 0) {
-    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
-                      "RmsNorm input must be an array");
+    return ffi::Error::InvalidArgument("RmsNorm input must be an array");
   }
   for (int64_t n = 0; n < totalSize; n += lastDim) {
     ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
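These hunks lean on `GetDims`, whose body sits outside the diff. A plausible reconstruction of its contract (assumed from how the handlers use it, not copied from the source): collapse all dimensions into a total element count and return the trailing dimension, with scalars mapped to a zero last dimension so the handlers can reject them:

```cpp
#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

// Assumed shape helper: returns {totalSize, lastDim}. For a scalar
// (no dimensions) it returns {0, 0}, which the FFI handlers treat as
// the invalid-argument case ("input must be an array").
std::pair<int64_t, int64_t> GetDims(const std::vector<int64_t> &dims) {
  if (dims.empty()) {
    return {0, 0};
  }
  int64_t totalSize = 1;
  for (int64_t d : dims) totalSize *= d;
  return {totalSize, dims.back()};
}
```

With this contract, `for (int64_t n = 0; n < totalSize; n += lastDim)` walks exactly one row of the last dimension per iteration, whatever the batch shape.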
@@ -75,17 +74,16 @@ ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
 XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl,
                               ffi::Ffi::Bind()
                                   .Attr<float>("eps")
-                                  .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x
-                                  .Ret<ffi::Buffer<ffi::DataType::F32>>()  // y
+                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
+                                  .Ret<ffi::Buffer<ffi::F32>>()  // y
 );
 
-ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
-                          ffi::Result<ffi::Buffer<ffi::DataType::F32>> y,
-                          ffi::Result<ffi::Buffer<ffi::DataType::F32>> res) {
+ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::F32> x,
+                          ffi::ResultBuffer<ffi::F32> y,
+                          ffi::ResultBuffer<ffi::F32> res) {
   auto [totalSize, lastDim] = GetDims(x);
   if (lastDim == 0) {
-    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
-                      "RmsNormFwd input must be an array");
+    return ffi::Error::InvalidArgument("RmsNormFwd input must be an array");
   }
   for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
     res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]),
@@ -94,13 +92,12 @@ ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::DataType::F32> x,
   return ffi::Error::Success();
 }
 
-XLA_FFI_DEFINE_HANDLER_SYMBOL(
-    RmsNormFwd, RmsNormFwdImpl,
-    ffi::Ffi::Bind()
-        .Attr<float>("eps")
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x
-        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // y
-        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // res
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdImpl,
+                              ffi::Ffi::Bind()
+                                  .Attr<float>("eps")
+                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
+                                  .Ret<ffi::Buffer<ffi::F32>>()  // y
+                                  .Ret<ffi::Buffer<ffi::F32>>()  // res
 );
 
 void ComputeRmsNormBwd(int64_t size, float res, const float *x,
||||
@ -115,14 +112,12 @@ void ComputeRmsNormBwd(int64_t size, float res, const float *x,
|
||||
}
|
||||
}
|
||||
|
||||
ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::DataType::F32> res,
|
||||
ffi::Buffer<ffi::DataType::F32> x,
|
||||
ffi::Buffer<ffi::DataType::F32> ct_y,
|
||||
ffi::Result<ffi::Buffer<ffi::DataType::F32>> ct_x) {
|
||||
ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::F32> res, ffi::Buffer<ffi::F32> x,
|
||||
ffi::Buffer<ffi::F32> ct_y,
|
||||
ffi::ResultBuffer<ffi::F32> ct_x) {
|
||||
auto [totalSize, lastDim] = GetDims(x);
|
||||
if (lastDim == 0) {
|
||||
return ffi::Error(ffi::ErrorCode::kInvalidArgument,
|
||||
"RmsNormBwd inputs must be arrays");
|
||||
return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays");
|
||||
}
|
||||
for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
|
||||
ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]),
|
||||
@@ -131,11 +126,10 @@ ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::DataType::F32> res,
   return ffi::Error::Success();
 }
 
-XLA_FFI_DEFINE_HANDLER_SYMBOL(
-    RmsNormBwd, RmsNormBwdImpl,
-    ffi::Ffi::Bind()
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // res
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // x
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()  // ct_y
-        .Ret<ffi::Buffer<ffi::DataType::F32>>()  // ct_x
+XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdImpl,
+                              ffi::Ffi::Bind()
+                                  .Arg<ffi::Buffer<ffi::F32>>()  // res
+                                  .Arg<ffi::Buffer<ffi::F32>>()  // x
+                                  .Arg<ffi::Buffer<ffi::F32>>()  // ct_y
+                                  .Ret<ffi::Buffer<ffi::F32>>()  // ct_x
 );
@@ -22,7 +22,7 @@ namespace nb = nanobind;
 namespace ffi = xla::ffi;
 
 ffi::Error ArrayAttrImpl(ffi::Span<const int32_t> array,
-                         ffi::Result<ffi::BufferR0<ffi::S32>> res) {
+                         ffi::ResultBufferR0<ffi::S32> res) {
   int64_t total = 0;
   for (int32_t x : array) {
     total += x;
@@ -37,8 +37,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ArrayAttr, ArrayAttrImpl,
                                   .Ret<ffi::BufferR0<ffi::S32>>());
 
 ffi::Error DictionaryAttrImpl(ffi::Dictionary attrs,
-                              ffi::Result<ffi::BufferR0<ffi::S32>> secret,
-                              ffi::Result<ffi::BufferR0<ffi::S32>> count) {
+                              ffi::ResultBufferR0<ffi::S32> secret,
+                              ffi::ResultBufferR0<ffi::S32> count) {
   auto maybe_secret = attrs.get<int64_t>("secret");
   if (maybe_secret.has_error()) {
     return maybe_secret.error();
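The `ArrayAttr` hunk sums an `int32_t` attribute array into a rank-0 (`BufferR0`) result. Stripped of the XLA types, the reduction is just the following sketch, with `std::vector` standing in for `ffi::Span` and a plain return value standing in for the rank-0 result buffer:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// std::vector stands in for ffi::Span<const int32_t>; the return value
// stands in for writing to a rank-0 (scalar) result buffer.
int32_t SumArrayAttr(const std::vector<int32_t> &array) {
  int64_t total = 0;
  for (int32_t x : array) {
    total += x;
  }
  return static_cast<int32_t>(total);
}
```

Accumulating in `int64_t` before narrowing mirrors the handler's choice of a wider accumulator than the element type.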
@@ -44,11 +44,9 @@ __global__ void FooFwdKernel(const float *a, const float *b, float *c,
 // Buffer type provides buffer dimensions, so the "n" argument here is not
 // strictly necessary, but it allows us to demonstrate the use of attributes
 // (.Attr in the FFI handler definition above).
-ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer<ffi::DataType::F32> a,
-                      ffi::Buffer<ffi::DataType::F32> b,
-                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> c,
-                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_plus_1,
-                      size_t n) {
+ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer<ffi::F32> a,
+                      ffi::Buffer<ffi::F32> b, ffi::ResultBuffer<ffi::F32> c,
+                      ffi::ResultBuffer<ffi::F32> b_plus_1, size_t n) {
   const int block_dim = 128;
   const int grid_dim = 1;
   // Note how we access regular Buffer data vs Result Buffer data:
@@ -73,12 +71,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
     FooFwd, FooFwdHost,
     ffi::Ffi::Bind()
         .Ctx<ffi::PlatformStream<cudaStream_t>>()  // stream
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // a
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // b
-        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // c
-        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // b_plus_1
+        .Arg<ffi::Buffer<ffi::F32>>()              // a
+        .Arg<ffi::Buffer<ffi::F32>>()              // b
+        .Ret<ffi::Buffer<ffi::F32>>()              // c
+        .Ret<ffi::Buffer<ffi::F32>>()              // b_plus_1
         .Attr<size_t>("n"),
-    {xla::ffi::Traits::kCmdBufferCompatible});  // cudaGraph enabled
+    {xla::ffi::Traits::kCmdBufferCompatible});     // cudaGraph enabled
 
 //----------------------------------------------------------------------------//
 // Backward pass                                                              //
@@ -106,11 +104,11 @@ __global__ void FooBwdKernel(const float *c_grad,  // incoming gradient wrt c
 }
 
 ffi::Error FooBwdHost(cudaStream_t stream,
-                      ffi::Buffer<ffi::DataType::F32> c_grad,
-                      ffi::Buffer<ffi::DataType::F32> a,
-                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_plus_1,
-                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> a_grad,
-                      ffi::Result<ffi::Buffer<ffi::DataType::F32>> b_grad,
+                      ffi::Buffer<ffi::F32> c_grad,
+                      ffi::Buffer<ffi::F32> a,
+                      ffi::ResultBuffer<ffi::F32> b_plus_1,
+                      ffi::ResultBuffer<ffi::F32> a_grad,
+                      ffi::ResultBuffer<ffi::F32> b_grad,
                       size_t n) {
   const int block_dim = 128;
   const int grid_dim = 1;
@@ -131,10 +129,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(
     FooBwd, FooBwdHost,
     ffi::Ffi::Bind()
         .Ctx<ffi::PlatformStream<cudaStream_t>>()  // stream
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // c_grad
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // a
-        .Arg<ffi::Buffer<ffi::DataType::F32>>()    // b_plus_1
-        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // a_grad
-        .Ret<ffi::Buffer<ffi::DataType::F32>>()    // b_grad
+        .Arg<ffi::Buffer<ffi::F32>>()              // c_grad
+        .Arg<ffi::Buffer<ffi::F32>>()              // a
+        .Arg<ffi::Buffer<ffi::F32>>()              // b_plus_1
+        .Ret<ffi::Buffer<ffi::F32>>()              // a_grad
+        .Ret<ffi::Buffer<ffi::F32>>()              // b_grad
         .Attr<size_t>("n"),
-    {xla::ffi::Traits::kCmdBufferCompatible});  // cudaGraph enabled
+    {xla::ffi::Traits::kCmdBufferCompatible});     // cudaGraph enabled
 
@@ -59,11 +59,10 @@ std::pair<int64_t, int64_t> GetDims(const ffi::Buffer<T> &buffer) {
 // library function `ComputeRmsNorm` above. This function handles the batch
 // dimensions by calling `ComputeRmsNorm` within a loop.
 ffi::Error RmsNormImpl(float eps, ffi::Buffer<ffi::F32> x,
-                       ffi::Result<ffi::Buffer<ffi::F32>> y) {
+                       ffi::ResultBuffer<ffi::F32> y) {
   auto [totalSize, lastDim] = GetDims(x);
   if (lastDim == 0) {
-    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
-                      "RmsNorm input must be an array");
+    return ffi::Error::InvalidArgument("RmsNorm input must be an array");
   }
   for (int64_t n = 0; n < totalSize; n += lastDim) {
     ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));
@@ -82,12 +81,11 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl,
 );
 
 ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer<ffi::F32> x,
-                          ffi::Result<ffi::Buffer<ffi::F32>> y,
-                          ffi::Result<ffi::Buffer<ffi::F32>> res) {
+                          ffi::ResultBuffer<ffi::F32> y,
+                          ffi::ResultBuffer<ffi::F32> res) {
   auto [totalSize, lastDim] = GetDims(x);
   if (lastDim == 0) {
-    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
-                      "RmsNormFwd input must be an array");
+    return ffi::Error::InvalidArgument("RmsNormFwd input must be an array");
   }
   for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
     res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]),
@@ -118,11 +116,10 @@ void ComputeRmsNormBwd(int64_t size, float res, const float *x,
 
 ffi::Error RmsNormBwdImpl(ffi::Buffer<ffi::F32> res, ffi::Buffer<ffi::F32> x,
                           ffi::Buffer<ffi::F32> ct_y,
-                          ffi::Result<ffi::Buffer<ffi::F32>> ct_x) {
+                          ffi::ResultBuffer<ffi::F32> ct_x) {
   auto [totalSize, lastDim] = GetDims(x);
   if (lastDim == 0) {
-    return ffi::Error(ffi::ErrorCode::kInvalidArgument,
-                      "RmsNormBwd inputs must be arrays");
+    return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays");
   }
   for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) {
     ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]),