streamline readme, add pmap

Matthew Johnson 2019-12-14 07:00:39 -08:00 committed by Matthew Johnson
parent 8927200316
commit 8dad859e04
2 changed files with 172 additions and 39 deletions

README.md

@@ -4,9 +4,10 @@
# JAX: Autograd and XLA [![Test status](https://travis-ci.org/google/jax.svg?branch=master)](https://travis-ci.org/google/jax)
[**Reference docs**](https://jax.readthedocs.io/en/latest/)
[**Quickstart**](#quickstart-colab-in-the-cloud)
| [**Transformations**](#transformations)
| [**Install guide**](#installation)
| [**Quickstart**](#quickstart-colab-in-the-cloud)
| [**Reference docs**](https://jax.readthedocs.io/en/latest/)
JAX is [Autograd](https://github.com/hips/autograd) and
[XLA](https://www.tensorflow.org/xla),
@@ -33,8 +34,10 @@ maximal performance without leaving Python.
Dig a little deeper, and you'll see that JAX is really an extensible system for
[composable function transformations](#transformations). Both
[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit)
are instances of such transformations. Another is [`vmap`](#auto-vectorization-with-vmap)
for automatic vectorization, with more to come.
are instances of such transformations. Others are
[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and
[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD)
parallel programming, with more to come.
This is a research project, not an official Google product. Expect bugs and
[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
@@ -61,22 +64,35 @@ perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0))) # fast per-example grad
```
### Contents
* [Transformations](#transformations)
* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud)
* [Transformations](#transformations)
* [Current gotchas](#current-gotchas)
* [Installation](#installation)
* [Citing JAX](#citing-jax)
* [Reference documentation](#reference-documentation)
## Quickstart: Colab in the Cloud
Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks:
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb)
And for a deeper dive into JAX:
- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
- See the [full list of
notebooks](https://github.com/google/jax/tree/master/docs/notebooks).
## Transformations
At its core, JAX is an extensible system for transforming numerical functions.
We currently expose three important transformations: `grad`, `jit`, and `vmap`.
Here are four of primary interest: `grad`, `jit`, `vmap`, and `pmap`.
### Automatic differentiation with grad
### Automatic differentiation with `grad`
JAX has roughly the same API as [Autograd](https://github.com/hips/autograd).
The most popular function is `grad` for reverse-mode gradients:
The most popular function is
[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad)
for reverse-mode gradients:
```python
from jax import grad
@@ -88,25 +104,34 @@ def tanh(x): # Define a function
grad_tanh = grad(tanh) # Obtain its gradient function
print(grad_tanh(1.0)) # Evaluate it at x = 1.0
# prints 0.41997434161402603
# prints 0.4199743
```
You can differentiate to any order with `grad`.
For more advanced autodiff, you can use `jax.vjp` for reverse-mode
vector-Jacobian products and `jax.jvp` for forward-mode Jacobian-vector
products. The two can be composed arbitrarily with one another, and with other
JAX transformations. Here's one way to compose
those to make a function that efficiently computes full Hessian matrices:
```python
print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673
```
For more advanced autodiff, you can use
[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for
reverse-mode vector-Jacobian products and
[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for
forward-mode Jacobian-vector products. The two can be composed arbitrarily with
one another, and with other JAX transformations. Here's one way to compose those
to make a function that efficiently computes [full Hessian
matrices](https://jax.readthedocs.io/en/latest/jax.html#jax.hessian):
```python
from jax import jit, jacfwd, jacrev
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
```
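
To make those building blocks concrete, here's a minimal sketch of our own (not from this README) showing `jvp` and `vjp` used directly:

```python
import jax.numpy as np
from jax import jvp, vjp

def f(x):
  return np.sin(x) * x

x, v = np.arange(3.), np.ones(3)

# Forward mode: push the tangent vector v through f at x, computing J(x) @ v
y, tangent_out = jvp(f, (x,), (v,))

# Reverse mode: get a function that pulls a cotangent vector u back
# through f, computing u @ J(x)
y, f_vjp = vjp(f, x)
cotangent_out, = f_vjp(np.ones(3))
```

Under the hood, `jacfwd` builds Jacobians column-by-column with `jvp`, while `jacrev` builds them row-by-row with `vjp`, which is what makes the `hessian` composition above efficient.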
As with Autograd, you're free to use differentiation with Python control
structures:
As with [Autograd](https://github.com/hips/autograd), you're free to use
differentiation with Python control structures:
```python
def abs_val(x):
@@ -120,10 +145,17 @@ print(abs_val_grad(1.0)) # prints 1.0
print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated)
```
### Compilation with jit
See the [reference docs on automatic
differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
and the [JAX Autodiff
Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
for more.
You can use XLA to compile your functions end-to-end with `jit`, used either as
an `@jit` decorator or as a higher-order function.
### Compilation with `jit`
You can use XLA to compile your functions end-to-end with
[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
used either as an `@jit` decorator or as a higher-order function.
```python
import jax.numpy as np
@@ -141,9 +173,16 @@ fast_f = jit(slow_f)
You can mix `jit` and `grad` and any other JAX transformation however you like.
### Auto-vectorization with vmap
Using `jit` puts constraints on the kind of Python control flow
the function can use; see
the [Gotchas
Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT)
for more.
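
To illustrate the decorator form, here's a minimal sketch of our own (the `scale_and_sum` function is an invented example, not from this README):

```python
from jax import jit
import jax.numpy as np

@jit  # compiled end-to-end by XLA on the first call, cached afterwards
def scale_and_sum(x):
  return np.sum(2.0 * x ** 2)

print(scale_and_sum(np.arange(4.)))  # prints 28.0
```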
`vmap` is the vectorizing map.
### Auto-vectorization with `vmap`
[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is
the vectorizing map.
It has the familiar semantics of mapping a function along array axes, but
instead of keeping the loop on the outside, it pushes the loop down into a
function's primitive operations for better performance.
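
For instance, here's a minimal sketch of our own (not the README's full example) of vectorizing a function written for single inputs:

```python
from jax import vmap
import jax.numpy as np

def dot(a, b):
  return np.dot(a, b)  # written for single vectors

# Map over a leading batch axis of both arguments, with no Python loop
batched_dot = vmap(dot)
a = np.ones((10, 3))
print(batched_dot(a, a).shape)  # (10,)
```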
@@ -203,23 +242,118 @@ JAX transformation! We use `vmap` with both forward- and reverse-mode automatic
differentiation for fast Jacobian and Hessian matrix calculations in
`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`.
### SPMD programming with `pmap`
## Quickstart: Colab in the Cloud
Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks:
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
- [Training a Simple Neural Network, with PyTorch Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb)
For parallel programming of multiple accelerators, like multiple GPUs, use
[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap).
With `pmap` you write single-program multiple-data (SPMD) programs, including
fast parallel collective communication operations.
And for a deeper dive into JAX:
- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
- [Directly using XLA in Python](https://jax.readthedocs.io/en/latest/notebooks/XLA_in_Python.html)
- [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html)
- [MAML Tutorial with JAX](https://jax.readthedocs.io/en/latest/notebooks/maml.html)
- [Generative Modeling by Estimating Gradients of Data Distribution in JAX](https://jax.readthedocs.io/en/latest/notebooks/score_matching.html).
Here's an example on an 8-GPU machine:
```python
from jax import random, pmap
# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: np.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000)
# Compute the mean on each device in parallel and print the result
print(pmap(np.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
```
In addition to expressing pure maps, you can use fast [collective communication
operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators)
between devices:
```python
from functools import partial
from jax import lax
@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')
print(normalize(np.arange(4.)))
# prints [0. 0.16666667 0.33333334 0.5 ]
```
You can even [nest `pmap` functions](https://github.com/google/jax) for more
sophisticated communication patterns.
It all composes, so you're free to differentiate through parallel computations:
```python
from jax import grad
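# NOTE: x is assumed to be defined elsewhere as a 4x2 array, so the nested
# pmaps below map over 4 outer devices and 2 inner devices (8 in total)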
@pmap
def f(x):
  y = np.sin(x)
  @pmap
  def g(z):
    return np.cos(z) * np.tan(y.sum()) * np.tanh(x).sum()
  return grad(lambda w: np.sum(g(w)))(x)
print(f(x))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]
print(grad(lambda x: np.sum(f(x)))(x))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
```
When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the
backward pass of the computation is parallelized just like the forward pass.
See the [SPMD Cookbook](https://github.com/google/jax) and the [SPMD MNIST
classifier from scratch
example](https://github.com/google/jax/blob/master/examples/spmd_mnist_classifier_fromscratch.py)
for more.
## Current gotchas
For a more thorough survey of current gotchas, with examples and explanations,
we highly recommend reading the [Gotchas
Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
Some standouts:
1. [In-place mutating updates of
arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-In-Place-Updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically (sketched after this list).
2. [Random numbers are
different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers), but for [good reasons](https://github.com/google/jax/blob/master/design_notes/prng.md).
3. If you're looking for [convolution
operators](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Convolutions),
they're in the `jax.lax` package.
4. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and
[to enable
double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision)
(64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at
startup (or set the environment variable `JAX_ENABLE_X64=True`); a sketch
follows this list.
5. Some of NumPy's dtype promotion semantics involving a mix of Python scalars
and NumPy types aren't preserved, namely `np.add(1, np.array([2],
np.float32)).dtype` is `float64` rather than `float32`.
6. Some transformations, like `jit`, [constrain how you can use Python control
flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Control-Flow).
You'll always get loud errors if something goes wrong. You might have to use
[`jit`'s `static_argnums`
parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
[structured control flow
primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators)
like
[`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan),
or just use `jit` on smaller subfunctions (see the sketch below).
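
As a hedged illustration of items 1, 4, and 6 above, here's a minimal sketch of our own (using the APIs as of this release, not code from the README):

```python
# Item 4: enable 64-bit precision; this must run at startup, before any
# JAX values are created
from jax.config import config
config.update("jax_enable_x64", True)

from functools import partial
import jax.numpy as np
from jax import jit, ops

# Item 1: a functional alternative to the in-place update x[1, :] += 10.
x = np.zeros((3, 3))
y = ops.index_add(x, ops.index[1, :], 10.)  # new array; x is unchanged

# Item 6: mark an argument as static so Python control flow on it is allowed
@partial(jit, static_argnums=(1,))
def scale(z, flip):
  return -z if flip else z

print(scale(y, True))
```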
## Installation
JAX is written in pure Python, but it depends on XLA, which needs to be compiled
and installed as the `jaxlib` package. Use the following instructions to
install a binary package with `pip`, or to build JAX from source.
@@ -278,7 +412,8 @@ Please let us know on [the issue tracker](https://github.com/google/jax/issues)
if you run into any errors or problems with the prebuilt wheels.
### Building JAX from source
See [Building JAX from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
See [Building JAX from
source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
## Citing JAX
@@ -290,7 +425,7 @@ To cite this repository:
author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and Skye Wanderman-Milne},
title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
url = {http://github.com/google/jax},
version = {0.1.46},
version = {0.1.55},
year = {2018},
}
```

docs/notebooks/Common_Gotchas_in_JAX.ipynb

@@ -167,9 +167,7 @@
"\n",
"Allowing mutation of variables in-place makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program. \n",
"\n",
"Instead, JAX offers the _functional_ update functions: __index_update__, __index_add__, __index_min__, __index_max__, and the __index__ helper.\n",
"\n",
"__NB__: _Fancy Indexing_ is __not__ yet supported, but will likely be added to JAX soon.\n",
"Instead, JAX offers the _functional_ update functions: [__index_update__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update), [__index_add__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_add.html#jax.ops.index_add), [__index_min__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_min.html#jax.ops.index_min), [__index_max__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_max.html#jax.ops.index_max), and the [__index__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index.html#jax.ops.index) helper.\n",
"\n",
"️⚠️ inside `jit`'d code and `lax.while_loop` or `lax.fori_loop` the __size__ of slices can't be functions of argument _values_ but only functions of argument _shapes_ -- the slice start indices have no such restriction. See the below __Control Flow__ Section for more information on this limitation."
]
@@ -1101,7 +1099,7 @@
"source": [
"### Structured control flow primitives\n",
"\n",
"There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. then you can use these 4 structured control flow primitives:\n",
"There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:\n",
"\n",
" - `lax.cond` _will be differentiable soon_\n",
" - `lax.while_loop` __non-differentiable__*\n",