mirror of https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00

streamline readme, add pmap

This commit is contained in:
parent 8927200316
commit 8dad859e04

Changed: README.md (205 lines)
@@ -4,9 +4,10 @@

# JAX: Autograd and XLA [![Build Status](https://travis-ci.org/google/jax.svg?branch=master)](https://travis-ci.org/google/jax)

[**Reference docs**](https://jax.readthedocs.io/en/latest/)
[**Quickstart**](#quickstart-colab-in-the-cloud)
| [**Transformations**](#transformations)
| [**Install guide**](#installation)
| [**Quickstart**](#quickstart-colab-in-the-cloud)
| [**Reference docs**](https://jax.readthedocs.io/en/latest/)

JAX is [Autograd](https://github.com/hips/autograd) and
[XLA](https://www.tensorflow.org/xla),
@@ -33,8 +34,10 @@ maximal performance without leaving Python.
Dig a little deeper, and you'll see that JAX is really an extensible system for
[composable function transformations](#transformations). Both
[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit)
are instances of such transformations. Another is [`vmap`](#auto-vectorization-with-vmap)
for automatic vectorization, with more to come.
are instances of such transformations. Others are
[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and
[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD)
parallel programming, with more to come.

This is a research project, not an official Google product. Expect bugs and
[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
@@ -61,22 +64,35 @@ perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0))) # fast per-example grad
```

### Contents
* [Transformations](#transformations)
* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud)
* [Transformations](#transformations)
* [Current gotchas](#current-gotchas)
* [Installation](#installation)
* [Citing JAX](#citing-jax)
* [Reference documentation](#reference-documentation)

## Quickstart: Colab in the Cloud
Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks:
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb)

And for a deeper dive into JAX:
- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
- See the [full list of
notebooks](https://github.com/google/jax/tree/master/docs/notebooks).
## Transformations

At its core, JAX is an extensible system for transforming numerical functions.
We currently expose three important transformations: `grad`, `jit`, and `vmap`.
Here are four of primary interest: `grad`, `jit`, `vmap`, and `pmap`.

### Automatic differentiation with grad
### Automatic differentiation with `grad`

JAX has roughly the same API as [Autograd](https://github.com/hips/autograd).
The most popular function is `grad` for reverse-mode gradients:
The most popular function is
[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad)
for reverse-mode gradients:
```python
from jax import grad
@@ -88,25 +104,34 @@ def tanh(x): # Define a function

grad_tanh = grad(tanh) # Obtain its gradient function
print(grad_tanh(1.0)) # Evaluate it at x = 1.0
# prints 0.41997434161402603
# prints 0.4199743
```

You can differentiate to any order with `grad`.

For more advanced autodiff, you can use `jax.vjp` for reverse-mode
vector-Jacobian products and `jax.jvp` for forward-mode Jacobian-vector
products. The two can be composed arbitrarily with one another, and with other
JAX transformations. Here's one way to compose
those to make a function that efficiently computes full Hessian matrices:
```python
print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673
```
|
||||
For more advanced autodiff, you can use
|
||||
[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for
|
||||
reverse-mode vector-Jacobian products and
|
||||
[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.defjvp) for
|
||||
forward-mode Jacobian-vector products. The two can be composed arbitrarily with
|
||||
one another, and with other JAX transformations. Here's one way to compose those
|
||||
to make a function that efficiently computes [full Hessian
|
||||
matrices](https://jax.readthedocs.io/en/latest/jax.html#jax.hessian):
|
||||
|
||||
```python
|
||||
from jax import jit, jacfwd, jacrev
|
||||
|
||||
def hessian(fun):
|
||||
return jit(jacfwd(jacrev(fun)))
|
||||
```
|
||||
|
||||
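As a quick usage sketch (an editorial addition, not part of the commit: the function and input below are made up for illustration, and `hessian` is the helper defined just above):

```python
import jax.numpy as np
from jax import jit, jacfwd, jacrev

def hessian(fun):
  # Forward-over-reverse is an efficient way to build full Hessians.
  return jit(jacfwd(jacrev(fun)))

def f(x):
  return np.sum(np.tanh(x) ** 2)

# For a scalar function of a length-3 input, the Hessian is a 3 x 3 matrix.
print(hessian(f)(np.arange(3.)).shape)  # (3, 3)
```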
As with Autograd, you're free to use differentiation with Python control
structures:
As with [Autograd](https://github.com/hips/autograd), you're free to use
differentiation with Python control structures:

```python
def abs_val(x):
@@ -120,10 +145,17 @@ print(abs_val_grad(1.0)) # prints 1.0
print(abs_val_grad(-1.0)) # prints -1.0 (abs_val is re-evaluated)
```

### Compilation with jit
See the [reference docs on automatic
differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
and the [JAX Autodiff
Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
for more.

You can use XLA to compile your functions end-to-end with `jit`, used either as
an `@jit` decorator or as a higher-order function.
### Compilation with `jit`

You can use XLA to compile your functions end-to-end with
[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
used either as an `@jit` decorator or as a higher-order function.

```python
import jax.numpy as np
@@ -141,9 +173,16 @@ fast_f = jit(slow_f)
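The hunk above shows only the start of the `jit` example. Here is a self-contained sketch of both usage styles; the function body and array size are illustrative rather than the README's exact example:

```python
import jax.numpy as np
from jax import jit

def slow_f(x):
  # Element-wise ops like these can fuse into a single XLA kernel under jit.
  return x * x + x * 2.0

fast_f = jit(slow_f)  # jit used as a higher-order function

@jit                  # jit used as a decorator
def also_fast_f(x):
  return x * x + x * 2.0

x = np.ones((5000, 5000))
print(fast_f(x)[0, 0])       # 3.0
print(also_fast_f(x)[0, 0])  # 3.0
```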
You can mix `jit` and `grad` and any other JAX transformation however you like.

### Auto-vectorization with vmap
Using `jit` puts constraints on the kind of Python control flow
the function can use; see
the [Gotchas
Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#python-control-flow-+-JIT)
for more.

`vmap` is the vectorizing map.
### Auto-vectorization with `vmap`

[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is
the vectorizing map.
It has the familiar semantics of mapping a function along array axes, but
instead of keeping the loop on the outside, it pushes the loop down into a
function’s primitive operations for better performance.
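Since the diff elides the README's `vmap` example, here is a minimal illustrative sketch (the `dot` function and batch shapes are made up, not the commit's own code):

```python
import jax.numpy as np
from jax import vmap

def dot(v, w):
  return np.vdot(v, w)     # written for a single pair of vectors

batched_dot = vmap(dot)    # maps over a leading batch axis of both arguments

xs = np.ones((10, 3))
ys = np.ones((10, 3))
print(batched_dot(xs, ys).shape)  # (10,) -- one dot product per row pair
```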
@@ -203,23 +242,118 @@ JAX transformation! We use `vmap` with both forward- and reverse-mode automatic
differentiation for fast Jacobian and Hessian matrix calculations in
`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`.

### SPMD programming with `pmap`

## Quickstart: Colab in the Cloud
Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks:
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
- [Training a Simple Neural Network, with PyTorch Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/Neural_Network_and_Data_Loading.ipynb)
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb)
For parallel programming of multiple accelerators, like multiple GPUs, use
[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap).
With `pmap` you write single-program multiple-data (SPMD) programs, including
fast parallel collective communication operations.

And for a deeper dive into JAX:
- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
- [Directly using XLA in Python](https://jax.readthedocs.io/en/latest/notebooks/XLA_in_Python.html)
- [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html)
- [MAML Tutorial with JAX](https://jax.readthedocs.io/en/latest/notebooks/maml.html)
- [Generative Modeling by Estimating Gradients of Data Distribution in JAX](https://jax.readthedocs.io/en/latest/notebooks/score_matching.html).
Here's an example on an 8-GPU machine:
```python
from jax import random

# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: np.dot(x, x.T))(mats) # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(np.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
```
In addition to expressing pure maps, you can use fast [collective communication
operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators)
between devices:
```python
from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')

print(normalize(np.arange(4.)))
# prints [0. 0.16666667 0.33333334 0.5 ]
```
You can even [nest `pmap` functions](https://github.com/google/jax) for more
sophisticated communication patterns.

It all composes, so you're free to differentiate through parallel computations:
```python
from jax import grad

@pmap
def f(x):
  y = np.sin(x)
  @pmap
  def g(z):
    return np.cos(z) * np.tan(y.sum()) * np.tanh(x).sum()
  return grad(lambda w: np.sum(g(w)))(x)

print(f(x))
# [[ 0. , -0.7170853 ],
# [-3.1085174 , -0.4824318 ],
# [10.366636 , 13.135289 ],
# [ 0.22163185, -0.52112055]]

print(grad(lambda x: np.sum(f(x)))(x))
# [[ -3.2369726, -1.6356447],
# [ 4.7572474, 11.606951 ],
# [-98.524414 , 42.76499 ],
# [ -1.6007166, -1.2568436]]
```
When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the
backward pass of the computation is parallelized just like the forward pass.

See the [SPMD Cookbook](https://github.com/google/jax) and the [SPMD MNIST
classifier from scratch
example](https://github.com/google/jax/blob/master/examples/spmd_mnist_classifier_fromscratch.py)
for more.
## Current gotchas

For a more thorough survey of current gotchas, with examples and explanations,
we highly recommend reading the [Gotchas
Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
Some standouts:
1. [In-place mutating updates of
arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-In-Place-Updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html). Under a `jit`, those functional alternatives will reuse buffers in-place automatically (a short sketch follows this list).
2. [Random numbers are
different](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers), but for [good reasons](https://github.com/google/jax/blob/master/design_notes/prng.md).
3. If you're looking for [convolution
operators](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Convolutions),
they're in the `jax.lax` package.
4. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and
[to enable
double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Double-(64bit)-precision)
(64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at
startup (or set the environment variable `JAX_ENABLE_X64=True`); a configuration sketch also follows this list.
5. Some of NumPy's dtype promotion semantics involving a mix of Python scalars
and NumPy types aren't preserved, namely `np.add(1, np.array([2],
np.float32)).dtype` is `float64` rather than `float32`.
6. Some transformations, like `jit`, [constrain how you can use Python control
flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Control-Flow).
You'll always get loud errors if something goes wrong. You might have to use
[`jit`'s `static_argnums`
parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
[structured control flow
primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators)
like
[`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan)
(sketched after this list),
or just use `jit` on smaller subfunctions.
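A small sketch of the functional-update style that item 1 points to (the array and index are illustrative, not from the commit):

```python
import jax.numpy as np
from jax import ops

x = np.zeros((3, 3))
# Functional alternative to the unsupported `x[1, :] += 1.0`: returns a new array.
y = ops.index_add(x, ops.index[1, :], 1.0)
print(y[1])  # [1. 1. 1.]
print(x[1])  # [0. 0. 0.] -- the original array is untouched
```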
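One way to flip the 64-bit flag from Python, per item 4 (a configuration sketch; the `JAX_ENABLE_X64` environment variable route works too):

```python
# Must run before JAX creates any arrays, e.g. at the top of your program.
from jax.config import config
config.update("jax_enable_x64", True)

import jax.numpy as np
print(np.ones(3).dtype)  # float64 once 64-bit mode is enabled
```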
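And a minimal `lax.scan` sketch for the structured control flow mentioned in item 6 (the running-sum example is made up for illustration):

```python
import jax.numpy as np
from jax import jit, lax

@jit
def running_sum(xs):
  # lax.scan threads a carry through the sequence without unrolling
  # a Python loop inside the jitted computation.
  def step(carry, x):
    carry = carry + x
    return carry, carry    # (new carry, per-step output)
  _, ys = lax.scan(step, 0.0, xs)
  return ys

print(running_sum(np.arange(4.)))  # [0. 1. 3. 6.]
```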
## Installation

JAX is written in pure Python, but it depends on XLA, which needs to be compiled
and installed as the `jaxlib` package. Use the following instructions to
install a binary package with `pip`, or to build JAX from source.
@@ -278,7 +412,8 @@ Please let us know on [the issue tracker](https://github.com/google/jax/issues)
if you run into any errors or problems with the prebuilt wheels.

### Building JAX from source
See [Building JAX from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).
See [Building JAX from
source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source).


## Citing JAX
@@ -290,7 +425,7 @@ To cite this repository:
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and Skye Wanderman-Milne},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/google/jax},
  version = {0.1.46},
  version = {0.1.55},
  year = {2018},
}
```
@@ -167,9 +167,7 @@
"\n",
"Allowing mutation of variables in-place makes program analysis and transformation very difficult. JAX requires a pure functional expression of a numerical program. \n",
"\n",
"Instead, JAX offers the _functional_ update functions: __index_update__, __index_add__, __index_min__, __index_max__, and the __index__ helper.\n",
"\n",
"__NB__: _Fancy Indexing_ is __not__ yet supported, but will likely be added to JAX soon.\n",
"Instead, JAX offers the _functional_ update functions: [__index_update__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update), [__index_add__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_add.html#jax.ops.index_add), [__index_min__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_min.html#jax.ops.index_min), [__index_max__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_max.html#jax.ops.index_max), and the [__index__](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index.html#jax.ops.index) helper.\n",
"\n",
"⚠️ inside `jit`'d code and `lax.while_loop` or `lax.fori_loop` the __size__ of slices can't be functions of argument _values_ but only functions of argument _shapes_ -- the slice start indices have no such restriction. See the below __Control Flow__ Section for more information on this limitation."
]
@@ -1101,7 +1099,7 @@
"source": [
"### Structured control flow primitives\n",
"\n",
"There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. then you can use these 4 structured control flow primitives:\n",
"There are more options for control flow in JAX. Say you want to avoid re-compilations but still want to use control flow that's traceable, and that avoids un-rolling large loops. Then you can use these 4 structured control flow primitives:\n",
"\n",
" - `lax.cond` _will be differentiable soon_\n",
" - `lax.while_loop` __non-differentiable__*\n",