Mirror of https://github.com/ROCm/jax.git

commit b42e68968c
parent 4a63488825

    Small clarification.
````diff
@@ -1,6 +1,6 @@
 # Custom operation for GPUs
 
-JAX, a popular framework for machine learning, comes with a large number of operations, however, users occasionally run into a situation where they need a new operation that is not supported by JAX.
+JAX, a popular framework for machine learning, comes with a large number of operations. However, users occasionally run into a situation where they need a new operation that is not supported by JAX.
 To accommodate such scenarios, JAX allows users to define custom operations and this tutorial is to explain how we can define one for GPUs and use it in single-GPU and multi-GPU environments.
 
 This tutorial contains information from [Extending JAX with custom C++ and CUDA code](https://github.com/dfm/extending-jax).
````
````diff
@@ -8,8 +8,8 @@ This tutorial contains information from [Extending JAX with custom C++ and CUDA
 # RMS Normalization
 
 For this tutorial, we are going to add the RMS normalization as a custom operation in JAX.
-It is to be noted that the RMS normalization can be expressed with [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) directly, however, we are using it as an example to show the process of creating a custom operation for GPUs.
-The CUDA code in `gpu_ops/rms_norm_kernels.cu` for this operation has been borrowed from <https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda_kernel.cu> and adapted to eliminate any dependency on PyTorch.
+Note that the RMS normalization can be expressed with [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) directly. However, we are using it as an example to show the process of creating a custom operation for GPUs.
+The CUDA code in `gpu_ops/rms_norm_kernels.cu` for this operation has been borrowed from [Apex](https://github.com/NVIDIA/apex/blob/master/csrc/layer_norm_cuda_kernel.cu) and adapted to eliminate any dependency on PyTorch.
 See [`gpu_ops` code listing](#gpu_ops-code-listing) for complete code listing of C++ and CUDA files.
 
 `gpu_ops/rms_norm_kernels.cu` defines the following functions, which are declared with the XLA custom function signature.
````
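For reference, the `jax.numpy` formulation mentioned in that hunk can be sketched as below. This is not the tutorial's verbatim code; it assumes the usual definition y = x / sqrt(mean(x²) + eps) · weight, with `weight` spanning the trailing axes of `x`:

```python
import jax
import jax.numpy as jnp

def rms_norm_ref(x, weight, eps=1e-5):
    """RMS normalization in plain jax.numpy (reference sketch)."""
    # Normalize over the trailing axes covered by `weight`.
    axes = tuple(range(-weight.ndim, 0))
    mean_sq = jnp.mean(jnp.square(x), axis=axes, keepdims=True)
    return x * jax.lax.rsqrt(mean_sq + eps) * weight

# Example: 32*128 rows of 512 features, each row normalized independently.
x = jnp.ones((32, 128, 512))
w = jnp.ones((512,))
y = rms_norm_ref(x, w)
```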
````diff
@@ -64,8 +64,8 @@ pybind11::dict RMSNormRegistrations() {
 }
 
 PYBIND11_MODULE(gpu_ops, m) {
-  m.def("rms_norm_registrations", &RMSNormRegistrations);
-  m.def("rms_norm_descriptor",
+  m.def("get_rms_norm_registrations", &RMSNormRegistrations);
+  m.def("create_rms_norm_descriptor",
         [](int n1, int n2, double eps, gpu_ops::ElementType x_type,
            gpu_ops::ElementType w_type, int part_grad_size) {
           return gpu_ops::PackDescriptor(gpu_ops::RMSNormDescriptor{
````
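From Python, the renamed binding reads more naturally as a constructor. Assuming the full listing also exposes the element-type enum as `gpu_ops.ElementType` (an assumption here, not shown in this diff), a call might look like:

```python
import gpu_ops

# Hypothetical values: 4096 rows of 512 features, fp32 input and weight,
# and a kernel-specific partial-gradient buffer size of 16.
opaque = gpu_ops.create_rms_norm_descriptor(
    4096,                     # n1: number of rows
    512,                      # n2: elements per row
    1e-5,                     # eps
    gpu_ops.ElementType.F32,  # x_type (assumed enum member)
    gpu_ops.ElementType.F32,  # w_type (assumed enum member)
    16,                       # part_grad_size
)
# `opaque` is the packed RMSNormDescriptor, later handed to the custom call.
```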
````diff
@@ -85,16 +85,17 @@ PYBIND11_MODULE(gpu_ops, m) {
 
 We build the `gpu_ops` Python extension module with the aforementioned code.
 
-```ipython3
-!python -m pip install pybind11==2.10.1
-import pybind11
-import sys
-
-!mkdir -p build
-!nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3 --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86] -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c gpu_ops/rms_norm_kernels.cu -o build/rms_norm_kernels.cu.o
-!c++ -I/usr/local/cuda/include -I{pybind11.get_include()} $({sys.executable}-config --cflags) -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o build/gpu_ops.cpp.o -c gpu_ops/gpu_ops.cpp
-!c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o build/gpu_ops$({sys.executable}-config --extension-suffix) build/gpu_ops.cpp.o build/rms_norm_kernels.cu.o -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl
-!strip build/gpu_ops$({sys.executable}-config --extension-suffix)
+```shell
+python -m pip install pybind11==2.10.1
+
+mkdir -p build
+pybind_include_path=$(python -c "import pybind11; print(pybind11.get_include())")
+python_executable=$(python -c 'import sys; print(sys.executable)')
+
+nvcc --threads 4 -Xcompiler -Wall -ldl --expt-relaxed-constexpr -O3 -DNDEBUG -Xcompiler -O3 --generate-code=arch=compute_70,code=[compute_70,sm_70] --generate-code=arch=compute_75,code=[compute_75,sm_75] --generate-code=arch=compute_80,code=[compute_80,sm_80] --generate-code=arch=compute_86,code=[compute_86,sm_86] -Xcompiler=-fPIC -Xcompiler=-fvisibility=hidden -x cu -c gpu_ops/rms_norm_kernels.cu -o build/rms_norm_kernels.cu.o
+c++ -I/usr/local/cuda/include -I$pybind_include_path $(${python_executable}-config --cflags) -O3 -DNDEBUG -O3 -fPIC -fvisibility=hidden -flto -fno-fat-lto-objects -o build/gpu_ops.cpp.o -c gpu_ops/gpu_ops.cpp
+c++ -fPIC -O3 -DNDEBUG -O3 -flto -shared -o build/gpu_ops$(${python_executable}-config --extension-suffix) build/gpu_ops.cpp.o build/rms_norm_kernels.cu.o -L/usr/local/cuda/lib64 -lcudadevrt -lcudart_static -lrt -lpthread -ldl
+strip build/gpu_ops$(${python_executable}-config --extension-suffix)
 ```
 
 # Add RMS normalization to JAX as custom call
````
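Once built, the shared object lands in `build/`. A quick import check, sketched here under the assumption that Python is launched from the directory containing `build/`, confirms both bindings are exposed:

```python
import sys

sys.path.append("build")  # directory containing gpu_ops.<abi-suffix>.so

import gpu_ops

# Both functions bound in PYBIND11_MODULE above should be visible.
assert hasattr(gpu_ops, "get_rms_norm_registrations")
assert hasattr(gpu_ops, "create_rms_norm_descriptor")
```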
````diff
@@ -106,6 +107,7 @@ import sys
 We first create primitives, `_rms_norm_fwd_p` and `_rms_norm_bwd_p`, which the custom functions can be mapped to.
 We set the `multiple_results` attribute to `True` for these operations, which means that the operation produces multiple outputs as a tuple.
 When it is set to `False`, the operation produces a single output without a tuple.
+For more details, see [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html).
 
 ```python
 from functools import partial
````
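In the spirit of that hunk, declaring such primitives can be sketched as follows (abridged; the tutorial's real definitions also attach abstract evaluation rules):

```python
from functools import partial

from jax import core
from jax.interpreters import xla

# A primitive that returns multiple arrays must set multiple_results = True;
# calls to it then yield a tuple of outputs.
_rms_norm_fwd_p = core.Primitive("rms_norm_fwd")
_rms_norm_fwd_p.multiple_results = True
_rms_norm_fwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_fwd_p))

_rms_norm_bwd_p = core.Primitive("rms_norm_bwd")
_rms_norm_bwd_p.multiple_results = True
_rms_norm_bwd_p.def_impl(partial(xla.apply_primitive, _rms_norm_bwd_p))
```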
````diff
@@ -155,13 +157,15 @@ The functions `_rms_norm_fwd_cuda_lowering` and `_rms_norm_bwd_cuda_lowering` be
 Note that an `RMSNormDescriptor` object is created in the lowering function, and passed to the custom call as `opaque`.
 
 ```python
 from functools import reduce
 
 from jax.interpreters import mlir
 from jax.interpreters.mlir import ir
 from jaxlib.mhlo_helpers import custom_call
 
 
 # Register functions defined in gpu_ops as custom call target for GPUs
-for _name, _value in gpu_ops.rms_norm_registrations().items():
+for _name, _value in gpu_ops.get_rms_norm_registrations().items():
     xla_client.register_custom_call_target(_name, _value, platform="gpu")
+
+
````
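These imports set up the lowering machinery. Once the lowering functions shown in the next hunks are defined, the usual pattern is to attach them to the primitives roughly as below (a sketch, not the verbatim file):

```python
from jax.interpreters import mlir

# Attach the GPU lowering rules so jit-compiled programs emit the
# registered custom calls when running on a GPU backend.
mlir.register_lowering(
    _rms_norm_fwd_p, _rms_norm_fwd_cuda_lowering, platform="gpu"
)
mlir.register_lowering(
    _rms_norm_bwd_p, _rms_norm_bwd_cuda_lowering, platform="gpu"
)
```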
````diff
@@ -193,7 +197,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
     n2 = reduce(lambda x, y: x * y, w_shape)
     n1 = reduce(lambda x, y: x * y, x_shape) // n2
 
-    opaque = gpu_ops.rms_norm_descriptor(
+    opaque = gpu_ops.create_rms_norm_descriptor(
         n1,
         n2,
         eps,
````
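To make the `n1`/`n2` bookkeeping concrete: `n2` is the element count of `weight` and `n1` the number of rows being normalized. A quick check with hypothetical shapes (pure Python, no GPU required):

```python
from functools import reduce

x_shape = (32, 128, 512)  # hypothetical input: 32*128 rows of 512 features
w_shape = (512,)          # weight spans the normalized trailing axis

n2 = reduce(lambda a, b: a * b, w_shape)        # 512 elements per row
n1 = reduce(lambda a, b: a * b, x_shape) // n2  # 4096 rows

assert (n1, n2) == (4096, 512)
```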
````diff
@@ -234,7 +238,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
 
     part_grad_shape = ctx.avals_out[-1].shape
 
-    opaque = gpu_ops.rms_norm_descriptor(
+    opaque = gpu_ops.create_rms_norm_descriptor(
         n1,
         n2,
         eps,
````
````diff
@@ -811,7 +815,7 @@ def rms_norm_bwd(eps, res, g):
 
 
 # Register functions defined in gpu_ops as custom call target for GPUs
-for _name, _value in gpu_ops.rms_norm_registrations().items():
+for _name, _value in gpu_ops.get_rms_norm_registrations().items():
     xla_client.register_custom_call_target(_name, _value, platform="gpu")
 
 
````
````diff
@@ -843,7 +847,7 @@ def _rms_norm_fwd_cuda_lowering(ctx, x, weight, eps):
     n2 = reduce(lambda x, y: x * y, w_shape)
     n1 = reduce(lambda x, y: x * y, x_shape) // n2
 
-    opaque = gpu_ops.rms_norm_descriptor(
+    opaque = gpu_ops.create_rms_norm_descriptor(
         n1,
         n2,
         eps,
````
````diff
@@ -884,7 +888,7 @@ def _rms_norm_bwd_cuda_lowering(ctx, grad_output, invvar, x, weight, eps):
 
     part_grad_shape = ctx.avals_out[-1].shape
 
-    opaque = gpu_ops.rms_norm_descriptor(
+    opaque = gpu_ops.create_rms_norm_descriptor(
         n1,
         n2,
         eps,
````
````diff
@@ -1201,8 +1205,8 @@ pybind11::dict RMSNormRegistrations() {
 }
 
 PYBIND11_MODULE(gpu_ops, m) {
-  m.def("rms_norm_registrations", &RMSNormRegistrations);
-  m.def("rms_norm_descriptor",
+  m.def("get_rms_norm_registrations", &RMSNormRegistrations);
+  m.def("create_rms_norm_descriptor",
         [](int n1, int n2, double eps, gpu_ops::ElementType x_type,
            gpu_ops::ElementType w_type, int part_grad_size) {
           return gpu_ops::PackDescriptor(gpu_ops::RMSNormDescriptor{
````