<div align="center">
<img src="https://raw.githubusercontent.com/jax-ml/jax/main/images/jax_logo_250px.png" alt="logo"></img>
</div>

# Transformable numerical computing at scale

[**Continuous integration**](https://github.com/jax-ml/jax/actions/workflows/ci-build.yaml)
[**PyPI version**](https://pypi.org/project/jax/)

[**Quickstart**](#quickstart-colab-in-the-cloud)
| [**Transformations**](#transformations)
| [**Install guide**](#installation)
| [**Neural net libraries**](#neural-network-libraries)
| [**Change logs**](https://jax.readthedocs.io/en/latest/changelog.html)
| [**Reference docs**](https://jax.readthedocs.io/en/latest/)

## What is JAX?

JAX is a Python library for accelerator-oriented array computation and program transformation,
designed for high-performance numerical computing and large-scale machine learning.

With its updated version of [Autograd](https://github.com/hips/autograd),
JAX can automatically differentiate native
Python and NumPy functions. It can differentiate through loops, branches,
recursion, and closures, and it can take derivatives of derivatives of
derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation)
via [`grad`](#automatic-differentiation-with-grad) as well as forward-mode differentiation,
and the two can be composed arbitrarily to any order.

What’s new is that JAX uses [XLA](https://www.tensorflow.org/xla)
to compile and run your NumPy programs on GPUs and TPUs. Compilation happens
under the hood by default, with library calls getting just-in-time compiled and
executed. But JAX also lets you just-in-time compile your own Python functions
into XLA-optimized kernels using a one-function API,
[`jit`](#compilation-with-jit). Compilation and automatic differentiation can be
composed arbitrarily, so you can express sophisticated algorithms and get
maximal performance without leaving Python. You can even program multiple GPUs
or TPU cores at once using [`pmap`](#spmd-programming-with-pmap), and
differentiate through the whole thing.

Dig a little deeper, and you'll see that JAX is really an extensible system for
[composable function transformations](#transformations). Both
[`grad`](#automatic-differentiation-with-grad) and [`jit`](#compilation-with-jit)
are instances of such transformations. Others are
[`vmap`](#auto-vectorization-with-vmap) for automatic vectorization and
[`pmap`](#spmd-programming-with-pmap) for single-program multiple-data (SPMD)
parallel programming of multiple accelerators, with more to come.

This is a research project, not an official Google product. Expect
[sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
Please help by trying it out, [reporting
bugs](https://github.com/jax-ml/jax/issues), and letting us know what you
think!

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)  # inputs to the next layer
  return outputs                # no activation on last layer

def loss(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_loss = jit(grad(loss))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0)))  # fast per-example grads
```
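
As a minimal usage sketch for the snippet above (the layer sizes, initialization scale, and fake data here are illustrative assumptions, not part of the original example), you could build `params` with `jax.random` and exercise the compiled gradient functions like this:

```python
import jax.numpy as jnp
from jax import random

# Hypothetical layer sizes for a small MLP: 784 -> 512 -> 10.
layer_sizes = [(784, 512), (512, 10)]
key = random.key(0)

params = []
for n_in, n_out in layer_sizes:
  key, w_key = random.split(key)
  params.append((0.01 * random.normal(w_key, (n_in, n_out)), jnp.zeros(n_out)))

# A batch of 32 fake examples, just to exercise the functions above.
key, x_key, y_key = random.split(key, 3)
inputs = random.normal(x_key, (32, 784))
targets = random.normal(y_key, (32, 10))

print(grad_loss(params, inputs, targets)[0][0].shape)    # (784, 512)
print(perex_grads(params, inputs, targets)[0][0].shape)  # (32, 784, 512)
```
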

### Contents
* [Quickstart: Colab in the Cloud](#quickstart-colab-in-the-cloud)
* [Transformations](#transformations)
* [Current gotchas](#current-gotchas)
* [Installation](#installation)
* [Neural net libraries](#neural-network-libraries)
* [Citing JAX](#citing-jax)
* [Reference documentation](#reference-documentation)

## Quickstart: Colab in the Cloud

Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
Here are some starter notebooks:
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html)
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/jax-ml/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)

**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU
Colabs](https://github.com/jax-ml/jax/tree/main/cloud_tpu_colabs).

For a deeper dive into JAX:
- [The Autodiff Cookbook, Part 1: easy and powerful automatic differentiation in JAX](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
- [Common gotchas and sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)
- See the [full list of
notebooks](https://github.com/jax-ml/jax/tree/main/docs/notebooks).

## Transformations

At its core, JAX is an extensible system for transforming numerical functions.
Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and
`pmap`.

### Automatic differentiation with `grad`

JAX has roughly the same API as [Autograd](https://github.com/hips/autograd).
The most popular function is
[`grad`](https://jax.readthedocs.io/en/latest/jax.html#jax.grad)
for reverse-mode gradients:

```python
from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743
```

You can differentiate to any order with `grad`.

```python
print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673
```

For more advanced autodiff, you can use
[`jax.vjp`](https://jax.readthedocs.io/en/latest/jax.html#jax.vjp) for
reverse-mode vector-Jacobian products and
[`jax.jvp`](https://jax.readthedocs.io/en/latest/jax.html#jax.jvp) for
forward-mode Jacobian-vector products. The two can be composed arbitrarily with
one another, and with other JAX transformations. Here's one way to compose those
to make a function that efficiently computes [full Hessian
matrices](https://jax.readthedocs.io/en/latest/_autosummary/jax.hessian.html#jax.hessian):

```python
from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
```
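
As a quick sanity check of that composition, here's a hedged example on a toy function of our own (the expected output follows from the closed-form Hessian):

```python
import jax.numpy as jnp

def f(x):
  return jnp.sum(x**3) + x[0] * x[1]

x = jnp.array([1.0, 2.0, 3.0])
print(hessian(f)(x))
# Expected: 6*x on the diagonal plus 1.0 in the (0, 1) and (1, 0) entries:
# [[ 6.  1.  0.]
#  [ 1. 12.  0.]
#  [ 0.  0. 18.]]
```
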

As with [Autograd](https://github.com/hips/autograd), you're free to use
differentiation with Python control structures:

```python
def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)
```

See the [reference docs on automatic
differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
and the [JAX Autodiff
Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html)
for more.

### Compilation with `jit`
You can use XLA to compile your functions end-to-end with
[`jit`](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
used either as an `@jit` decorator or as a higher-order function.

```python
import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x)  # ~ 14.5 ms / loop (also on GPU via JAX)
```

You can mix `jit` and `grad` and any other JAX transformation however you like.

Using `jit` puts constraints on the kind of Python control flow
the function can use; see
the tutorial on [Control Flow and Logical Operators with JIT](https://jax.readthedocs.io/en/latest/control-flow.html)
for more.

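For instance, here's a minimal hedged sketch (a toy function of ours, not from the tutorial) of `jit`'s `static_argnums` parameter, which lets ordinary Python control flow depend on an argument as long as that argument is marked static:

```python
from functools import partial
import jax.numpy as jnp
from jax import jit

@partial(jit, static_argnums=(1,))
def power(x, n):
  # `n` controls a Python loop, so it must be a static (hashable) value;
  # jit compiles a separate version of the function for each distinct n.
  y = x
  for _ in range(n - 1):
    y = y * x
  return y

print(power(jnp.arange(3.0), 3))  # [0. 1. 8.]
print(power(jnp.arange(3.0), 5))  # [ 0.  1. 32.], recompiled for n=5
```

Each new static value triggers a retrace and recompile, so this pattern works best when the set of distinct static values stays small.
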
### Auto-vectorization with `vmap`

[`vmap`](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap) is
the vectorizing map.
It has the familiar semantics of mapping a function along array axes, but
instead of keeping the loop on the outside, it pushes the loop down into a
function’s primitive operations for better performance.

Using `vmap` can save you from having to carry around batch dimensions in your
code. For example, consider this simple *unbatched* neural network prediction
function:

```python
def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b  # `activations` on the right-hand side!
    activations = jnp.tanh(outputs)        # inputs to the next layer
  return outputs                           # no activation on last layer
```

We often instead write `jnp.dot(activations, W)` to allow for a batch dimension on the
left side of `activations`, but we’ve written this particular prediction function to
apply only to single input vectors. If we wanted to apply this function to a
batch of inputs at once, semantically we could just write

```python
from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))
```

But pushing one example through the network at a time would be slow! It’s better
to vectorize the computation, so that at every layer we’re doing matrix-matrix
multiplication rather than matrix-vector multiplication.

The `vmap` function does that transformation for us. That is, if we write

```python
from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
```

then the `vmap` function will push the outer loop inside the function, and our
machine will end up executing matrix-matrix multiplications exactly as if we’d
done the batching by hand.

It’s easy enough to manually batch a simple neural network without `vmap`, but
in other cases manual vectorization can be impractical or impossible. Take the
problem of efficiently computing per-example gradients: that is, for a fixed set
of parameters, we want to compute the gradient of our loss function evaluated
separately at each example in a batch. With `vmap`, it’s easy:

```python
per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)
```

Of course, `vmap` can be arbitrarily composed with `jit`, `grad`, and any other
JAX transformation! We use `vmap` with both forward- and reverse-mode automatic
differentiation for fast Jacobian and Hessian matrix calculations in
`jax.jacfwd`, `jax.jacrev`, and `jax.hessian`.

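As a rough illustration of that claim, here is a simplified sketch (not JAX's actual implementation of `jacfwd`) of how a forward-mode Jacobian can be assembled from `jvp` plus `vmap`:

```python
import jax.numpy as jnp
from jax import jvp, vmap

def my_jacfwd(f):
  def jacfun(x):
    basis = jnp.eye(x.size, dtype=x.dtype)
    # Each jvp call computes J(x) @ v; vmap batches over the basis vectors,
    # stacking the Jacobian columns along the leading axis.
    _, cols = vmap(lambda v: jvp(f, (x,), (v,)))(basis)
    return jnp.transpose(cols)
  return jacfun

f = lambda x: jnp.sin(x) * x[0]
x = jnp.arange(1.0, 4.0)
print(my_jacfwd(f)(x))  # should match jax.jacfwd(f)(x)
```
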
### SPMD programming with `pmap`

For parallel programming of multiple accelerators, like multiple GPUs, use
[`pmap`](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap).
With `pmap` you write single-program multiple-data (SPMD) programs, including
fast parallel collective communication operations. Applying `pmap` will mean
that the function you write is compiled by XLA (similarly to `jit`), then
replicated and executed in parallel across devices.

Here's an example on an 8-GPU machine:

```python
from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.key(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]
```

In addition to expressing pure maps, you can use fast [collective communication
operations](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators)
between devices:

```python
from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')

print(normalize(jnp.arange(4.)))
# prints [0. 0.16666667 0.33333334 0.5 ]
```

You can even [nest `pmap` functions](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb#scrollTo=MdRscR5MONuN) for more
sophisticated communication patterns.

It all composes, so you're free to differentiate through parallel computations:

```python
from jax import grad

# Assumes `x` is a previously defined array, e.g. of shape (4, 2) so that the
# nested pmaps map over the 8 devices above and produce the outputs shown below.
@pmap
def f(x):
  y = jnp.sin(x)
  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  return grad(lambda w: jnp.sum(g(w)))(x)

print(f(x))
# [[ 0.        , -0.7170853 ],
#  [-3.1085174 , -0.4824318 ],
#  [10.366636  , 13.135289  ],
#  [ 0.22163185, -0.52112055]]

print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726,  -1.6356447],
#  [  4.7572474,  11.606951 ],
#  [-98.524414 ,  42.76499  ],
#  [ -1.6007166,  -1.2568436]]
```

When reverse-mode differentiating a `pmap` function (e.g. with `grad`), the
backward pass of the computation is parallelized just like the forward pass.

See the [SPMD
Cookbook](https://colab.research.google.com/github/jax-ml/jax/blob/main/cloud_tpu_colabs/Pmap_Cookbook.ipynb)
and the [SPMD MNIST classifier from scratch
example](https://github.com/jax-ml/jax/blob/main/examples/spmd_mnist_classifier_fromscratch.py)
for more.

## Current gotchas
For a more thorough survey of current gotchas, with examples and explanations,
we highly recommend reading the [Gotchas
Notebook](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html).
Some standouts:

1. JAX transformations only work on [pure functions](https://en.wikipedia.org/wiki/Pure_function), which don't have side-effects and respect [referential transparency](https://en.wikipedia.org/wiki/Referential_transparency) (i.e. object identity testing with `is` isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like `Exception: Can't lift Traced...` or `Exception: Different traces at same level`.
1. [In-place mutating updates of
   arrays](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#in-place-updates), like `x[i] += y`, aren't supported, but [there are functional alternatives](https://jax.readthedocs.io/en/latest/jax.ops.html) (see the sketch after this list). Under a `jit`, those functional alternatives will reuse buffers in-place automatically.
1. [Random numbers are
   different](https://jax.readthedocs.io/en/latest/random-numbers.html), but for [good reasons](https://github.com/jax-ml/jax/blob/main/docs/jep/263-prng.md).
1. If you're looking for [convolution
   operators](https://jax.readthedocs.io/en/latest/notebooks/convolutions.html),
   they're in the `jax.lax` package.
1. JAX enforces single-precision (32-bit, e.g. `float32`) values by default, and
   [to enable
   double-precision](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision)
   (64-bit, e.g. `float64`) one needs to set the `jax_enable_x64` variable at
   startup (or set the environment variable `JAX_ENABLE_X64=True`).
   On TPU, JAX uses 32-bit values by default for everything _except_ internal
   temporary variables in 'matmul-like' operations, such as `jax.numpy.dot` and `lax.conv`.
   Those ops have a `precision` parameter which can be used to approximate 32-bit operations
   via three bfloat16 passes, with a cost of possibly slower runtime.
   Non-matmul operations on TPU lower to implementations that often emphasize speed over
   accuracy, so in practice computations on TPU will be less precise than similar
   computations on other backends.
1. Some of NumPy's dtype promotion semantics involving a mix of Python scalars
   and NumPy types aren't preserved, namely `np.add(1, np.array([2],
   np.float32)).dtype` is `float64` rather than `float32`.
1. Some transformations, like `jit`, [constrain how you can use Python control
   flow](https://jax.readthedocs.io/en/latest/control-flow.html).
   You'll always get loud errors if something goes wrong. You might have to use
   [`jit`'s `static_argnums`
   parameter](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit),
   [structured control flow
   primitives](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators)
   like
   [`lax.scan`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html#jax.lax.scan),
   or just use `jit` on smaller subfunctions.

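To make a few of these concrete, here is a small hedged sketch (our own toy snippet, not from the Gotchas notebook) showing the 64-bit opt-in, the functional `.at[]` update API, explicit PRNG keys, and `lax.scan` as a structured loop:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Gotcha 5: enable float64 before doing any computation (it's off by default).
jax.config.update("jax_enable_x64", True)

# Gotcha 2: no in-place mutation; use the functional .at[] API instead.
x = jnp.zeros(4)
# x[1] += 3.0          # would raise a TypeError: JAX arrays are immutable
x = x.at[1].add(3.0)   # returns a new array; buffers are reused under jit
print(x)               # [0. 3. 0. 0.]

# Gotcha 3: random numbers take an explicit PRNG key, split for independence.
key = jax.random.key(0)
key, subkey = jax.random.split(key)
print(jax.random.normal(subkey, (2,)))

# Gotcha 7: structured control flow under jit, e.g. lax.scan instead of a Python loop.
def cumsum(xs):
  def step(carry, x):
    total = carry + x
    return total, total  # (new carry, per-step output)
  _, ys = lax.scan(step, jnp.zeros((), dtype=xs.dtype), xs)
  return ys

print(jax.jit(cumsum)(jnp.arange(4.0)))  # [0. 1. 3. 6.]
```
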
## Installation

### Supported platforms

|            | Linux x86_64 | Linux aarch64 | Mac x86_64   | Mac aarch64  | Windows x86_64 | Windows WSL2 x86_64 |
|------------|--------------|---------------|--------------|--------------|----------------|---------------------|
| CPU        | yes          | yes           | yes          | yes          | yes            | yes                 |
| NVIDIA GPU | yes          | yes           | no           | n/a          | no             | experimental        |
| Google TPU | yes          | n/a           | n/a          | n/a          | n/a            | n/a                 |
| AMD GPU    | yes          | no            | experimental | n/a          | no             | no                  |
| Apple GPU  | n/a          | no            | n/a          | experimental | n/a            | n/a                 |
| Intel GPU  | experimental | n/a           | n/a          | n/a          | no             | no                  |

### Instructions

| Platform        | Instructions                                                                                                    |
|-----------------|-----------------------------------------------------------------------------------------------------------------|
| CPU             | `pip install -U jax`                                                                                            |
| NVIDIA GPU      | `pip install -U "jax[cuda12]"`                                                                                  |
| Google TPU      | `pip install -U "jax[tpu]"`                                                                                     |
| AMD GPU (Linux) | Follow [AMD's instructions](https://github.com/jax-ml/jax/blob/main/build/rocm/README.md).                      |
| Mac GPU         | Follow [Apple's instructions](https://developer.apple.com/metal/jax/).                                          |
| Intel GPU       | Follow [Intel's instructions](https://github.com/intel/intel-extension-for-openxla/blob/main/docs/acc_jax.md).  |

See [the documentation](https://jax.readthedocs.io/en/latest/installation.html)
for information on alternative installation strategies. These include compiling
from source, installing with Docker, using other versions of CUDA, a
community-supported conda build, and answers to some frequently-asked questions.

## Neural network libraries

Multiple research groups at Google DeepMind and Alphabet develop and share libraries
for training neural networks in JAX. If you want a fully featured library for neural network
training with examples and how-to guides, try
[Flax](https://github.com/google/flax) and its [documentation site](https://flax.readthedocs.io/en/latest/nnx/index.html).

Check out the [JAX Ecosystem section](https://jax.readthedocs.io/en/latest/#ecosystem)
on the JAX documentation site for a list of JAX-based network libraries, which includes
[Optax](https://github.com/deepmind/optax) for gradient processing and
optimization, [chex](https://github.com/deepmind/chex) for reliable code and testing, and
[Equinox](https://github.com/patrick-kidger/equinox) for neural networks.
(Watch the NeurIPS 2020 JAX Ecosystem at DeepMind talk
[here](https://www.youtube.com/watch?v=iDxJxIyzSiM) for additional details.)

## Citing JAX

To cite this repository:

```
@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/jax-ml/jax},
  version = {0.3.13},
  year = {2018},
}
```

In the above bibtex entry, names are in alphabetical order, the version number
is intended to be that from [jax/version.py](../main/jax/version.py), and
the year corresponds to the project's open-source release.

A nascent version of JAX, supporting only automatic differentiation and
compilation to XLA, was described in a [paper that appeared at SysML
2018](https://mlsys.org/Conferences/2019/doc/2018/146.pdf). We're currently working on
covering JAX's ideas and capabilities in a more comprehensive and up-to-date
paper.

## Reference documentation

For details about the JAX API, see the
[reference documentation](https://jax.readthedocs.io/).

For getting started as a JAX developer, see the
[developer documentation](https://jax.readthedocs.io/en/latest/developer.html).