Small clarification.

This commit is contained in:
Frederic Bastien 2022-12-21 16:10:51 +00:00 committed by Sean Lee
parent 4a63488825
commit b42e68968c


@ -1,6 +1,6 @@
# Custom operation for GPUs
JAX, a popular framework for machine learning, comes with a large number of operations, however, users occasionally run into a situation where they need a new operation that is not supported by JAX.
JAX, a popular framework for machine learning, comes with a large number of operations. However, users occasionally run into a situation where they need a new operation that is not supported by JAX.
To accommodate such scenarios, JAX allows users to define custom operations, and this tutorial explains how to define one for GPUs and use it in single-GPU and multi-GPU environments.
This tutorial contains information from [Extending JAX with custom C++ and CUDA code](https://github.com/dfm/extending-jax).
@ -8,8 +8,8 @@ This tutorial contains information from [Extending JAX with custom C++ and CUDA
# RMS Normalization
For this tutorial, we are going to add the RMS normalization as a custom operation in JAX.
It is to be noted that the RMS normalization can be expressed with [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) directly, however, we are using it as an example to show the process of creating a custom operation for GPUs.
The CUDA code in `gpu_ops/rms_norm_kernels.cu` for this operation has been borrowed from <https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda_kernel.cu> and adapted to eliminate any dependency on PyTorch.
Note that the RMS normalization can be expressed with [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) directly. However, we are using it as an example to show the process of creating a custom operation for GPUs.
The CUDA code in `gpu_ops/rms_norm_kernels.cu` for this operation has been borrowed from [Apex](https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda_kernel.cu) and adapted to eliminate any dependency on PyTorch.
See [`gpu_ops` code listing](#gpu_ops-code-listing) for the complete code listing of the C++ and CUDA files.
`gpu_ops/rms_norm_kernels.cu` defines the following functions, which are declared with the XLA custom function signature.
@ -64,8 +64,8 @@ pybind11::dict RMSNormRegistrations() {
}
PYBIND11_MODULE(gpu_ops, m) {
m.def("rms_norm_registrations", &RMSNormRegistrations);
m.def("rms_norm_descriptor",
m.def("get_rms_norm_registrations", &RMSNormRegistrations);
m.def("create_rms_norm_descriptor",
[](int n1, int n2, double eps, gpu_ops::ElementType x_type,
gpu_ops::ElementType w_type, int part_grad_size) {
return gpu_ops::PackDescriptor(gpu_ops::RMSNormDescriptor{
@ -85,16 +85,17 @@ PYBIND11_MODULE(gpu_ops, m) {
We build the `gpu_ops` Python extension module with the aforementioned code.
```ipython3
!python -m pip install pybind11==2.10.1
import pybind11
import sys
```shell
python -m pip install pybind11==2.10.1
mkdir -p build
pybind_include_path=$(python -c "import pybind11; print(pybind11.get_include())")
python_executable=$(python -c 'import sys; print(sys.executable)')
!mkdir -p build
!nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3 --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86] -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c gpu_ops/rms_norm_kernels.cu -o build/rms_norm_kernels.cu.o
!c++ -I/usr/local/cuda/include -I{pybind11.get_include()} $({sys.executable}-config --cflags) -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o build/gpu_ops.cpp.o -c gpu_ops/gpu_ops.cpp
!c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o build/gpu_ops$({sys.executable}-config --extension-suffix) build/gpu_ops.cpp.o build/rms_norm_kernels.cu.o -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl
!strip build/gpu_ops$({sys.executable}-config --extension-suffix)
nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3 --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86] -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c gpu_ops/rms_norm_kernels.cu -o build/rms_norm_kernels.cu.o
c++ -I/usr/local/cuda/include -I$pybind_include_path $(${python_executable}-config --cflags) -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o build/gpu_ops.cpp.o -c gpu_ops/gpu_ops.cpp
c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o build/gpu_ops$(${python_executable}-config --extension-suffix) build/gpu_ops.cpp.o build/rms_norm_kernels.cu.o -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl
strip build/gpu_ops$(${python_executable}-config --extension-suffix)
```
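If the build succeeds, the extension module can be imported from the `build` directory. Below is a minimal sanity check, assuming the commands above were run from the tutorial root so that `build/` is a valid relative path; the printed keys are the custom call target names exported by `gpu_ops.cpp`.

```python
import sys

# Make the freshly built extension importable; assumes the build
# commands above were run from the current working directory.
sys.path.insert(0, "build")

import gpu_ops

# List the custom call targets exported by the extension module.
print(list(gpu_ops.get_rms_norm_registrations().keys()))
```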
# Add RMS normalization to JAX as custom call
@ -106,6 +107,7 @@ import sys
We first create the primitives `_rms_norm_fwd_p` and `_rms_norm_bwd_p`, to which the custom functions can be mapped.
We set the `multiple_results` attribute to `True` for these primitives, meaning that each produces multiple outputs as a tuple.
When it is set to `False`, a primitive produces a single output without a tuple.
For more details, see [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).
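Below is a minimal sketch of this step; `jax.core.Primitive` and its `multiple_results` attribute are the standard JAX primitive machinery, and the comments on the output tuples follow the forward/backward signatures used later in this tutorial.

```python
from jax import core

# Forward primitive: produces multiple outputs (the normalized output
# and the inverse variance consumed by the backward pass).
_rms_norm_fwd_p = core.Primitive("rms_norm_fwd")
_rms_norm_fwd_p.multiple_results = True

# Backward primitive: also produces multiple outputs, including the
# partial weight gradient workspace.
_rms_norm_bwd_p = core.Primitive("rms_norm_bwd")
_rms_norm_bwd_p.multiple_results = True
```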
```python
from functools import partial
@ -155,13 +157,15 @@ The functions `_rms_norm_fwd_cuda_lowering` and `_rms_norm_bwd_cuda_lowering` be
Note that an `RMSNormDescriptor` object is created in the lowering function, and passed to the custom call as `opaque`.
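For reference, here is a sketch of that descriptor call. The argument order is taken from the `create_rms_norm_descriptor` binding shown earlier; the element-type arguments are placeholders here, as the full listing derives them from the operand types in the lowering.

```python
# Sketch: pack the problem description into an opaque byte string for
# the XLA custom call. Argument order follows the pybind11 binding:
# (n1, n2, eps, x_type, w_type, part_grad_size).
opaque = gpu_ops.create_rms_norm_descriptor(
    n1,              # rows: product of the leading input dimensions
    n2,              # row size: product of the weight dimensions
    eps,             # epsilon used in the RMS computation
    x_type,          # gpu_ops.ElementType of the input (placeholder)
    w_type,          # gpu_ops.ElementType of the weight (placeholder)
    part_grad_size,  # size of the partial-gradient workspace (backward pass)
)
```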
```python
from functools import reduce
from jax.interpreters import mlir
from jax.interpreters.mlir import ir
from jaxlib.mhlo_helpers import custom_call
# Register functions defined in gpu_ops as custom call target for GPUs
for _name, _value in gpu_ops.rms_norm_registrations().items():
for _name, _value in gpu_ops.get_rms_norm_registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="gpu")
@ -193,7 +197,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
n2 = reduce(lambda x, y: x * y, w_shape)
n1 = reduce(lambda x, y: x * y, x_shape) // n2
opaque = gpu_ops.rms_norm_descriptor(
opaque = gpu_ops.create_rms_norm_descriptor(
n1,
n2,
eps,
@ -234,7 +238,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
part_grad_shape = ctx.avals_out[-1].shape
opaque = gpu_ops.rms_norm_descriptor(
opaque = gpu_ops.create_rms_norm_descriptor(
n1,
n2,
eps,
@ -811,7 +815,7 @@ def rms_norm_bwd(eps, res, g):
# Register functions defined in gpu_ops as custom call target for GPUs
for _name, _value in gpu_ops.rms_norm_registrations().items():
for _name, _value in gpu_ops.get_rms_norm_registrations().items():
xla_client.register_custom_call_target(_name, _value, platform="gpu")
@ -843,7 +847,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
n2 = reduce(lambda x, y: x * y, w_shape)
n1 = reduce(lambda x, y: x * y, x_shape) // n2
opaque = gpu_ops.rms_norm_descriptor(
opaque = gpu_ops.create_rms_norm_descriptor(
n1,
n2,
eps,
@ -884,7 +888,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
part_grad_shape = ctx.avals_out[-1].shape
opaque = gpu_ops.rms_norm_descriptor(
opaque = gpu_ops.create_rms_norm_descriptor(
n1,
n2,
eps,
@ -1201,8 +1205,8 @@ pybind11::dict RMSNormRegistrations() {
}
PYBIND11_MODULE(gpu_ops, m) {
m.def("rms_norm_registrations", &RMSNormRegistrations);
m.def("rms_norm_descriptor",
m.def("get_rms_norm_registrations", &RMSNormRegistrations);
m.def("create_rms_norm_descriptor",
[](int n1, int n2, double eps, gpu_ops::ElementType x_type,
gpu_ops::ElementType w_type, int part_grad_size) {
return gpu_ops::PackDescriptor(gpu_ops::RMSNormDescriptor{