mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #13911 from 8bitmp3:main
PiperOrigin-RevId: 502743593
This commit is contained in:
commit
b0e30eb067
@ -1,11 +1,12 @@
|
||||
# Custom operations for GPUs
|
||||
# Custom operations for GPUs with C++ and CUDA
|
||||
|
||||
JAX ships with a large number of built-in operations, but users occasionally run into a situation where they need a new operation that is not supported by JAX.
|
||||
|
||||
To accommodate such scenarios, JAX allows users to define custom operations and this tutorial is to explain how we can define one for GPUs and use it in single-GPU and multi-GPU environments.
|
||||
|
||||
This tutorial contains information from [Extending JAX with custom C++ and CUDA code](https://github.com/dfm/extending-jax).
|
||||
|
||||
## RMS Normalization
|
||||
## RMS normalization
|
||||
|
||||
For this tutorial, we are going to add the RMS normalization as a custom operation in JAX.
|
||||
Note that the RMS normalization can be expressed with [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) directly. However, we are using it as an example to show the process of creating a custom operation for GPUs.
|
||||
@ -370,13 +371,13 @@ _rms_norm_bwd_p.def_abstract_eval(_rms_norm_bwd_abstract)
|
||||
|
||||
## Let's test it again
|
||||
|
||||
### Test forward function
|
||||
### Test the forward function
|
||||
|
||||
```python
|
||||
out = rms_norm_fwd(x, weight)
|
||||
```
|
||||
|
||||
### Test backward function
|
||||
### Test the backward function
|
||||
|
||||
Now let's test the backward operation using `jax.grad` and `jtu.check_grads`.
|
||||
|
||||
@ -473,7 +474,7 @@ jtu.check_grads(loss, (x, weight), modes=["rev"], order=1)
|
||||
|
||||
We are using [`jax.experimental.pjit.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#jax.experimental.pjit.pjit) for parallel execution on multiple devices, and we produce reference values with sequential execution on a single device.
|
||||
|
||||
### Test forward function
|
||||
### Test the forward function
|
||||
|
||||
Let's first test the forward operation on multiple devices. We are creating a simple 1D mesh and sharding `x` on all devices.
|
||||
|
||||
@ -594,7 +595,7 @@ True
|
||||
|
||||
With this modification, the `all-gather` operation is eliminated and the custom call is made on each shard of `x`.
|
||||
|
||||
### Test backward function
|
||||
### Test the backward function
|
||||
|
||||
We are moving onto the backward operation using `jax.grad` on multiple devices.
|
||||
|
||||
|
@ -170,12 +170,12 @@ sets up symbolic links from site-packages into the repository.
|
||||
|
||||
(running-tests)=
|
||||
|
||||
# Running the tests
|
||||
## Running the tests
|
||||
|
||||
There are two supported mechanisms for running the JAX tests, either using Bazel
|
||||
or using pytest.
|
||||
|
||||
## Using Bazel
|
||||
### Using Bazel
|
||||
|
||||
First, configure the JAX build by running:
|
||||
```
|
||||
@ -221,7 +221,7 @@ MULTI_GPU="--run_under $PWD/build/parallel_accelerator_execute.sh --test_env=JAX
|
||||
bazel test //tests:gpu_tests //tests:backend_independent_tests --test_env=XLA_PYTHON_CLIENT_PREALLOCATE=false --test_tag_filters=-multiaccelerator $MULTI_GPU
|
||||
```
|
||||
|
||||
## Using pytest
|
||||
### Using `pytest`
|
||||
|
||||
To run all the JAX tests using `pytest`, we recommend using `pytest-xdist`,
|
||||
which can run tests in parallel. First, install `pytest-xdist` and
|
||||
@ -232,7 +232,7 @@ Then, from the repository root directory run:
|
||||
pytest -n auto tests
|
||||
```
|
||||
|
||||
## Controlling test behavior
|
||||
### Controlling test behavior
|
||||
|
||||
JAX generates test cases combinatorially, and you can control the number of
|
||||
cases that are generated and checked for each test (default is 10) using the
|
||||
@ -279,7 +279,8 @@ python tests/lax_numpy_test.py --test_targets="testPad"
|
||||
|
||||
The Colab notebooks are tested for errors as part of the documentation build.
|
||||
|
||||
## Doctests
|
||||
### Doctests
|
||||
|
||||
JAX uses pytest in doctest mode to test the code examples within the documentation.
|
||||
You can run this using
|
||||
```
|
||||
@ -294,7 +295,7 @@ Keep in mind that there are several files that are marked to be skipped when the
|
||||
doctest command is run on the full package; you can see the details in
|
||||
[`ci-build.yaml`](https://github.com/google/jax/blob/main/.github/workflows/ci-build.yaml)
|
||||
|
||||
# Type checking
|
||||
## Type checking
|
||||
|
||||
We use `mypy` to check the type hints. To check types locally the same way
|
||||
as the CI checks them:
|
||||
@ -312,7 +313,7 @@ in the GitHub CI:
|
||||
pre-commit run mypy
|
||||
```
|
||||
|
||||
# Linting
|
||||
## Linting
|
||||
|
||||
JAX uses the [flake8](https://flake8.pycqa.org/) linter to ensure code quality. You can check
|
||||
your local changes by running:
|
||||
@ -330,7 +331,7 @@ the GitHub tests:
|
||||
pre-commit run flake8
|
||||
```
|
||||
|
||||
# Update documentation
|
||||
## Update documentation
|
||||
|
||||
To rebuild the documentation, install several packages:
|
||||
```
|
||||
@ -352,14 +353,14 @@ in place of `auto` to control how many CPU cores to use.
|
||||
|
||||
(update-notebooks)=
|
||||
|
||||
## Update notebooks
|
||||
### Update notebooks
|
||||
|
||||
We use [jupytext](https://jupytext.readthedocs.io/) to maintain two synced copies of the notebooks
|
||||
in `docs/notebooks`: one in `ipynb` format, and one in `md` format. The advantage of the former
|
||||
is that it can be opened and executed directly in Colab; the advantage of the latter is that
|
||||
it makes it much easier to track diffs within version control.
|
||||
|
||||
### Editing ipynb
|
||||
#### Editing `ipynb`
|
||||
|
||||
For making large changes that substantially modify code and outputs, it is easiest to
|
||||
edit the notebooks in Jupyter or in Colab. To edit notebooks in the Colab interface,
|
||||
@ -367,12 +368,12 @@ open <http://colab.research.google.com> and `Upload` from your local repo.
|
||||
Update it as needed, `Run all cells` then `Download ipynb`.
|
||||
You may want to test that it executes properly, using `sphinx-build` as explained above.
|
||||
|
||||
### Editing md
|
||||
#### Editing `md`
|
||||
|
||||
For making smaller changes to the text content of the notebooks, it is easiest to edit the
|
||||
`.md` versions using a text editor.
|
||||
|
||||
### Syncing notebooks
|
||||
#### Syncing notebooks
|
||||
|
||||
After editing either the ipynb or md versions of the notebooks, you can sync the two versions
|
||||
using [jupytext](https://jupytext.readthedocs.io/) by running `jupytext --sync` on the updated
|
||||
@ -395,7 +396,7 @@ git add docs -u # pre-commit runs on files in git staging.
|
||||
pre-commit run jupytext
|
||||
```
|
||||
|
||||
### Creating new notebooks
|
||||
#### Creating new notebooks
|
||||
|
||||
If you are adding a new notebook to the documentation and would like to use the `jupytext --sync`
|
||||
command discussed here, you can set up your notebook for jupytext by using the following command:
|
||||
@ -407,7 +408,7 @@ jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb
|
||||
This works by adding a `"jupytext"` metadata field to the notebook file which specifies the
|
||||
desired formats, and which the `jupytext --sync` command recognizes when invoked.
|
||||
|
||||
### Notebooks within the sphinx build
|
||||
#### Notebooks within the Sphinx build
|
||||
|
||||
Some of the notebooks are built automatically as part of the pre-submit checks and
|
||||
as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build.
|
||||
@ -419,7 +420,7 @@ re-saves the notebook.
|
||||
We exclude some notebooks from the build, e.g., because they contain long computations.
|
||||
See `exclude_patterns` in [conf.py](https://github.com/google/jax/blob/main/docs/conf.py).
|
||||
|
||||
## Documentation building on readthedocs.io
|
||||
### Documentation building on `readthedocs.io`
|
||||
|
||||
JAX's auto-generated documentation is at <https://jax.readthedocs.io/>.
|
||||
|
||||
@ -456,5 +457,3 @@ python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
|
||||
cd docs
|
||||
python `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html
|
||||
```
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user