JAX ships with a large number of built-in operations, but users occasionally run into a case where they need a new operation that is not supported by JAX.
To accommodate such scenarios, JAX allows users to define custom operations. This tutorial explains how we can define one for GPUs and use it in single-GPU and multi-GPU environments.
This tutorial contains information from [Extending JAX with custom C++ and CUDA code](https://github.com/dfm/extending-jax) and assumes that you are familiar with [JAX primitives](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).
## RMS normalization
For this tutorial, we are going to add RMS normalization as a custom operation in JAX.
Note that RMS normalization can be expressed with [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) directly. However, we are using it as an example to show the process of creating a custom operation for GPUs.
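For reference, here is a minimal `jax.numpy` sketch of weighted RMS normalization; the function name, the `weight` argument, and the `eps` default are illustrative choices of ours, not taken from the kernels below.

```python
import jax
import jax.numpy as jnp


def rms_norm_ref(x, weight, eps=1e-5):
    # Scale each row by the reciprocal of its root-mean-square.
    mean_sq = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
    return x * jax.lax.rsqrt(mean_sq + eps) * weight
```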
The CUDA code in `gpu_ops/rms_norm_kernels.cu` for this operation has been borrowed from [Apex](https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda_kernel.cu) and adapted to eliminate any dependency on PyTorch.
See [`gpu_ops` code listing](#gpu_ops-code-listing) for a complete code listing of C++ and CUDA files.
## High-level steps
This tutorial shows how to write both a custom operation and its gradient.
You need to follow these steps in C for each new JAX primitive:
* Have CUDA kernel(s).
* Create a C function, to be called by XLA, that dispatches the CUDA kernel (see the sketch after this list).
* Create a descriptor to convey information needed for the computation.
  * The types, the shapes and other attributes.
* Bind C functions to Python.
  * To create the descriptor and to call the primitive during execution.
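To make these steps concrete, here is a hedged CUDA/C++ sketch of a descriptor and a dispatch function. The `stream`/`buffers`/`opaque`/`opaque_len` signature is XLA's GPU custom-call convention; the struct, function, and kernel names are illustrative placeholders rather than the exact symbols used in `gpu_ops`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <cuda_runtime_api.h>

// Descriptor carrying shapes/attributes from Python to the C function.
// Illustrative layout; the real gpu_ops descriptor holds more fields.
struct RMSNormDescriptor {
  int n_rows;
  int n_cols;
  float eps;
};

// The CUDA kernel itself (defined elsewhere, e.g. in rms_norm_kernels.cu).
__global__ void rms_norm_fwd_kernel(const float *x, const float *w, float *y,
                                    int n_cols, float eps);

// XLA invokes this function with the launch stream, the operand buffers
// followed by the result buffers, and the serialized descriptor bytes.
extern "C" void rms_norm_fwd(cudaStream_t stream, void **buffers,
                             const char *opaque, std::size_t opaque_len) {
  RMSNormDescriptor d;
  assert(opaque_len == sizeof(d));
  std::memcpy(&d, opaque, sizeof(d));
  auto *x = static_cast<const float *>(buffers[0]);  // input
  auto *w = static_cast<const float *>(buffers[1]);  // weight
  auto *y = static_cast<float *>(buffers[2]);        // output
  rms_norm_fwd_kernel<<<d.n_rows, 256, 0, stream>>>(x, w, y, d.n_cols, d.eps);
}
```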
You need to follow these steps in Python:
* Define a new JAX primitive (instruction/operation); a minimal sketch follows this list.
* Write Python functions to build the graph nodes with the primitive.
* Define its abstract evaluation.
* Define its lowering to MLIR.
* [Optional] Define the gradient.
* [Optional] Use [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) (or one of the experimental [custom_partitioning](https://jax.readthedocs.io/en/latest/jax.experimental.custom_partitioning.html) or [shard_map](https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html) functions) for fast multi-GPU.
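To ground the Python-side steps, here is a minimal, hedged sketch of a primitive with an abstract evaluation and a (stubbed) GPU lowering. The primitive and function names are illustrative; the tutorial's actual registration and lowering code appears in the sections below.

```python
from functools import partial

from jax import core
from jax.interpreters import mlir, xla

# A new primitive; "rms_norm_fwd" is our placeholder name.
_rms_norm_fwd_p = core.Primitive("rms_norm_fwd")
_rms_norm_fwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_fwd_p))


def rms_norm_fwd(x, weight, eps=1e-5):
    """User-facing wrapper that builds a graph node with the primitive."""
    return _rms_norm_fwd_p.bind(x, weight, eps=eps)


# Abstract evaluation: derive the output aval from input avals alone.
def _rms_norm_fwd_abstract(x_aval, weight_aval, *, eps):
    return core.ShapedArray(x_aval.shape, x_aval.dtype)


_rms_norm_fwd_p.def_abstract_eval(_rms_norm_fwd_abstract)


# Lowering to MLIR: in the full tutorial this emits a custom call that
# targets the C function exposed by gpu_ops; stubbed out in this sketch.
def _rms_norm_fwd_lowering(ctx, x, weight, *, eps):
    raise NotImplementedError("sketch only; see the full lowering below")


mlir.register_lowering(_rms_norm_fwd_p, _rms_norm_fwd_lowering, platform="gpu")
```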
## C code
See [`gpu_ops` code listing](#gpu_ops-code-listing) for a complete code listing of C++ and CUDA files.
`gpu_ops/rms_norm_kernels.cu` defines the following functions, which are declared with the XLA custom function signature.
These functions are responsible for launching RMS normalization kernels with the given `buffers` on the specified `stream`.
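In shape, each of these declarations follows XLA's GPU custom-call signature, along these lines (the exact `gpu_ops` symbol names may differ):

```cpp
#include <cstddef>
#include <cuda_runtime_api.h>

namespace gpu_ops {

// Each launcher receives the CUDA stream to run on, the operand buffers
// followed by the result buffers, and an opaque serialized descriptor.
void rms_forward_affine(cudaStream_t stream, void **buffers,
                        const char *opaque, std::size_t opaque_len);
void rms_backward_affine(cudaStream_t stream, void **buffers,
                         const char *opaque, std::size_t opaque_len);

}  // namespace gpu_ops
```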
## Build `gpu_ops` extension module
We build the `gpu_ops` Python extension module with the aforementioned code.
(See [`gpu_ops` code listing](#gpu_ops-code-listing) for a complete code listing of C++ and CUDA files.)
```shell
python -m pip install pybind11==2.10.1
```
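The remaining build commands are not shown above. As a rough illustration only (the file name `gpu_ops/gpu_ops.cpp`, the include paths, and the flags are our assumptions, not the tutorial's exact commands), compiling the kernels and bindings into an extension module can look like:

```shell
# Illustrative sketch of one possible build; adjust paths and GPU
# architectures to your environment.
mkdir -p build
pybind_include=$(python -c "import pybind11; print(pybind11.get_include())")

# Compile the CUDA kernels to an object file.
nvcc -O3 --compiler-options -fPIC -I"$pybind_include" \
    -c gpu_ops/rms_norm_kernels.cu -o build/rms_norm_kernels.o

# Compile the pybind11 bindings and link the Python extension module.
c++ -O3 -fPIC -shared -I"$pybind_include" $(python3-config --includes) \
    -I/usr/local/cuda/include -L/usr/local/cuda/lib64 \
    build/rms_norm_kernels.o gpu_ops/gpu_ops.cpp \
    -o build/gpu_ops$(python3-config --extension-suffix) -lcudart
```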