diff --git a/README.md b/README.md index 2b7b0184f..8554a8ce6 100644 --- a/README.md +++ b/README.md @@ -4,9 +4,10 @@ # JAX: Autograd and XLA [![Test status](https://travis-ci.org/google/jax.svg?branch=master)](https://travis-ci.org/google/jax) -[**Reference docs**](https://jax.readthedocs.io/en/latest/) +[**Quickstart**](#quickstart-colab-in-the-cloud) +| [**Transformations**](#transformations) | [**Install guide**](#installation) -| [**Quickstart**](#quickstart-colab-in-the-cloud) +| [**Reference docs**](https://jax.readthedocs.io/en/latest/) JAX is [Autograd](https://github.com/hips/autograd) and [XLA](https://www.tensorflow.org/xla), @@ -33,8 +34,10 @@ maximal performance without leaving Python. Dig a little deeper, and you'll see that JAX is really an extensible system for [composable function transformations](#transformations). Both [`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit) -are instances of such transformations. Another is [`vmap`](#auto-vectorization-with-vmap) -for automatic vectorization, with more to come. +are instances of such transformations. Others are +[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and +[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD) +parallel programming, with more to come. This is a research project, not an official Google product. Expect bugs and [sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). @@ -61,22 +64,35 @@ perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0))) # fast per-example grad ``` ### Contents -* [Transformations](#transformations) * [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud) +* [Transformations](#transformations) +* [Current gotchas](#current-gotchas) * [Installation](#installation) * [Citing JAX](#citing-jax) * [Reference documentation](#reference-documentation) +## Quickstart: Colab in the Cloud +Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks: +- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) +- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb) + +And for a deeper dive into JAX: +- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) +- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) +- See the [full list of +notebooks](https://github.com/google/jax/tree/master/docs/notebooks). ## Transformations At its core, JAX is an extensible system for transforming numerical functions. -We currently expose three important transformations: `grad`, `jit`, and `vmap`. +Here are four of primary interest: `grad`, `jit`, `vmap`, and `pmap`. -### Automatic differentiation with grad +### Automatic differentiation with `grad` JAX has roughly the same API as [Autograd](https://github.com/hips/autograd). 
-The most popular function is `grad` for reverse-mode gradients:
+The most popular function is
+[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad)
+for reverse-mode gradients:

```python
from jax import grad
@@ -88,25 +104,34 @@ def tanh(x):  # Define a function
grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
-# prints 0.41997434161402603
+# prints 0.4199743
```

You can differentiate to any order with `grad`.

-For more advanced autodiff, you can use `jax.vjp` for reverse-mode
-vector-Jacobian products and `jax.jvp` for forward-mode Jacobian-vector
-products. The two can be composed arbitrarily with one another, and with other
-JAX transformations. Here's one way to compose
-those to make a function that efficiently computes full Hessian matrices:
+```python
+print(grad(grad(grad(tanh)))(1.0))
+# prints 0.62162673
+```
+
+For more advanced autodiff, you can use
+[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for
+reverse-mode vector-Jacobian products and
+[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for
+forward-mode Jacobian-vector products. The two can be composed arbitrarily with
+one another, and with other JAX transformations. Here's one way to compose those
+to make a function that efficiently computes [full Hessian
+matrices](https://jax.readthedocs.io/en/latest/jax.html#jax.hessian):

```python
from jax import jit, jacfwd, jacrev
+
def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
```

-As with Autograd, you're free to use differentiation with Python control
-structures:
+As with [Autograd](https://github.com/hips/autograd), you're free to use
+differentiation with Python control structures:

```python
def abs_val(x):
@@ -120,10 +145,17 @@ print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)
```

-### Compilation with jit
+See the [reference docs on automatic
+differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
+and the [JAX Autodiff
+Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
+for more.

-You can use XLA to compile your functions end-to-end with `jit`, used either as
-an `@jit` decorator or as a higher-order function.
+### Compilation with `jit`
+
+You can use XLA to compile your functions end-to-end with
+[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
+used either as an `@jit` decorator or as a higher-order function.

```python
import jax.numpy as np
@@ -141,9 +173,16 @@ fast_f = jit(slow_f)

You can mix `jit` and `grad` and any other JAX transformation however you like.

-### Auto-vectorization with vmap
+Using `jit` puts constraints on the kind of Python control flow
+the function can use; see
+the [Gotchas
+Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT)
+for more.

-`vmap` is the vectorizing map.
+### Auto-vectorization with `vmap`
+
+[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is
+the vectorizing map.
It has the familiar semantics of mapping a function along array axes, but
instead of keeping the loop on the outside, it pushes the loop down into a
function’s primitive operations for better performance.
@@ -203,23 +242,118 @@ JAX transformation! We use `vmap` with both forward- and reverse-mode automatic
differentiation for fast Jacobian and Hessian matrix calculations in
`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`.
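+
+To get a concrete feel for how `vmap` composes with autodiff, here is a rough
+sketch along those lines. It is illustrative only (not JAX's actual
+implementation of `jacrev`), and the helper name `jacrev_sketch` and the example
+inputs are made up for this snippet: it builds a reverse-mode Jacobian by
+`vmap`-ing a vector-Jacobian product over the rows of an identity matrix.
+
+```python
+import jax.numpy as np
+from jax import vjp, vmap
+
+def jacrev_sketch(f):
+  def jacfun(x):
+    y, pullback = vjp(f, x)                # y = f(x); pullback maps a cotangent v to v^T df/dx
+    basis = np.eye(y.size, dtype=y.dtype)  # one one-hot cotangent per output entry
+    jac, = vmap(pullback)(basis)           # pull back all output entries in one vectorized call
+    return jac
+  return jacfun
+
+f = lambda x: np.sin(x) * x[0]
+print(jacrev_sketch(f)(np.arange(1., 4.)))  # prints a 3x3 Jacobian matrix
+```
+
+In practice you'd call `jax.jacrev`, `jax.jacfwd`, or `jax.hessian` directly;
+the sketch is only meant to show how the pieces fit together.
+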
+### SPMD programming with `pmap`

-## Quickstart: Colab in the Cloud
-Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks:
-- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
-- [Training a Simple Neural Network, with PyTorch Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)
-- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb)
+For parallel programming of multiple accelerators, like multiple GPUs, use
+[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap).
+With `pmap` you write single-program multiple-data (SPMD) programs, including
+fast parallel collective communication operations.

-And for a deeper dive into JAX:
-- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
-- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
-- [Directly using XLA in Python](https://jax.readthedocs.io/en/latest/notebooks/XLA_in_Python.html)
-- [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html)
-- [MAML Tutorial with JAX](https://jax.readthedocs.io/en/latest/notebooks/maml.html)
-- [Generative Modeling by Estimating Gradients of Data Distribution in JAX](https://jax.readthedocs.io/en/latest/notebooks/score_matching.html).
+Here's an example on an 8-GPU machine:

+```python
+import jax.numpy as np
+from jax import pmap, random
+
+# Create 8 random 5000 x 6000 matrices, one per GPU
+keys = random.split(random.PRNGKey(0), 8)
+mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)
+
+# Run a local matmul on each device in parallel (no data transfer)
+result = pmap(lambda x: np.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)
+
+# Compute the mean on each device in parallel and print the result
+print(pmap(np.mean)(result))
+# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
+```
+
+In addition to expressing pure maps, you can use fast [collective communication
+operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators)
+between devices:
+
+```python
+from functools import partial
+from jax import lax
+
+@partial(pmap, axis_name='i')
+def normalize(x):
+  return x / lax.psum(x, 'i')
+
+print(normalize(np.arange(4.)))
+# prints [0. 0.16666667 0.33333334 0.5 ]
+```
+
+You can even [nest `pmap` functions](https://github.com/google/jax) for more
+sophisticated communication patterns.
+
+It all composes, so you're free to differentiate through parallel computations:
+
+```python
+from jax import grad
+
+@pmap
+def f(x):
+  y = np.sin(x)
+  @pmap
+  def g(z):
+    return np.cos(z) * np.tan(y.sum()) * np.tanh(x).sum()
+  return grad(lambda w: np.sum(g(w)))(x)
+
+print(f(x))
+# [[ 0. , -0.7170853 ],
+#  [-3.1085174 , -0.4824318 ],
+#  [10.366636 , 13.135289 ],
+#  [ 0.22163185, -0.52112055]]
+
+print(grad(lambda x: np.sum(f(x)))(x))
+# [[ -3.2369726, -1.6356447],
+#  [ 4.7572474, 11.606951 ],
+#  [-98.524414 , 42.76499 ],
+#  [ -1.6007166, -1.2568436]]
+```
+
+When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the
+backward pass of the computation is parallelized just like the forward pass.
+
+See the [SPMD Cookbook](https://github.com/google/jax) and the [SPMD MNIST
+classifier from scratch
+example](https://github.com/google/jax/blob/master/examples/spmd_mnist_classifier_fromscratch.py)
+for more.
+
+## Current gotchas
+
+For a more thorough survey of current gotchas, with examples and explanations,
+we highly recommend reading the [Gotchas
+Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
+Some standouts:
+
+1. [In-place mutating updates of
+   arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-In-Place-Updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
+2. [Random numbers are
+   different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers), but for [good reasons](https://github.com/google/jax/blob/master/design_notes/prng.md).
+3. If you're looking for [convolution
+   operators](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Convolutions),
+   they're in the `jax.lax` package.
+4. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and
+   [to enable
+   double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision)
+   (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at
+   startup (or set the environment variable `JAX_ENABLE_X64=True`).
+5. Some of NumPy's dtype promotion semantics involving a mix of Python scalars
+   and NumPy types aren't preserved, namely `np.add(1, np.array([2],
+   np.float32)).dtype` is `float32` rather than `float64`.
+6. Some transformations, like `jit`, [constrain how you can use Python control
+   flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Control-Flow).
+   You'll always get loud errors if something goes wrong. You might have to use
+   [`jit`'s `static_argnums`
+   parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
+   [structured control flow
+   primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators)
+   like
+   [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan),
+   or just use `jit` on smaller subfunctions.

## Installation
+
JAX is written in pure Python, but it depends on XLA, which needs to be
compiled and installed as the `jaxlib` package. Use the following instructions
to install a binary package with `pip`, or to build JAX from source.
@@ -278,7 +412,8 @@ Please let us know on [the issue tracker](https://github.com/google/jax/issues)
if you run into any errors or problems with the prebuilt wheels.

### Building JAX from source
-See [Building JAX from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
+See [Building JAX from
+source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
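+
+Whichever install route you take, a quick smoke test (an informal check, not an
+official verification step) is to evaluate a small compiled gradient and compare
+against the value quoted earlier in this README:
+
+```python
+import jax.numpy as np
+from jax import grad, jit
+
+# Gradient of tanh at 1.0, compiled with XLA; should print roughly 0.4199743
+print(jit(grad(np.tanh))(1.0))
+```
+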
## Citing JAX @@ -290,7 +425,7 @@ To cite this repository: author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and Skye Wanderman-Milne}, title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs}, url = {http://github.com/google/jax}, - version = {0.1.46}, + version = {0.1.55}, year = {2018}, } ``` diff --git a/docs/notebooks/Common_Gotchas_in_JAX.ipynb b/docs/notebooks/Common_Gotchas_in_JAX.ipynb index 9104331c9..e53785047 100644 --- a/docs/notebooks/Common_Gotchas_in_JAX.ipynb +++ b/docs/notebooks/Common_Gotchas_in_JAX.ipynb @@ -167,9 +167,7 @@ "\n", "Allowing mutation of variables in-place makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program. \n", "\n", - "Instead, JAX offers the _functional_ update functions: __index_update__, __index_add__, __index_min__, __index_max__, and the __index__ helper.\n", - "\n", - "__NB__: _Fancy Indexing_ is __not__ yet supported, but will likely be added to JAX soon.\n", + "Instead, JAX offers the _functional_ update functions: [__index_update__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update), [__index_add__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_add.html#jax.ops.index_add), [__index_min__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_min.html#jax.ops.index_min), [__index_max__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_max.html#jax.ops.index_max), and the [__index__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index.html#jax.ops.index) helper.\n", "\n", "️⚠️ inside `jit`'d code and `lax.while_loop` or `lax.fori_loop` the __size__ of slices can't be functions of argument _values_ but only functions of argument _shapes_ -- the slice start indices have no such restriction. See the below __Control Flow__ Section for more information on this limitation." ] @@ -1101,7 +1099,7 @@ "source": [ "### Structured control flow primitives\n", "\n", - "There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. then you can use these 4 structured control flow primitives:\n", + "There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:\n", "\n", " - `lax.cond` _will be differentiable soon_\n", " - `lax.while_loop` __non-differentiable__*\n",