diff --git a/README.md b/README.md index 5b2ab1b2c..2b7b0184f 100644 --- a/README.md +++ b/README.md @@ -60,261 +60,13 @@ grad_fun = jit(grad(logprob_fun)) # compiled gradient evaluation function perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0))) # fast per-example grads ``` -JAX started as a research project by [Matt Johnson](https://github.com/mattjj), -[Roy Frostig](https://github.com/froystig), [Dougal -Maclaurin](https://github.com/dougalm), and [Chris -Leary](https://github.com/learyg), and is now developed [in the -open](https://github.com/google/jax) by a growing number of -[contributors](#contributors). - ### Contents +* [Transformations](#transformations) * [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud) * [Installation](#installation) -* [Reference documentation](#reference-documentation) -* [A brief tour](#a-brief-tour) -* [What's supported](#whats-supported) -* [Transformations](#transformations) -* [Random numbers are different](#random-numbers-are-different) -* [Mini-libraries](#mini-libraries) -* [How it works](#how-it-works) -* [Current gotchas](#current-gotchas) * [Citing JAX](#citing-jax) +* [Reference documentation](#reference-documentation) -## Quickstart: Colab in the Cloud -Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks: -- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) -- [Training a Simple Neural Network, with PyTorch Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) -- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb) - -And for a deeper dive into JAX: -- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) -- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) -- [Directly using XLA in Python](https://jax.readthedocs.io/en/latest/notebooks/XLA_in_Python.html) -- [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html) -- [MAML Tutorial with JAX](https://jax.readthedocs.io/en/latest/notebooks/maml.html) -- [Generative Modeling by Estimating Gradients of Data Distribution in JAX](https://jax.readthedocs.io/en/latest/notebooks/score_matching.html). - - -## Installation -JAX is written in pure Python, but it depends on XLA, which needs to be compiled -and installed as the `jaxlib` package. Use the following instructions to -install a binary package with `pip`, or to build JAX from source. - -We support installing or building `jaxlib` on Linux (Ubuntu 16.04 or later) and -macOS (10.12 or later) platforms, but not yet Windows. We're not currently -working on Windows support, but contributions are welcome -(see [#438](https://github.com/google/jax/issues/438)). Some users have reported -success with building a CPU-only `jaxlib` from source using the Windows Subsytem -for Linux. 
- -### pip installation - -To install a CPU-only version, which might be useful for doing local -development on a laptop, you can run - -```bash -pip install --upgrade pip -pip install --upgrade jax jaxlib # CPU-only version -``` - -On Linux, it is often necessary to first update `pip` to a version that supports -`manylinux2010` wheels. - -If you want to install JAX with both CPU and GPU support, using existing CUDA -and CUDNN7 installations on your machine (for example, preinstalled on your -cloud VM), you can run - -```bash -# install jaxlib -PYTHON_VERSION=cp37 # alternatives: cp27, cp35, cp36, cp37 -CUDA_VERSION=cuda92 # alternatives: cuda90, cuda92, cuda100, cuda101 -PLATFORM=linux_x86_64 # alternatives: linux_x86_64 -BASE_URL='https://storage.googleapis.com/jax-releases' -pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.36-$PYTHON_VERSION-none-$PLATFORM.whl - -pip install --upgrade jax # install jax -``` - -The library package name must correspond to the version of the existing CUDA -installation you want to use, with `cuda101` for CUDA 10.1, `cuda100` for CUDA -10.0, `cuda92` for CUDA 9.2, and `cuda90` for CUDA 9.0. To find your CUDA and -CUDNN versions, you can run commands like these, depending on your CUDNN install -path: - -```bash -nvcc --version -grep CUDNN_MAJOR -A 2 /usr/local/cuda/include/cudnn.h # might need different path -``` - -The Python version must match your Python interpreter. There are prebuilt wheels -for Python 2.7, 3.5, 3.6, and 3.7; for anything else, you must build from -source. - -Please let us know on [the issue tracker](https://github.com/google/jax/issues) -if you run into any errors or problems with the prebuilt wheels. - -### Building JAX from source -See [Building JAX from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source). - - -## Reference documentation - -For details about the JAX API, see the -[reference documentation](https://jax.readthedocs.io/). - -## Developer documentation - -For getting started as a JAX developer, see the -[developer documentation](https://jax.readthedocs.io/en/latest/developer.html). - -## A brief tour - -```python -In [1]: import jax.numpy as np - -In [2]: from jax import random - -In [3]: key = random.PRNGKey(0) - -In [4]: x = random.normal(key, (5000, 5000)) - -In [5]: print(np.dot(x, x.T) / 2) # fast! -[[ 2.52727051e+03 8.15895557e+00 -8.53276134e-01 ..., # ... - -In [6]: print(np.dot(x, x.T) / 2) # even faster! -# JIT-compiled code is cached and reused in the 2nd call -[[ 2.52727051e+03 8.15895557e+00 -8.53276134e-01 ..., # ... -``` - -What’s happening behind-the-scenes is that JAX is using XLA to just-in-time -(JIT) compile and execute these individual operations on the GPU. First the -`random.normal` call is compiled and the array referred to by `x` is generated -on the GPU. Next, each function called on `x` (namely `transpose`, `dot`, and -`divide`) is individually JIT-compiled and executed, each keeping its results on -the device. -It’s only when a value needs to be printed, plotted, saved, or passed into a raw -NumPy function that a read-only copy of the value is brought back to the host as -an ndarray and cached. The second call to `dot` is faster because the -JIT-compiled code is cached and reused, saving the compilation time. - -The fun really starts when you use `grad` for automatic differentiation and -`jit` to compile your own functions end-to-end. 
Here’s a more complete toy -example: - -```python -from jax import grad, jit -import jax.numpy as np - -def sigmoid(x): - return 0.5 * (np.tanh(x / 2.) + 1) - -# Outputs probability of a label being true according to logistic model. -def logistic_predictions(weights, inputs): - return sigmoid(np.dot(inputs, weights)) - -# Training loss is the negative log-likelihood of the training labels. -def loss(weights, inputs, targets): - preds = logistic_predictions(weights, inputs) - label_logprobs = np.log(preds) * targets + np.log(1 - preds) * (1 - targets) - return -np.sum(label_logprobs) - -# Build a toy dataset. -inputs = np.array([[0.52, 1.12, 0.77], - [0.88, -1.08, 0.15], - [0.52, 0.06, -1.30], - [0.74, -2.49, 1.39]]) -targets = np.array([True, True, False, True]) - -# Define a compiled function that returns gradients of the training loss -training_gradient_fun = jit(grad(loss)) - -# Optimize weights using gradient descent. -weights = np.array([0.0, 0.0, 0.0]) -print("Initial loss: {:0.2f}".format(loss(weights, inputs, targets))) -for i in range(100): - weights -= 0.1 * training_gradient_fun(weights, inputs, targets) - -print("Trained loss: {:0.2f}".format(loss(weights, inputs, targets))) -``` - -To see more, check out the [quickstart -notebook](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html), -a [simple MNIST classifier -example](https://github.com/google/jax/blob/master/examples/mnist_classifier.py) -and the rest of the [JAX -examples](https://github.com/google/jax/blob/master/examples/). - -## What's supported - -If you’re using JAX just as an accelerator-backed NumPy, without using `grad` or -`jit` in your code, then in principle there are no constraints, though some -NumPy functions haven’t been implemented yet. A list of supported functions can -be found in the [reference documentation](https://jax.readthedocs.io/). - -Generally using `np.dot(A, B)` is -better than `A.dot(B)` because the former gives us more opportunities to run the -computation on the device. NumPy also does a lot of work to cast any array-like -function arguments to arrays, as in `np.sum([x, y])`, while `jax.numpy` -typically requires explicit casting of array arguments, like -`np.sum(np.array([x, y]))`. - -For automatic differentiation with `grad`, JAX has the same restrictions -as [Autograd](https://github.com/hips/autograd). Specifically, differentiation -works with indexing (`x = A[i, j, :]`) but not indexed assignment (`A[i, j] = -x`) or indexed in-place updating (`A[i] += b`) (use -[`jax.ops.index_update`](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update) -or -[`jax.ops.index_add`](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_add.html#jax.ops.index_add) -instead). You can use lists, tuples, and -dicts freely: JAX doesn't even see them. Using `np.dot(A, B)` rather than -`A.dot(B)` is required for automatic differentiation when `A` is a raw ndarray. - -For compiling your own functions with `jit` there are a few more requirements. -Because `jit` aims to specialize Python functions only on shapes and dtypes -during tracing, rather than on concrete values, Python control flow that depends -on concrete values won’t be able to execute and will instead raise an error. If -you want compiled control flow, use structured control flow primitives like -`lax.cond` and `lax.while_loop`. Some indexing features, like slice-based -indexing, e.g. `A[i:i+5]` for argument-dependent `i`, or boolean-based indexing, -e.g. 
`A[bool_ind]` for argument-dependent `bool_ind`, produce abstract values of -unknown shape and are thus unsupported in `jit` functions. - -In general, JAX is intended to be used with a functional style of Python -programming. Functions passed to transformations like `grad` and `jit` are -expected to be free of side-effects. You can write print statements for -debugging but they may only be executed once if they're under a `jit` decorator. - -> TLDR **Do use** -> -> * Functional programming -> * [Many](https://jax.readthedocs.io/en/latest/jax.numpy.html) of NumPy’s -> functions (help us add more!) -> * [Some](https://jax.readthedocs.io/en/latest/jax.scipy.html) SciPy functions -> * Indexing and slicing of arrays like `x = A[[5, 1, 7], :, 2:4]` -> * Explicit array creation from lists like `A = np.array([x, y])` -> -> **Don’t use** -> -> * Assignment into arrays like `A[0, 0] = x` (use -> [`jax.ops.index_update`](https://jax.readthedocs.io/en/latest/_autosummary/jax.ops.index_update.html#jax.ops.index_update) -> instead) -> * Implicit casting to arrays like `np.sum([x, y])` (use `np.sum(np.array([x, -> y])` instead) -> * `A.dot(B)` method syntax for functions of more than one argument (use -> `np.dot(A, B)` instead) -> * Side-effects like mutation of arguments or mutation of global variables -> * The `out` argument of NumPy functions -> * Dtype casting like `np.float64(x)` (use `x.astype('float64')` or -> `x.astype(np.float64)` instead). -> -> **For jit functions, also don’t use** -> -> * Control flow based on dynamic values `if x > 0: ...`. Control flow based -> on shapes is fine: `if x.shape[0] > 2: ...` and `for subarr in array`. -> * Slicing `A[i:i+5]` for dynamic index `i` (use `lax.dynamic_slice` instead) -> or boolean indexing `A[bool_ind]` for traced values `bool_ind`. - -You should get loud errors if your code violates any of these. ## Transformations @@ -452,248 +204,83 @@ differentiation for fast Jacobian and Hessian matrix calculations in `jax.jacfwd`, `jax.jacrev`, and `jax.hessian`. -## Random numbers are different +## Quickstart: Colab in the Cloud +Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks: +- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) +- [Training a Simple Neural Network, with PyTorch Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/Neural_Network_and_Data_Loading.ipynb) +- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/master/docs/notebooks/neural_network_with_tfds_data.ipynb) -JAX needs a [functional pseudo-random number generator (PRNG) system](design_notes/prng.md) to provide -reproducible results invariant to compilation boundaries and backends, while -also maximizing performance by enabling vectorized generation and -parallelization across random calls. The `numpy.random` library doesn’t have -those properties. The `jax.random` library meets those needs: it’s functionally -pure, but it doesn’t require you to pass stateful random objects back out of -every function. 
+And for a deeper dive into JAX:
+- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
+- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
+- [Directly using XLA in Python](https://jax.readthedocs.io/en/latest/notebooks/XLA_in_Python.html)
+- [How JAX primitives work](https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html)
+- [MAML Tutorial with JAX](https://jax.readthedocs.io/en/latest/notebooks/maml.html)
+- [Generative Modeling by Estimating Gradients of Data Distribution in JAX](https://jax.readthedocs.io/en/latest/notebooks/score_matching.html).
-The `jax.random` library uses
-[count-based PRNGs](http://www.thesalmons.org/john/random123/papers/random123sc11.pdf)
-and a functional array-oriented
-[splitting model](http://publications.lib.chalmers.se/records/fulltext/183348/local_183348.pdf).
-To generate random values, you call a function like `jax.random.normal` and give
-it a PRNG key:
-```python
-import jax.random as random
+## Installation
+JAX is written in pure Python, but it depends on XLA, which needs to be compiled
+and installed as the `jaxlib` package. Use the following instructions to
+install a binary package with `pip`, or to build JAX from source.
-key = random.PRNGKey(0)
-print(random.normal(key, shape=(3,))) # [ 1.81608593 -0.48262325 0.33988902]
+We support installing or building `jaxlib` on Linux (Ubuntu 16.04 or later) and
+macOS (10.12 or later) platforms, but not yet Windows. We're not currently
+working on Windows support, but contributions are welcome
+(see [#438](https://github.com/google/jax/issues/438)). Some users have reported
+success with building a CPU-only `jaxlib` from source using the Windows Subsystem
+for Linux.
+
+### pip installation
+
+To install a CPU-only version, which might be useful for doing local
+development on a laptop, you can run
+
+```bash
+pip install --upgrade pip
+pip install --upgrade jax jaxlib # CPU-only version
```
-If we make the same call again with the same key, we get the same values:
+On Linux, it is often necessary to first update `pip` to a version that supports
+`manylinux2010` wheels.
-```python
-print(random.normal(key, shape=(3,))) # [ 1.81608593 -0.48262325 0.33988902]
+If you want to install JAX with both CPU and GPU support, using existing CUDA
+and CUDNN7 installations on your machine (for example, preinstalled on your
+cloud VM), you can run
+
+```bash
+# install jaxlib
+PYTHON_VERSION=cp37 # alternatives: cp27, cp35, cp36, cp37
+CUDA_VERSION=cuda92 # alternatives: cuda90, cuda92, cuda100, cuda101
+PLATFORM=linux_x86_64 # alternatives: linux_x86_64
+BASE_URL='https://storage.googleapis.com/jax-releases'
+pip install --upgrade $BASE_URL/$CUDA_VERSION/jaxlib-0.1.36-$PYTHON_VERSION-none-$PLATFORM.whl
+
+pip install --upgrade jax # install jax
```
-The key never gets updated. So how do we get fresh random values? We use
-`jax.random.split` to create new keys from existing ones. A common pattern is to
-split off a new key for every function call that needs random values:
+The library package name must correspond to the version of the existing CUDA
+installation you want to use, with `cuda101` for CUDA 10.1, `cuda100` for CUDA
+10.0, `cuda92` for CUDA 9.2, and `cuda90` for CUDA 9.0.
To find your CUDA and +CUDNN versions, you can run commands like these, depending on your CUDNN install +path: -```python -key = random.PRNGKey(0) - -key, subkey = random.split(key) -print(random.normal(subkey, shape=(3,))) # [ 1.1378783 -1.22095478 -0.59153646] - -key, subkey = random.split(key) -print(random.normal(subkey, shape=(3,))) # [-0.06607265 0.16676566 1.17800343] +```bash +nvcc --version +grep CUDNN_MAJOR -A 2 /usr/local/cuda/include/cudnn.h # might need different path ``` -By splitting the PRNG key, not only do we avoid having to thread random states -back out of every function call, but also we can generate multiple random arrays -in parallel because we can avoid unnecessary sequential dependencies. +The Python version must match your Python interpreter. There are prebuilt wheels +for Python 2.7, 3.5, 3.6, and 3.7; for anything else, you must build from +source. -There's a gotcha here, which is that it's easy to unintentionally reuse a key -without splitting. We intend to add a check for this (a sort of dynamic linear -typing) but for now it's something to be careful about. +Please let us know on [the issue tracker](https://github.com/google/jax/issues) +if you run into any errors or problems with the prebuilt wheels. -For more detailed information on the design and the reasoning behind it, see the -[PRNG design doc](design_notes/prng.md). +### Building JAX from source +See [Building JAX from source](https://jax.readthedocs.io/en/latest/developer.html#building-from-source). -## Mini-libraries - -JAX provides some small, experimental libraries for machine learning. These -libraries are in part about providing tools and in part about serving as -examples for how to build such libraries using JAX. Each one is only a few -hundred lines of code, so take a look inside and adapt them as you need! - -### Neural-net building with Stax - -**Stax** is a functional neural network building library. The basic idea is that -a single layer or an entire network can be modeled as an `(init_fun, apply_fun)` -pair. The `init_fun` is used to initialize network parameters and the -`apply_fun` takes parameters and inputs to produce outputs. There are -constructor functions for common basic pairs, like `Conv` and `Relu`, and these -pairs can be composed in series using `stax.serial` or in parallel using -`stax.parallel`. - -Here’s an example: - -```python -import jax.numpy as np -from jax import random -from jax.experimental import stax -from jax.experimental.stax import Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax - -# Use stax to set up network initialization and evaluation functions -net_init, net_apply = stax.serial( - Conv(32, (3, 3), padding='SAME'), Relu, - Conv(64, (3, 3), padding='SAME'), Relu, - MaxPool((2, 2)), Flatten, - Dense(128), Relu, - Dense(10), LogSoftmax, -) - -# Initialize parameters, not committing to a batch shape -rng = random.PRNGKey(0) -in_shape = (-1, 28, 28, 1) -out_shape, net_params = net_init(rng, in_shape) - -# Apply network to dummy inputs -inputs = np.zeros((128, 28, 28, 1)) -predictions = net_apply(net_params, inputs) -``` - -### First-order optimization - -JAX has a minimal optimization library focused on stochastic first-order -optimizers. Every optimizer is modeled as an `(init_fun, update_fun, -get_params)` triple of functions. The `init_fun` is used to initialize the -optimizer state, which could include things like momentum variables, and the -`update_fun` accepts a gradient and an optimizer state to produce a new -optimizer state. 
The `get_params` function extracts the current iterate (i.e. -the current parameters) from the optimizer state. The parameters being optimized -can be ndarrays or arbitrarily-nested list/tuple/dict structures, so you can -store your parameters however you’d like. - -Here’s an example, using `jit` to compile the whole update end-to-end: - -```python -from jax.experimental import optimizers -from jax import jit, grad - -# Define a simple squared-error loss -def loss(params, batch): - inputs, targets = batch - predictions = net_apply(params, inputs) - return np.sum((predictions - targets)**2) - -# Use optimizers to set optimizer initialization and update functions -opt_init, opt_update, get_params = optimizers.momentum(step_size=1e-3, mass=0.9) - -# Define a compiled update step -@jit -def step(i, opt_state, batch): - params = get_params(opt_state) - g = grad(loss)(params, batch) - return opt_update(i, g, opt_state) - -# Dummy input data stream -data_generator = ((np.zeros((128, 28, 28, 1)), np.zeros((128, 10))) - for _ in range(10)) - -# Optimize parameters in a loop -opt_state = opt_init(net_params) -for i in range(10): - opt_state = step(i, opt_state, next(data_generator)) -net_params = get_params(opt_state) -``` - -## How it works - -Programming in machine learning is about expressing and transforming functions. -Transformations include automatic differentiation, compilation for accelerators, -and automatic batching. High-level languages like Python are great for -expressing functions, but usually all we can do with them is apply them. We lose -access to their internal structure which would let us perform transformations. - -JAX is a tool for specializing and translating high-level Python+NumPy functions -into a representation that can be transformed and then lifted back into a Python -function. - -![simplified-lifecycle](https://raw.githubusercontent.com/google/jax/master/images/lifecycle.png) - -JAX specializes Python functions by tracing. Tracing a function means monitoring -all the basic operations that are applied to its input to produce its output, -and recording these operations and the data-flow between them in a directed -acyclic graph (DAG). To perform tracing, JAX wraps primitive operations, like -basic numerical kernels, so that when they’re called they add themselves to a -list of operations performed along with their inputs and outputs. To keep track -of how data flows between these primitives, values being tracked are wrapped in -instances of the `Tracer` class. - -When a Python function is provided to `grad` or `jit`, it’s wrapped for tracing -and returned. When the wrapped function is called, we abstract the concrete -arguments provided into instances of the `AbstractValue` class, box them for -tracing in instances of the `Tracer` class, and call the function on them. -Abstract arguments represent sets of possible values rather than specific -values: for example, `jit` abstracts ndarray arguments to abstract values that -represent all ndarrays with the same shape and dtype. In contrast, `grad` -abstracts ndarray arguments to represent an infinitesimal neighborhood of the -underlying -value. By tracing the Python function on these abstract values, we ensure that -it’s specialized enough so that it’s tractable to transform, and that it’s still -general enough so that the transformed result is useful, and possibly reusable. -These transformed functions are then lifted back into Python callables in a way -that allows them to be traced and transformed again as needed. 
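+
+Whichever way you install it, a quick smoke test can confirm that JAX sees a
+working `jaxlib` and show which devices it will use. This is only a minimal
+sketch (`jax.devices()` simply lists the devices on the default backend; with
+the CUDA wheels it should mention a GPU device):
+
+```python
+import jax
+import jax.numpy as np
+from jax import grad, jit
+
+print(jax.devices())  # the devices JAX found, e.g. a GPU device for CUDA wheels
+
+x = np.ones((1000, 1000))
+print(jit(lambda a: np.dot(a, a.T))(x)[0, 0])          # a compiled matmul
+print(grad(lambda a: np.sum(np.tanh(a)))(np.ones(3)))  # a reverse-mode gradient
+```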
- -The primitive functions that JAX traces are mostly in 1:1 correspondence with -[XLA HLO](https://www.tensorflow.org/xla/operation_semantics) and are defined -in [lax.py](https://github.com/google/jax/blob/master/jax/lax/lax.py). This 1:1 -correspondence makes most of the translations to XLA essentially trivial, and -ensures we only have a small set of primitives to cover for other -transformations like automatic differentiation. The [`jax.numpy` -layer](https://github.com/google/jax/blob/master/jax/numpy/) is written in pure -Python simply by expressing NumPy functions in terms of the LAX functions (and -other NumPy functions we’ve already written). That makes `jax.numpy` easy to -extend. - -When you use `jax.numpy`, the underlying LAX primitives are `jit`-compiled -behind the scenes, allowing you to write unrestricted Python+Numpy code while -still executing each primitive operation on an accelerator. - -But JAX can do more: instead of just compiling and dispatching to a fixed set of -individual primitives, you can use `jit` on larger and larger functions to be -end-to-end compiled and optimized. For example, instead of just compiling and -dispatching a convolution op, you can compile a whole network, or a whole -gradient evaluation and optimizer update step. - -The tradeoff is that `jit` functions have to satisfy some additional -specialization requirements: since we want to compile traces that are -specialized on shapes and dtypes, but not specialized all the way to concrete -values, the Python code under a `jit` decorator must be applicable to abstract -values. If we try to evaluate `x > 0` on an abstract `x`, the result is an -abstract value representing the set `{True, False}`, and so a Python branch like -`if x > 0` will raise an error: it doesn’t know which way to go! -See [What’s supported](#whats-supported) for more -information about `jit` requirements. - -The good news about this tradeoff is that `jit` is opt-in: JAX libraries use -`jit` on individual operations and functions behind the scenes, allowing you to -write unrestricted Python+Numpy and still make use of a hardware accelerator. -But when you want to maximize performance, you can often use `jit` in your own -code to compile and end-to-end optimize much bigger functions. - -## Current gotchas - -For a survey of current gotchas, with examples and explanations, we highly -recommend reading the [Gotchas Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). - -Some stand-out gotchas that might surprise NumPy users: -1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and - to enable double-precision (64-bit, e.g. `float64`) one needs to set the - `jax_enable_x64` variable **at startup** (or set the environment variable - `JAX_ENABLE_X64=True`, see [the Gotchas Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#scrollTo=Double-(64bit)-precision)) -2. Some of NumPy's dtype promotion semantics involving a mix of Python scalars - and NumPy types aren't preserved, namely `np.add(1, np.array([2], - np.float32)).dtype` is `float64` rather than `float32`. -3. In-place mutation of arrays isn't supported, though [there is an - alternative](https://jax.readthedocs.io/en/latest/jax.ops.html). Generally - JAX requires functional code. -4. PRNGs are different and can be awkward, though for [good - reasons](https://github.com/google/jax/blob/master/design_notes/prng.md), and - non-reuse (linearity) is not yet checked. 
- -See [the notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) for much more information. - ## Citing JAX To cite this repository: @@ -718,19 +305,10 @@ compilation to XLA, was described in a [paper that appeared at SysML covering JAX's ideas and capabilities in a more comprehensive and up-to-date paper. -## Contributors +## Reference documentation -So far, JAX includes lots of help and [contributions](https://github.com/google/jax/graphs/contributors). In addition to the code contributions reflected on GitHub, JAX has benefitted substantially from the advice of -[Jamie Townsend](https://github.com/j-towns), -[Peter Hawkins](https://github.com/hawkinsp), -[Jonathan Ragan-Kelley](https://people.eecs.berkeley.edu/~jrk/), -[Alex Wiltschko](http://github.com/alexbw), -George Dahl, -[Stephan Hoyer](http://stephanhoyer.com/), -Sam Schoenholz, -[Eli Bendersky](https://github.com/eliben), -Zak Stone, -[Alexey Radul](https://github.com/axch), -Michael Isard, -Skye Wanderman-Milne, -and many others. +For details about the JAX API, see the +[reference documentation](https://jax.readthedocs.io/). + +For getting started as a JAX developer, see the +[developer documentation](https://jax.readthedocs.io/en/latest/developer.html).
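+
+When reporting a problem or asking a question, it usually helps to include the
+installed versions and the active backend. A small sketch for collecting them
+(`jax.lib.xla_bridge` is the internal bridge JAX uses to talk to XLA):
+
+```python
+import jax, jaxlib
+from jax.lib import xla_bridge
+
+print(jax.__version__, jaxlib.__version__)  # library versions
+print(xla_bridge.get_backend().platform)    # 'cpu', 'gpu', or 'tpu'
+```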