<div align="center">
<img src="https://raw.githubusercontent.com/google/jax/master/images/jax_logo_250px.png" alt="logo"></img>
</div>
# JAX: Autograd and XLA [![Test status](https://travis-ci.org/google/jax.svg?branch=master)](https://travis-ci.org/google/jax)

[**Reference docs**](https://jax.readthedocs.io/en/latest/)
| [**Install guide**](#installation)
| [**Quickstart**](#quickstart-colab-in-the-cloud)

JAX is [Autograd](https://github.com/hips/autograd) and
[XLA](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/overview.md),
brought together for high-performance machine learning research.

With its updated version of [Autograd](https://github.com/hips/autograd),
JAX can automatically differentiate native
Python and NumPy functions. It can differentiate through loops, branches,
recursion, and closures, and it can take derivatives of derivatives of
derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation)
via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation,
and the two can be composed arbitrarily to any order.

What's new is that JAX uses
[XLA](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/g3doc/overview.md)
to compile and run your NumPy programs on GPUs and TPUs. Compilation happens
under the hood by default, with library calls getting just-in-time compiled and
executed. But JAX also lets you just-in-time compile your own Python functions
into XLA-optimized kernels using a one-function API,
[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be
composed arbitrarily, so you can express sophisticated algorithms and get
maximal performance without leaving Python.

Dig a little deeper, and you'll see that JAX is really an extensible system for
[composable function transformations](#transformations). Both
[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit)
are instances of such transformations. Another is [`vmap`](#auto-vectorization-with-vmap)
for automatic vectorization, with more to come.

This is a research project, not an official Google product. Expect bugs and
[sharp edges](https://colab.research.google.com/github/google/jax/blob/master/notebooks/Common_Gotchas_in_JAX.ipynb).
Please help by trying it out, [reporting
bugs](https://github.com/google/jax/issues), and letting us know what you
think!
```python
import jax.numpy as np
from jax import grad, jit, vmap

def predict(params, inputs):
  for W, b in params:
    outputs = np.dot(inputs, W) + b
    inputs = np.tanh(outputs)
  return outputs

def logprob_fun(params, inputs, targets):
  preds = predict(params, inputs)
  return np.sum((preds - targets)**2)

grad_fun = jit(grad(logprob_fun))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0)))  # fast per-example grads
```
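The snippet above defines `grad_fun` and `perex_grads` without calling them. Here is a minimal, self-contained sketch of calling both; the layer sizes, the random parameters, and the batch of 8 examples are made up for illustration and are not part of the original example. Because the loss is a sum over examples, the per-example gradients should add up to the batch gradient.

```python
import jax.numpy as np
from jax import grad, jit, vmap, random

def predict(params, inputs):
  for W, b in params:
    outputs = np.dot(inputs, W) + b
    inputs = np.tanh(outputs)
  return outputs

def logprob_fun(params, inputs, targets):
  preds = predict(params, inputs)
  return np.sum((preds - targets)**2)

grad_fun = jit(grad(logprob_fun))
perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0)))

# Hypothetical sizes: 3 input features -> 4 hidden units -> 2 outputs.
k1, k2, k3, k4 = random.split(random.PRNGKey(0), 4)
params = [(random.normal(k1, (3, 4)), np.zeros(4)),
          (random.normal(k2, (4, 2)), np.zeros(2))]
inputs = random.normal(k3, (8, 3))   # a batch of 8 examples
targets = random.normal(k4, (8, 2))

batch_grads = grad_fun(params, inputs, targets)      # one gradient for the whole batch
example_grads = perex_grads(params, inputs, targets) # one gradient per example
print(batch_grads[0][0].shape)    # (3, 4)
print(example_grads[0][0].shape)  # (8, 3, 4): a leading batch axis on every leaf
```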
JAX started as a research project by [Matt Johnson](https://github.com/mattjj),
[Roy Frostig](https://github.com/froystig), [Dougal
Maclaurin](https://github.com/dougalm), and [Chris
Leary](https://github.com/learyg), and is now developed [in the
open](https://github.com/google/jax) by a growing number of
[contributors](#contributors).
### Contents
* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud)
* [Installation](#installation)
* [Running the tests](#running-the-tests)
* [Reference documentation](#reference-documentation)
* [A brief tour](#a-brief-tour)
* [What's supported](#whats-supported)
* [Transformations](#transformations)
* [Random numbers are different](#random-numbers-are-different)
* [Mini-libraries](#mini-libraries)
* [How it works](#how-it-works)
* [What we're working on](#what-were-working-on)
* [Current gotchas](#current-gotchas)
## Quickstart: Colab in the Cloud
Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks:
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://colab.research.google.com/github/google/jax/blob/master/notebooks/quickstart.ipynb)
- [Training a Simple Neural Network, with PyTorch Data Loading](https://colab.research.google.com/github/google/jax/blob/master/notebooks/neural_network_and_data_loading.ipynb)
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/master/notebooks/neural_network_with_tfds_data.ipynb)

And for a deeper dive into JAX:
- [Common gotchas and sharp edges](https://colab.research.google.com/github/google/jax/blob/master/notebooks/Common_Gotchas_in_JAX.ipynb)
- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://colab.research.google.com/github/google/jax/blob/master/notebooks/autodiff_cookbook.ipynb)
- [Directly using XLA in Python](https://colab.research.google.com/github/google/jax/blob/master/notebooks/XLA_in_Python.ipynb)
- [MAML Tutorial with JAX](https://colab.research.google.com/github/google/jax/blob/master/notebooks/maml.ipynb)
- [Generative Modeling by Estimating Gradients of Data Distribution in JAX](https://colab.research.google.com/github/google/jax/blob/master/notebooks/score-matching.ipynb)
## Installation
JAX is written in pure Python, but it depends on XLA, which needs to be compiled
and installed as the `jaxlib` package. Use the following instructions to
install a binary package with `pip`, or to build JAX from source.

We support installing or building `jaxlib` on Linux and macOS platforms, but not
Windows. We're not currently working on Windows support, but contributions are
welcome (see [#438](https://github.com/google/jax/issues/438)).
### pip installation
To install a CPU-only version, which might be useful for doing local
development on a laptop, you can run
```bash
pip install --upgrade jax jaxlib  # CPU-only version
```
If you want to install JAX with both CPU and GPU support, using existing CUDA
and CUDNN7 installations on your machine (for example, preinstalled on your
cloud VM), you can run
```bash
# install jaxlib
PYTHON_VERSION=cp37  # alternatives: cp27, cp35, cp36, cp37
CUDA_VERSION=cuda92  # alternatives: cuda90, cuda92, cuda100
PLATFORM=linux_x86_64  # alternatives: linux_x86_64
BASE_URL='https://storage.googleapis.com/jax-releases'
pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.23-$PYTHON_VERSION-none-$PLATFORM.whl

pip install --upgrade jax  # install jax
```
The library package name must correspond to the version of the existing CUDA
installation you want to use, with `cuda100` for CUDA 10.0, `cuda92` for CUDA
9.2, and `cuda90` for CUDA 9.0. To find your CUDA and CUDNN versions, you can
run commands like these, depending on your CUDNN install path:
```bash
nvcc --version
grep CUDNN_MAJOR -A 2 /usr/local/cuda/include/cudnn.h # might need different path
```
The Python version must match your Python interpreter. There are prebuilt wheels
for Python 2.7, 3.5, 3.6, and 3.7; for anything else, you must build from
source.
Please let us know on [the issue tracker](https://github.com/google/jax/issues)
if you run into any errors or problems with the prebuilt wheels.
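A quick way to confirm an installation is to ask JAX which devices it can see and run a tiny computation. This sketch assumes a JAX release that provides `jax.devices()`; older releases may not have it.

```python
import jax
import jax.numpy as np

# List the devices JAX can dispatch to. A CPU-only jaxlib reports CPU devices;
# a CUDA-enabled jaxlib should also report GPU devices.
devices = jax.devices()
print(devices)

# A tiny computation to confirm that jaxlib (and XLA) load and run.
x = np.arange(4.0)
print(np.dot(x, x))  # 14.0
```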
### Building JAX from source
First, obtain the JAX source code.
```bash
git clone https://github.com/google/jax
cd jax
```
You must also install some prerequisites:
* a C++ compiler (g++ or clang)
* NumPy
* SciPy
* Cython

On Ubuntu 18.04 or Debian you can install the necessary prerequisites with:
```
sudo apt-get install g++ python python3-dev python3-numpy python3-scipy cython3
```
If you are building on a Mac, make sure Xcode and the Xcode command line tools
are installed.

You can also install the necessary Python dependencies using `pip`:
```
pip install numpy scipy cython
```
To build `jaxlib` with CUDA support, you can run
```bash
python build/build.py --enable_cuda
pip install -e build # installs jaxlib (includes XLA)
pip install -e . # installs jax (pure Python)
```
See `python build/build.py --help` for configuration options, including ways to
specify the paths to CUDA and CUDNN, which you must have installed. The build
also depends on NumPy and on a compiler toolchain corresponding to that of
Ubuntu 16.04 or newer.

To build `jaxlib` without CUDA GPU support (CPU only), drop the `--enable_cuda` flag:
```bash
python build/build.py
pip install -e build # installs jaxlib (includes XLA)
pip install -e . # installs jax
```
To upgrade to the latest version from GitHub, just run `git pull` from the JAX
repository root, and rebuild by running `build.py` if necessary. You shouldn't have
to reinstall because `pip install -e` sets up symbolic links from site-packages
into the repository.
## Running the tests
To run all the JAX tests, we recommend using `pytest-xdist`, which can run tests in
parallel. First, install `pytest-xdist` by running `pip install pytest-xdist`.
Then, from the repository root directory run
```bash
pytest -n auto tests
```
JAX generates test cases combinatorially, and you can control the number of
cases that are generated and checked for each test (default 10):
```bash
JAX_NUM_GENERATED_CASES=100 pytest -n auto tests
```
You can run a more specific set of tests using
[`pytest`](https://docs.pytest.org/en/latest/usage.html#specifying-tests-selecting-tests)'s
built-in selection mechanisms, or alternatively you can run a specific test
file directly to see more detailed information about the cases being run:
```bash
python tests/lax_numpy_test.py --num_generated_cases=5
```
## Reference documentation
For details about the JAX API, see the
[reference documentation](https://jax.readthedocs.io/).
## A brief tour

```python
In [1]: import jax.numpy as np
In [2]: from jax import random
In [3]: key = random.PRNGKey(0)
In [4]: x = random.normal(key, (5000, 5000))
In [5]: print(np.dot(x, x.T) / 2)  # fast!
[[ 2.52727051e+03  8.15895557e+00 -8.53276134e-01 ...,  # ...
In [6]: print(np.dot(x, x.T) / 2)  # even faster!
# JIT-compiled code is cached and reused in the 2nd call
[[ 2.52727051e+03  8.15895557e+00 -8.53276134e-01 ...,  # ...
```
What's happening behind the scenes is that JAX is using XLA to just-in-time
(JIT) compile and execute these individual operations on the GPU. First the
`random.normal` call is compiled and the array referred to by `x` is generated
on the GPU. Next, each function called on `x` (namely `transpose`, `dot`, and
`divide`) is individually JIT-compiled and executed, each keeping its results on
the device.

It's only when a value needs to be printed, plotted, saved, or passed into a raw
NumPy function that a read-only copy of the value is brought back to the host as
an ndarray and cached. The second call to `dot` is faster because the
JIT-compiled code is cached and reused, saving the compilation time.
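The host transfer described above can be made explicit. In this sketch (the array size is arbitrary, and `onp` is just an alias for the original NumPy), the result of `np.dot` stays on the device until `onp.asarray` asks for a host copy:

```python
import numpy as onp
import jax.numpy as np
from jax import random

key = random.PRNGKey(0)
x = random.normal(key, (1000, 1000))

# Dispatched to the default device (e.g. a GPU); the result stays there.
y = np.dot(x, x.T) / 2

# Converting to a raw NumPy ndarray is what brings a copy back to the host.
y_host = onp.asarray(y)
print(type(y_host), y_host.shape)
```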

The fun really starts when you use `grad` for automatic differentiation and
`jit` to compile your own functions end-to-end. Here's a more complete toy
example:
```python
from jax import grad, jit
import jax.numpy as np

def sigmoid(x):
  return 0.5 * (np.tanh(x / 2.) + 1)

# Outputs probability of a label being true according to logistic model.
def logistic_predictions(weights, inputs):
  return sigmoid(np.dot(inputs, weights))

# Training loss is the negative log-likelihood of the training labels.
def loss(weights, inputs, targets):
  preds = logistic_predictions(weights, inputs)
  label_logprobs = np.log(preds) * targets + np.log(1 - preds) * (1 - targets)
  return -np.sum(label_logprobs)

# Build a toy dataset.
inputs = np.array([[0.52, 1.12,  0.77],
                   [0.88, -1.08, 0.15],
                   [0.52, 0.06, -1.30],
                   [0.74, -2.49, 1.39]])
targets = np.array([True, True, False, True])

# Define a compiled function that returns gradients of the training loss
training_gradient_fun = jit(grad(loss))

# Optimize weights using gradient descent.
weights = np.array([0.0, 0.0, 0.0])
print("Initial loss: {:0.2f}".format(loss(weights, inputs, targets)))
for i in range(100):
  weights -= 0.1 * training_gradient_fun(weights, inputs, targets)

print("Trained loss: {:0.2f}".format(loss(weights, inputs, targets)))
```
To see more, check out the [quickstart
notebook](https://colab.research.google.com/github/google/jax/blob/master/notebooks/quickstart.ipynb),
a [simple MNIST classifier
example](https://github.com/google/jax/blob/master/examples/mnist_classifier.py)
and the rest of the [JAX
examples](https://github.com/google/jax/blob/master/examples/).
## What's supported
If you're using JAX just as an accelerator-backed NumPy, without using `grad` or
`jit` in your code, then in principle there are no constraints, though some
NumPy functions haven't been implemented yet. A list of supported functions can
be found in the [reference documentation](https://jax.readthedocs.io/).

Generally using `np.dot(A, B)` is better than `A.dot(B)` because the former
gives us more opportunities to run the computation on the device. NumPy also
does a lot of work to cast any array-like function arguments to arrays, as in
`np.sum([x, y])`, while `jax.numpy` typically requires explicit casting of array
arguments, like `np.sum(np.array([x, y]))`.
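A small sketch of the recommended forms, with made-up values. Recent JAX releases may reject a plain Python list passed to a `jax.numpy` reduction outright, so building the array explicitly is the portable choice:

```python
import jax.numpy as np

x = np.array(1.5)
y = np.array(2.5)

# Build the array explicitly instead of passing a Python list to np.sum.
total = np.sum(np.array([x, y]))
print(total)  # 4.0

A = np.ones((2, 3))
B = np.ones((3, 4))
# Prefer the function form np.dot(A, B) over the method form A.dot(B).
C = np.dot(A, B)
print(C.shape)  # (2, 4)
```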

For automatic differentiation with `grad`, JAX has the same restrictions
as [Autograd](https://github.com/hips/autograd). Specifically, differentiation
works with indexing (`x = A[i, j, :]`) but not indexed assignment
(`A[i, j] = x`) or indexed in-place updating (`A[i] += b`) (use
[`jax.ops.index_update`](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update)
or
[`jax.ops.index_add`](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_add.html#jax.ops.index_add)
instead). You can use lists, tuples, and
dicts freely: JAX doesn't even see them. Using `np.dot(A, B)` rather than
`A.dot(B)` is required for automatic differentiation when `A` is a raw ndarray.
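To make the indexing and container rules concrete, here is a sketch with a hypothetical function `f` whose parameters live in a plain dict. Reading `A[1, :]` is differentiable, and `grad` returns a gradient with the same dict structure:

```python
import jax.numpy as np
from jax import grad

def f(params):
  # params is an ordinary Python dict; JAX only sees the arrays inside it.
  A = params["A"]
  row = A[1, :]  # indexing (reading) is differentiable
  return np.sum(row ** 2) + params["b"]

params = {"A": np.arange(6.0).reshape(2, 3), "b": np.array(0.5)}
g = grad(f)(params)
print(g["A"])  # zero except row 1, which is 2 * A[1, :] = [6, 8, 10]
print(g["b"])  # 1.0
```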

For compiling your own functions with `jit` there are a few more requirements.
Because `jit` aims to specialize Python functions only on shapes and dtypes
during tracing, rather than on concrete values, Python control flow that depends
on concrete values won't be able to execute and will instead raise an error. If
you want compiled control flow, use structured control flow primitives like
`lax.cond` and `lax.while_loop`. Some indexing features, like slice-based indexing
`A[i:i+5]` for argument-dependent `i`, or boolean-based indexing `A[bool_ind]`
for argument-dependent `bool_ind`, produce abstract values of unknown shape and
are thus unsupported in `jit` functions.
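Here is a sketch of both workarounds under `jit`. It assumes the `lax.cond(pred, true_fun, false_fun, operand)` calling convention used by recent JAX releases (older releases took the operands and branch functions in a different order); `abs_val_jit` and `window` are illustrative names.

```python
import jax.numpy as np
from jax import jit, lax

@jit
def abs_val_jit(x):
  # `if x > 0:` would fail here because x is abstract. lax.cond stages both
  # branches into the compiled computation and selects one at run time.
  return lax.cond(x > 0, lambda v: v, lambda v: -v, x)

@jit
def window(A, i):
  # A[i:i+5] needs a concrete i. lax.dynamic_slice accepts a traced start
  # index and a static size, so the output shape (5,) is known when tracing.
  return lax.dynamic_slice(A, (i,), (5,))

print(abs_val_jit(-3.0))            # 3.0
print(window(np.arange(10.0), 2))   # [2. 3. 4. 5. 6.]
```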
In general, JAX is intended to be used with a functional style of Python
programming. Functions passed to transformations like `grad` and `jit` are
expected to be free of side-effects. You can write print statements for
debugging but they may only be executed once if they're under a `jit` decorator.
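One way to see this trace-time behavior is a Python side effect inside a jitted function. In this sketch (the global counter is purely for illustration, and is exactly the kind of side effect to avoid in real code), the body runs once per new input shape and dtype, not once per call:

```python
import jax.numpy as np
from jax import jit

trace_count = 0

@jit
def double(x):
  global trace_count
  trace_count += 1  # Python side effect: runs only while JAX traces `double`
  print("tracing double for shape", x.shape)
  return x * 2

double(np.ones(3))  # traces and compiles, so it prints
double(np.ones(3))  # same shape and dtype: reuses compiled code, no print
double(np.ones(4))  # new shape: traces again and prints
print(trace_count)  # 2
```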
> TLDR **Do use**
>
> * Functional programming
> * [Many](https://jax.readthedocs.io/en/latest/jax.numpy.html) of NumPy's
>   functions (help us add more!)
> * [Some](https://jax.readthedocs.io/en/latest/jax.scipy.html) SciPy functions
> * Indexing and slicing of arrays like `x = A[[5, 1, 7], :, 2:4]`
> * Explicit array creation from lists like `A = np.array([x, y])`
>
> **Don't use**
>
> * Assignment into arrays like `A[0, 0] = x` (use
>   [`jax.ops.index_update`](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update)
>   instead)
> * Implicit casting to arrays like `np.sum([x, y])` (use
>   `np.sum(np.array([x, y]))` instead)
> * `A.dot(B)` method syntax for functions of more than one argument (use
>   `np.dot(A, B)` instead)
> * Side-effects like mutation of arguments or mutation of global variables
> * The `out` argument of NumPy functions
> * Dtype casting like `np.float64(x)` (use `x.astype('float64')` or
>   `x.astype(np.float64)` instead).
>
> **For jit functions, also don't use**
>
> * Control flow based on dynamic values `if x > 0: ...`. Control flow based
>   on shapes is fine: `if x.shape[0] > 2: ...` and `for subarr in array`.
> * Slicing `A[i:i+5]` for dynamic index `i` (use `lax.dynamic_slice` instead)
>   or boolean indexing `A[bool_ind]` for traced values `bool_ind`.

You should get loud errors if your code violates any of these.
## Transformations
At its core, JAX is an extensible system for transforming numerical functions.
We currently expose three important transformations: `grad`, `jit`, and `vmap`.
### Automatic differentiation with grad
JAX has roughly the same API as [Autograd](https://github.com/hips/autograd).
The most popular function is `grad` for reverse-mode gradients:
```python
from jax import grad
import jax.numpy as np

def tanh(x):  # Define a function
  y = np.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints approximately 0.419974, i.e. 1 - tanh(1)**2
```
You can differentiate to any order with `grad`.
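For example, nesting `grad` gives the first three derivatives of the `tanh` defined above; the closed forms printed alongside are standard calculus, included only as a check:

```python
import math
import jax.numpy as np
from jax import grad

def tanh(x):
  y = np.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

# Nesting grad gives the 1st, 2nd, and 3rd derivatives of tanh.
d1 = grad(tanh)
d2 = grad(grad(tanh))
d3 = grad(grad(grad(tanh)))

# Closed forms for comparison, with t = tanh(x):
#   tanh' = 1 - t^2,  tanh'' = -2 t (1 - t^2),  tanh''' = (1 - t^2)(6 t^2 - 2)
t = math.tanh(1.0)
print(d1(1.0), 1 - t**2)
print(d2(1.0), -2 * t * (1 - t**2))
print(d3(1.0), (1 - t**2) * (6 * t**2 - 2))
```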
For more advanced autodiff, you can use `jax.vjp` for reverse-mode
vector-Jacobian products and `jax.jvp` for forward-mode Jacobian-vector
products. The two can be composed arbitrarily with one another, and with other
JAX transformations. Here's one way to compose
those to make a function that efficiently computes full Hessian matrices:
```python
from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
```
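A sketch of `jax.jvp`, `jax.vjp`, and the `hessian` helper above on small made-up functions (`f` and `g` are illustrative). For an elementwise function the Jacobian is diagonal, so pushing forward and pulling back the all-ones vector give the same result:

```python
import jax.numpy as np
from jax import jit, jacfwd, jacrev, jvp, vjp

def f(x):
  return np.sin(x) * x  # elementwise, R^n -> R^n

x = np.array([0.5, 1.0, 2.0])
v = np.ones(3)

# Forward mode: push tangent v through f, giving J(x) @ v.
y, jv = jvp(f, (x,), (v,))

# Reverse mode: pull cotangent v back through f, giving v @ J(x).
y2, f_vjp = vjp(f, x)
(vj,) = f_vjp(v)
print(np.allclose(jv, vj))

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))

def g(x):
  return np.sum(x ** 3)  # scalar function whose Hessian is diag(6 * x)

print(hessian(g)(x))
```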
As with Autograd, you're free to use differentiation with Python control
structures:
```python
def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)
```
### Compilation with jit
You can use XLA to compile your functions end-to-end with `jit`, used either as
an `@jit` decorator or as a higher-order function.
```python
import jax.numpy as np
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = np.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x)  # ~ 14.5 ms / loop (also on GPU via JAX)
```
You can mix `jit` and `grad` and any other JAX transformation however you like.
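A short sketch of both styles together, with a made-up function: `@jit` as a decorator, and `jit(grad(...))` as a higher-order function. The expected gradient follows from d/dx (x**2 + 2x) = 2x + 2, applied elementwise.

```python
import jax.numpy as np
from jax import grad, jit

@jit  # decorator form
def f(x):
  return x * x + x * 2.0

def loss(x):
  return np.sum(f(x))

# Higher-order-function form, composed with grad.
grad_loss = jit(grad(loss))
print(grad_loss(np.arange(3.0)))  # [2. 4. 6.]
```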
### Auto-vectorization with vmap
`vmap` is the vectorizing map.
It has the familiar semantics of mapping a function along array axes, but
instead of keeping the loop on the outside, it pushes the loop down into a
function's primitive operations for better performance.

Using `vmap` can save you from having to carry around batch dimensions in your
code. For example, consider this simple *unbatched* neural network prediction
function:
```python
def predict(params, input_vec):
  assert input_vec.ndim == 1
  for W, b in params:
    output_vec = np.dot(W, input_vec) + b  # `input_vec` on the right-hand side!
    input_vec = np.tanh(output_vec)
  return output_vec
```
We often instead write `np.dot(inputs, W)` to allow for a batch dimension on the
left side of `inputs`, but we've written this particular prediction function to
apply only to single input vectors. If we wanted to apply this function to a
batch of inputs at once, semantically we could just write
```python
from functools import partial
predictions = np.stack(list(map(partial(predict, params), input_batch)))
```
But pushing one example through the network at a time would be slow! It's better
to vectorize the computation, so that at every layer we're doing matrix-matrix
multiplies rather than matrix-vector multiplies.

The `vmap` function does that transformation for us. That is, if we write
```python
from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
```
then the `vmap` function will push the outer loop inside the function, and our
machine will end up executing matrix-matrix multiplications exactly as if we'd
done the batching by hand.
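To check that claim on a small case, here is a sketch that runs the unbatched `predict` both ways. The layer sizes and random parameters are made up; the two results should agree up to floating-point rounding:

```python
from functools import partial
import jax.numpy as np
from jax import random, vmap

def predict(params, input_vec):
  assert input_vec.ndim == 1  # still holds inside vmap: the batch axis is hidden
  for W, b in params:
    output_vec = np.dot(W, input_vec) + b
    input_vec = np.tanh(output_vec)
  return output_vec

# Hypothetical sizes: 3 inputs -> 4 hidden -> 2 outputs, and a batch of 5.
k1, k2, k3 = random.split(random.PRNGKey(0), 3)
params = [(random.normal(k1, (4, 3)), np.zeros(4)),
          (random.normal(k2, (2, 4)), np.zeros(2))]
input_batch = random.normal(k3, (5, 3))

# One example at a time (the slow path)...
looped = np.stack(list(map(partial(predict, params), input_batch)))
# ...versus vmap, mapping over axis 0 of input_batch but not over params.
batched = vmap(predict, in_axes=(None, 0))(params, input_batch)

print(batched.shape)                            # (5, 2)
print(np.allclose(looped, batched, atol=1e-5))  # True
```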

It's easy enough to manually batch a simple neural network without `vmap`, but
in other cases manual vectorization can be impractical or impossible. Take the
problem of efficiently computing per-example gradients: that is, for a fixed set
of parameters, we want to compute the gradient of our loss function evaluated
separately at each example in a batch. With `vmap`, it's easy:
```python
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
```
Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other
JAX transformation! We use `vmap` with both forward- and reverse-mode automatic
differentiation for fast Jacobian and Hessian matrix calculations in
`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`.
## Random numbers are different
JAX needs a [functional pseudo-random number generator (PRNG) system](design_notes/prng.md) to provide
reproducible results invariant to compilation boundaries and backends, while
also maximizing performance by enabling vectorized generation and
parallelization across random calls. The `numpy.random` library doesn't have
those properties. The `jax.random` library meets those needs: it's functionally
pure, but it doesn't require you to pass stateful random objects back out of
every function.

The `jax.random` library uses
[count-based PRNGs](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
and a functional array-oriented
[splitting model](http://publications.lib.chalmers.se/records/fulltext/183348/local_183348.pdf).
To generate random values, you call a function like `jax.random.normal` and give
it a PRNG key:
```python
import jax.random as random
key = random.PRNGKey(0)
print(random.normal(key, shape=(3,))) # [ 1.81608593 -0.48262325 0.33988902]
```
If we make the same call again with the same key, we get the same values:
```python
print(random.normal(key, shape=(3,))) # [ 1.81608593 -0.48262325 0.33988902]
```
The key never gets updated. So how do we get fresh random values? We use
`jax.random.split` to create new keys from existing ones. A common pattern is to
split off a new key for every function call that needs random values:
```python
key = random.PRNGKey(0)
key, subkey = random.split(key)
print(random.normal(subkey, shape=(3,))) # [ 1.1378783 -1.22095478 -0.59153646]
key, subkey = random.split(key)
print(random.normal(subkey, shape=(3,))) # [-0.06607265 0.16676566 1.17800343]
```
By splitting the PRNG key, not only do we avoid having to thread random states
back out of every function call, but also we can generate multiple random arrays
in parallel because we can avoid unnecessary sequential dependencies.
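Splitting also works in bulk. In this sketch, `random.split(key, 4)` produces four subkeys at once and `vmap` draws from all of them in parallel; each row matches what the corresponding subkey produces on its own:

```python
import jax.numpy as np
from jax import random, vmap

key = random.PRNGKey(0)

# Split one key into four subkeys in a single call.
keys = random.split(key, 4)

# Draw from every subkey in parallel; no draw depends on another.
samples = vmap(lambda k: random.normal(k, shape=(3,)))(keys)
print(samples.shape)  # (4, 3)

# Row 1 matches drawing from keys[1] on its own.
print(np.allclose(samples[1], random.normal(keys[1], shape=(3,))))
```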

There's a gotcha here, which is that it's easy to unintentionally reuse a key
without splitting. We intend to add a check for this (a sort of dynamic linear
typing) but for now it's something to be careful about.

For more detailed information on the design and the reasoning behind it, see the
[PRNG design doc](design_notes/prng.md).
## Mini-libraries
JAX provides some small, experimental libraries for machine learning. These
libraries are in part about providing tools and in part about serving as
examples for how to build such libraries using JAX. Each one is only a few
hundred lines of code, so take a look inside and adapt them as you need!
### Neural-net building with Stax
**Stax** is a functional neural network building library. The basic idea is that
a single layer or an entire network can be modeled as an `(init_fun, apply_fun)`
pair. The `init_fun` is used to initialize network parameters and the
`apply_fun` takes parameters and inputs to produce outputs. There are
constructor functions for common basic pairs, like `Conv` and `Relu`, and these
pairs can be composed in series using `stax.serial` or in parallel using
`stax.parallel`.

Here's an example:
```python
import jax.numpy as np
from jax import random
from jax.experimental import stax
from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax

# Use stax to set up network initialization and evaluation functions
net_init, net_apply = stax.serial(
    Conv(32, (3, 3), padding='SAME'), Relu,
    Conv(64, (3, 3), padding='SAME'), Relu,
    MaxPool((2, 2)), Flatten,
    Dense(128), Relu,
    Dense(10), LogSoftmax,
)

# Initialize parameters, not committing to a batch shape
rng = random.PRNGKey(0)
in_shape = (-1, 28, 28, 1)
out_shape, net_params = net_init(rng, in_shape)

# Apply network to dummy inputs
inputs = np.zeros((128, 28, 28, 1))
predictions = net_apply(net_params, inputs)
```
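To show what the `(init_fun, apply_fun)` convention amounts to, here is a hypothetical re-implementation of a tiny `Dense` layer, a `Tanh` layer, and a `serial` combinator in plain JAX, without importing stax. These are illustrative stand-ins, not stax's actual code (stax's `Dense`, for instance, uses a different initializer):

```python
import jax.numpy as np
from jax import random

# Hypothetical Dense layer: init_fun(rng, input_shape) -> (output_shape, params).
def Dense(out_dim):
  def init_fun(rng, input_shape):
    k1, k2 = random.split(rng)
    W = random.normal(k1, (input_shape[-1], out_dim)) * 0.1
    b = random.normal(k2, (out_dim,)) * 0.1
    return input_shape[:-1] + (out_dim,), (W, b)
  def apply_fun(params, inputs, **kwargs):
    W, b = params
    return np.dot(inputs, W) + b
  return init_fun, apply_fun

# A parameter-free layer: the shape passes through and params are empty.
Tanh = (lambda rng, input_shape: (input_shape, ()),
        lambda params, inputs, **kwargs: np.tanh(inputs))

# Hypothetical serial combinator: chains init_funs and apply_funs in order.
def serial(*layers):
  init_funs, apply_funs = zip(*layers)
  def init_fun(rng, input_shape):
    params = []
    for init in init_funs:
      rng, layer_rng = random.split(rng)
      input_shape, layer_params = init(layer_rng, input_shape)
      params.append(layer_params)
    return input_shape, params
  def apply_fun(params, inputs, **kwargs):
    for apply, p in zip(apply_funs, params):
      inputs = apply(p, inputs, **kwargs)
    return inputs
  return init_fun, apply_fun

net_init, net_apply = serial(Dense(16), Tanh, Dense(3))
out_shape, params = net_init(random.PRNGKey(0), (-1, 8))
print(out_shape)                                 # (-1, 3)
print(net_apply(params, np.ones((5, 8))).shape)  # (5, 3)
```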
### First-order optimization

JAX has a minimal optimization library focused on stochastic first-order
optimizers. Every optimizer is modeled as an
`(init_fun, update_fun, get_params)` triple of functions. The `init_fun` is used
to initialize the optimizer state, which could include things like momentum
variables, and the `update_fun` accepts a gradient and an optimizer state to
produce a new optimizer state. The `get_params` function extracts the current
iterate (i.e. the current parameters) from the optimizer state. The parameters
being optimized can be ndarrays or arbitrarily-nested list/tuple/dict
structures, so you can store your parameters however you'd like.

Here's an example, using `jit` to compile the whole update end-to-end:
```python
from jax.experimental import optimizers
from jax import jit, grad

# Define a simple squared-error loss
def loss(params, batch):
  inputs, targets = batch
  predictions = net_apply(params, inputs)
  return np.sum((predictions - targets)**2)

# Use optimizers to set optimizer initialization and update functions
opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-3, mass=0.9)

# Define a compiled update step
@jit
def step(i, opt_state, batch):
  params = get_params(opt_state)
  g = grad(loss)(params, batch)
  return opt_update(i, g, opt_state)

# Dummy input data stream
data_generator = ((np.zeros((128, 28, 28, 1)), np.zeros((128, 10)))
                  for _ in range(10))

# Optimize parameters in a loop
opt_state = opt_init(net_params)
for i in range(10):
  opt_state = step(i, opt_state, next(data_generator))
net_params = get_params(opt_state)
```
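The `(init_fun, update_fun, get_params)` triple is easy to write by hand. Here is a hypothetical plain-SGD optimizer in that style (not part of `jax.experimental.optimizers`), applied to a made-up linear regression with a dict of parameters. It uses `jax.tree_util.tree_map` so that every array in a nested parameter structure gets updated:

```python
import jax.numpy as np
from jax import grad, jit
from jax.tree_util import tree_map

# Hypothetical plain-SGD optimizer following the (init, update, get_params) triple.
def sgd(step_size):
  def init_fun(params):
    return params  # for plain SGD, the state is just the parameters
  def update_fun(i, g, state):
    return tree_map(lambda p, dp: p - step_size * dp, state, g)
  def get_params(state):
    return state
  return init_fun, update_fun, get_params

# Made-up data: y = 3x - 1, fit with a two-parameter linear model.
x = np.linspace(-1.0, 1.0, 20)
y = 3.0 * x - 1.0

def loss(params, x, y):
  pred = params["w"] * x + params["b"]
  return np.mean((pred - y) ** 2)

opt_init, opt_update, get_params = sgd(step_size=0.1)

@jit
def step(i, opt_state):
  g = grad(loss)(get_params(opt_state), x, y)
  return opt_update(i, g, opt_state)

opt_state = opt_init({"w": np.array(0.0), "b": np.array(0.0)})
for i in range(500):
  opt_state = step(i, opt_state)
params = get_params(opt_state)
print(params["w"], params["b"])  # close to 3.0 and -1.0
```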
## How it works
Programming in machine learning is about expressing and transforming functions.
Transformations include automatic differentiation, compilation for accelerators,
and automatic batching. High-level languages like Python are great for
expressing functions, but usually all we can do with them is apply them. We lose
access to their internal structure, which would let us perform transformations.

JAX is a tool for specializing and translating high-level Python+NumPy functions
into a representation that can be transformed and then lifted back into a Python
function.
![simplified-lifecycle](https://raw.githubusercontent.com/google/jax/master/images/lifecycle.png)
JAX specializes Python functions by tracing. Tracing a function means monitoring
all the basic operations that are applied to its input to produce its output,
and recording these operations and the data-flow between them in a directed
acyclic graph (DAG). To perform tracing, JAX wraps primitive operations, like
basic numerical kernels, so that when they're called they add themselves to a
list of operations performed along with their inputs and outputs. To keep track
of how data flows between these primitives, values being tracked are wrapped in
instances of the `Tracer` class.
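You can inspect the result of tracing directly. This sketch uses `jax.make_jaxpr`, which prints the primitive operations JAX recorded for a function (`f` is a made-up example); expect to see `tanh`, `mul`, and `add` in the output, though the exact text format varies between releases:

```python
import jax.numpy as np
from jax import make_jaxpr

def f(x):
  return np.tanh(x) * 2.0 + 1.0

# Trace f on an example input and print the recorded primitives.
jaxpr_text = str(make_jaxpr(f)(np.ones(3)))
print(jaxpr_text)
```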

When a Python function is provided to `grad` or `jit`, it's wrapped for tracing
and returned. When the wrapped function is called, we abstract the concrete
arguments provided into instances of the `AbstractValue` class, box them for
tracing in instances of the `Tracer` class, and call the function on them.
Abstract arguments represent sets of possible values rather than specific
values: for example, `jit` abstracts ndarray arguments to abstract values that
represent all ndarrays with the same shape and dtype. In contrast, `grad`
abstracts ndarray arguments to represent an infinitesimal neighborhood of the
underlying value. By tracing the Python function on these abstract values, we
ensure that it's specialized enough so that it's tractable to transform, and
that it's still general enough so that the transformed result is useful, and
possibly reusable. These transformed functions are then lifted back into Python
callables in a way that allows them to be traced and transformed again as needed.

The primitive functions that JAX traces are mostly in 1:1 correspondence with
[XLA HLO](https://www.tensorflow.org/xla/operation_semantics) and are defined
in [lax.py](https://github.com/google/jax/blob/master/jax/lax/lax.py). This 1:1
correspondence makes most of the translations to XLA essentially trivial, and
ensures we only have a small set of primitives to cover for other
transformations like automatic differentiation. The [`jax.numpy`
layer](https://github.com/google/jax/blob/master/jax/numpy/) is written in pure
Python simply by expressing NumPy functions in terms of the LAX functions (and
other NumPy functions we've already written). That makes `jax.numpy` easy to
extend.

When you use `jax.numpy`, the underlying LAX primitives are `jit`-compiled
behind the scenes, allowing you to write unrestricted Python+NumPy code while
still executing each primitive operation on an accelerator.
But JAX can do more: instead of just compiling and dispatching to a fixed set of
individual primitives, you can use `jit` on larger and larger functions to have
them compiled and optimized end to end. For example, instead of just compiling
and dispatching a convolution op, you can compile a whole network, or a whole
gradient evaluation and optimizer update step.

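For instance, a gradient evaluation and an optimizer update can sit under a single `jit`. The sketch below assumes a toy linear-regression loss and plain gradient descent; `loss`, `update`, the learning rate, and the synthetic data are illustrative choices, not part of JAX.

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    # Mean squared error of a linear model y ~ x @ w + b.
    w, b = params
    pred = x @ w + b
    return jnp.mean((pred - y) ** 2)

@jax.jit
def update(params, x, y, lr=0.1):
    # The gradient computation and the SGD step compile into one XLA program.
    grads = jax.grad(loss)(params, x, y)
    return [p - lr * g for p, g in zip(params, grads)]

# Synthetic data from a known linear model.
key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (32, 3))
y = x @ jnp.array([1.0, -2.0, 0.5]) + 0.3

params = [jnp.zeros(3), jnp.array(0.0)]
initial = loss(params, x, y)
for _ in range(100):
    params = update(params, x, y)  # compiled once, then reused every step
print(initial, loss(params, x, y))  # the loss should drop substantially
```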
The tradeoff is that `jit` functions have to satisfy some additional
specialization requirements: since we want to compile traces that are
specialized on shapes and dtypes, but not specialized all the way to concrete
values, the Python code under a `jit` decorator must be applicable to abstract
values. If we try to evaluate `x > 0` on an abstract `x`, the result is an
abstract value representing the set `{True, False}`, and so a Python branch like
`if x > 0` will raise an error: it doesn't know which way to go!

See [What's supported](#whats-supported) for more
information about `jit` requirements.

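A minimal sketch of this failure mode, assuming a recent JAX release in which the error is raised as `jax.errors.ConcretizationTypeError` (or a subclass of it); the function names are illustrative. Replacing the Python branch with `jnp.where`, which evaluates both sides and selects elementwise, keeps the function valid on abstract values.

```python
import jax
import jax.numpy as jnp

def abs_with_branch(x):
    # Python control flow on the value of x: works eagerly, fails under jit.
    if x > 0:
        return x
    return -x

def abs_with_where(x):
    # Both branches are computed and selected on-device; no Python bool needed.
    return jnp.where(x > 0, x, -x)

print(abs_with_branch(-3.0))  # 3.0, since x is concrete here

try:
    jax.jit(abs_with_branch)(-3.0)
except jax.errors.ConcretizationTypeError as err:
    print("jit failed:", type(err).__name__)

print(jax.jit(abs_with_where)(-3.0))  # 3.0
```

When a branch genuinely depends on a value known at trace time, marking that argument static with `jax.jit(..., static_argnums=...)` is another option, at the cost of a retrace for each distinct static value.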
The good news about this tradeoff is that `jit` is opt-in: JAX libraries use
`jit` on individual operations and functions behind the scenes, allowing you to
write unrestricted Python+NumPy and still make use of a hardware accelerator.
But when you want to maximize performance, you can often use `jit` in your own
code to compile and end-to-end optimize much bigger functions.

## What we're working on
1. Documentation!
2. Cloud TPU support
3. Multi-GPU and multi-TPU support
4. Full NumPy coverage and some SciPy coverage
5. Full coverage for `vmap`
6. Make everything faster
* Lowering the XLA function dispatch overhead
* Linear algebra routines (MKL on CPU, MAGMA on GPU)
7. `cond` and `while` primitives with efficient automatic differentiation
## Current gotchas
For a survey of current gotchas, with examples and explanations, we highly
recommend reading the [Gotchas Notebook](https://colab.research.google.com/github/google/jax/blob/master/notebooks/Common_Gotchas_in_JAX.ipynb).

Some stand-out gotchas that might surprise NumPy users:

1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default. To
   enable double-precision (64-bit, e.g. `float64`), set the
   `jax_enable_x64` variable **at startup** (or set the environment variable
   `JAX_ENABLE_X64=True`; see [the Gotchas Notebook](https://colab.research.google.com/github/google/jax/blob/master/notebooks/Common_Gotchas_in_JAX.ipynb#scrollTo=YTktlwTTMgFl)).
2. Some of NumPy's dtype promotion semantics involving a mix of Python scalars
and NumPy types aren't preserved, namely `np.add(1, np.array([2],
np.float32)).dtype` is `float64` rather than `float32`.
3. In-place mutation of arrays isn't supported, though [there is an
alternative](https://jax.readthedocs.io/en/latest/jax.ops.html). Generally
JAX requires functional code.
4. PRNGs are different and can be awkward, though for [good
reasons](https://github.com/google/jax/blob/master/design_notes/prng.md), and
non-reuse (linearity) is not yet checked.
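The following sketch touches gotchas 1, 3 and 4. It assumes a recent JAX release, where indexed updates are written with the `.at[...]` property (the successor to the `jax.ops` functions linked above) and random numbers come from explicit keys in `jax.random`; variable names are illustrative.

```python
import jax
import jax.numpy as jnp

# Gotcha 1: turn on 64-bit mode at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

x = jnp.arange(5.0)
print(x.dtype)  # float64, because x64 mode was enabled above

# Gotcha 3: arrays are immutable; .at[...].set(...) returns a new array.
y = x.at[2].set(10.0)
print(x[2], y[2])  # x is unchanged (2.0); y holds the update (10.0)

# Gotcha 4: explicit PRNG keys; split a key rather than reusing it.
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
a = jax.random.normal(subkey, (3,))
b = jax.random.normal(subkey, (3,))   # same key gives identical numbers
key, subkey2 = jax.random.split(key)
c = jax.random.normal(subkey2, (3,))  # a fresh key gives different numbers
print(bool((a == b).all()), bool((a == c).all()))
```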

See [the notebook](https://colab.research.google.com/github/google/jax/blob/master/notebooks/Common_Gotchas_in_JAX.ipynb) for much more information.

## Contributors
So far, JAX has received lots of help and [contributions](https://github.com/google/jax/graphs/contributors). In addition to the code contributions reflected on GitHub, JAX has benefited substantially from the advice of
[Jamie Townsend](https://github.com/j-towns),
[Peter Hawkins](https://github.com/hawkinsp),
[Jonathan Ragan-Kelley](https://people.eecs.berkeley.edu/~jrk/),
[Alex Wiltschko](http://github.com/alexbw),
George Dahl,
[Stephan Hoyer](http://stephanhoyer.com/),
Sam Schoenholz,
[Eli Bendersky](https://github.com/eliben),
Zak Stone,
[Alexey Radul](https://github.com/axch),
Michael Isard,
Skye Wanderman-Milne,
and many others.