Merge pull request #13950 from jakevdp:fix-gpu-headings

PiperOrigin-RevId: 501111194
jax authors 2023-01-10 15:35:50 -08:00
commit 3f712480c6


@@ -1,11 +1,11 @@
# Custom operation for GPUs
# Custom operations for GPUs
JAX, a popular framework for machine learning, comes with a large number of operations. However, users occasionally run into a situation where they need a new operation that is not supported by JAX.
JAX ships with a large number of built-in operations, but users occasionally run into a situation where they need a new operation that is not supported by JAX.
To accommodate such scenarios, JAX allows users to define custom operations, and this tutorial explains how to define one for GPUs and use it in single-GPU and multi-GPU environments.
This tutorial contains information from [Extending JAX with custom C++ and CUDA code](https://github.com/dfm/extending-jax).
# RMS Normalization
## RMS Normalization
For this tutorial, we are going to add RMS normalization as a custom operation in JAX.
Note that RMS normalization can be expressed with [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) directly. However, we are using it as an example to show the process of creating a custom operation for GPUs.
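For reference, a pure `jax.numpy` version of the operation might look like the following sketch; the `eps` default and the convention that `weight` covers the trailing axes of `x` are assumptions here.

```python
import jax.numpy as jnp

def rms_norm_ref(x, weight, eps=1e-05):
    # y = x / sqrt(mean(x**2) + eps) * weight, with the mean taken over the
    # trailing axes that `weight` covers.
    axes = tuple(range(x.ndim - weight.ndim, x.ndim))
    mean_sq = jnp.mean(jnp.square(x), axis=axes, keepdims=True)
    return x / jnp.sqrt(mean_sq + eps) * weight
```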
@@ -81,7 +81,7 @@ PYBIND11_MODULE(gpu_ops, m) {
}
```
# Build `gpu_ops` extension module
## Build `gpu_ops` extension module
We build the `gpu_ops` Python extension module with the aforementioned code.
@@ -98,11 +98,11 @@ c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o build/gpu_ops$(${python_executable}
strip build/gpu_ops$(${python_executable}-config --extension-suffix)
```
# Add RMS normalization to JAX as custom call
## Add RMS normalization to JAX as custom call
`gpu_ops` is just a Python extension module; more work is needed to plug it into JAX.
## Create primitives
### Create primitives
We first create primitives, `_rms_norm_fwd_p` and `_rms_norm_bwd_p`, to which the custom functions can be mapped.
We set the `multiple_results` attribute to `True` for these operations, which means that each operation produces multiple outputs as a tuple.
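A minimal sketch of this step is shown below; the `eps` keyword and the exact set of outputs of each primitive are assumptions of the sketch rather than the tutorial's final code.

```python
from jax.core import Primitive

# One primitive per direction; both return several arrays, so we mark them
# with multiple_results=True.
_rms_norm_fwd_p = Primitive("rms_norm_fwd")
_rms_norm_fwd_p.multiple_results = True  # e.g. (output, invvar)

_rms_norm_bwd_p = Primitive("rms_norm_bwd")
_rms_norm_bwd_p.multiple_results = True  # e.g. (grad_input, grad_weight)

def rms_norm_fwd(x, weight, eps=1e-05):
    # Binding the primitive stages `rms_norm_fwd` into the jaxpr instead of
    # calling a Python implementation directly.
    output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
    return output, invvar
```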
@@ -145,7 +145,7 @@ def rms_norm_bwd(g, invvar, x, weight, eps):
return grad_input, grad_weight
```
## Lowering to MLIR custom call
### Lowering to MLIR custom call
To map the custom functions to the new primitives, `_rms_norm_fwd_p` and `_rms_norm_bwd_p`, we need to:
@@ -268,7 +268,7 @@ mlir.register_lowering(
)
```
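Besides registering the lowering rules with `mlir.register_lowering` as above, the kernels exported by `gpu_ops` also have to be registered with XLA as custom-call targets. A minimal sketch follows, assuming the pybind11 module exposes a `get_rms_norm_registrations()` mapping of target names to capsules.

```python
from jax.lib import xla_client
import gpu_ops

# Each entry maps a custom-call target name to the PyCapsule that wraps the
# corresponding CUDA kernel entry point.
for name, fn in gpu_ops.get_rms_norm_registrations().items():
    xla_client.register_custom_call_target(name, fn, platform="gpu")
```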
# Let's test it
## Let's test it
```python
per_core_batch_size=4
@@ -283,7 +283,7 @@ norm_shape = x.shape[-2:]
weight = jnp.ones(norm_shape, dtype=jnp.bfloat16)
```
## Test forward function
### Test forward function
```python
out = rms_norm_fwd(x, weight)
@@ -299,7 +299,7 @@ Cell In [5], line 1
NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented
```
# Abstract evaluation
## Abstract evaluation
The test above failed with `NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented`. Why did the test fail? What does it mean?
@@ -368,15 +368,15 @@ def _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps):
_rms_norm_bwd_p.def_abstract_eval(_rms_norm_bwd_abstract)
```
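For context, a forward abstract-evaluation rule can be sketched as below; the shapes and dtypes (a float32 `invvar` with one entry per normalized row) are assumptions of this sketch.

```python
import math
import jax.numpy as jnp
from jax.core import ShapedArray

def _rms_norm_fwd_abstract(x, weight, eps):
    # The output matches x; invvar holds one statistic per normalized row.
    n_rows = math.prod(x.shape) // math.prod(weight.shape)
    return (
        ShapedArray(x.shape, x.dtype),        # output
        ShapedArray((n_rows,), jnp.float32),  # invvar
    )

_rms_norm_fwd_p.def_abstract_eval(_rms_norm_fwd_abstract)
```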
# Let's test it again
## Let's test it again
## Test forward function
### Test forward function
```python
out = rms_norm_fwd(x, weight)
```
## Test backward function
### Test backward function
Now let's test the backward operation using `jax.grad` and `jtu.check_grads`.
@@ -403,7 +403,7 @@ Cell In [8], line 7
NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented
```
# Differentiation rule
## Differentiation rule
The backward operation failed with the error `NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented`. It means that, although we have defined `rms_norm_fwd` and `rms_norm_bwd`, JAX doesn't know the relationship between them.
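One way to express that relationship is `jax.custom_vjp`. The sketch below shows the wiring; the rule names `_fwd_rule`/`_bwd_rule` and the residuals they pass around are made up for this sketch, which assumes the backward primitive returns exactly the two gradients.

```python
from functools import partial
import jax

@partial(jax.custom_vjp, nondiff_argnums=(2,))
def rms_norm(x, weight, eps=1e-05):
    output, _ = _fwd_rule(x, weight, eps)
    return output

def _fwd_rule(x, weight, eps):
    output, invvar = _rms_norm_fwd_p.bind(x, weight, eps=eps)
    return output, (invvar, x, weight)  # residuals saved for the backward pass

def _bwd_rule(eps, res, g):
    invvar, x, weight = res
    grad_input, grad_weight = _rms_norm_bwd_p.bind(g, invvar, x, weight, eps=eps)
    return grad_input, grad_weight  # one gradient per differentiable input

rms_norm.defvjp(_fwd_rule, _bwd_rule)
```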
@@ -469,11 +469,11 @@ out = loss_grad(x, weight)
jtu.check_grads(loss, (x, weight), modes=["rev"], order=1)
```
# Let's test it on multiple devices
## Let's test it on multiple devices
We are using [`jax.experimental.pjit.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#jax.experimental.pjit.pjit) for parallel execution on multiple devices, and we produce reference values with sequential execution on a single device.
## Test forward function
### Test forward function
Let's first test the forward operation on multiple devices. We are creating a simple 1D mesh and sharding `x` on all devices.
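A sketch of that setup follows; the mesh axis name `"x"`, the choice to shard only the leading (batch) dimension of `x`, and the `in_axis_resources`/`out_axis_resources` spelling of the pjit arguments are assumptions of this sketch.

```python
import numpy as np
import jax
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec

mesh = Mesh(np.asarray(jax.local_devices()), ("x",))
with mesh:
    pjitted_rms_norm = pjit(
        rms_norm,
        # Shard x along its leading dimension; replicate weight.
        in_axis_resources=(PartitionSpec("x", None, None), None),
        out_axis_resources=PartitionSpec("x", None, None),
    )
    out = pjitted_rms_norm(x, weight)
```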
@@ -594,7 +594,7 @@ True
With this modification, the `all-gather` operation is eliminated and the custom call is made on each shard of `x`.
## Test backward function
### Test backward function
We are moving on to the backward operation using `jax.grad` on multiple devices.
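A sketch of that test, reusing the mesh and pjit arguments from the forward sketch above; the loss function and sharding choices are again assumptions.

```python
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec

def loss(x, weight):
    return jnp.mean(rms_norm(x, weight))

with mesh:
    pjitted_grad = pjit(
        jax.grad(loss, argnums=(0, 1)),
        in_axis_resources=(PartitionSpec("x", None, None), None),
        out_axis_resources=(PartitionSpec("x", None, None), None),
    )
    grad_x, grad_weight = pjitted_grad(x, weight)
```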
@@ -763,7 +763,7 @@ return (
)
```
# Let's put it together
## Let's put it together
Here is the complete code.
@@ -1065,11 +1065,11 @@ True
True
```
# Appendix
## Appendix
## `gpu_ops` code listing
### `gpu_ops` code listing
### `gpu_ops/kernel_helpers.h`
#### `gpu_ops/kernel_helpers.h`
```cpp
// This header is not specific to our application and you'll probably want
@@ -1123,7 +1123,7 @@ const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) {
#endif
```
### `gpu_ops/kernels.h`
#### `gpu_ops/kernels.h`
```cpp
#ifndef _GPU_OPS_KERNELS_H_
@@ -1157,7 +1157,7 @@ void rms_backward_affine(cudaStream_t stream, void **buffers,
#endif
```
### `gpu_ops/pybind11_kernel_helpers.h`
#### `gpu_ops/pybind11_kernel_helpers.h`
```cpp
// This header extends kernel_helpers.h with the pybind11 specific interface to
@@ -1188,7 +1188,7 @@ template <typename T> pybind11::capsule EncapsulateFunction(T *fn) {
#endif
```
### `gpu_ops/gpu_ops.cpp`
#### `gpu_ops/gpu_ops.cpp`
```cpp
#include "kernels.h"
@@ -1223,7 +1223,7 @@ PYBIND11_MODULE(gpu_ops, m) {
} // namespace
```
### `gpu_ops/rms_norm_kernels.cu`
#### `gpu_ops/rms_norm_kernels.cu`
```cpp
#include "kernel_helpers.h"