Merge pull request #13950 from jakevdp:fix-gpu-headings
PiperOrigin-RevId: 501111194
Commit 3f712480c6

@@ -1,11 +1,11 @@
-# Custom operation for GPUs
+# Custom operations for GPUs

-JAX, a popular framework for machine learning, comes with a large number of operations. However, users occasionally run into a situation where they need a new operation that is not supported by JAX.
+JAX ships with a large number of built-in operations, but users occasionally run into a situation where they need a new operation that is not supported by JAX.

To accommodate such scenarios, JAX allows users to define custom operations, and this tutorial explains how we can define one for GPUs and use it in single-GPU and multi-GPU environments.

This tutorial contains information from [Extending JAX with custom C++ and CUDA code](https://github.com/dfm/extending-jax).

-# RMS Normalization
+## RMS Normalization

For this tutorial, we are going to add RMS normalization as a custom operation in JAX.

Note that RMS normalization can be expressed with [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) directly. However, we are using it as an example to show the process of creating a custom operation for GPUs.
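
For reference, RMS normalization in plain `jax.numpy` looks roughly like the sketch below; the function name, the `eps` default, and the choice to normalize over all axes spanned by `weight` are illustrative assumptions, not part of the diff.

```python
import jax
import jax.numpy as jnp


def rms_norm_ref(x, weight, eps=1e-05):
    # Normalize over the trailing axes covered by `weight` (two of them in the
    # test setup later in the tutorial, where weight.shape == x.shape[-2:]),
    # scaling by the reciprocal root-mean-square and the learned weight.
    axes = tuple(range(x.ndim - weight.ndim, x.ndim))
    mean_sq = jnp.mean(jnp.square(x), axis=axes, keepdims=True)
    return x * jax.lax.rsqrt(mean_sq + eps) * weight
```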

@@ -81,7 +81,7 @@ PYBIND11_MODULE(gpu_ops, m) {
}
```

-# Build `gpu_ops` extension module
+## Build `gpu_ops` extension module

We build the `gpu_ops` Python extension module with the aforementioned code.

@@ -98,11 +98,11 @@ c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o build/gpu_ops$(${python_executable}
strip build/gpu_ops$(${python_executable}-config --extension-suffix)
```
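
If the build succeeds, the module can be imported directly from the `build` directory; a quick smoke test might look like the sketch below (illustrative only; the exported names come from the pybind11 bindings listed in the appendix).

```python
import sys

# Make the freshly built extension importable (the path is an assumption about
# where the build commands above were run).
sys.path.append("build")

import gpu_ops

# List the functions the pybind11 module exposes.
print([name for name in dir(gpu_ops) if not name.startswith("_")])
```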

-# Add RMS normalization to JAX as custom call
+## Add RMS normalization to JAX as custom call

`gpu_ops` is just a Python extension module; we need more work to plug it into JAX.

-## Create primitives
+### Create primitives

We first create primitives, `_rms_norm_fwd_p` and `_rms_norm_bwd_p`, which the custom functions can be mapped to.
We set the `multiple_results` attribute to `True` for these operations, which means that each operation produces multiple outputs as a tuple.
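
A minimal sketch of what creating these primitives can look like; the exact import paths and the `xla.apply_primitive` plumbing are version-dependent details, shown here only to illustrate the `multiple_results` setting.

```python
from functools import partial

from jax.core import Primitive
from jax.interpreters import xla

# Forward primitive: produces the normalized output plus the saved invvar,
# so it has multiple results.
_rms_norm_fwd_p = Primitive("rms_norm_fwd")
_rms_norm_fwd_p.multiple_results = True
_rms_norm_fwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_fwd_p))

# Backward primitive: likewise returns several gradient buffers as a tuple.
_rms_norm_bwd_p = Primitive("rms_norm_bwd")
_rms_norm_bwd_p.multiple_results = True
_rms_norm_bwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_bwd_p))
```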

@@ -145,7 +145,7 @@ def rms_norm_bwd(g, invvar, x, weight, eps):
return grad_input, grad_weight
```

-## Lowering to MLIR custom call
+### Lowering to MLIR custom call

To map the custom functions to the new primitives, `_rms_norm_fwd_p` and `_rms_norm_bwd_p`, we need to:

@@ -268,7 +268,7 @@ mlir.register_lowering(
)
```
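
For orientation, the two registration steps typically have the shape sketched below. The `gpu_ops.get_rms_norm_registrations()` helper and the `_cuda_lowering` rule names are assumptions standing in for definitions elided from this diff.

```python
from jax.interpreters import mlir
from jax.lib import xla_client

import gpu_ops

# 1) Expose each C++/CUDA entry point to XLA under a string name
#    (assumed helper returning a {name: PyCapsule} mapping).
for name, fn in gpu_ops.get_rms_norm_registrations().items():
    xla_client.register_custom_call_target(name, fn, platform="gpu")

# 2) Tell JAX how to lower each primitive to that custom call on GPU
#    (the lowering rules build the MLIR custom_call registered above).
mlir.register_lowering(_rms_norm_fwd_p, _rms_norm_fwd_cuda_lowering, platform="gpu")
mlir.register_lowering(_rms_norm_bwd_p, _rms_norm_bwd_cuda_lowering, platform="gpu")
```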

-# Let's test it
+## Let's test it

```python
per_core_batch_size=4

@@ -283,7 +283,7 @@ norm_shape = x.shape[-2:]
weight = jnp.ones(norm_shape, dtype=jnp.bfloat16)
```

-## Test forward function
+### Test forward function

```python
out = rms_norm_fwd(x, weight)

@@ -299,7 +299,7 @@ Cell In [5], line 1
NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented
```

-# Abstract evaluation
+## Abstract evaluation

The test above failed with `NotImplementedError: Abstract evaluation for 'rms_norm_fwd' not implemented`. Why did the test fail? What does it mean?

@@ -368,15 +368,15 @@ def _rms_norm_bwd_abstract(grad_output, invvar, x, weight, eps):
_rms_norm_bwd_p.def_abstract_eval(_rms_norm_bwd_abstract)
```
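
The forward rule (in the lines elided above) has the same flavor; a simplified sketch, where the output dtype and the one-invvar-per-row layout are assumptions rather than the tutorial's exact rule:

```python
from functools import reduce
from operator import mul

import jax.numpy as jnp
from jax.core import ShapedArray


def _rms_norm_fwd_abstract(x, weight, eps):
    # Abstract evaluation reasons about shapes and dtypes only; no values flow here.
    n_weight = reduce(mul, weight.shape)
    n_rows = reduce(mul, x.shape) // n_weight
    return (
        ShapedArray(x.shape, x.dtype),        # normalized output
        ShapedArray((n_rows,), jnp.float32),  # invvar, one entry per normalized row
    )


_rms_norm_fwd_p.def_abstract_eval(_rms_norm_fwd_abstract)
```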

-# Let's test it again
+## Let's test it again

-## Test forward function
+### Test forward function

```python
out = rms_norm_fwd(x, weight)
```

-## Test backward function
+### Test backward function

Now let's test the backward operation using `jax.grad` and `jtu.check_grads`.
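
The elided test has roughly the shape below; the scalar loss is an assumption, while the `loss_grad` and `check_grads` lines mirror ones that appear later in the diff.

```python
import jax
import jax.numpy as jnp
import jax.test_util as jtu


def loss(x, weight):
    # Assumed scalar loss built on the custom forward op; the first result is
    # taken to be the normalized output.
    output = rms_norm_fwd(x, weight)[0]
    return -jnp.mean(output ** 2)


loss_grad = jax.grad(loss)
out = loss_grad(x, weight)
jtu.check_grads(loss, (x, weight), modes=["rev"], order=1)
```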

@@ -403,7 +403,7 @@ Cell In [8], line 7
NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented
```

-# Differentiation rule
+## Differentiation rule

The backward operation failed with the error `NotImplementedError: Differentiation rule for 'rms_norm_fwd' not implemented`. It means that, although we have defined `rms_norm_fwd` and `rms_norm_bwd`, JAX doesn't know the relationship between them.
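
The fix (elided from this hunk) is to tie the two wrappers together as a custom-VJP pair. One common shape for that, with the residual layout and the `nondiff_argnums` choice treated as assumptions, is:

```python
from functools import partial

import jax


# Assumes rms_norm_fwd returns (output, residuals) and rms_norm_bwd maps
# (eps, residuals, cotangent) -> (grad_input, grad_weight).
@partial(jax.custom_vjp, nondiff_argnums=(2,))
def rms_norm(x, weight, eps=1e-05):
    output, _ = rms_norm_fwd(x, weight, eps)
    return output


rms_norm.defvjp(rms_norm_fwd, rms_norm_bwd)
```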

@@ -469,11 +469,11 @@ out = loss_grad(x, weight)
jtu.check_grads(loss, (x, weight), modes=["rev"], order=1)
```

-# Let's test it on multiple devices
+## Let's test it on multiple devices

We are using [`jax.experimental.pjit.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#jax.experimental.pjit.pjit) for parallel execution on multiple devices, and we produce reference values with sequential execution on a single device.

-## Test forward function
+### Test forward function

Let's first test the forward operation on multiple devices. We are creating a simple 1D mesh and sharding `x` on all devices.
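
A sketch of what that can look like; the mesh axis name, the sharding specs, and the use of the newer `in_shardings`/`out_shardings` keywords (older releases spell them `in_axis_resources`/`out_axis_resources`) are assumptions layered on the test setup above.

```python
import numpy as np

import jax
from jax.experimental.pjit import pjit
from jax.sharding import Mesh, PartitionSpec

# One mesh axis spanning all local devices; shard the batch dimension of x
# along it and replicate weight.
mesh = Mesh(np.array(jax.local_devices()), ("x",))
pjitted = pjit(
    rms_norm,
    in_shardings=(PartitionSpec("x", None, None), PartitionSpec(None, None)),
    out_shardings=PartitionSpec("x", None, None),
)

with mesh:
    out = pjitted(x, weight)
```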

@@ -594,7 +594,7 @@ True

With this modification, the `all-gather` operation is eliminated and the custom call is made on each shard of `x`.
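
One illustrative way to check that claim, reusing the pjit-wrapped forward from the sketch above, is to look for collectives in the compiled HLO:

```python
# Lower and compile under the mesh, then search the optimized HLO text for an
# all-gather (purely a debugging aid; not part of the diff).
with mesh:
    compiled = pjitted.lower(x, weight).compile()
print("all-gather" in compiled.as_text())
```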

-## Test backward function
+### Test backward function

We are moving on to the backward operation using `jax.grad` on multiple devices.
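
A sketch of that test, reusing `mesh`, the sharding assumptions, and the assumed scalar loss from the earlier sketches:

```python
import jax
import jax.numpy as jnp
from jax.experimental.pjit import pjit
from jax.sharding import PartitionSpec


def loss(x, weight):
    # Assumed scalar loss over the custom-VJP-wrapped rms_norm.
    return -jnp.mean(rms_norm(x, weight) ** 2)


with mesh:
    pjitted_grad = pjit(
        jax.grad(loss, argnums=(0, 1)),
        in_shardings=(PartitionSpec("x", None, None), PartitionSpec(None, None)),
        out_shardings=(PartitionSpec("x", None, None), PartitionSpec(None, None)),
    )
    grad_x, grad_weight = pjitted_grad(x, weight)
```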

@@ -763,7 +763,7 @@ return (
)
```

-# Let's put it together
+## Let's put it together

Here is the complete code.

@@ -1065,11 +1065,11 @@ True
True
```

-# Appendix
+## Appendix

-## `gpu_ops` code listing
+### `gpu_ops` code listing

-### `gpu_ops/kernel_helpers.h`
+#### `gpu_ops/kernel_helpers.h`

```cpp
// This header is not specific to our application and you'll probably want

@@ -1123,7 +1123,7 @@ const T *UnpackDescriptor(const char *opaque, std::size_t opaque_len) {
#endif
```

-### `gpu_ops/kernels.h`
+#### `gpu_ops/kernels.h`

```cpp
#ifndef _GPU_OPS_KERNELS_H_

@@ -1157,7 +1157,7 @@ void rms_backward_affine(cudaStream_t stream, void **buffers,
#endif
```

-### `gpu_ops/pybind11_kernel_helpers.h`
+#### `gpu_ops/pybind11_kernel_helpers.h`

```cpp
// This header extends kernel_helpers.h with the pybind11 specific interface to

@@ -1188,7 +1188,7 @@ template <typename T> pybind11::capsule EncapsulateFunction(T *fn) {
#endif
```

-### `gpu_ops/gpu_ops.cpp`
+#### `gpu_ops/gpu_ops.cpp`

```cpp
#include "kernels.h"

@@ -1223,7 +1223,7 @@ PYBIND11_MODULE(gpu_ops, m) {
} // namespace
```

-### `gpu_ops/rms_norm_kernels.cu`
+#### `gpu_ops/rms_norm_kernels.cu`

```cpp
#include "kernel_helpers.h"