Merge pull request #14837 from nouiz:doc_gpu_custom

PiperOrigin-RevId: 515444284
jax authors 2023-03-09 14:36:30 -08:00
commit d9395959ca


@ -4,15 +4,42 @@ JAX ships with a large number of built-in operations, but users occasionally run
To accommodate such scenarios, JAX allows users to define custom operations, and this tutorial explains how to define one for GPUs and use it in single-GPU and multi-GPU environments.
This tutorial contains information from [Extending JAX with custom C++ and CUDA code](https://github.com/dfm/extending-jax) and
assumes that you are familiar with [JAX primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).
## RMS normalization
For this tutorial, we are going to add RMS normalization as a custom operation in JAX.
Note that RMS normalization can be expressed with [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) directly. However, we are using it as an example to show the process of creating a custom operation for GPUs.
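For orientation, here is what a plain `jax.numpy` version might look like. This is only a sketch: the function name, the `eps` default, and the weight handling are illustrative assumptions rather than the exact semantics of the Apex kernel.

```python
import jax
import jax.numpy as jnp

def rms_norm_ref(x, weight, eps=1e-05):
    # Normalize by the root mean square over the last axis, then apply
    # a learned per-feature scale.
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight

x = jnp.ones((2, 128), dtype=jnp.float32)
weight = jnp.ones((128,), dtype=jnp.float32)
print(rms_norm_ref(x, weight).shape)  # (2, 128)
```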
The CUDA code in `gpu_ops/rms_norm_kernels.cu` for this operation has been borrowed from [Apex](https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda_kernel.cu) and adapted to eliminate any dependency on PyTorch.
See [`gpu_ops` code listing](#gpu_ops-code-listing) for a complete code listing of C++ and CUDA files.
## High-level steps
This tutorial shows how to write both a custom operation and its gradient.
In C:
You need to follow these steps in C for each new JAX primitive (the descriptor idea is sketched in Python right after this list):
* Have CUDA kernel(s).
* Create a C function, to be called by XLA, that dispatches the CUDA kernel.
* Create a descriptor to convey information needed for the computation.
  * The types, the shapes and other attributes.
* Bind C functions to Python.
  * To create the descriptor and to call the primitive during execution.
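The descriptor itself is just a packed blob of bytes that XLA hands to the C function through the custom call's `opaque` argument. As a toy Python illustration of the idea (the field layout and the helper name `make_rms_norm_descriptor` are made up for this sketch; the tutorial builds the real descriptor in C++):

```python
import struct

def make_rms_norm_descriptor(n_rows: int, n_cols: int, eps: float) -> bytes:
    # Pack the attributes the kernel launch needs (shapes and epsilon)
    # into an opaque byte string that travels with the custom call.
    return struct.pack("<qqd", n_rows, n_cols, eps)

opaque = make_rms_norm_descriptor(2, 128, 1e-05)
print(len(opaque))  # 24 bytes: two int64 fields followed by one double
```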
In Python:
You need to follow these steps in Python (a minimal sketch of them follows this list):
* Define a new JAX primitive (instruction/operation).
* Write Python functions to build the graph nodes with the primitive.
* Define its abstract evaluation.
* Define its lowering to MLIR.
* [Optional] Define the gradient.
* [Optional] Use [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) (or one of the experimental [custom_partitioning](https://jax.readthedocs.io/en/latest/jax.experimental.custom_partitioning.html) or [shard_map](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html) functions) for fast multi-GPU execution.
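To make the Python-side steps concrete, below is a minimal, self-contained sketch. It stands in a pure-`jax.numpy` lowering via `mlir.lower_fun` for the real custom call into `gpu_ops`, so the names (`rms_norm_p`, `_rms_norm_ref`, the `eps` parameter) are illustrative assumptions rather than the tutorial's actual implementation, but the steps are the same.

```python
import jax
import jax.numpy as jnp
from jax import core
from jax.interpreters import mlir

# 1. Define a new JAX primitive.
rms_norm_p = core.Primitive("rms_norm")

# 2. A Python function that builds graph nodes with the primitive.
def rms_norm(x, weight, eps=1e-05):
    return rms_norm_p.bind(x, weight, eps=eps)

# Pure-JAX reference used here both as the concrete implementation and as
# the lowering; the tutorial instead emits a custom call into gpu_ops.
def _rms_norm_ref(x, weight, eps):
    variance = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(variance + eps) * weight

# 3. Abstract evaluation: output shape/dtype from the input avals.
def _rms_norm_abstract_eval(x_aval, weight_aval, eps):
    return core.ShapedArray(x_aval.shape, x_aval.dtype)

rms_norm_p.def_impl(_rms_norm_ref)
rms_norm_p.def_abstract_eval(_rms_norm_abstract_eval)

# 4. Lowering to MLIR, here expressed in terms of existing JAX operations.
mlir.register_lowering(rms_norm_p,
                       mlir.lower_fun(_rms_norm_ref, multiple_results=False))

x = jnp.ones((2, 128), dtype=jnp.float32)
w = jnp.ones((128,), dtype=jnp.float32)
print(jax.jit(rms_norm)(x, w).shape)  # (2, 128)
```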
## C code
See [`gpu_ops` code listing](#gpu_ops-code-listing) for a complete code listing of C++ and CUDA files.
`gpu_ops/rms_norm_kernels.cu` defines the following functions, which are declared with the XLA custom function signature.
These functions are responsible for launching RMS normalization kernels with the given `buffers` on the specified `stream`.
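These functions will later be registered with XLA as custom call targets so that the MLIR lowering can refer to them by name. A minimal sketch of that registration, assuming a hypothetical `gpu_ops.get_rms_norm_registrations()` helper that returns a dict mapping target names to PyCapsules:

```python
from jax.lib import xla_client
import gpu_ops  # the pybind11 extension module built below

# Register each exported kernel entry point as an XLA custom call target
# for the GPU platform, keyed by the name the lowering will emit.
for name, capsule in gpu_ops.get_rms_norm_registrations().items():
    xla_client.register_custom_call_target(name, capsule, platform="gpu")
```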
@ -85,6 +112,7 @@ PYBIND11_MODULE(gpu_ops, m) {
## Build `gpu_ops` extension module
We build the `gpu_ops` Python extension module with the aforementioned code.
(See [`gpu_ops` code listing](#gpu_ops-code-listing) for a complete code listing of C++ and CUDA files.)
```shell
python -m pip install pybind11==2.10.1