From f75705426742b4c3a2c323f81a12a0c369cfd454 Mon Sep 17 00:00:00 2001 From: Dan Foreman-Mackey Date: Fri, 8 Nov 2024 11:48:45 -0500 Subject: [PATCH] Update some outdated syntax in FFI tutorial. --- docs/ffi.ipynb | 35 ++++++------ docs/ffi.md | 33 +++++------- docs/ffi/rms_norm.cc | 56 +++++++++----------- examples/ffi/src/jax_ffi_example/attrs.cc | 6 +-- examples/ffi/src/jax_ffi_example/cuda_e2e.cu | 40 +++++++------- examples/ffi/src/jax_ffi_example/rms_norm.cc | 17 +++--- 6 files changed, 83 insertions(+), 104 deletions(-) diff --git a/docs/ffi.ipynb b/docs/ffi.ipynb index ea8a86fa8..f1a699b5c 100644 --- a/docs/ffi.ipynb +++ b/docs/ffi.ipynb @@ -26,10 +26,7 @@ "In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases.\n", "We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below.\n", "\n", - "This tutorial comes with two supplementary files:\n", - "\n", - "* [`rms_norm.cc`](ffi/rms_norm.cc), which includes all the backend code, and\n", - "* [`CMakeLists.txt`](ffi/CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code.\n", + "The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi).\n", "\n", "## A simple example\n", "\n", @@ -101,7 +98,7 @@ "\n", "To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla).\n", "For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call).\n", - "The full source listing can be downloaded [here](ffi/rms_norm.cc), but the key implementation details are reproduced here:\n", + "The full source listing can be downloaded [here](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc), but the key implementation details are reproduced here:\n", "\n", "```c++\n", "#include \n", @@ -129,12 +126,11 @@ "// A wrapper function providing the interface between the XLA FFI call and our\n", "// library function `ComputeRmsNorm` above. This function handles the batch\n", "// dimensions by calling `ComputeRmsNorm` within a loop.\n", - "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", - " ffi::Result> y) {\n", + "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", + " ffi::ResultBuffer y) {\n", " auto [totalSize, lastDim] = GetDims(x);\n", " if (lastDim == 0) {\n", - " return ffi::Error(ffi::ErrorCode::kInvalidArgument,\n", - " \"RmsNorm input must be an array\");\n", + " return ffi::Error::InvalidArgument(\"RmsNorm input must be an array\");\n", " }\n", " for (int64_t n = 0; n < totalSize; n += lastDim) {\n", " ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n]));\n", @@ -149,8 +145,8 @@ " RmsNorm, RmsNormImpl,\n", " ffi::Ffi::Bind()\n", " .Attr(\"eps\")\n", - " .Arg>() // x\n", - " .Ret>() // y\n", + " .Arg>() // x\n", + " .Ret>() // y\n", ");\n", "```\n", "\n", @@ -173,8 +169,7 @@ "Now that we have our minimal FFI wrapper implemented, we need to expose this function (`RmsNorm`) to Python.\n", "In this tutorial, we compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), but another common pattern is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) as discussed below.\n", "\n", - "To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble.\n", - "The full `CMakeLists.txt` can be downloaded [here](ffi/CMakeLists.txt)." + "To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble." ] }, { @@ -433,7 +428,7 @@ "1. `rms_norm_fwd` returns two outputs: (a) the \"primal\" result, and (b) the \"residuals\" which are used in the backwards pass.\n", "2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents.\n", "\n", - "We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](ffi/rms_norm.cc) to see how these functions are implemented on the back end.\n", + "We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end.\n", "The main point to emphasize here is that the \"residual\" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes.\n", "\n", "This custom derivative rule can be wired in as follows:" @@ -508,16 +503,16 @@ "When defining our FFI wrapper for CPU, the function signature that we used was:\n", "\n", "```c++\n", - "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", - " ffi::Result> y)\n", + "ffi::Error RmsNormImpl(float eps, ffi::Buffer x,\n", + " ffi::ResultBuffer y)\n", "```\n", "\n", "To update this to interface with a CUDA kernel, this signature becomes:\n", "\n", "```c++\n", "ffi::Error RmsNormImpl(cudaStream_t stream, float eps,\n", - " ffi::Buffer x,\n", - " ffi::Result> y)\n", + " ffi::Buffer x,\n", + " ffi::ResultBuffer y)\n", "```\n", "\n", "And the handler definition is updated to include a `Ctx` in its binding:\n", @@ -528,8 +523,8 @@ " ffi::Ffi::Bind()\n", " .Ctx>()\n", " .Attr(\"eps\")\n", - " .Arg>() // x\n", - " .Ret>() // y\n", + " .Arg>() // x\n", + " .Ret>() // y\n", ");\n", "```\n", "\n", diff --git a/docs/ffi.md b/docs/ffi.md index 5afc8f809..dbe901237 100644 --- a/docs/ffi.md +++ b/docs/ffi.md @@ -34,10 +34,7 @@ JAX's FFI support is provided in two parts: In this tutorial we demonstrate the use of both of these components using a simple example, and then go on to discuss some lower-level extensions for more complicated use cases. We start by presenting the FFI on CPU, and discuss generalizations to GPU or multi-device environments below. -This tutorial comes with two supplementary files: - -* [`rms_norm.cc`](ffi/rms_norm.cc), which includes all the backend code, and -* [`CMakeLists.txt`](ffi/CMakeLists.txt), which tells [CMake](https://cmake.org) how to build the code. +The end-to-end code for this example and some other more advanced use cases can be found in the JAX FFI examples project on GitHub at [`examples/ffi` in the JAX repository](https://github.com/jax-ml/jax/tree/main/examples/ffi). ## A simple example @@ -96,7 +93,7 @@ and, for our example, this is the function that we want to expose to JAX via the To expose our library function to JAX and XLA, we need to write a thin wrapper using the APIs provided by the header-only library in the [`xla/ffi/api`](https://github.com/openxla/xla/tree/main/xla/ffi/api) directory of the [XLA project](https://github.com/openxla/xla). For more information about this interface, take a look at [the XLA custom call documentation](https://openxla.org/xla/custom_call). -The full source listing can be downloaded [here](ffi/rms_norm.cc), but the key implementation details are reproduced here: +The full source listing can be downloaded [here](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc), but the key implementation details are reproduced here: ```c++ #include @@ -124,12 +121,11 @@ std::pair GetDims(const ffi::Buffer &buffer) { // A wrapper function providing the interface between the XLA FFI call and our // library function `ComputeRmsNorm` above. This function handles the batch // dimensions by calling `ComputeRmsNorm` within a loop. -ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) { +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNorm input must be an array"); + return ffi::Error::InvalidArgument("RmsNorm input must be an array"); } for (int64_t n = 0; n < totalSize; n += lastDim) { ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); @@ -144,8 +140,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( RmsNorm, RmsNormImpl, ffi::Ffi::Bind() .Attr("eps") - .Arg>() // x - .Ret>() // y + .Arg>() // x + .Ret>() // y ); ``` @@ -166,7 +162,6 @@ Now that we have our minimal FFI wrapper implemented, we need to expose this fun In this tutorial, we compile `RmsNorm` into a shared library and load it using [ctypes](https://docs.python.org/3/library/ctypes.html), but another common pattern is to use [nanobind](https://nanobind.readthedocs.io/) or [pybind11](https://pybind11.readthedocs.io/) as discussed below. To compile the shared library, we're using CMake here, but you should be able to use your favorite build system without too much trouble. -The full `CMakeLists.txt` can be downloaded [here](ffi/CMakeLists.txt). ```{code-cell} ipython3 :tags: [hide-output] @@ -357,7 +352,7 @@ In this case, we actually define two new FFI calls: 1. `rms_norm_fwd` returns two outputs: (a) the "primal" result, and (b) the "residuals" which are used in the backwards pass. 2. `rms_norm_bwd` takes the residuals and the output co-tangents, and returns the input co-tangents. -We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](ffi/rms_norm.cc) to see how these functions are implemented on the back end. +We won't get into the details of the RMS normalization backwards pass, but take a look at the [C++ source code](https://github.com/jax-ml/jax/blob/main/examples/ffi/src/jax_ffi_example/rms_norm.cc) to see how these functions are implemented on the back end. The main point to emphasize here is that the "residual" computed has a different shape than the primal output, therefore, in the {func}`~jax.extend.ffi.ffi_call` to `res_norm_fwd`, the output type has two elements with different shapes. This custom derivative rule can be wired in as follows: @@ -422,16 +417,16 @@ Since this documentation page is automatically generated on a machine without ac When defining our FFI wrapper for CPU, the function signature that we used was: ```c++ -ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y) ``` To update this to interface with a CUDA kernel, this signature becomes: ```c++ ffi::Error RmsNormImpl(cudaStream_t stream, float eps, - ffi::Buffer x, - ffi::Result> y) + ffi::Buffer x, + ffi::ResultBuffer y) ``` And the handler definition is updated to include a `Ctx` in its binding: @@ -442,8 +437,8 @@ XLA_FFI_DEFINE_HANDLER( ffi::Ffi::Bind() .Ctx>() .Attr("eps") - .Arg>() // x - .Ret>() // y + .Arg>() // x + .Ret>() // y ); ``` diff --git a/docs/ffi/rms_norm.cc b/docs/ffi/rms_norm.cc index 4dc8a8904..467f13d44 100644 --- a/docs/ffi/rms_norm.cc +++ b/docs/ffi/rms_norm.cc @@ -56,12 +56,11 @@ std::pair GetDims(const ffi::Buffer &buffer) { // A wrapper function providing the interface between the XLA FFI call and our // library function `ComputeRmsNorm` above. This function handles the batch // dimensions by calling `ComputeRmsNorm` within a loop. -ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) { +ffi::Error RmsNormImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNorm input must be an array"); + return ffi::Error::InvalidArgument("RmsNorm input must be an array"); } for (int64_t n = 0; n < totalSize; n += lastDim) { ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); @@ -75,17 +74,16 @@ ffi::Error RmsNormImpl(float eps, ffi::Buffer x, XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl, ffi::Ffi::Bind() .Attr("eps") - .Arg>() // x - .Ret>() // y + .Arg>() // x + .Ret>() // y ); -ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, - ffi::Result> y, - ffi::Result> res) { +ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, + ffi::ResultBuffer y, + ffi::ResultBuffer res) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormFwd input must be an array"); + return ffi::Error::InvalidArgument("RmsNormFwd input must be an array"); } for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), @@ -94,13 +92,12 @@ ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, return ffi::Error::Success(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL( - RmsNormFwd, RmsNormFwdImpl, - ffi::Ffi::Bind() - .Attr("eps") - .Arg>() // x - .Ret>() // y - .Ret>() // res +XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormFwd, RmsNormFwdImpl, + ffi::Ffi::Bind() + .Attr("eps") + .Arg>() // x + .Ret>() // y + .Ret>() // res ); void ComputeRmsNormBwd(int64_t size, float res, const float *x, @@ -115,14 +112,12 @@ void ComputeRmsNormBwd(int64_t size, float res, const float *x, } } -ffi::Error RmsNormBwdImpl(ffi::Buffer res, - ffi::Buffer x, - ffi::Buffer ct_y, - ffi::Result> ct_x) { +ffi::Error RmsNormBwdImpl(ffi::Buffer res, ffi::Buffer x, + ffi::Buffer ct_y, + ffi::ResultBuffer ct_x) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormBwd inputs must be arrays"); + return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays"); } for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]), @@ -131,11 +126,10 @@ ffi::Error RmsNormBwdImpl(ffi::Buffer res, return ffi::Error::Success(); } -XLA_FFI_DEFINE_HANDLER_SYMBOL( - RmsNormBwd, RmsNormBwdImpl, - ffi::Ffi::Bind() - .Arg>() // res - .Arg>() // x - .Arg>() // ct_y - .Ret>() // ct_x +XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNormBwd, RmsNormBwdImpl, + ffi::Ffi::Bind() + .Arg>() // res + .Arg>() // x + .Arg>() // ct_y + .Ret>() // ct_x ); diff --git a/examples/ffi/src/jax_ffi_example/attrs.cc b/examples/ffi/src/jax_ffi_example/attrs.cc index 2a6e8d847..7ff5c98e5 100644 --- a/examples/ffi/src/jax_ffi_example/attrs.cc +++ b/examples/ffi/src/jax_ffi_example/attrs.cc @@ -22,7 +22,7 @@ namespace nb = nanobind; namespace ffi = xla::ffi; ffi::Error ArrayAttrImpl(ffi::Span array, - ffi::Result> res) { + ffi::ResultBufferR0 res) { int64_t total = 0; for (int32_t x : array) { total += x; @@ -37,8 +37,8 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(ArrayAttr, ArrayAttrImpl, .Ret>()); ffi::Error DictionaryAttrImpl(ffi::Dictionary attrs, - ffi::Result> secret, - ffi::Result> count) { + ffi::ResultBufferR0 secret, + ffi::ResultBufferR0 count) { auto maybe_secret = attrs.get("secret"); if (maybe_secret.has_error()) { return maybe_secret.error(); diff --git a/examples/ffi/src/jax_ffi_example/cuda_e2e.cu b/examples/ffi/src/jax_ffi_example/cuda_e2e.cu index 858b5f8a8..240adb6d6 100644 --- a/examples/ffi/src/jax_ffi_example/cuda_e2e.cu +++ b/examples/ffi/src/jax_ffi_example/cuda_e2e.cu @@ -44,11 +44,9 @@ __global__ void FooFwdKernel(const float *a, const float *b, float *c, // Buffer type provides buffer dimensions, so the "n" argument here is not // strictly necessary, but it allows us to demonstrate the use of attributes // (.Attr in the FFI handler definition above). -ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer a, - ffi::Buffer b, - ffi::Result> c, - ffi::Result> b_plus_1, - size_t n) { +ffi::Error FooFwdHost(cudaStream_t stream, ffi::Buffer a, + ffi::Buffer b, ffi::ResultBuffer c, + ffi::ResultBuffer b_plus_1, size_t n) { const int block_dim = 128; const int grid_dim = 1; // Note how we access regular Buffer data vs Result Buffer data: @@ -73,12 +71,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( FooFwd, FooFwdHost, ffi::Ffi::Bind() .Ctx>() // stream - .Arg>() // a - .Arg>() // b - .Ret>() // c - .Ret>() // b_plus_1 + .Arg>() // a + .Arg>() // b + .Ret>() // c + .Ret>() // b_plus_1 .Attr("n"), - {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled + {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled //----------------------------------------------------------------------------// // Backward pass // @@ -106,11 +104,11 @@ __global__ void FooBwdKernel(const float *c_grad, // incoming gradient wrt c } ffi::Error FooBwdHost(cudaStream_t stream, - ffi::Buffer c_grad, - ffi::Buffer a, - ffi::Result> b_plus_1, - ffi::Result> a_grad, - ffi::Result> b_grad, + ffi::Buffer c_grad, + ffi::Buffer a, + ffi::ResultBuffer b_plus_1, + ffi::ResultBuffer a_grad, + ffi::ResultBuffer b_grad, size_t n) { const int block_dim = 128; const int grid_dim = 1; @@ -131,10 +129,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL( FooBwd, FooBwdHost, ffi::Ffi::Bind() .Ctx>() // stream - .Arg>() // c_grad - .Arg>() // a - .Arg>() // b_plus_1 - .Ret>() // a_grad - .Ret>() // b_grad + .Arg>() // c_grad + .Arg>() // a + .Arg>() // b_plus_1 + .Ret>() // a_grad + .Ret>() // b_grad .Attr("n"), - {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled + {xla::ffi::Traits::kCmdBufferCompatible}); // cudaGraph enabled diff --git a/examples/ffi/src/jax_ffi_example/rms_norm.cc b/examples/ffi/src/jax_ffi_example/rms_norm.cc index 2fb8d96c8..455a0e557 100644 --- a/examples/ffi/src/jax_ffi_example/rms_norm.cc +++ b/examples/ffi/src/jax_ffi_example/rms_norm.cc @@ -59,11 +59,10 @@ std::pair GetDims(const ffi::Buffer &buffer) { // library function `ComputeRmsNorm` above. This function handles the batch // dimensions by calling `ComputeRmsNorm` within a loop. ffi::Error RmsNormImpl(float eps, ffi::Buffer x, - ffi::Result> y) { + ffi::ResultBuffer y) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNorm input must be an array"); + return ffi::Error::InvalidArgument("RmsNorm input must be an array"); } for (int64_t n = 0; n < totalSize; n += lastDim) { ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), &(y->typed_data()[n])); @@ -82,12 +81,11 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(RmsNorm, RmsNormImpl, ); ffi::Error RmsNormFwdImpl(float eps, ffi::Buffer x, - ffi::Result> y, - ffi::Result> res) { + ffi::ResultBuffer y, + ffi::ResultBuffer res) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormFwd input must be an array"); + return ffi::Error::InvalidArgument("RmsNormFwd input must be an array"); } for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { res->typed_data()[idx] = ComputeRmsNorm(eps, lastDim, &(x.typed_data()[n]), @@ -118,11 +116,10 @@ void ComputeRmsNormBwd(int64_t size, float res, const float *x, ffi::Error RmsNormBwdImpl(ffi::Buffer res, ffi::Buffer x, ffi::Buffer ct_y, - ffi::Result> ct_x) { + ffi::ResultBuffer ct_x) { auto [totalSize, lastDim] = GetDims(x); if (lastDim == 0) { - return ffi::Error(ffi::ErrorCode::kInvalidArgument, - "RmsNormBwd inputs must be arrays"); + return ffi::Error::InvalidArgument("RmsNormBwd inputs must be arrays"); } for (int64_t n = 0, idx = 0; n < totalSize; n += lastDim, ++idx) { ComputeRmsNormBwd(lastDim, res.typed_data()[idx], &(x.typed_data()[n]),