Merge pull request #13911 from 8bitmp3:main

PiperOrigin-RevId: 502743593
This commit is contained in:
jax authors 2023-01-17 18:50:27 -08:00
commit b0e30eb067
2 changed files with 23 additions and 23 deletions

View File

@ -1,11 +1,12 @@
# Custom operations for GPUs
# Custom operations for GPUs with C++ and CUDA
JAX ships with a large number of built-in operations, but users occasionally run into a situation where they need a new operation that is not supported by JAX.
To accommodate such scenarios, JAX allows users to define custom operations and this tutorial is to explain how we can define one for GPUs and use it in single-GPU and multi-GPU environments.
This tutorial contains information from [Extending JAX with custom C++ and CUDA code](https://github.com/dfm/extending-jax).
## RMS Normalization
## RMS normalization
For this tutorial, we are going to add the RMS normalization as a custom operation in JAX.
Note that the RMS normalization can be expressed with [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) directly. However, we are using it as an example to show the process of creating a custom operation for GPUs.
@ -370,13 +371,13 @@ _rms_norm_bwd_p.def_abstract_eval(_rms_norm_bwd_abstract)
## Let's test it again
### Test forward function
### Test the forward function
```python
out = rms_norm_fwd(x, weight)
```
### Test backward function
### Test the backward function
Now let's test the backward operation using `jax.grad` and `jtu.check_grads`.
@ -473,7 +474,7 @@ jtu.check_grads(loss, (x, weight), modes=["rev"], order=1)
We are using [`jax.experimental.pjit.pjit`](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html#jax.experimental.pjit.pjit) for parallel execution on multiple devices, and we produce reference values with sequential execution on a single device.
### Test forward function
### Test the forward function
Let's first test the forward operation on multiple devices. We are creating a simple 1D mesh and sharding `x` on all devices.
@ -594,7 +595,7 @@ True
With this modification, the `all-gather` operation is eliminated and the custom call is made on each shard of `x`.
### Test backward function
### Test the backward function
We are moving onto the backward operation using `jax.grad` on multiple devices.

View File

@ -170,12 +170,12 @@ sets up symbolic links from site-packages into the repository.
(running-tests)=
# Running the tests
## Running the tests
There are two supported mechanisms for running the JAX tests, either using Bazel
or using pytest.
## Using Bazel
### Using Bazel
First, configure the JAX build by running:
```
@ -221,7 +221,7 @@ MULTI_GPU="--run_under $PWD/build/parallel_accelerator_execute.sh --test_env=JAX
bazel test //tests:gpu_tests //tests:backend_independent_tests --test_env=XLA_PYTHON_CLIENT_PREALLOCATE=false --test_tag_filters=-multiaccelerator $MULTI_GPU
```
## Using pytest
### Using `pytest`
To run all the JAX tests using `pytest`, we recommend using `pytest-xdist`,
which can run tests in parallel. First, install `pytest-xdist` and
@ -232,7 +232,7 @@ Then, from the repository root directory run:
pytest -n auto tests
```
## Controlling test behavior
### Controlling test behavior
JAX generates test cases combinatorially, and you can control the number of
cases that are generated and checked for each test (default is 10) using the
@ -279,7 +279,8 @@ python tests/lax_numpy_test.py --test_targets="testPad"
The Colab notebooks are tested for errors as part of the documentation build.
## Doctests
### Doctests
JAX uses pytest in doctest mode to test the code examples within the documentation.
You can run this using
```
@ -294,7 +295,7 @@ Keep in mind that there are several files that are marked to be skipped when the
doctest command is run on the full package; you can see the details in
[`ci-build.yaml`](https://github.com/google/jax/blob/main/.github/workflows/ci-build.yaml)
# Type checking
## Type checking
We use `mypy` to check the type hints. To check types locally the same way
as the CI checks them:
@ -312,7 +313,7 @@ in the GitHub CI:
pre-commit run mypy
```
# Linting
## Linting
JAX uses the [flake8](https://flake8.pycqa.org/) linter to ensure code quality. You can check
your local changes by running:
@ -330,7 +331,7 @@ the GitHub tests:
pre-commit run flake8
```
# Update documentation
## Update documentation
To rebuild the documentation, install several packages:
```
@ -352,14 +353,14 @@ in place of `auto` to control how many CPU cores to use.
(update-notebooks)=
## Update notebooks
### Update notebooks
We use [jupytext](https://jupytext.readthedocs.io/) to maintain two synced copies of the notebooks
in `docs/notebooks`: one in `ipynb` format, and one in `md` format. The advantage of the former
is that it can be opened and executed directly in Colab; the advantage of the latter is that
it makes it much easier to track diffs within version control.
### Editing ipynb
#### Editing `ipynb`
For making large changes that substantially modify code and outputs, it is easiest to
edit the notebooks in Jupyter or in Colab. To edit notebooks in the Colab interface,
@ -367,12 +368,12 @@ open <http://colab.research.google.com> and `Upload` from your local repo.
Update it as needed, `Run all cells` then `Download ipynb`.
You may want to test that it executes properly, using `sphinx-build` as explained above.
### Editing md
#### Editing `md`
For making smaller changes to the text content of the notebooks, it is easiest to edit the
`.md` versions using a text editor.
### Syncing notebooks
#### Syncing notebooks
After editing either the ipynb or md versions of the notebooks, you can sync the two versions
using [jupytext](https://jupytext.readthedocs.io/) by running `jupytext --sync` on the updated
@ -395,7 +396,7 @@ git add docs -u # pre-commit runs on files in git staging.
pre-commit run jupytext
```
### Creating new notebooks
#### Creating new notebooks
If you are adding a new notebook to the documentation and would like to use the `jupytext --sync`
command discussed here, you can set up your notebook for jupytext by using the following command:
@ -407,7 +408,7 @@ jupytext --set-formats ipynb,md:myst path/to/the/notebook.ipynb
This works by adding a `"jupytext"` metadata field to the notebook file which specifies the
desired formats, and which the `jupytext --sync` command recognizes when invoked.
### Notebooks within the sphinx build
#### Notebooks within the Sphinx build
Some of the notebooks are built automatically as part of the pre-submit checks and
as part of the [Read the docs](https://jax.readthedocs.io/en/latest) build.
@ -419,7 +420,7 @@ re-saves the notebook.
We exclude some notebooks from the build, e.g., because they contain long computations.
See `exclude_patterns` in [conf.py](https://github.com/google/jax/blob/main/docs/conf.py).
## Documentation building on readthedocs.io
### Documentation building on `readthedocs.io`
JAX's auto-generated documentation is at <https://jax.readthedocs.io/>.
@ -456,5 +457,3 @@ python -m pip install --exists-action=w --no-cache-dir -r docs/requirements.txt
cd docs
python `which sphinx-build` -T -E -b html -d _build/doctrees-readthedocs -D language=en . _build/html
```