DOC: replace old tutorials with new content

This commit is contained in:
Jake VanderPlas 2024-04-18 13:11:25 -07:00
parent e70191bd9e
commit 10ed827fe9
57 changed files with 119 additions and 9220 deletions

View File

@ -83,7 +83,7 @@ perex_grads = jit(vmap(grad_loss, in_axes=(None, 0, 0))) # fast per-example gra
## Quickstart: Colab in the Cloud
Jump right in using a notebook in your browser, connected to a Google Cloud GPU.
Here are some starter notebooks:
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
- [The basics: NumPy on accelerators, `grad` for differentiation, `jit` for compilation, and `vmap` for vectorization](https://jax.readthedocs.io/en/latest/quickstart.html)
- [Training a Simple Neural Network, with TensorFlow Dataset Data Loading](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/neural_network_with_tfds_data.ipynb)
**JAX now runs on Cloud TPUs.** To try out the preview, see the [Cloud TPU

View File

@ -1,7 +1,7 @@
# Advanced compilation
```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials`.
This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.
For the time being, you may find some related content in the old documentation:
- {doc}`../aot`

View File

@ -15,7 +15,7 @@ kernelspec:
(advanced-debugging)=
# Advanced debugging
```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials`.
This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.
For the time being, you may find some related content in the old documentation:
- {doc}`../debugging/index`

57
docs/_tutorials/index.rst Normal file
View File

@ -0,0 +1,57 @@
:orphan:
.. _jax-tutorials-draft:
JAX tutorials draft
===================
.. note::
This is a
The tutorials below are a work in progress; for the time being, please refer
to the older tutorial content, including :ref:`beginner-guide`,
:ref:`user-guides`, and the now-deleted *JAX 101* tutorials.
JAX 101
-------
Mostly finalized at :ref:`jax-tutorials`!
.. toctree::
:maxdepth: 1
../quickstart
../key-concepts
../jit-compilation
../automatic-vectorization
../automatic-differentiation
../debugging
../random-numbers
../working-with-pytrees
../sharded-computation
../stateful-computations
simple-neural-network
JAX 201
-------
.. toctree::
:maxdepth: 1
parallelism
advanced-autodiff
gradient-checkpointing
advanced-debugging
external-callbacks
profiling-and-performance
JAX 301
-------
.. toctree::
:maxdepth: 1
jax-primitives
jaxpr
advanced-compilation

View File

@ -1,7 +1,7 @@
# Parallel computation
```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials`.
This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.
For the time being, you may find some related content in the old documentation:
- {doc}`../multi_process`

View File

@ -1,7 +1,7 @@
# Profiling and performance
```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials`.
This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.
For the time being, you may find some related content in the old documentation:
- {doc}`../profiling`

View File

@ -1,5 +1,5 @@
# Example: Writing a simple neural network
```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials`.
This is a placeholder for a section in the new {ref}`jax-tutorials-draft`.
```

View File

@ -5,7 +5,7 @@
Getting Started with JAX
========================
Welcome to JAX! The JAX documentation contains a number of useful resources for getting started.
:doc:`notebooks/quickstart` is the easiest place to jump-in and get an overview of the JAX project.
:doc:`quickstart` is the easiest place to jump-in and get an overview of the JAX project.
If you're accustomed to writing NumPy code and are starting to explore JAX, you might find the following resources helpful:
@ -15,12 +15,12 @@ If you're accustomed to writing NumPy code and are starting to explore JAX, you
Tutorials
---------
If you're ready to explore JAX more deeply, the JAX 101 tutorial goes into much more detail:
If you're ready to explore JAX more deeply, the JAX tutorials go into much more detail:
.. toctree::
:maxdepth: 2
jax-101/index
tutorials
If you prefer a video introduction here is one from JAX contributor Jake VanderPlas:

View File

@ -43,7 +43,7 @@ Here are more specific examples of each pattern.
### Direct Usage
Jax can be directly imported and utilized to build models “from scratch” as shown across this website,
for example in [JAX 101](https://jax.readthedocs.io/en/latest/jax-101/index.html)
for example in [JAX Tutorials](https://jax.readthedocs.io/en/latest/tutorials.html)
or [Neural Network with JAX](https://jax.readthedocs.io/en/latest/notebooks/neural_network_with_tfds_data.html).
This may be the best option if you are unable to find prebuilt code
for your particular challenge, or if you're looking to reduce the number

View File

@ -73,7 +73,8 @@ extensions = [
"sphinx_remove_toctrees",
'sphinx_copybutton',
'jax_extensions',
'sphinx_design'
'sphinx_design',
'sphinxext.rediraffe',
]
intersphinx_mapping = {
@ -125,9 +126,8 @@ exclude_patterns = [
'pallas/quickstart.md',
'pallas/tpu/pipelining.md',
'jep/9407-type-promotion.md',
'jax-101/*.md',
'autodidax.md',
'tutorials/sharded-computation.md',
'sharded-computation.md',
]
# The name of the Pygments (syntax highlighting) style to use.
@ -199,8 +199,6 @@ nb_execution_timeout = 100
# List of patterns, relative to source directory, that match notebook
# files that will not be executed.
nb_execution_excludepatterns = [
# Includes GPU timings that shouldn't be executed by doc build
'notebooks/quickstart.*',
# Slow notebook: long time to load tf.ds
'notebooks/neural_network_with_tfds_data.*',
# Slow notebook
@ -208,14 +206,13 @@ nb_execution_excludepatterns = [
# Has extra requirements: networkx, pandas, pytorch, tensorflow, etc.
'jep/9407-type-promotion.*',
# TODO(jakevdp): enable execution on the following if possible:
'jax-101/*',
'notebooks/xmap_tutorial.*',
'notebooks/Distributed_arrays_and_automatic_parallelization.*',
'notebooks/autodiff_remat.*',
# Requires accelerators
'pallas/quickstart.*',
'pallas/tpu/pipelining.*',
'tutorials/sharded-computation.*'
'sharded-computation.*'
]
# -- Options for HTMLHelp output ---------------------------------------------
@ -331,3 +328,18 @@ def linkcode_resolve(domain, info):
filename = os.path.relpath(filename, start=os.path.dirname(jax.__file__))
lines = f"#L{linenum}-L{linenum + len(source)}" if linenum else ""
return f"https://github.com/google/jax/blob/main/jax/{filename}{lines}"
# Generate redirects from deleted files to new sources
rediraffe_redirects = {
'notebooks/quickstart.md': 'quickstart.md',
'jax-101/01-jax-basics.md': 'key-concepts.md',
'jax-101/02-jitting.md': 'jit-compilation.md',
'jax-101/03-vectorization.md': 'automatic-vectorization.md',
'jax-101/04-advanced-autodiff.md': 'automatic-differentiation.md',
'jax-101/05-random-numbers.md': 'random-numbers.md',
'jax-101/05.1-pytrees.md': 'working-with-pytrees.md',
'jax-101/06-parallelism.md': 'sharded-computation.md',
'jax-101/07-state.md': 'stateful-computations.md',
'jax-101/08-pjit.rst': 'sharded-computation.md',
'jax-101/index.rst': 'tutorials.rst',
}

View File

@ -130,7 +130,7 @@ def f(x):
f(2.) # ==> Pauses during execution
```
![JAX debugger](../_static/debugger.gif)
![JAX debugger](_static/debugger.gif)
For value-dependent breakpointing, you can use runtime conditionals like {func}`jax.lax.cond`:

View File

@ -418,7 +418,7 @@ notebooks; for example:
```
pip install jupytext==1.16.0
jupytext --sync docs/notebooks/quickstart.ipynb
jupytext --sync docs/notebooks/thinking_in_jax.ipynb
```
The jupytext version should match that specified in

View File

@ -61,8 +61,7 @@ For an end-to-end transformer library built on JAX, see MaxText_.
:caption: Getting Started
installation
notebooks/quickstart
notebooks/thinking_in_jax
quickstart
notebooks/Common_Gotchas_in_JAX
faq
@ -70,7 +69,7 @@ For an end-to-end transformer library built on JAX, see MaxText_.
:hidden:
:maxdepth: 1
jax-101/index
tutorials
.. toctree::

File diff suppressed because one or more lines are too long

View File

@ -1,383 +0,0 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3
---
+++ {"id": "6_117sy0CGEU"}
# JAX As Accelerated NumPy
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/01-jax-basics.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/01-jax-basics.ipynb)
*Authors: Rosalia Schneider & Vladimir Mikulik*
In this first section you will learn the very fundamentals of JAX.
+++ {"id": "CXjHL4L6ku3-"}
## Getting started with JAX numpy
Fundamentally, JAX is a library that enables transformations of array-manipulating programs written with a NumPy-like API.
Over the course of this series of guides, we will unpack exactly what that means. For now, you can think of JAX as *differentiable NumPy that runs on accelerators*.
The code below shows how to import JAX and create a vector.
```{code-cell} ipython3
:id: ZqUzvqF1B1TO
import jax
import jax.numpy as jnp
x = jnp.arange(10)
print(x)
```
+++ {"id": "rPBmlAxXlBAy"}
So far, everything is just like NumPy. A big appeal of JAX is that you don't need to learn a new API. Many common NumPy programs would run just as well in JAX if you substitute `np` for `jnp`. However, there are some important differences which we touch on at the end of this section.
You can notice the first difference if you check the type of `x`. It is a variable of type `Array`, which is the way JAX represents arrays.
```{code-cell} ipython3
:id: 3fLtgPUAn7mi
x
```
+++ {"id": "Yx8VofzzoHFH"}
One useful feature of JAX is that the same code can be run on different backends -- CPU, GPU and TPU.
We will now perform a dot product to demonstrate that it can be done in different devices without changing the code. We use `%timeit` to check the performance.
(Technical detail: when a JAX function is called (including `jnp.array`
creation), the corresponding operation is dispatched to an accelerator to be
computed asynchronously when possible. The returned array is therefore not
necessarily 'filled in' as soon as the function returns. Thus, if we don't
require the result immediately, the computation won't block Python execution.
Therefore, unless we `block_until_ready` or convert the array to a regular
Python type, we will only time the dispatch, not the actual computation. See
[Asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html#asynchronous-dispatch)
in the JAX docs.)
```{code-cell} ipython3
:id: mRvjVxoqo-Bi
long_vector = jnp.arange(int(1e7))
%timeit jnp.dot(long_vector, long_vector).block_until_ready()
```
+++ {"id": "DKBB0zs-p-RC"}
**Tip**: Try running the code above twice, once without an accelerator, and once with a GPU runtime (while in Colab, click *Runtime**Change Runtime Type* and choose `GPU`). Notice how much faster it runs on a GPU.
+++ {"id": "PkCpI-v0uQQO"}
## JAX first transformation: `grad`
A fundamental feature of JAX is that it allows you to transform functions.
One of the most commonly used transformations is `jax.grad`, which takes a numerical function written in Python and returns you a new Python function that computes the gradient of the original function.
To use it, let's first define a function that takes an array and returns the sum of squares.
```{code-cell} ipython3
:id: LuaGUVRUvbzQ
def sum_of_squares(x):
return jnp.sum(x**2)
```
+++ {"id": "QAqloI1Wvtp2"}
Applying `jax.grad` to `sum_of_squares` will return a different function, namely the gradient of `sum_of_squares` with respect to its first parameter `x`.
Then, you can use that function on an array to return the derivatives with respect to each element of the array.
```{code-cell} ipython3
:id: dKeorwJfvpeI
sum_of_squares_dx = jax.grad(sum_of_squares)
x = jnp.asarray([1.0, 2.0, 3.0, 4.0])
print(sum_of_squares(x))
print(sum_of_squares_dx(x))
```
+++ {"id": "VfBt5CYbyKUX"}
You can think of `jax.grad` by analogy to the $\nabla$ operator from vector calculus. Given a function $f(x)$, $\nabla f$ represents the function that computes $f$'s gradient, i.e.
$$
(\nabla f)(x)_i = \frac{\partial f}{\partial x_i}(x).
$$
Analogously, `jax.grad(f)` is the function that computes the gradient, so `jax.grad(f)(x)` is the gradient of `f` at `x`.
(Like $\nabla$, `jax.grad` will only work on functions with a scalar output -- it will raise an error otherwise.)
This makes the JAX API quite different from other autodiff libraries like Tensorflow and PyTorch, where to compute the gradient we use the loss tensor itself (e.g. by calling `loss.backward()`). The JAX API works directly with functions, staying closer to the underlying math. Once you become accustomed to this way of doing things, it feels natural: your loss function in code really is a function of parameters and data, and you find its gradient just like you would in the math.
This way of doing things makes it straightforward to control things like which variables to differentiate with respect to. By default, `jax.grad` will find the gradient with respect to the first argument. In the example below, the result of `sum_squared_error_dx` will be the gradient of `sum_squared_error` with respect to `x`.
```{code-cell} ipython3
:id: f3NfaVu4yrQE
def sum_squared_error(x, y):
return jnp.sum((x-y)**2)
sum_squared_error_dx = jax.grad(sum_squared_error)
y = jnp.asarray([1.1, 2.1, 3.1, 4.1])
print(sum_squared_error_dx(x, y))
```
+++ {"id": "1tOztA5zpLWN"}
To find the gradient with respect to a different argument (or several), you can set `argnums`:
```{code-cell} ipython3
:id: FQSczVQkqIPY
jax.grad(sum_squared_error, argnums=(0, 1))(x, y) # Find gradient wrt both x & y
```
+++ {"id": "yQAMTnZSqo-t"}
Does this mean that when doing machine learning, we need to write functions with gigantic argument lists, with an argument for each model parameter array? No. JAX comes equipped with machinery for bundling arrays together in data structures called 'pytrees', on which more in a [later guide](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb). So, most often, use of `jax.grad` looks like this:
```
def loss_fn(params, data):
...
grads = jax.grad(loss_fn)(params, data_batch)
```
+++ {"id": "oBowiovisT97"}
where `params` is, for example, a nested dict of arrays, and the returned `grads` is another nested dict of arrays with the same structure.
+++ {"id": "LNjf9jUEsZZ8"}
## Value and Grad
Often, you need to find both the value and the gradient of a function, e.g. if you want to log the training loss. JAX has a handy sister transformation for efficiently doing that:
```{code-cell} ipython3
:id: dWg4_-h3sYwl
jax.value_and_grad(sum_squared_error)(x, y)
```
+++ {"id": "QVT2EWHJsvvv"}
which returns a tuple of, you guessed it, (value, grad). To be precise, for any `f`,
```
jax.value_and_grad(f)(*xs) == (f(*xs), jax.grad(f)(*xs))
```
+++ {"id": "QmHTVpAks3OX"}
## Auxiliary data
In addition to wanting to log the value, we often want to report some intermediate results obtained in computing the loss function. But if we try doing that with regular `jax.grad`, we run into trouble:
```{code-cell} ipython3
:id: ffGCEzT4st41
:tags: [raises-exception]
def squared_error_with_aux(x, y):
return sum_squared_error(x, y), x-y
jax.grad(squared_error_with_aux)(x, y)
```
+++ {"id": "IUubno3nth4i"}
This is because `jax.grad` is only defined on scalar functions, and our new function returns a tuple. But we need to return a tuple to return our intermediate results! This is where `has_aux` comes in:
```{code-cell} ipython3
:id: uzUFihyatgiF
jax.grad(squared_error_with_aux, has_aux=True)(x, y)
```
+++ {"id": "g5s3UiFauwDk"}
`has_aux` signifies that the function returns a pair, `(out, aux)`. It makes `jax.grad` ignore `aux`, passing it through to the user, while differentiating the function as if only `out` was returned.
+++ {"id": "fk4FUXe7vsW4"}
## Differences from NumPy
The `jax.numpy` API closely follows that of NumPy. However, there are some important differences. We cover many of these in future guides, but it's worth pointing some out now.
The most important difference, and in some sense the root of all the rest, is that JAX is designed to be _functional_, as in _functional programming_. The reason behind this is that the kinds of program transformations that JAX enables are much more feasible in functional-style programs.
An introduction to functional programming (FP) is out of scope of this guide. If you already are familiar with FP, you will find your FP intuition helpful while learning JAX. If not, don't worry! The important feature of functional programming to grok when working with JAX is very simple: don't write code with side-effects.
A side-effect is any effect of a function that doesn't appear in its output. One example is modifying an array in place:
```{code-cell} ipython3
:id: o_YBuLQC1wPJ
import numpy as np
x = np.array([1, 2, 3])
def in_place_modify(x):
x[0] = 123
return None
in_place_modify(x)
x
```
+++ {"id": "JTtUihVZ13F6"}
The side-effectful function modifies its argument, but returns a completely unrelated value. The modification is a side-effect.
The code below will run in NumPy. However, JAX arrays won't allow themselves to be modified in-place:
```{code-cell} ipython3
:id: u6grTYIVcZ3f
:tags: [raises-exception]
in_place_modify(jnp.array(x)) # Raises error if we cast input to jnp.ndarray
```
+++ {"id": "RGqVfYSpc49s"}
Helpfully, the error points us to JAX's side-effect-free way of doing the same thing via the [`jax.numpy.ndarray.at`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html) index update operators (be careful [`jax.ops.index_*`](https://jax.readthedocs.io/en/latest/jax.ops.html#indexed-update-functions-deprecated) functions are deprecated). They are analogous to in-place modification by index, but create a new array with the corresponding modifications made:
```{code-cell} ipython3
:id: Rmklk6BB2xF0
def jax_in_place_modify(x):
return x.at[0].set(123)
y = jnp.array([1, 2, 3])
jax_in_place_modify(y)
```
+++ {"id": "91tn_25vdrNf"}
Note that the old array was untouched, so there is no side-effect:
```{code-cell} ipython3
:id: KQGXig4Hde6T
y
```
+++ {"id": "d5TibzPO25qa"}
Side-effect-free code is sometimes called *functionally pure*, or just *pure*.
Isn't the pure version less efficient? Strictly, yes; we are creating a new array. However, as we will explain in the next guide, JAX computations are often compiled before being run using another program transformation, `jax.jit`. If we don't use the old array after modifying it 'in place' using indexed update operators, the compiler can recognise that it can in fact compile to an in-place modify, resulting in efficient code in the end.
Of course, it's possible to mix side-effectful Python code and functionally pure JAX code, and we will touch on this more later. As you get more familiar with JAX, you will learn how and when this can work. As a rule of thumb, however, any functions intended to be transformed by JAX should avoid side-effects, and the JAX primitives themselves will try to help you do that.
We will explain other places where the JAX idiosyncrasies become relevant as they come up. There is even a section that focuses entirely on getting used to the functional programming style of handling state: [Part 7: Problem of State](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). However, if you're impatient, you can find a [summary of JAX's sharp edges](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) in the JAX docs.
+++ {"id": "dFn_VBFFlGCz"}
## Your first JAX training loop
We still have much to learn about JAX, but you already know enough to understand how we can use JAX to build a simple training loop.
To keep things simple, we'll start with a linear regression.
Our data is sampled according to $y = w_{true} x + b_{true} + \epsilon$.
```{code-cell} ipython3
:id: WGgyEWFqrPq1
import numpy as np
import matplotlib.pyplot as plt
xs = np.random.normal(size=(100,))
noise = np.random.normal(scale=0.1, size=(100,))
ys = xs * 3 - 1 + noise
plt.scatter(xs, ys);
```
+++ {"id": "RTh22mo4rR1x"}
Therefore, our model is $\hat y(x; \theta) = wx + b$.
We will use a single array, `theta = [w, b]` to house both parameters:
```{code-cell} ipython3
:id: TnVrRTMamyzb
def model(theta, x):
"""Computes wx + b on a batch of input x."""
w, b = theta
return w * x + b
```
+++ {"id": "qCrLmmKrn9_h"}
The loss function is $J(x, y; \theta) = (\hat y - y)^2$.
```{code-cell} ipython3
:id: 07eMcDLMn9Ww
def loss_fn(theta, x, y):
prediction = model(theta, x)
return jnp.mean((prediction-y)**2)
```
+++ {"id": "ejMt4dulnoYX"}
How do we optimize a loss function? Using gradient descent. At each update step, we will find the gradient of the loss w.r.t. the parameters, and take a small step in the direction of steepest descent:
$\theta_{new} = \theta - 0.1 (\nabla_\theta J) (x, y; \theta)$
```{code-cell} ipython3
:id: 2I6T5Wphpaaa
def update(theta, x, y, lr=0.1):
return theta - lr * jax.grad(loss_fn)(theta, x, y)
```
+++ {"id": "MAUL1gT_opVn"}
In JAX, it's common to define an `update()` function that is called every step, taking the current parameters as input and returning the new parameters. This is a natural consequence of JAX's functional nature, and is explained in more detail in [The Problem of State](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb).
This function can then be JIT-compiled in its entirety for maximum efficiency. The next guide will explain exactly how `jax.jit` works, but if you want to, you can try adding `@jax.jit` before the `update()` definition, and see how the training loop below runs much faster.
```{code-cell} ipython3
:id: WLZxY7nIpuVW
theta = jnp.array([1., 1.])
for _ in range(1000):
theta = update(theta, xs, ys)
plt.scatter(xs, ys)
plt.plot(xs, model(theta, xs))
w, b = theta
print(f"w: {w:<.2f}, b: {b:<.2f}")
```
+++ {"id": "5-q17kJ_rjLc"}
As you will see going through these guides, this basic recipe underlies almost all training loops you'll see implemented in JAX. The main difference between this example and real training loops is the simplicity of our model: that allows us to use a single array to house all our parameters. We cover managing more parameters in the later [pytree guide](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb). Feel free to skip forward to that guide now to see how to manually define and train a simple MLP in JAX.

View File

@ -1,673 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "O-SkdlPxvETZ"
},
"source": [
"# Just In Time Compilation with JAX\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/02-jitting.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/02-jitting.ipynb)\n",
"\n",
"*Authors: Rosalia Schneider & Vladimir Mikulik*\n",
"\n",
"In this section, we will further explore how JAX works, and how we can make it performant.\n",
"We will discuss the `jax.jit()` transform, which will perform *Just In Time* (JIT) compilation\n",
"of a JAX Python function so it can be executed efficiently in XLA.\n",
"\n",
"## How JAX transforms work\n",
"\n",
"In the previous section, we discussed that JAX allows us to transform Python functions. This is done by first converting the Python function into a simple intermediate language called jaxpr. The transformations then work on the jaxpr representation. \n",
"\n",
"We can show a representation of the jaxpr of a function by using `jax.make_jaxpr`:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "P9Xj77Wx3Z2P",
"outputId": "5a0597eb-86c9-4762-ce10-2811debbc732"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{ lambda ; a:f32[]. let\n",
" b:f32[] = log a\n",
" c:f32[] = log 2.0\n",
" d:f32[] = div b c\n",
" in (d,) }\n"
]
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"global_list = []\n",
"\n",
"def log2(x):\n",
" global_list.append(x)\n",
" ln_x = jnp.log(x)\n",
" ln_2 = jnp.log(2.0)\n",
" return ln_x / ln_2\n",
"\n",
"print(jax.make_jaxpr(log2)(3.0))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jiDsT7y0RwIp"
},
"source": [
"The [Understanding Jaxprs](https://jax.readthedocs.io/en/latest/jaxpr.html) section of the documentation provides more information on the meaning of the above output.\n",
"\n",
"Importantly, note how the jaxpr does not capture the side-effect of the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX is designed to understand side-effect-free (a.k.a. functionally pure) code. If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).\n",
"\n",
"Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behaviour once converted to jaxpr. However, as a rule of thumb, you can expect (but shouldn't rely on) the side-effects of a JAX-transformed function to run once (during the first call), and never again. This is because of the way that JAX generates jaxpr, using a process called 'tracing'.\n",
"\n",
"When tracing, JAX wraps each argument by a *tracer* object. These tracers then record all JAX operations performed on them during the function call (which happens in regular Python). Then, JAX uses the tracer records to reconstruct the entire function. The output of that reconstruction is the jaxpr. Since the tracers do not record the Python side-effects, they do not appear in the jaxpr. However, the side-effects still happen during the trace itself.\n",
"\n",
"Note: the Python `print()` function is not pure: the text output is a side-effect of the function. Therefore, any `print()` calls will only happen during tracing, and will not appear in the jaxpr:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "JxV2p7e2RawC",
"outputId": "9dfe8a56-e553-4640-a04e-5405aea7832d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>\n",
"{ lambda ; a:f32[]. let\n",
" b:f32[] = log a\n",
" c:f32[] = log 2.0\n",
" d:f32[] = div b c\n",
" in (d,) }\n"
]
}
],
"source": [
"def log2_with_print(x):\n",
" print(\"printed x:\", x)\n",
" ln_x = jnp.log(x)\n",
" ln_2 = jnp.log(2.0)\n",
" return ln_x / ln_2\n",
"\n",
"print(jax.make_jaxpr(log2_with_print)(3.))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f6W_YYwRRwGp"
},
"source": [
"See how the printed `x` is a `Traced` object? That's the JAX internals at work.\n",
"\n",
"The fact that the Python code runs at least once is strictly an implementation detail, and so shouldn't be relied upon. However, it's useful to understand as you can use it when debugging to print out intermediate values of a computation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PgVqi6NlRdWZ"
},
"source": [
"A key thing to understand is that jaxpr captures the function as executed on the parameters given to it. For example, if we have a conditional, jaxpr will only know about the branch we take:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "hn0CuphEZKZm",
"outputId": "99dae727-d2be-4577-831c-e1e14af5890a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{ lambda ; a:i32[3]. let in (a,) }\n"
]
}
],
"source": [
"def log2_if_rank_2(x):\n",
" if x.ndim == 2:\n",
" ln_x = jnp.log(x)\n",
" ln_2 = jnp.log(2.0)\n",
" return ln_x / ln_2\n",
" else:\n",
" return x\n",
"\n",
"print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Qp3WhqaqvHyD"
},
"source": [
"## JIT compiling a function\n",
"\n",
"As explained before, JAX enables operations to execute on CPU/GPU/TPU using the same code.\n",
"Let's look at an example of computing a *Scaled Exponential Linear Unit*\n",
"([SELU](https://proceedings.neurips.cc/paper/6698-self-normalizing-neural-networks.pdf)), an\n",
"operation commonly used in deep learning:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "JAXFYtlRvD6p",
"outputId": "e94d7dc2-a9a1-4ac2-fd3f-152e3f6d141b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100 loops, best of 5: 2.05 ms per loop\n"
]
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"def selu(x, alpha=1.67, lambda_=1.05):\n",
" return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n",
"\n",
"x = jnp.arange(1000000)\n",
"%timeit selu(x).block_until_ready()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ecN5lEXe6ncy"
},
"source": [
"The code above is sending one operation at a time to the accelerator. This limits the ability of the XLA compiler to optimize our functions.\n",
"\n",
"Naturally, what we want to do is give the XLA compiler as much code as possible, so it can fully optimize it. For this purpose, JAX provides the `jax.jit` transformation, which will JIT compile a JAX-compatible function. The example below shows how to use JIT to speed up the previous function."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "nJVEwPcH6bQX",
"outputId": "289eb2f7-a5ce-4cec-f652-5c4e5b0b86cf"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000 loops, best of 5: 150 µs per loop\n"
]
}
],
"source": [
"selu_jit = jax.jit(selu)\n",
"\n",
"# Warm up\n",
"selu_jit(x).block_until_ready()\n",
"\n",
"%timeit selu_jit(x).block_until_ready()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hMNKi1mYXQg5"
},
"source": [
"Here's what just happened:\n",
"\n",
"1) We defined `selu_jit` as the compiled version of `selu`.\n",
"\n",
"2) We called `selu_jit` once on `x`. This is where JAX does its tracing -- it needs to have some inputs to wrap in tracers, after all. The jaxpr is then compiled using XLA into very efficient code optimized for your GPU or TPU. Finally, the compiled code is executed to satisfy the call. Subsequent calls to `selu_jit` will use the compiled code directly, skipping the python implementation entirely.\n",
"\n",
"(If we didn't include the warm-up call separately, everything would still work, but then the compilation time would be included in the benchmark. It would still be faster, because we run many loops in the benchmark, but it wouldn't be a fair comparison.)\n",
"\n",
"3) We timed the execution speed of the compiled version. (Note the use of `block_until_ready()`, which is required due to JAX's [Asynchronous execution](https://jax.readthedocs.io/en/latest/async_dispatch.html) model)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DRJ6R6-d9Q_U"
},
"source": [
"## Why can't we just JIT everything?\n",
"\n",
"After going through the example above, you might be wondering whether we should simply apply `jax.jit` to every function. To understand why this is not the case, and when we should/shouldn't apply `jit`, let's first check some cases where JIT doesn't work."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "GO1Mwd_3_W6g",
"outputId": "a6fcf6d1-7bd6-4bb7-99c3-2a5a827183e2",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "ConcretizationTypeError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mUnfilteredStackTrace\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-12-2c1a07641e48>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mf_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mf_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Should raise an error.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mcache_miss\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 418\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mflat_fun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 419\u001b[0;31m donated_invars=donated_invars, inline=inline)\n\u001b[0m\u001b[1;32m 420\u001b[0m \u001b[0mout_pytree_def\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1631\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1632\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcall_bind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1633\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mcall_bind\u001b[0;34m(primitive, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1622\u001b[0m \u001b[0mtracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1623\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1624\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapply_todos\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv_trace_todo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess\u001b[0;34m(self, trace, fun, tracers, params)\u001b[0m\n\u001b[1;32m 1634\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1635\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1636\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess_call\u001b[0;34m(self, primitive, f, tracers, params)\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 627\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 628\u001b[0m \u001b[0mprocess_map\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_call_impl\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 687\u001b[0m compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,\n\u001b[0;32m--> 688\u001b[0;31m *unsafe_map(arg_spec, args))\n\u001b[0m\u001b[1;32m 689\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mmemoized_fun\u001b[0;34m(fun, *args)\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0mcache\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_callable_uncached\u001b[0;34m(fun, device, backend, name, donated_invars, *arg_specs)\u001b[0m\n\u001b[1;32m 759\u001b[0m return lower_xla_callable(fun, device, backend, name, donated_invars,\n\u001b[0;32m--> 760\u001b[0;31m *arg_specs).compile().unsafe_call\n\u001b[0m\u001b[1;32m 761\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36mlower_xla_callable\u001b[0;34m(fun, device, backend, name, donated_invars, *arg_specs)\u001b[0m\n\u001b[1;32m 771\u001b[0m jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(\n\u001b[0;32m--> 772\u001b[0;31m fun, abstract_args, pe.debug_info_final(fun, \"jit\"))\n\u001b[0m\u001b[1;32m 773\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTracer\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mc\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr_final\u001b[0;34m(fun, in_avals, debug_info)\u001b[0m\n\u001b[1;32m 1541\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_sublevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1542\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr_dynamic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1543\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_subjaxpr_dynamic\u001b[0;34m(fun, main, in_avals)\u001b[0m\n\u001b[1;32m 1519\u001b[0m \u001b[0min_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_arg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0min_tracers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0mout_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 166\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 167\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-12-2c1a07641e48>\u001b[0m in \u001b[0;36mf\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36m__bool__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 548\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__nonzero__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nonzero\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 549\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0m__bool__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_bool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 550\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__int__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_int\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36merror\u001b[0;34m(self, arg)\u001b[0m\n\u001b[1;32m 999\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1000\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mConcretizationTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfname_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1001\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mUnfilteredStackTrace\u001b[0m: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>\nThe problem arose with the bool function. \nWhile tracing the function f at <ipython-input-12-2c1a07641e48>:3 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[0;31mConcretizationTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-12-2c1a07641e48>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mf_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mf_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Should raise an error.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-12-2c1a07641e48>\u001b[0m in \u001b[0;36mf\u001b[0;34m(x)\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 4\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 5\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mConcretizationTypeError\u001b[0m: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>\nThe problem arose with the bool function. \nWhile tracing the function f at <ipython-input-12-2c1a07641e48>:3 for jit, this concrete value was not available in Python because it depends on the value of the argument 'x'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError"
]
}
],
"source": [
"# Condition on value of x.\n",
"\n",
"def f(x):\n",
" if x > 0:\n",
" return x\n",
" else:\n",
" return 2 * x\n",
"\n",
"f_jit = jax.jit(f)\n",
"f_jit(10) # Should raise an error. "
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "LHlipkIMFUhi",
"outputId": "54935882-a180-45c0-ad03-9dfb5e3baa97",
"tags": [
"raises-exception"
]
},
"outputs": [
{
"ename": "ConcretizationTypeError",
"evalue": "ignored",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mUnfilteredStackTrace\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-13-2aa78f448d5d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mg_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mg_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Should raise an error.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 163\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mcache_miss\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 418\u001b[0m \u001b[0mdevice\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbackend\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbackend\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mflat_fun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__name__\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 419\u001b[0;31m donated_invars=donated_invars, inline=inline)\n\u001b[0m\u001b[1;32m 420\u001b[0m \u001b[0mout_pytree_def\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mout_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mbind\u001b[0;34m(self, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1631\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1632\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mcall_bind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1633\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mcall_bind\u001b[0;34m(primitive, fun, *args, **params)\u001b[0m\n\u001b[1;32m 1622\u001b[0m \u001b[0mtracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1623\u001b[0;31m \u001b[0mouts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtop_trace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1624\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfull_lower\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapply_todos\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0menv_trace_todo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mouts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess\u001b[0;34m(self, trace, fun, tracers, params)\u001b[0m\n\u001b[1;32m 1634\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1635\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1636\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36mprocess_call\u001b[0;34m(self, primitive, f, tracers, params)\u001b[0m\n\u001b[1;32m 626\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 627\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mprimitive\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mimpl\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0mtracers\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 628\u001b[0m \u001b[0mprocess_map\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mprocess_call\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_call_impl\u001b[0;34m(***failed resolving arguments***)\u001b[0m\n\u001b[1;32m 687\u001b[0m compiled_fun = _xla_callable(fun, device, backend, name, donated_invars,\n\u001b[0;32m--> 688\u001b[0;31m *unsafe_map(arg_spec, args))\n\u001b[0m\u001b[1;32m 689\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mmemoized_fun\u001b[0;34m(fun, *args)\u001b[0m\n\u001b[1;32m 262\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 263\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mcall\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 264\u001b[0m \u001b[0mcache\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mans\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstores\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36m_xla_callable_uncached\u001b[0;34m(fun, device, backend, name, donated_invars, *arg_specs)\u001b[0m\n\u001b[1;32m 759\u001b[0m return lower_xla_callable(fun, device, backend, name, donated_invars,\n\u001b[0;32m--> 760\u001b[0;31m *arg_specs).compile().unsafe_call\n\u001b[0m\u001b[1;32m 761\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py\u001b[0m in \u001b[0;36mlower_xla_callable\u001b[0;34m(fun, device, backend, name, donated_invars, *arg_specs)\u001b[0m\n\u001b[1;32m 771\u001b[0m jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(\n\u001b[0;32m--> 772\u001b[0;31m fun, abstract_args, pe.debug_info_final(fun, \"jit\"))\n\u001b[0m\u001b[1;32m 773\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0many\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTracer\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mc\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mconsts\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_jaxpr_final\u001b[0;34m(fun, in_avals, debug_info)\u001b[0m\n\u001b[1;32m 1541\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mcore\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_sublevel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1542\u001b[0;31m \u001b[0mjaxpr\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mout_avals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mconsts\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrace_to_subjaxpr_dynamic\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1543\u001b[0m \u001b[0;32mdel\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmain\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/interpreters/partial_eval.py\u001b[0m in \u001b[0;36mtrace_to_subjaxpr_dynamic\u001b[0;34m(fun, main, in_avals)\u001b[0m\n\u001b[1;32m 1519\u001b[0m \u001b[0min_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnew_arg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_avals\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1520\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcall_wrapped\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0min_tracers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1521\u001b[0m \u001b[0mout_tracers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfull_raise\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mans\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/linear_util.py\u001b[0m in \u001b[0;36mcall_wrapped\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 166\u001b[0;31m \u001b[0mans\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mdict\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparams\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 167\u001b[0m \u001b[0;32mexcept\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m<ipython-input-13-2aa78f448d5d>\u001b[0m in \u001b[0;36mg\u001b[0;34m(x, n)\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mwhile\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36m__bool__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 548\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__nonzero__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_nonzero\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 549\u001b[0;31m \u001b[0;32mdef\u001b[0m \u001b[0m__bool__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_bool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 550\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m__int__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maval\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_int\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/core.py\u001b[0m in \u001b[0;36merror\u001b[0;34m(self, arg)\u001b[0m\n\u001b[1;32m 999\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0marg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1000\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mConcretizationTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0marg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfname_context\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1001\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0merror\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mUnfilteredStackTrace\u001b[0m: jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>\nThe problem arose with the bool function. \nWhile tracing the function g at <ipython-input-13-2aa78f448d5d>:3 for jit, this concrete value was not available in Python because it depends on the value of the argument 'n'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------",
"\nThe above exception was the direct cause of the following exception:\n",
"\u001b[0;31mConcretizationTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-13-2aa78f448d5d>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0mg_jit\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mjit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 10\u001b[0;31m \u001b[0mg_jit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m# Should raise an error.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m<ipython-input-13-2aa78f448d5d>\u001b[0m in \u001b[0;36mg\u001b[0;34m(x, n)\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mg\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mwhile\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0mi\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mConcretizationTypeError\u001b[0m: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>\nThe problem arose with the bool function. \nWhile tracing the function g at <ipython-input-13-2aa78f448d5d>:3 for jit, this concrete value was not available in Python because it depends on the value of the argument 'n'.\n\nSee https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError"
]
}
],
"source": [
"# While loop conditioned on x and n.\n",
"\n",
"def g(x, n):\n",
" i = 0\n",
" while i < n:\n",
" i += 1\n",
" return x + i\n",
"\n",
"g_jit = jax.jit(g)\n",
"g_jit(10, 20) # Should raise an error. "
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "isz2U_XX_wH2"
},
"source": [
"The problem is that we tried to condition on the *value* of an input to the function being jitted. The reason we can't do this is related to the fact mentioned above that jaxpr depends on the actual values used to trace it. \n",
"\n",
"The more specific information about the values we use in the trace, the more we can use standard Python control flow to express ourselves. However, being too specific means we can't reuse the same traced function for other values. JAX solves this by tracing at different levels of abstraction for different purposes.\n",
"\n",
"For `jax.jit`, the default level is `ShapedArray` -- that is, each tracer has a concrete shape (which we're allowed to condition on), but no concrete value. This allows the compiled function to work on all possible inputs with the same shape -- the standard use case in machine learning. However, because the tracers have no concrete value, if we attempt to condition on one, we get the error above.\n",
"\n",
"In `jax.grad`, the constraints are more relaxed, so you can do more. If you compose several transformations, however, you must satisfy the constraints of the most strict one. So, if you `jit(grad(f))`, `f` mustn't condition on value. For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).\n",
"\n",
"One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special [control flow operators](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators) like `jax.lax.cond`. However, sometimes that is impossible. In that case, you can consider jitting only part of the function. For example, if the most computationally expensive part of the function is inside the loop, we can JIT just that inner part (though make sure to check the next section on caching to avoid shooting yourself in the foot):"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "OeR8hF-NHAML",
"outputId": "d47fd6b2-8bbd-4939-a794-0b80183d3179"
},
"outputs": [
{
"data": {
"text/plain": [
"Array(30, dtype=int32, weak_type=True)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# While loop conditioned on x and n with a jitted body.\n",
"\n",
"@jax.jit\n",
"def loop_body(prev_i):\n",
" return prev_i + 1\n",
"\n",
"def g_inner_jitted(x, n):\n",
" i = 0\n",
" while i < n:\n",
" i = loop_body(i)\n",
" return x + i\n",
"\n",
"g_inner_jitted(10, 20)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5XUT2acoHBz-"
},
"source": [
"If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified static input. It is only a good strategy if the function is guaranteed to get limited different values."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "2yQmQTDNAenY",
"outputId": "c48f07b8-c3f9-4d2a-9dfd-663838a52511"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"10\n"
]
}
],
"source": [
"f_jit_correct = jax.jit(f, static_argnums=0)\n",
"print(f_jit_correct(10))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "R4SXUEu-M-u1",
"outputId": "9e712e14-4e81-4744-dcf2-a10f470d9121"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"30\n"
]
}
],
"source": [
"g_jit_correct = jax.jit(g, static_argnames=['n'])\n",
"print(g_jit_correct(10, 20))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"To specify such arguments when using `jit` as a decorator, a common pattern is to use python's `functools.partial`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2X5rR4jkIO",
"outputId": "81-4744-dc2e4-4e10f470f2-a19e71d9121"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"30\n"
]
}
],
"source": [
"from functools import partial\n",
"\n",
"@partial(jax.jit, static_argnames=['n'])\n",
"def g_jit_decorated(x, n):\n",
" i = 0\n",
" while i < n:\n",
" i += 1\n",
" return x + i\n",
"\n",
"print(g_jit_decorated(10, 20))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LczjIBt2X2Ms"
},
"source": [
"## When to use JIT\n",
"\n",
"In many of the examples above, jitting is not worth it:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "uMOqsNnqYApD",
"outputId": "2d6c5122-43ad-4257-e56b-e77c889131c2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"g jitted:\n",
"The slowest run took 13.54 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1000 loops, best of 5: 229 µs per loop\n",
"g:\n",
"The slowest run took 11.72 times longer than the fastest. This could mean that an intermediate result is being cached.\n",
"1000000 loops, best of 5: 1.2 µs per loop\n"
]
}
],
"source": [
"print(\"g jitted:\")\n",
"%timeit g_jit_correct(10, 20).block_until_ready()\n",
"\n",
"print(\"g:\")\n",
"%timeit g(10, 20)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cZmGYq80YP0j"
},
"source": [
"This is because `jax.jit` introduces some overhead itself. Therefore, it usually only saves time if the compiled function is complex and you will run it numerous times. Fortunately, this is common in machine learning, where we tend to compile a large, complicated model, then run it for millions of iterations.\n",
"\n",
"Generally, you want to jit the largest possible chunk of your computation; ideally, the entire update step. This gives the compiler maximum freedom to optimise."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hJMjUlRcIzVS"
},
"source": [
"## Caching\n",
"\n",
"It's important to understand the caching behaviour of `jax.jit`.\n",
"\n",
"Suppose I define `f = jax.jit(g)`. When I first invoke `f`, it will get compiled, and the resulting XLA code will get cached. Subsequent calls of `f` will reuse the cached code. This is how `jax.jit` makes up for the up-front cost of compilation.\n",
"\n",
"If I specify `static_argnums`, then the cached code will be used only for the same values of arguments labelled as static. If any of them change, recompilation occurs. If there are many values, then your program might spend more time compiling than it would have executing ops one-by-one.\n",
"\n",
"Avoid calling `jax.jit` inside loops. For most cases, JAX will be able to use the compiled, cached function in subsequent calls to `jax.jit`. However, because the cache relies on the hash of the function, it becomes problematic when equivalent functions are redefined. This will cause unnecessary compilation each time in the loop:"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "6MDSXCfmSZVZ",
"outputId": "a035d0b7-6a4d-4a9e-c6b4-7521970829fc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"jit called in a loop with partials:\n",
"1 loop, best of 5: 192 ms per loop\n",
"jit called in a loop with lambdas:\n",
"10 loops, best of 5: 199 ms per loop\n",
"jit called in a loop with caching:\n",
"10 loops, best of 5: 21.6 ms per loop\n"
]
}
],
"source": [
"from functools import partial\n",
"\n",
"def unjitted_loop_body(prev_i):\n",
" return prev_i + 1\n",
"\n",
"def g_inner_jitted_partial(x, n):\n",
" i = 0\n",
" while i < n:\n",
" # Don't do this! each time the partial returns\n",
" # a function with different hash\n",
" i = jax.jit(partial(unjitted_loop_body))(i)\n",
" return x + i\n",
"\n",
"def g_inner_jitted_lambda(x, n):\n",
" i = 0\n",
" while i < n:\n",
" # Don't do this!, lambda will also return\n",
" # a function with a different hash\n",
" i = jax.jit(lambda x: unjitted_loop_body(x))(i)\n",
" return x + i\n",
"\n",
"def g_inner_jitted_normal(x, n):\n",
" i = 0\n",
" while i < n:\n",
" # this is OK, since JAX can find the\n",
" # cached, compiled function\n",
" i = jax.jit(unjitted_loop_body)(i)\n",
" return x + i\n",
"\n",
"print(\"jit called in a loop with partials:\")\n",
"%timeit g_inner_jitted_partial(10, 20).block_until_ready()\n",
"\n",
"print(\"jit called in a loop with lambdas:\")\n",
"%timeit g_inner_jitted_lambda(10, 20).block_until_ready()\n",
"\n",
"print(\"jit called in a loop with caching:\")\n",
"%timeit g_inner_jitted_normal(10, 20).block_until_ready()"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Jitting functions in JAX",
"provenance": []
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@ -1,338 +0,0 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3
---
+++ {"id": "O-SkdlPxvETZ"}
# Just In Time Compilation with JAX
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/02-jitting.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/02-jitting.ipynb)
*Authors: Rosalia Schneider & Vladimir Mikulik*
In this section, we will further explore how JAX works, and how we can make it performant.
We will discuss the `jax.jit()` transform, which will perform *Just In Time* (JIT) compilation
of a JAX Python function so it can be executed efficiently in XLA.
## How JAX transforms work
In the previous section, we discussed that JAX allows us to transform Python functions. This is done by first converting the Python function into a simple intermediate language called jaxpr. The transformations then work on the jaxpr representation.
We can show a representation of the jaxpr of a function by using `jax.make_jaxpr`:
```{code-cell} ipython3
:id: P9Xj77Wx3Z2P
:outputId: 5a0597eb-86c9-4762-ce10-2811debbc732
import jax
import jax.numpy as jnp
global_list = []
def log2(x):
global_list.append(x)
ln_x = jnp.log(x)
ln_2 = jnp.log(2.0)
return ln_x / ln_2
print(jax.make_jaxpr(log2)(3.0))
```
+++ {"id": "jiDsT7y0RwIp"}
The [Understanding Jaxprs](https://jax.readthedocs.io/en/latest/jaxpr.html) section of the documentation provides more information on the meaning of the above output.
Importantly, note how the jaxpr does not capture the side-effect of the function: there is nothing in it corresponding to `global_list.append(x)`. This is a feature, not a bug: JAX is designed to understand side-effect-free (a.k.a. functionally pure) code. If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).
Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behaviour once converted to jaxpr. However, as a rule of thumb, you can expect (but shouldn't rely on) the side-effects of a JAX-transformed function to run once (during the first call), and never again. This is because of the way that JAX generates jaxpr, using a process called 'tracing'.
When tracing, JAX wraps each argument by a *tracer* object. These tracers then record all JAX operations performed on them during the function call (which happens in regular Python). Then, JAX uses the tracer records to reconstruct the entire function. The output of that reconstruction is the jaxpr. Since the tracers do not record the Python side-effects, they do not appear in the jaxpr. However, the side-effects still happen during the trace itself.
Note: the Python `print()` function is not pure: the text output is a side-effect of the function. Therefore, any `print()` calls will only happen during tracing, and will not appear in the jaxpr:
```{code-cell} ipython3
:id: JxV2p7e2RawC
:outputId: 9dfe8a56-e553-4640-a04e-5405aea7832d
def log2_with_print(x):
print("printed x:", x)
ln_x = jnp.log(x)
ln_2 = jnp.log(2.0)
return ln_x / ln_2
print(jax.make_jaxpr(log2_with_print)(3.))
```
+++ {"id": "f6W_YYwRRwGp"}
See how the printed `x` is a `Traced` object? That's the JAX internals at work.
The fact that the Python code runs at least once is strictly an implementation detail, and so shouldn't be relied upon. However, it's useful to understand as you can use it when debugging to print out intermediate values of a computation.
+++ {"id": "PgVqi6NlRdWZ"}
A key thing to understand is that jaxpr captures the function as executed on the parameters given to it. For example, if we have a conditional, jaxpr will only know about the branch we take:
```{code-cell} ipython3
:id: hn0CuphEZKZm
:outputId: 99dae727-d2be-4577-831c-e1e14af5890a
def log2_if_rank_2(x):
if x.ndim == 2:
ln_x = jnp.log(x)
ln_2 = jnp.log(2.0)
return ln_x / ln_2
else:
return x
print(jax.make_jaxpr(log2_if_rank_2)(jax.numpy.array([1, 2, 3])))
```
+++ {"id": "Qp3WhqaqvHyD"}
## JIT compiling a function
As explained before, JAX enables operations to execute on CPU/GPU/TPU using the same code.
Let's look at an example of computing a *Scaled Exponential Linear Unit*
([SELU](https://proceedings.neurips.cc/paper/6698-self-normalizing-neural-networks.pdf)), an
operation commonly used in deep learning:
```{code-cell} ipython3
:id: JAXFYtlRvD6p
:outputId: e94d7dc2-a9a1-4ac2-fd3f-152e3f6d141b
import jax
import jax.numpy as jnp
def selu(x, alpha=1.67, lambda_=1.05):
return lambda_ * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(1000000)
%timeit selu(x).block_until_ready()
```
+++ {"id": "ecN5lEXe6ncy"}
The code above is sending one operation at a time to the accelerator. This limits the ability of the XLA compiler to optimize our functions.
Naturally, what we want to do is give the XLA compiler as much code as possible, so it can fully optimize it. For this purpose, JAX provides the `jax.jit` transformation, which will JIT compile a JAX-compatible function. The example below shows how to use JIT to speed up the previous function.
```{code-cell} ipython3
:id: nJVEwPcH6bQX
:outputId: 289eb2f7-a5ce-4cec-f652-5c4e5b0b86cf
selu_jit = jax.jit(selu)
# Warm up
selu_jit(x).block_until_ready()
%timeit selu_jit(x).block_until_ready()
```
+++ {"id": "hMNKi1mYXQg5"}
Here's what just happened:
1) We defined `selu_jit` as the compiled version of `selu`.
2) We called `selu_jit` once on `x`. This is where JAX does its tracing -- it needs to have some inputs to wrap in tracers, after all. The jaxpr is then compiled using XLA into very efficient code optimized for your GPU or TPU. Finally, the compiled code is executed to satisfy the call. Subsequent calls to `selu_jit` will use the compiled code directly, skipping the python implementation entirely.
(If we didn't include the warm-up call separately, everything would still work, but then the compilation time would be included in the benchmark. It would still be faster, because we run many loops in the benchmark, but it wouldn't be a fair comparison.)
3) We timed the execution speed of the compiled version. (Note the use of `block_until_ready()`, which is required due to JAX's [Asynchronous execution](https://jax.readthedocs.io/en/latest/async_dispatch.html) model).
+++ {"id": "DRJ6R6-d9Q_U"}
## Why can't we just JIT everything?
After going through the example above, you might be wondering whether we should simply apply `jax.jit` to every function. To understand why this is not the case, and when we should/shouldn't apply `jit`, let's first check some cases where JIT doesn't work.
```{code-cell} ipython3
:id: GO1Mwd_3_W6g
:outputId: a6fcf6d1-7bd6-4bb7-99c3-2a5a827183e2
:tags: [raises-exception]
# Condition on value of x.
def f(x):
if x > 0:
return x
else:
return 2 * x
f_jit = jax.jit(f)
f_jit(10) # Should raise an error.
```
```{code-cell} ipython3
:id: LHlipkIMFUhi
:outputId: 54935882-a180-45c0-ad03-9dfb5e3baa97
:tags: [raises-exception]
# While loop conditioned on x and n.
def g(x, n):
i = 0
while i < n:
i += 1
return x + i
g_jit = jax.jit(g)
g_jit(10, 20) # Should raise an error.
```
+++ {"id": "isz2U_XX_wH2"}
The problem is that we tried to condition on the *value* of an input to the function being jitted. The reason we can't do this is related to the fact mentioned above that jaxpr depends on the actual values used to trace it.
The more specific information about the values we use in the trace, the more we can use standard Python control flow to express ourselves. However, being too specific means we can't reuse the same traced function for other values. JAX solves this by tracing at different levels of abstraction for different purposes.
For `jax.jit`, the default level is `ShapedArray` -- that is, each tracer has a concrete shape (which we're allowed to condition on), but no concrete value. This allows the compiled function to work on all possible inputs with the same shape -- the standard use case in machine learning. However, because the tracers have no concrete value, if we attempt to condition on one, we get the error above.
In `jax.grad`, the constraints are more relaxed, so you can do more. If you compose several transformations, however, you must satisfy the constraints of the most strict one. So, if you `jit(grad(f))`, `f` mustn't condition on value. For more detail on the interaction between Python control flow and JAX, see [🔪 JAX - The Sharp Bits 🔪: Control Flow](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#control-flow).
One way to deal with this problem is to rewrite the code to avoid conditionals on value. Another is to use special [control flow operators](https://jax.readthedocs.io/en/latest/jax.lax.html#control-flow-operators) like `jax.lax.cond`. However, sometimes that is impossible. In that case, you can consider jitting only part of the function. For example, if the most computationally expensive part of the function is inside the loop, we can JIT just that inner part (though make sure to check the next section on caching to avoid shooting yourself in the foot):
```{code-cell} ipython3
:id: OeR8hF-NHAML
:outputId: d47fd6b2-8bbd-4939-a794-0b80183d3179
# While loop conditioned on x and n with a jitted body.
@jax.jit
def loop_body(prev_i):
return prev_i + 1
def g_inner_jitted(x, n):
i = 0
while i < n:
i = loop_body(i)
return x + i
g_inner_jitted(10, 20)
```
+++ {"id": "5XUT2acoHBz-"}
If we really need to JIT a function that has a condition on the value of an input, we can tell JAX to help itself to a less abstract tracer for a particular input by specifying `static_argnums` or `static_argnames`. The cost of this is that the resulting jaxpr is less flexible, so JAX will have to re-compile the function for every new value of the specified static input. It is only a good strategy if the function is guaranteed to get limited different values.
```{code-cell} ipython3
:id: 2yQmQTDNAenY
:outputId: c48f07b8-c3f9-4d2a-9dfd-663838a52511
f_jit_correct = jax.jit(f, static_argnums=0)
print(f_jit_correct(10))
```
```{code-cell} ipython3
:id: R4SXUEu-M-u1
:outputId: 9e712e14-4e81-4744-dcf2-a10f470d9121
g_jit_correct = jax.jit(g, static_argnames=['n'])
print(g_jit_correct(10, 20))
```
To specify such arguments when using `jit` as a decorator, a common pattern is to use python's `functools.partial`:
```{code-cell} ipython3
:id: 2X5rR4jkIO
:outputId: 81-4744-dc2e4-4e10f470f2-a19e71d9121
from functools import partial
@partial(jax.jit, static_argnames=['n'])
def g_jit_decorated(x, n):
i = 0
while i < n:
i += 1
return x + i
print(g_jit_decorated(10, 20))
```
+++ {"id": "LczjIBt2X2Ms"}
## When to use JIT
In many of the examples above, jitting is not worth it:
```{code-cell} ipython3
:id: uMOqsNnqYApD
:outputId: 2d6c5122-43ad-4257-e56b-e77c889131c2
print("g jitted:")
%timeit g_jit_correct(10, 20).block_until_ready()
print("g:")
%timeit g(10, 20)
```
+++ {"id": "cZmGYq80YP0j"}
This is because `jax.jit` introduces some overhead itself. Therefore, it usually only saves time if the compiled function is complex and you will run it numerous times. Fortunately, this is common in machine learning, where we tend to compile a large, complicated model, then run it for millions of iterations.
Generally, you want to jit the largest possible chunk of your computation; ideally, the entire update step. This gives the compiler maximum freedom to optimise.
+++ {"id": "hJMjUlRcIzVS"}
## Caching
It's important to understand the caching behaviour of `jax.jit`.
Suppose I define `f = jax.jit(g)`. When I first invoke `f`, it will get compiled, and the resulting XLA code will get cached. Subsequent calls of `f` will reuse the cached code. This is how `jax.jit` makes up for the up-front cost of compilation.
If I specify `static_argnums`, then the cached code will be used only for the same values of arguments labelled as static. If any of them change, recompilation occurs. If there are many values, then your program might spend more time compiling than it would have executing ops one-by-one.
Avoid calling `jax.jit` inside loops. For most cases, JAX will be able to use the compiled, cached function in subsequent calls to `jax.jit`. However, because the cache relies on the hash of the function, it becomes problematic when equivalent functions are redefined. This will cause unnecessary compilation each time in the loop:
```{code-cell} ipython3
:id: 6MDSXCfmSZVZ
:outputId: a035d0b7-6a4d-4a9e-c6b4-7521970829fc
from functools import partial
def unjitted_loop_body(prev_i):
return prev_i + 1
def g_inner_jitted_partial(x, n):
i = 0
while i < n:
# Don't do this! each time the partial returns
# a function with different hash
i = jax.jit(partial(unjitted_loop_body))(i)
return x + i
def g_inner_jitted_lambda(x, n):
i = 0
while i < n:
# Don't do this!, lambda will also return
# a function with a different hash
i = jax.jit(lambda x: unjitted_loop_body(x))(i)
return x + i
def g_inner_jitted_normal(x, n):
i = 0
while i < n:
# this is OK, since JAX can find the
# cached, compiled function
i = jax.jit(unjitted_loop_body)(i)
return x + i
print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()
print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()
print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready()
```

View File

@ -1,369 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "zMIrmiaZxiJC"
},
"source": [
"# Automatic Vectorization in JAX\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/03-vectorization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/03-vectorization.ipynb)\n",
"\n",
"*Authors: Matteo Hessel*\n",
"\n",
"In the previous section we discussed JIT compilation via the `jax.jit` function. This notebook discusses another of JAX's transforms: vectorization via `jax.vmap`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Kw-_imBrx4nN"
},
"source": [
"## Manual Vectorization\n",
"\n",
"Consider the following simple code that computes the convolution of two one-dimensional vectors:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "5Obro91lwE_s",
"outputId": "061983c6-2faa-4a54-83a5-d2a823f61087"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([11., 20., 29.], dtype=float32)"
]
},
"execution_count": 1,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"import jax\n",
"import jax.numpy as jnp\n",
"\n",
"x = jnp.arange(5)\n",
"w = jnp.array([2., 3., 4.])\n",
"\n",
"def convolve(x, w):\n",
" output = []\n",
" for i in range(1, len(x)-1):\n",
" output.append(jnp.dot(x[i-1:i+2], w))\n",
" return jnp.array(output)\n",
"\n",
"convolve(x, w)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "z_nPhEhLRysk"
},
"source": [
"Suppose we would like to apply this function to a batch of weights `w` to a batch of vectors `x`."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "rHQJnnrVUbxE"
},
"outputs": [],
"source": [
"xs = jnp.stack([x, x])\n",
"ws = jnp.stack([w, w])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ghaJQW1aUfPi"
},
"source": [
"The most naive option would be to simply loop over the batch in Python:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "yM-IycdlzGyJ",
"outputId": "07ed6ffc-0265-45ef-d585-4b5fa7d221f1"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
"execution_count": 10,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"def manually_batched_convolve(xs, ws):\n",
" output = []\n",
" for i in range(xs.shape[0]):\n",
" output.append(convolve(xs[i], ws[i]))\n",
" return jnp.stack(output)\n",
"\n",
"manually_batched_convolve(xs, ws)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VTh0l_1SUlh4"
},
"source": [
"This produces the correct result, however it is not very efficient.\n",
"\n",
"In order to batch the computation efficiently, you would normally have to rewrite the function manually to ensure it is done in vectorized form. This is not particularly difficult to implement, but does involve changing how the function treats indices, axes, and other parts of the input.\n",
"\n",
"For example, we could manually rewrite `convolve()` to support vectorized computation across the batch dimension as follows:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "I4Wd9nrcTRRL",
"outputId": "0b037b43-7b41-4625-f9e0-a6e0dbc4c65a"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
"execution_count": 5,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"def manually_vectorized_convolve(xs, ws):\n",
" output = []\n",
" for i in range(1, xs.shape[-1] -1):\n",
" output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))\n",
" return jnp.stack(output, axis=1)\n",
"\n",
"manually_vectorized_convolve(xs, ws)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "DW-RJ2Zs2QVu"
},
"source": [
"Such re-implementation is messy and error-prone; fortunately JAX provides another way."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "2oVLanQmUAo_"
},
"source": [
"## Automatic Vectorization\n",
"\n",
"In JAX, the `jax.vmap` transformation is designed to generate such a vectorized implementation of a function automatically:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "Brl-BoTqSQDw",
"outputId": "af608dbb-27f2-4fbc-f225-79f3101b13ff"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
"execution_count": 6,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"auto_batch_convolve = jax.vmap(convolve)\n",
"\n",
"auto_batch_convolve(xs, ws)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7aVAy7332lFj"
},
"source": [
"It does this by tracing the function similarly to `jax.jit`, and automatically adding batch axes at the beginning of each input.\n",
"\n",
"If the batch dimension is not the first, you may use the `in_axes` and `out_axes` arguments to specify the location of the batch dimension in inputs and outputs. These may be an integer if the batch axis is the same for all inputs and outputs, or lists, otherwise."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "_VEEm1CGT2n0",
"outputId": "751e0fbf-bdfb-41df-9436-4da5de23123f"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([[11., 11.],\n",
" [20., 20.],\n",
" [29., 29.]], dtype=float32)"
]
},
"execution_count": 7,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)\n",
"\n",
"xst = jnp.transpose(xs)\n",
"wst = jnp.transpose(ws)\n",
"\n",
"auto_batch_convolve_v2(xst, wst)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-gNiLuxzSX32"
},
"source": [
"`jax.vmap` also supports the case where only one of the arguments is batched: for example, if you would like to convolve to a single set of weights `w` with a batch of vectors `x`; in this case the `in_axes` argument can be set to `None`:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "2s2YDsamSxki",
"outputId": "5c70879b-5cce-4549-e38a-f45dbe663ab2"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
"execution_count": 8,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])\n",
"\n",
"batch_convolve_v3(xs, w)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bsxT4hA6RTCG"
},
"source": [
"## Combining transformations\n",
"\n",
"As with all JAX transformations, `jax.jit` and `jax.vmap` are designed to be composable, which means you can wrap a vmapped function with `jit`, or a JITted function with `vmap`, and everything will work correctly:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "gsC-Myg0RVdj",
"outputId": "cbdd384e-6633-4cea-b1a0-a01ad934a768"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([[11., 20., 29.],\n",
" [11., 20., 29.]], dtype=float32)"
]
},
"execution_count": 9,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"jitted_batch_convolve = jax.jit(auto_batch_convolve)\n",
"\n",
"jitted_batch_convolve(xs, ws)"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Vectorization in JAX",
"provenance": []
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@ -1,161 +0,0 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3
---
+++ {"id": "zMIrmiaZxiJC"}
# Automatic Vectorization in JAX
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/03-vectorization.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/03-vectorization.ipynb)
*Authors: Matteo Hessel*
In the previous section we discussed JIT compilation via the `jax.jit` function. This notebook discusses another of JAX's transforms: vectorization via `jax.vmap`.
+++ {"id": "Kw-_imBrx4nN"}
## Manual Vectorization
Consider the following simple code that computes the convolution of two one-dimensional vectors:
```{code-cell} ipython3
:id: 5Obro91lwE_s
:outputId: 061983c6-2faa-4a54-83a5-d2a823f61087
import jax
import jax.numpy as jnp
x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
def convolve(x, w):
output = []
for i in range(1, len(x)-1):
output.append(jnp.dot(x[i-1:i+2], w))
return jnp.array(output)
convolve(x, w)
```
+++ {"id": "z_nPhEhLRysk"}
Suppose we would like to apply this function to a batch of weights `w` to a batch of vectors `x`.
```{code-cell} ipython3
:id: rHQJnnrVUbxE
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])
```
+++ {"id": "ghaJQW1aUfPi"}
The most naive option would be to simply loop over the batch in Python:
```{code-cell} ipython3
:id: yM-IycdlzGyJ
:outputId: 07ed6ffc-0265-45ef-d585-4b5fa7d221f1
def manually_batched_convolve(xs, ws):
output = []
for i in range(xs.shape[0]):
output.append(convolve(xs[i], ws[i]))
return jnp.stack(output)
manually_batched_convolve(xs, ws)
```
+++ {"id": "VTh0l_1SUlh4"}
This produces the correct result, however it is not very efficient.
In order to batch the computation efficiently, you would normally have to rewrite the function manually to ensure it is done in vectorized form. This is not particularly difficult to implement, but does involve changing how the function treats indices, axes, and other parts of the input.
For example, we could manually rewrite `convolve()` to support vectorized computation across the batch dimension as follows:
```{code-cell} ipython3
:id: I4Wd9nrcTRRL
:outputId: 0b037b43-7b41-4625-f9e0-a6e0dbc4c65a
def manually_vectorized_convolve(xs, ws):
output = []
for i in range(1, xs.shape[-1] -1):
output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
return jnp.stack(output, axis=1)
manually_vectorized_convolve(xs, ws)
```
+++ {"id": "DW-RJ2Zs2QVu"}
Such re-implementation is messy and error-prone; fortunately JAX provides another way.
+++ {"id": "2oVLanQmUAo_"}
## Automatic Vectorization
In JAX, the `jax.vmap` transformation is designed to generate such a vectorized implementation of a function automatically:
```{code-cell} ipython3
:id: Brl-BoTqSQDw
:outputId: af608dbb-27f2-4fbc-f225-79f3101b13ff
auto_batch_convolve = jax.vmap(convolve)
auto_batch_convolve(xs, ws)
```
+++ {"id": "7aVAy7332lFj"}
It does this by tracing the function similarly to `jax.jit`, and automatically adding batch axes at the beginning of each input.
If the batch dimension is not the first, you may use the `in_axes` and `out_axes` arguments to specify the location of the batch dimension in inputs and outputs. These may be an integer if the batch axis is the same for all inputs and outputs, or lists, otherwise.
```{code-cell} ipython3
:id: _VEEm1CGT2n0
:outputId: 751e0fbf-bdfb-41df-9436-4da5de23123f
auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)
xst = jnp.transpose(xs)
wst = jnp.transpose(ws)
auto_batch_convolve_v2(xst, wst)
```
+++ {"id": "-gNiLuxzSX32"}
`jax.vmap` also supports the case where only one of the arguments is batched: for example, if you would like to convolve to a single set of weights `w` with a batch of vectors `x`; in this case the `in_axes` argument can be set to `None`:
```{code-cell} ipython3
:id: 2s2YDsamSxki
:outputId: 5c70879b-5cce-4549-e38a-f45dbe663ab2
batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])
batch_convolve_v3(xs, w)
```
+++ {"id": "bsxT4hA6RTCG"}
## Combining transformations
As with all JAX transformations, `jax.jit` and `jax.vmap` are designed to be composable, which means you can wrap a vmapped function with `jit`, or a JITted function with `vmap`, and everything will work correctly:
```{code-cell} ipython3
:id: gsC-Myg0RVdj
:outputId: cbdd384e-6633-4cea-b1a0-a01ad934a768
jitted_batch_convolve = jax.jit(auto_batch_convolve)
jitted_batch_convolve(xs, ws)
```

View File

@ -1,738 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "kORMl5KmfByI"
},
"source": [
"# Advanced Automatic Differentiation in JAX\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/04-advanced-autodiff.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/04-advanced-autodiff.ipynb)\n",
"\n",
"*Authors: Vlatimir Mikulik & Matteo Hessel*\n",
"\n",
"Computing gradients is a critical part of modern machine learning methods. This section considers a few advanced topics in the areas of automatic differentiation as it relates to modern machine learning.\n",
"\n",
"While understanding how automatic differentiation works under the hood isn't crucial for using JAX in most contexts, we encourage the reader to check out this quite accessible [video](https://www.youtube.com/watch?v=wG_nF1awSSY) to get a deeper sense of what's going on.\n",
"\n",
"[The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) is a more advanced and more detailed explanation of how these ideas are implemented in the JAX backend. It's not necessary to understand this to do most things in JAX. However, some features (like defining [custom derivatives](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html)) depend on understanding this, so it's worth knowing this explanation exists if you ever need to use them."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qx50CO1IorCc"
},
"source": [
"## Higher-order derivatives\n",
"\n",
"JAX's autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations.\n",
"\n",
"We illustrate this in the single-variable case:\n",
"\n",
"The derivative of $f(x) = x^3 + 2x^2 - 3x + 1$ can be computed as:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "Kqsbj98UTVdi"
},
"outputs": [],
"source": [
"import jax\n",
"\n",
"f = lambda x: x**3 + 2*x**2 - 3*x + 1\n",
"\n",
"dfdx = jax.grad(f)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ItEt15OGiiAF"
},
"source": [
"The higher-order derivatives of $f$ are:\n",
"\n",
"$$\n",
"\\begin{array}{l}\n",
"f'(x) = 3x^2 + 4x -3\\\\\n",
"f''(x) = 6x + 4\\\\\n",
"f'''(x) = 6\\\\\n",
"f^{iv}(x) = 0\n",
"\\end{array}\n",
"$$\n",
"\n",
"Computing any of these in JAX is as easy as chaining the `grad` function:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "5X3yQqLgimqH"
},
"outputs": [],
"source": [
"d2fdx = jax.grad(dfdx)\n",
"d3fdx = jax.grad(d2fdx)\n",
"d4fdx = jax.grad(d3fdx)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fVL2P_pcj8T1"
},
"source": [
"Evaluating the above in $x=1$ would give us:\n",
"\n",
"$$\n",
"\\begin{array}{l}\n",
"f'(1) = 4\\\\\n",
"f''(1) = 10\\\\\n",
"f'''(1) = 6\\\\\n",
"f^{iv}(1) = 0\n",
"\\end{array}\n",
"$$\n",
"\n",
"Using JAX:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "tJkIp9wFjxL3",
"outputId": "581ecf87-2d20-4c83-9443-5befc1baf51d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"4.0\n",
"10.0\n",
"6.0\n",
"0.0\n"
]
}
],
"source": [
"print(dfdx(1.))\n",
"print(d2fdx(1.))\n",
"print(d3fdx(1.))\n",
"print(d4fdx(1.))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3-fTelU7LHRr"
},
"source": [
"In the multivariable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), defined according to\n",
"\n",
"$$(\\mathbf{H}f)_{i,j} = \\frac{\\partial^2 f}{\\partial_i\\partial_j}.$$\n",
"\n",
"The Hessian of a real-valued function of several variables, $f: \\mathbb R^n\\to\\mathbb R$, can be identified with the Jacobian of its gradient. JAX provides two transformations for computing the Jacobian of a function, `jax.jacfwd` and `jax.jacrev`, corresponding to forward- and reverse-mode autodiff. They give the same answer, but one can be more efficient than the other in different circumstances see the [video about autodiff](https://www.youtube.com/watch?v=wG_nF1awSSY) linked above for an explanation."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "ILhkef1rOB6_"
},
"outputs": [],
"source": [
"def hessian(f):\n",
" return jax.jacfwd(jax.grad(f))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "xaENwADXOGf_"
},
"source": [
"Let's double check this is correct on the dot-product $f: \\mathbf{x} \\mapsto \\mathbf{x} ^\\top \\mathbf{x}$.\n",
"\n",
"if $i=j$, $\\frac{\\partial^2 f}{\\partial_i\\partial_j}(\\mathbf{x}) = 2$. Otherwise, $\\frac{\\partial^2 f}{\\partial_i\\partial_j}(\\mathbf{x}) = 0$."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "Xm3A0QdWRdJl",
"outputId": "e1e8cba9-b567-439b-b8fc-34b21497e67f"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([[2., 0., 0.],\n",
" [0., 2., 0.],\n",
" [0., 0., 2.]], dtype=float32)"
]
},
"execution_count": 6,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"import jax.numpy as jnp\n",
"\n",
"def f(x):\n",
" return jnp.dot(x, x)\n",
"\n",
"hessian(f)(jnp.array([1., 2., 3.]))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "7_gbi34WSUsD"
},
"source": [
"Often, however, we aren't interested in computing the full Hessian itself, and doing so can be very inefficient. [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) explains some tricks, like the Hessian-vector product, that allow to use it without materialising the whole matrix.\n",
"\n",
"If you plan to work with higher-order derivatives in JAX, we strongly recommend reading the Autodiff Cookbook."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zMT2qAi-SvcK"
},
"source": [
"## Higher order optimization\n",
"\n",
"Some meta-learning techniques, such as Model-Agnostic Meta-Learning ([MAML](https://arxiv.org/abs/1703.03400)), require differentiating through gradient updates. In other frameworks this can be quite cumbersome, but in JAX it's much easier:\n",
"\n",
"```python\n",
"def meta_loss_fn(params, data):\n",
" \"\"\"Computes the loss after one step of SGD.\"\"\"\n",
" grads = jax.grad(loss_fn)(params, data)\n",
" return loss_fn(params - lr * grads, data)\n",
"\n",
"meta_grads = jax.grad(meta_loss_fn)(params, data)\n",
"```"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "3h9Aj3YyuL6P"
},
"source": [
"## Stopping gradients\n",
"\n",
"Auto-diff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, we might want some additional control: for instance, we might want to avoid back-propagating gradients through some subset of the computational graph.\n",
"\n",
"Consider for instance the TD(0) ([temporal difference](https://en.wikipedia.org/wiki/Temporal_difference_learning)) reinforcement learning update. This is used to learn to estimate the *value* of a state in an environment from experience of interacting with the environment. Let's assume the value estimate $v_{\\theta}(s_{t-1}$) in a state $s_{t-1}$ is parameterised by a linear function."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "fjLqbCb6SiOm"
},
"outputs": [],
"source": [
"# Value function and initial parameters\n",
"value_fn = lambda theta, state: jnp.dot(theta, state)\n",
"theta = jnp.array([0.1, -0.1, 0.])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "85S7HBo1tBzt"
},
"source": [
"Consider a transition from a state $s_{t-1}$ to a state $s_t$ during which we observed the reward $r_t$"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "T6cRPau6tCSE"
},
"outputs": [],
"source": [
"# An example transition.\n",
"s_tm1 = jnp.array([1., 2., -1.])\n",
"r_t = jnp.array(1.)\n",
"s_t = jnp.array([2., 1., 0.])"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QO5CHA9_Sk01"
},
"source": [
"The TD(0) update to the network parameters is:\n",
"\n",
"$$\n",
"\\Delta \\theta = (r_t + v_{\\theta}(s_t) - v_{\\theta}(s_{t-1})) \\nabla v_{\\theta}(s_{t-1})\n",
"$$\n",
"\n",
"This update is not the gradient of any loss function.\n",
"\n",
"However, it can be **written** as the gradient of the pseudo loss function\n",
"\n",
"$$\n",
"L(\\theta) = - \\frac{1}{2} [r_t + v_{\\theta}(s_t) - v_{\\theta}(s_{t-1})]^2\n",
"$$\n",
"\n",
"if the dependency of the target $r_t + v_{\\theta}(s_t)$ on the parameter $\\theta$ is ignored.\n",
"\n",
"How can we implement this in JAX? If we write the pseudo loss naively we get:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "uMcFny2xuOwz",
"outputId": "79c10af9-10b8-4e18-9753-a53918b9d72d"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([ -1.2, 1.2, -1.2], dtype=float32)"
]
},
"execution_count": 9,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"def td_loss(theta, s_tm1, r_t, s_t):\n",
" v_tm1 = value_fn(theta, s_tm1)\n",
" target = r_t + value_fn(theta, s_t)\n",
" return -0.5 * ((target - v_tm1) ** 2)\n",
"\n",
"td_update = jax.grad(td_loss)\n",
"delta_theta = td_update(theta, s_tm1, r_t, s_t)\n",
"\n",
"delta_theta"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "CPnjm59GG4Gq"
},
"source": [
"But `td_update` will **not** compute a TD(0) update, because the gradient computation will include the dependency of `target` on $\\theta$.\n",
"\n",
"We can use `jax.lax.stop_gradient` to force JAX to ignore the dependency of the target on $\\theta$:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "MKeq7trKPS4V",
"outputId": "0f27d754-a871-4c47-8e3a-a961418a24cc"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([1.2, 2.4, -1.2], dtype=float32)"
]
},
"execution_count": 10,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"def td_loss(theta, s_tm1, r_t, s_t):\n",
" v_tm1 = value_fn(theta, s_tm1)\n",
" target = r_t + value_fn(theta, s_t)\n",
" return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2)\n",
"\n",
"td_update = jax.grad(td_loss)\n",
"delta_theta = td_update(theta, s_tm1, r_t, s_t)\n",
"\n",
"delta_theta"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JOnjm59GG4Gq"
},
"source": [
"This will treat `target` as if it did **not** depend on the parameters $\\theta$ and compute the correct update to the parameters.\n",
"\n",
"Now, let's also calculate $\\Delta \\theta$ using the original TD(0) update expression, to cross-check our work. You may wish to try and implement this yourself using jax.grad and your knowledge so far. Here's our solution:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "WCeq7trKPS4V",
"outputId": "0f19d754-a871-4c47-8e3a-a961418a24cc"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([ 1.2, 2.4, -1.2], dtype=float32)"
]
},
"execution_count": 11,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"s_grad = jax.grad(value_fn)(theta, s_tm1)\n",
"delta_theta_original_calculation = (r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)) * s_grad\n",
"\n",
"delta_theta_original_calculation # [1.2, 2.4, -1.2], same as `delta_theta`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TNF0CkwOTKpD"
},
"source": [
"`jax.lax.stop_gradient` may also be useful in other settings, for instance if you want the gradient from some loss to only affect a subset of the parameters of the neural network (because, for instance, the other parameters are trained using a different loss)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "UMY0IyuOTKpG"
},
"source": [
"## Straight-through estimator using `stop_gradient`\n",
"\n",
"The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. Given a non-differentiable function $f : \\mathbb{R}^n \\to \\mathbb{R}^n$ that is used as part of a larger function that we wish to find a gradient of, we simply pretend during the backward pass that $f$ is the identity function. This can be implemented neatly using `jax.lax.stop_gradient`:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "hdORJENmVHvX",
"outputId": "f0839541-46a4-45a9-fce7-ead08f20046b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"f(x): 3.0\n",
"straight_through_f(x): 3.0\n",
"grad(f)(x): 0.0\n",
"grad(straight_through_f)(x): 1.0\n"
]
}
],
"source": [
"def f(x):\n",
" return jnp.round(x) # non-differentiable\n",
"\n",
"def straight_through_f(x):\n",
" # Create an exactly-zero expression with Sterbenz lemma that has\n",
" # an exactly-one gradient.\n",
" zero = x - jax.lax.stop_gradient(x)\n",
" return zero + jax.lax.stop_gradient(f(x))\n",
"\n",
"print(\"f(x): \", f(3.2))\n",
"print(\"straight_through_f(x):\", straight_through_f(3.2))\n",
"\n",
"print(\"grad(f)(x):\", jax.grad(f)(3.2))\n",
"print(\"grad(straight_through_f)(x):\", jax.grad(straight_through_f)(3.2))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Wx3RNE0Sw5mn"
},
"source": [
"## Per-example gradients\n",
"\n",
"While most ML systems compute gradients and updates from batches of data, for reasons of computational efficiency and/or variance reduction, it is sometimes necessary to have access to the gradient/update associated with each specific sample in the batch.\n",
"\n",
"For instance, this is needed to prioritise data based on gradient magnitude, or to apply clipping / normalisations on a sample by sample basis.\n",
"\n",
"In many frameworks (PyTorch, TF, Theano) it is often not trivial to compute per-example gradients, because the library directly accumulates the gradient over the batch. Naive workarounds, such as computing a separate loss per example and then aggregating the resulting gradients are typically very inefficient.\n",
"\n",
"In JAX we can define the code to compute the gradient per-sample in an easy but efficient way.\n",
"\n",
"Just combine the `jit`, `vmap` and `grad` transformations together:"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "tFLyd9ifw4GG",
"outputId": "bf3ad4a3-102d-47a6-ece0-f4a8c9e5d434"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([[1.2, 2.4, -1.2],\n",
" [1.2, 2.4, -1.2]], dtype=float32)"
]
},
"execution_count": 13,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))\n",
"\n",
"# Test it:\n",
"batched_s_tm1 = jnp.stack([s_tm1, s_tm1])\n",
"batched_r_t = jnp.stack([r_t, r_t])\n",
"batched_s_t = jnp.stack([s_t, s_t])\n",
"\n",
"perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "VxvYVEYQYiS_"
},
"source": [
"Let's walk through this one transformation at a time.\n",
"\n",
"First, we apply `jax.grad` to `td_loss` to obtain a function that computes the gradient of the loss w.r.t. the parameters on single (unbatched) inputs:"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "rPO67QQrY5Bk",
"outputId": "fbb45b98-2dbf-4865-e6e5-87dc3eef5560"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([1.2, 2.4, -1.2], dtype=float32)"
]
},
"execution_count": 14,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"dtdloss_dtheta = jax.grad(td_loss)\n",
"\n",
"dtdloss_dtheta(theta, s_tm1, r_t, s_t)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cU36nVAlcnJ0"
},
"source": [
"This function computes one row of the array above."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "c6DQF0b3ZA5u"
},
"source": [
"Then, we vectorise this function using `jax.vmap`. This adds a batch dimension to all inputs and outputs. Now, given a batch of inputs, we produce a batch of outputs -- each output in the batch corresponds to the gradient for the corresponding member of the input batch."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "5agbNKavaNDM",
"outputId": "ab081012-88ab-4904-a367-68e9f81445f0"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([[1.2, 2.4, -1.2],\n",
" [1.2, 2.4, -1.2]], dtype=float32)"
]
},
"execution_count": 15,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"almost_perex_grads = jax.vmap(dtdloss_dtheta)\n",
"\n",
"batched_theta = jnp.stack([theta, theta])\n",
"almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "K-v34yLuan7k"
},
"source": [
"This isn't quite what we want, because we have to manually feed this function a batch of `theta`s, whereas we actually want to use a single `theta`. We fix this by adding `in_axes` to the `jax.vmap`, specifying theta as `None`, and the other args as `0`. This makes the resulting function add an extra axis only to the other arguments, leaving `theta` unbatched, as we want:"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"id": "S6kd5MujbGrr",
"outputId": "d3d731ef-3f7d-4a0a-ce91-7df57627ddbd"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([[1.2, 2.4, -1.2],\n",
" [1.2, 2.4, -1.2]], dtype=float32)"
]
},
"execution_count": 16,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))\n",
"\n",
"inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "O0hbsm70be5T"
},
"source": [
"Almost there! This does what we want, but is slower than it has to be. Now, we wrap the whole thing in a `jax.jit` to get the compiled, efficient version of the same function:"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"id": "Fvr709FcbrSW",
"outputId": "627db899-5620-4bed-8d34-cd1364d3d187"
},
"outputs": [
{
"data": {
"text/plain": [
"Array([[1.2, 2.4, -1.2],\n",
" [1.2, 2.4, -1.2]], dtype=float32)"
]
},
"execution_count": 17,
"metadata": {
"tags": []
},
"output_type": "execute_result"
}
],
"source": [
"perex_grads = jax.jit(inefficient_perex_grads)\n",
"\n",
"perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"id": "FH42yzbHcNs2",
"outputId": "c8e52f93-615a-4ce7-d8ab-fb6215995a39"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"100 loops, best of 5: 7.74 ms per loop\n",
"10000 loops, best of 5: 86.2 µs per loop\n"
]
}
],
"source": [
"%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()\n",
"%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "Advanced Grads",
"provenance": []
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@ -1,374 +0,0 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3
---
+++ {"id": "kORMl5KmfByI"}
# Advanced Automatic Differentiation in JAX
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/04-advanced-autodiff.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/04-advanced-autodiff.ipynb)
*Authors: Vlatimir Mikulik & Matteo Hessel*
Computing gradients is a critical part of modern machine learning methods. This section considers a few advanced topics in the areas of automatic differentiation as it relates to modern machine learning.
While understanding how automatic differentiation works under the hood isn't crucial for using JAX in most contexts, we encourage the reader to check out this quite accessible [video](https://www.youtube.com/watch?v=wG_nF1awSSY) to get a deeper sense of what's going on.
[The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) is a more advanced and more detailed explanation of how these ideas are implemented in the JAX backend. It's not necessary to understand this to do most things in JAX. However, some features (like defining [custom derivatives](https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html)) depend on understanding this, so it's worth knowing this explanation exists if you ever need to use them.
+++ {"id": "qx50CO1IorCc"}
## Higher-order derivatives
JAX's autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations.
We illustrate this in the single-variable case:
The derivative of $f(x) = x^3 + 2x^2 - 3x + 1$ can be computed as:
```{code-cell} ipython3
:id: Kqsbj98UTVdi
import jax
f = lambda x: x**3 + 2*x**2 - 3*x + 1
dfdx = jax.grad(f)
```
+++ {"id": "ItEt15OGiiAF"}
The higher-order derivatives of $f$ are:
$$
\begin{array}{l}
f'(x) = 3x^2 + 4x -3\\
f''(x) = 6x + 4\\
f'''(x) = 6\\
f^{iv}(x) = 0
\end{array}
$$
Computing any of these in JAX is as easy as chaining the `grad` function:
```{code-cell} ipython3
:id: 5X3yQqLgimqH
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)
```
+++ {"id": "fVL2P_pcj8T1"}
Evaluating the above in $x=1$ would give us:
$$
\begin{array}{l}
f'(1) = 4\\
f''(1) = 10\\
f'''(1) = 6\\
f^{iv}(1) = 0
\end{array}
$$
Using JAX:
```{code-cell} ipython3
:id: tJkIp9wFjxL3
:outputId: 581ecf87-2d20-4c83-9443-5befc1baf51d
print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))
```
+++ {"id": "3-fTelU7LHRr"}
In the multivariable case, higher-order derivatives are more complicated. The second-order derivative of a function is represented by its [Hessian matrix](https://en.wikipedia.org/wiki/Hessian_matrix), defined according to
$$(\mathbf{H}f)_{i,j} = \frac{\partial^2 f}{\partial_i\partial_j}.$$
The Hessian of a real-valued function of several variables, $f: \mathbb R^n\to\mathbb R$, can be identified with the Jacobian of its gradient. JAX provides two transformations for computing the Jacobian of a function, `jax.jacfwd` and `jax.jacrev`, corresponding to forward- and reverse-mode autodiff. They give the same answer, but one can be more efficient than the other in different circumstances see the [video about autodiff](https://www.youtube.com/watch?v=wG_nF1awSSY) linked above for an explanation.
```{code-cell} ipython3
:id: ILhkef1rOB6_
def hessian(f):
return jax.jacfwd(jax.grad(f))
```
+++ {"id": "xaENwADXOGf_"}
Let's double check this is correct on the dot-product $f: \mathbf{x} \mapsto \mathbf{x} ^\top \mathbf{x}$.
if $i=j$, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 2$. Otherwise, $\frac{\partial^2 f}{\partial_i\partial_j}(\mathbf{x}) = 0$.
```{code-cell} ipython3
:id: Xm3A0QdWRdJl
:outputId: e1e8cba9-b567-439b-b8fc-34b21497e67f
import jax.numpy as jnp
def f(x):
return jnp.dot(x, x)
hessian(f)(jnp.array([1., 2., 3.]))
```
+++ {"id": "7_gbi34WSUsD"}
Often, however, we aren't interested in computing the full Hessian itself, and doing so can be very inefficient. [The Autodiff Cookbook](https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html) explains some tricks, like the Hessian-vector product, that allow to use it without materialising the whole matrix.
If you plan to work with higher-order derivatives in JAX, we strongly recommend reading the Autodiff Cookbook.
+++ {"id": "zMT2qAi-SvcK"}
## Higher order optimization
Some meta-learning techniques, such as Model-Agnostic Meta-Learning ([MAML](https://arxiv.org/abs/1703.03400)), require differentiating through gradient updates. In other frameworks this can be quite cumbersome, but in JAX it's much easier:
```python
def meta_loss_fn(params, data):
"""Computes the loss after one step of SGD."""
grads = jax.grad(loss_fn)(params, data)
return loss_fn(params - lr * grads, data)
meta_grads = jax.grad(meta_loss_fn)(params, data)
```
+++ {"id": "3h9Aj3YyuL6P"}
## Stopping gradients
Auto-diff enables automatic computation of the gradient of a function with respect to its inputs. Sometimes, however, we might want some additional control: for instance, we might want to avoid back-propagating gradients through some subset of the computational graph.
Consider for instance the TD(0) ([temporal difference](https://en.wikipedia.org/wiki/Temporal_difference_learning)) reinforcement learning update. This is used to learn to estimate the *value* of a state in an environment from experience of interacting with the environment. Let's assume the value estimate $v_{\theta}(s_{t-1}$) in a state $s_{t-1}$ is parameterised by a linear function.
```{code-cell} ipython3
:id: fjLqbCb6SiOm
# Value function and initial parameters
value_fn = lambda theta, state: jnp.dot(theta, state)
theta = jnp.array([0.1, -0.1, 0.])
```
+++ {"id": "85S7HBo1tBzt"}
Consider a transition from a state $s_{t-1}$ to a state $s_t$ during which we observed the reward $r_t$
```{code-cell} ipython3
:id: T6cRPau6tCSE
# An example transition.
s_tm1 = jnp.array([1., 2., -1.])
r_t = jnp.array(1.)
s_t = jnp.array([2., 1., 0.])
```
+++ {"id": "QO5CHA9_Sk01"}
The TD(0) update to the network parameters is:
$$
\Delta \theta = (r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})) \nabla v_{\theta}(s_{t-1})
$$
This update is not the gradient of any loss function.
However, it can be **written** as the gradient of the pseudo loss function
$$
L(\theta) = - \frac{1}{2} [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2
$$
if the dependency of the target $r_t + v_{\theta}(s_t)$ on the parameter $\theta$ is ignored.
How can we implement this in JAX? If we write the pseudo loss naively we get:
```{code-cell} ipython3
:id: uMcFny2xuOwz
:outputId: 79c10af9-10b8-4e18-9753-a53918b9d72d
def td_loss(theta, s_tm1, r_t, s_t):
v_tm1 = value_fn(theta, s_tm1)
target = r_t + value_fn(theta, s_t)
return -0.5 * ((target - v_tm1) ** 2)
td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)
delta_theta
```
+++ {"id": "CPnjm59GG4Gq"}
But `td_update` will **not** compute a TD(0) update, because the gradient computation will include the dependency of `target` on $\theta$.
We can use `jax.lax.stop_gradient` to force JAX to ignore the dependency of the target on $\theta$:
```{code-cell} ipython3
:id: MKeq7trKPS4V
:outputId: 0f27d754-a871-4c47-8e3a-a961418a24cc
def td_loss(theta, s_tm1, r_t, s_t):
v_tm1 = value_fn(theta, s_tm1)
target = r_t + value_fn(theta, s_t)
return -0.5 * ((jax.lax.stop_gradient(target) - v_tm1) ** 2)
td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)
delta_theta
```
+++ {"id": "JOnjm59GG4Gq"}
This will treat `target` as if it did **not** depend on the parameters $\theta$ and compute the correct update to the parameters.
Now, let's also calculate $\Delta \theta$ using the original TD(0) update expression, to cross-check our work. You may wish to try and implement this yourself using jax.grad and your knowledge so far. Here's our solution:
```{code-cell} ipython3
:id: WCeq7trKPS4V
:outputId: 0f19d754-a871-4c47-8e3a-a961418a24cc
s_grad = jax.grad(value_fn)(theta, s_tm1)
delta_theta_original_calculation = (r_t + value_fn(theta, s_t) - value_fn(theta, s_tm1)) * s_grad
delta_theta_original_calculation # [1.2, 2.4, -1.2], same as `delta_theta`
```
+++ {"id": "TNF0CkwOTKpD"}
`jax.lax.stop_gradient` may also be useful in other settings, for instance if you want the gradient from some loss to only affect a subset of the parameters of the neural network (because, for instance, the other parameters are trained using a different loss).
+++ {"id": "UMY0IyuOTKpG"}
## Straight-through estimator using `stop_gradient`
The straight-through estimator is a trick for defining a 'gradient' of a function that is otherwise non-differentiable. Given a non-differentiable function $f : \mathbb{R}^n \to \mathbb{R}^n$ that is used as part of a larger function that we wish to find a gradient of, we simply pretend during the backward pass that $f$ is the identity function. This can be implemented neatly using `jax.lax.stop_gradient`:
```{code-cell} ipython3
:id: hdORJENmVHvX
:outputId: f0839541-46a4-45a9-fce7-ead08f20046b
def f(x):
return jnp.round(x) # non-differentiable
def straight_through_f(x):
# Create an exactly-zero expression with Sterbenz lemma that has
# an exactly-one gradient.
zero = x - jax.lax.stop_gradient(x)
return zero + jax.lax.stop_gradient(f(x))
print("f(x): ", f(3.2))
print("straight_through_f(x):", straight_through_f(3.2))
print("grad(f)(x):", jax.grad(f)(3.2))
print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))
```
+++ {"id": "Wx3RNE0Sw5mn"}
## Per-example gradients
While most ML systems compute gradients and updates from batches of data, for reasons of computational efficiency and/or variance reduction, it is sometimes necessary to have access to the gradient/update associated with each specific sample in the batch.
For instance, this is needed to prioritise data based on gradient magnitude, or to apply clipping / normalisations on a sample by sample basis.
In many frameworks (PyTorch, TF, Theano) it is often not trivial to compute per-example gradients, because the library directly accumulates the gradient over the batch. Naive workarounds, such as computing a separate loss per example and then aggregating the resulting gradients are typically very inefficient.
In JAX we can define the code to compute the gradient per-sample in an easy but efficient way.
Just combine the `jit`, `vmap` and `grad` transformations together:
```{code-cell} ipython3
:id: tFLyd9ifw4GG
:outputId: bf3ad4a3-102d-47a6-ece0-f4a8c9e5d434
perex_grads = jax.jit(jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0)))
# Test it:
batched_s_tm1 = jnp.stack([s_tm1, s_tm1])
batched_r_t = jnp.stack([r_t, r_t])
batched_s_t = jnp.stack([s_t, s_t])
perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
```
+++ {"id": "VxvYVEYQYiS_"}
Let's walk through this one transformation at a time.
First, we apply `jax.grad` to `td_loss` to obtain a function that computes the gradient of the loss w.r.t. the parameters on single (unbatched) inputs:
```{code-cell} ipython3
:id: rPO67QQrY5Bk
:outputId: fbb45b98-2dbf-4865-e6e5-87dc3eef5560
dtdloss_dtheta = jax.grad(td_loss)
dtdloss_dtheta(theta, s_tm1, r_t, s_t)
```
+++ {"id": "cU36nVAlcnJ0"}
This function computes one row of the array above.
+++ {"id": "c6DQF0b3ZA5u"}
Then, we vectorise this function using `jax.vmap`. This adds a batch dimension to all inputs and outputs. Now, given a batch of inputs, we produce a batch of outputs -- each output in the batch corresponds to the gradient for the corresponding member of the input batch.
```{code-cell} ipython3
:id: 5agbNKavaNDM
:outputId: ab081012-88ab-4904-a367-68e9f81445f0
almost_perex_grads = jax.vmap(dtdloss_dtheta)
batched_theta = jnp.stack([theta, theta])
almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)
```
+++ {"id": "K-v34yLuan7k"}
This isn't quite what we want, because we have to manually feed this function a batch of `theta`s, whereas we actually want to use a single `theta`. We fix this by adding `in_axes` to the `jax.vmap`, specifying theta as `None`, and the other args as `0`. This makes the resulting function add an extra axis only to the other arguments, leaving `theta` unbatched, as we want:
```{code-cell} ipython3
:id: S6kd5MujbGrr
:outputId: d3d731ef-3f7d-4a0a-ce91-7df57627ddbd
inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))
inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
```
+++ {"id": "O0hbsm70be5T"}
Almost there! This does what we want, but is slower than it has to be. Now, we wrap the whole thing in a `jax.jit` to get the compiled, efficient version of the same function:
```{code-cell} ipython3
:id: Fvr709FcbrSW
:outputId: 627db899-5620-4bed-8d34-cd1364d3d187
perex_grads = jax.jit(inefficient_perex_grads)
perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)
```
```{code-cell} ipython3
:id: FH42yzbHcNs2
:outputId: c8e52f93-615a-4ce7-d8ab-fb6215995a39
%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
```

View File

@ -1,509 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "1Op_vnmkjw3z"
},
"source": [
"# Pseudo Random Numbers in JAX\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb)\n",
"\n",
"*Authors: Matteo Hessel & Rosalia Schneider*\n",
"\n",
"In this section we focus on pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution. \n",
"\n",
"PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next.\n",
"\n",
"Pseudo random number generation is an essential component of any machine learning or scientific computing framework. Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception.\n",
"\n",
"To better understand the difference between the approaches taken by JAX and NumPy when it comes to random number generation we will discuss both approaches in this section."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6_117sy0CGEU"
},
"source": [
"## Random numbers in NumPy\n",
"\n",
"Pseudo random number generation is natively supported in NumPy by the `numpy.random` module.\n",
"\n",
"In NumPy, pseudo random number generation is based on a global `state`.\n",
"\n",
"This can be set to a deterministic initial condition using `random.seed(SEED)`."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "qbmCquES5beU"
},
"outputs": [],
"source": [
"import numpy as np\n",
"np.random.seed(0)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WImNZxJ-7plK"
},
"source": [
"You can inspect the content of the state using the following command."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "qNO_vG7z7qUb",
"outputId": "47817350-83be-40cc-85c3-46419fdbfda0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,\n",
" 2481403966, 4042607538, 337614300, 3232553940, 1018809052,\n",
" 3202401494, 1775180719, 3192392114, 594215549, 184016991,\n",
" 829906058, 610491522, 3879932251, 3139825610, 297902587,\n",
" 4075895579, 2943625357, 3530655617, 1423771745, 2135928312,\n",
" 2891506774, 1066338622, 135451537, 933040465, 2759011858,\n",
" 2273819758, 3545703099, 2516396728, 127 ...\n"
]
}
],
"source": [
"def print_truncated_random_state():\n",
" \"\"\"To avoid spamming the outputs, print only part of the state.\"\"\"\n",
" full_random_state = np.random.get_state()\n",
" print(str(full_random_state)[:460], '...')\n",
"\n",
"print_truncated_random_state()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "nmqx0gJW9CFo"
},
"source": [
"The `state` is updated by each call to a random function:"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "ZqUzvqF1B1TO",
"outputId": "c1874391-eb8d-43d8-eb8f-c918ed0a0c1a"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('MT19937', array([ 0, 1, 1812433255, 1900727105, 1208447044,\n",
" 2481403966, 4042607538, 337614300, 3232553940, 1018809052,\n",
" 3202401494, 1775180719, 3192392114, 594215549, 184016991,\n",
" 829906058, 610491522, 3879932251, 3139825610, 297902587,\n",
" 4075895579, 2943625357, 3530655617, 1423771745, 2135928312,\n",
" 2891506774, 1066338622, 135451537, 933040465, 2759011858,\n",
" 2273819758, 3545703099, 2516396728, 127 ...\n",
"('MT19937', array([2443250962, 1093594115, 1878467924, 2709361018, 1101979660,\n",
" 3904844661, 676747479, 2085143622, 1056793272, 3812477442,\n",
" 2168787041, 275552121, 2696932952, 3432054210, 1657102335,\n",
" 3518946594, 962584079, 1051271004, 3806145045, 1414436097,\n",
" 2032348584, 1661738718, 1116708477, 2562755208, 3176189976,\n",
" 696824676, 2399811678, 3992505346, 569184356, 2626558620,\n",
" 136797809, 4273176064, 296167901, 343 ...\n"
]
}
],
"source": [
"np.random.seed(0)\n",
"\n",
"print_truncated_random_state()\n",
"\n",
"_ = np.random.uniform()\n",
"\n",
"print_truncated_random_state()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "G1ICXejY_xR0"
},
"source": [
"NumPy allows you to sample both individual numbers, or entire vectors of numbers in a single function call. For instance, you may sample a vector of 3 scalars from a uniform distribution by doing:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "6Xqx2e8tAW5d",
"outputId": "a428facb-cd16-4375-f5c4-8fc601e60169"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.5488135 0.71518937 0.60276338]\n"
]
}
],
"source": [
"np.random.seed(0)\n",
"print(np.random.uniform(size=3))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "zPfs8tXTAlr7"
},
"source": [
"NumPy provides a *sequential equivalent guarantee*, meaning that sampling N numbers in a row individually or sampling a vector of N numbers results in the same pseudo-random sequences:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "bZiBZXHW_2wO",
"outputId": "3aff9a51-8a19-4737-a7ad-91b23bfc05f8"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"individually: [0.5488135 0.71518937 0.60276338]\n",
"all at once: [0.5488135 0.71518937 0.60276338]\n"
]
}
],
"source": [
"np.random.seed(0)\n",
"print(\"individually:\", np.stack([np.random.uniform() for _ in range(3)]))\n",
"\n",
"np.random.seed(0)\n",
"print(\"all at once: \", np.random.uniform(size=3))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "JGCZI9UTl7o4"
},
"source": [
"## Random numbers in JAX\n",
"\n",
"JAX's random number generation differs from NumPy's in important ways. The reason is that NumPy's PRNG design makes it hard to simultaneously guarantee a number of desirable properties for JAX, specifically that code must be:\n",
"\n",
"1. reproducible,\n",
"2. parallelizable,\n",
"3. vectorisable.\n",
"\n",
"We will discuss why in the following. First, we will focus on the implications of a PRNG design based on a global state. Consider the code:"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "j441y2NCmnbt",
"outputId": "77fe84d7-c86e-417a-95b9-d73663ed40fc"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.9791922366721637\n"
]
}
],
"source": [
"import numpy as np\n",
"\n",
"np.random.seed(0)\n",
"\n",
"def bar(): return np.random.uniform()\n",
"def baz(): return np.random.uniform()\n",
"\n",
"def foo(): return bar() + 2 * baz()\n",
"\n",
"print(foo())"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "5kVpfSV5n1d7"
},
"source": [
"The function `foo` sums two scalars sampled from a uniform distribution.\n",
"\n",
"The output of this code can only satisfy requirement #1 if we assume a specific order of execution for `bar()` and `baz()`, as native Python does.\n",
"\n",
"This doesn't seem to be a major issue in NumPy, as it is already enforced by Python, but it becomes an issue in JAX. \n",
"\n",
"Making this code reproducible in JAX would require enforcing this specific order of execution. This would violate requirement #2, as JAX should be able to parallelize `bar` and `baz` when jitting as these functions don't actually depend on each other.\n",
"\n",
"To avoid this issue, JAX does not use a global state. Instead, random functions explicitly consume the state, which is referred to as a `key` ."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "LuaGUVRUvbzQ",
"outputId": "bbf525d7-d407-49b4-8bee-2cd827846e04"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[ 0 42]\n"
]
}
],
"source": [
"from jax import random\n",
"\n",
"key = random.key(42)\n",
"\n",
"print(key)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XhFpKnW9F2nF"
},
"source": [
"A single key is an array of scalar shape `()` and key element type.\n",
"\n",
"'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "Tc_Tsv06Fz3l",
"outputId": "1472ae73-edbf-4163-9992-46781d258014"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-0.18471184\n",
"-0.18471184\n"
]
}
],
"source": [
"print(random.normal(key))\n",
"print(random.normal(key))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "foUEgtmTesOx"
},
"source": [
"**Note:** Feeding the same key to different random functions can result in correlated outputs, which is generally undesirable. \n",
"\n",
"**The rule of thumb is: never reuse keys (unless you want identical outputs).**"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "T4dOLP0GGJuB"
},
"source": [
"In order to generate different and independent samples, you must `split()` the key *yourself* whenever you want to call a random function:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "qChuz1C9CSJe",
"outputId": "f6eb1dc3-d83c-45ef-d90e-5a12d36fa7e6"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"old key [ 0 42]\n",
" \\---SPLIT --> new key [2465931498 3679230171]\n",
" \\--> new subkey [255383827 267815257] --> normal 1.3694694\n"
]
}
],
"source": [
"print(\"old key\", key)\n",
"new_key, subkey = random.split(key)\n",
"del key # The old key is discarded -- we must never use it again.\n",
"normal_sample = random.normal(subkey)\n",
"print(r\" \\---SPLIT --> new key \", new_key)\n",
"print(r\" \\--> new subkey\", subkey, \"--> normal\", normal_sample)\n",
"del subkey # The subkey is also discarded after use.\n",
"\n",
"# Note: you don't actually need to `del` keys -- that's just for emphasis.\n",
"# Not reusing the same values is enough.\n",
"\n",
"key = new_key # If we wanted to do this again, we would use new_key as the key."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WKQMJQB6cGhV"
},
"source": [
"`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.\n",
"\n",
"If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNG key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.\n",
"\n",
"It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.\n",
"\n",
"Usually, the above example would be written concisely as"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "Xkt5OYjHjWiP"
},
"outputs": [],
"source": [
"key, subkey = random.split(key)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ULmPVyd9jWSv"
},
"source": [
"which discards the old key automatically."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "dlaAsObh68R1"
},
"source": [
"It's worth noting that `split()` can create as many keys as you need, not just 2:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "hbHZP2xM7Egf"
},
"outputs": [],
"source": [
"key, *forty_two_subkeys = random.split(key, num=43)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Fhu7ejhLB4R_"
},
"source": [
"Another difference between NumPy's and JAX's random modules relates to the sequential equivalence guarantee mentioned above.\n",
"\n",
"As in NumPy, JAX's random module also allows sampling of vectors of numbers.\n",
"However, JAX does not provide a sequential equivalence guarantee, because doing so would interfere with the vectorization on SIMD hardware (requirement #3 above).\n",
"\n",
"In the example below, sampling 3 values out of a normal distribution individually using three subkeys gives a different result to using giving a single key and specifying `shape=(3,)`:"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "4nB_TA54D-HT",
"outputId": "2f259f63-3c45-46c8-f597-4e53dc63cb56"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"individually: [-0.04838839 0.10796146 -1.2226542 ]\n",
"all at once: [ 0.18693541 -1.2806507 -1.5593133 ]\n"
]
}
],
"source": [
"key = random.key(42)\n",
"subkeys = random.split(key, 3)\n",
"sequence = np.stack([random.normal(subkey) for subkey in subkeys])\n",
"print(\"individually:\", sequence)\n",
"\n",
"key = random.key(42)\n",
"print(\"all at once: \", random.normal(key, shape=(3,)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_vBAaU2jrWPk"
},
"source": [
"Note that contrary to our recommendation above, we use `key` directly as an input to `random.normal()` in the second example. This is because we won't reuse it anywhere else, so we don't violate the single-use principle."
]
}
],
"metadata": {
"colab": {
"name": "Random Numbers in JAX",
"provenance": []
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@ -1,254 +0,0 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3
---
+++ {"id": "1Op_vnmkjw3z"}
# Pseudo Random Numbers in JAX
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb)
*Authors: Matteo Hessel & Rosalia Schneider*
In this section we focus on pseudo random number generation (PRNG); that is, the process of algorithmically generating sequences of numbers whose properties approximate the properties of sequences of random numbers sampled from an appropriate distribution.
PRNG-generated sequences are not truly random because they are actually determined by their initial value, which is typically referred to as the `seed`, and each step of random sampling is a deterministic function of some `state` that is carried over from a sample to the next.
Pseudo random number generation is an essential component of any machine learning or scientific computing framework. Generally, JAX strives to be compatible with NumPy, but pseudo random number generation is a notable exception.
To better understand the difference between the approaches taken by JAX and NumPy when it comes to random number generation we will discuss both approaches in this section.
+++ {"id": "6_117sy0CGEU"}
## Random numbers in NumPy
Pseudo random number generation is natively supported in NumPy by the `numpy.random` module.
In NumPy, pseudo random number generation is based on a global `state`.
This can be set to a deterministic initial condition using `random.seed(SEED)`.
```{code-cell} ipython3
:id: qbmCquES5beU
import numpy as np
np.random.seed(0)
```
+++ {"id": "WImNZxJ-7plK"}
You can inspect the content of the state using the following command.
```{code-cell} ipython3
:id: qNO_vG7z7qUb
:outputId: 47817350-83be-40cc-85c3-46419fdbfda0
def print_truncated_random_state():
"""To avoid spamming the outputs, print only part of the state."""
full_random_state = np.random.get_state()
print(str(full_random_state)[:460], '...')
print_truncated_random_state()
```
+++ {"id": "nmqx0gJW9CFo"}
The `state` is updated by each call to a random function:
```{code-cell} ipython3
:id: ZqUzvqF1B1TO
:outputId: c1874391-eb8d-43d8-eb8f-c918ed0a0c1a
np.random.seed(0)
print_truncated_random_state()
_ = np.random.uniform()
print_truncated_random_state()
```
+++ {"id": "G1ICXejY_xR0"}
NumPy allows you to sample both individual numbers, or entire vectors of numbers in a single function call. For instance, you may sample a vector of 3 scalars from a uniform distribution by doing:
```{code-cell} ipython3
:id: 6Xqx2e8tAW5d
:outputId: a428facb-cd16-4375-f5c4-8fc601e60169
np.random.seed(0)
print(np.random.uniform(size=3))
```
+++ {"id": "zPfs8tXTAlr7"}
NumPy provides a *sequential equivalent guarantee*, meaning that sampling N numbers in a row individually or sampling a vector of N numbers results in the same pseudo-random sequences:
```{code-cell} ipython3
:id: bZiBZXHW_2wO
:outputId: 3aff9a51-8a19-4737-a7ad-91b23bfc05f8
np.random.seed(0)
print("individually:", np.stack([np.random.uniform() for _ in range(3)]))
np.random.seed(0)
print("all at once: ", np.random.uniform(size=3))
```
+++ {"id": "JGCZI9UTl7o4"}
## Random numbers in JAX
JAX's random number generation differs from NumPy's in important ways. The reason is that NumPy's PRNG design makes it hard to simultaneously guarantee a number of desirable properties for JAX, specifically that code must be:
1. reproducible,
2. parallelizable,
3. vectorisable.
We will discuss why in the following. First, we will focus on the implications of a PRNG design based on a global state. Consider the code:
```{code-cell} ipython3
:id: j441y2NCmnbt
:outputId: 77fe84d7-c86e-417a-95b9-d73663ed40fc
import numpy as np
np.random.seed(0)
def bar(): return np.random.uniform()
def baz(): return np.random.uniform()
def foo(): return bar() + 2 * baz()
print(foo())
```
+++ {"id": "5kVpfSV5n1d7"}
The function `foo` sums two scalars sampled from a uniform distribution.
The output of this code can only satisfy requirement #1 if we assume a specific order of execution for `bar()` and `baz()`, as native Python does.
This doesn't seem to be a major issue in NumPy, as it is already enforced by Python, but it becomes an issue in JAX.
Making this code reproducible in JAX would require enforcing this specific order of execution. This would violate requirement #2, as JAX should be able to parallelize `bar` and `baz` when jitting as these functions don't actually depend on each other.
To avoid this issue, JAX does not use a global state. Instead, random functions explicitly consume the state, which is referred to as a `key` .
```{code-cell} ipython3
:id: LuaGUVRUvbzQ
:outputId: bbf525d7-d407-49b4-8bee-2cd827846e04
from jax import random
key = random.key(42)
print(key)
```
+++ {"id": "XhFpKnW9F2nF"}
A single key is an array of scalar shape `()` and key element type.
'Random key' is essentially just another word for 'random seed'. However, instead of setting it once as in NumPy, any call of a random function in JAX requires a key to be specified. Random functions consume the key, but do not modify it. Feeding the same key to a random function will always result in the same sample being generated:
```{code-cell} ipython3
:id: Tc_Tsv06Fz3l
:outputId: 1472ae73-edbf-4163-9992-46781d258014
print(random.normal(key))
print(random.normal(key))
```
+++ {"id": "foUEgtmTesOx"}
**Note:** Feeding the same key to different random functions can result in correlated outputs, which is generally undesirable.
**The rule of thumb is: never reuse keys (unless you want identical outputs).**
+++ {"id": "T4dOLP0GGJuB"}
In order to generate different and independent samples, you must `split()` the key *yourself* whenever you want to call a random function:
```{code-cell} ipython3
:id: qChuz1C9CSJe
:outputId: f6eb1dc3-d83c-45ef-d90e-5a12d36fa7e6
print("old key", key)
new_key, subkey = random.split(key)
del key # The old key is discarded -- we must never use it again.
normal_sample = random.normal(subkey)
print(r" \---SPLIT --> new key ", new_key)
print(r" \--> new subkey", subkey, "--> normal", normal_sample)
del subkey # The subkey is also discarded after use.
# Note: you don't actually need to `del` keys -- that's just for emphasis.
# Not reusing the same values is enough.
key = new_key # If we wanted to do this again, we would use new_key as the key.
```
+++ {"id": "WKQMJQB6cGhV"}
`split()` is a deterministic function that converts one `key` into several independent (in the pseudorandomness sense) keys. We keep one of the outputs as the `new_key`, and can safely use the unique extra key (called `subkey`) as input into a random function, and then discard it forever.
If you wanted to get another sample from the normal distribution, you would split `key` again, and so on. The crucial point is that you never use the same PRNG key twice. Since `split()` takes a key as its argument, we must throw away that old key when we split it.
It doesn't matter which part of the output of `split(key)` we call `key`, and which we call `subkey`. They are all pseudorandom numbers with equal status. The reason we use the key/subkey convention is to keep track of how they're consumed down the road. Subkeys are destined for immediate consumption by random functions, while the key is retained to generate more randomness later.
Usually, the above example would be written concisely as
```{code-cell} ipython3
:id: Xkt5OYjHjWiP
key, subkey = random.split(key)
```
+++ {"id": "ULmPVyd9jWSv"}
which discards the old key automatically.
+++ {"id": "dlaAsObh68R1"}
It's worth noting that `split()` can create as many keys as you need, not just 2:
```{code-cell} ipython3
:id: hbHZP2xM7Egf
key, *forty_two_subkeys = random.split(key, num=43)
```
+++ {"id": "Fhu7ejhLB4R_"}
Another difference between NumPy's and JAX's random modules relates to the sequential equivalence guarantee mentioned above.
As in NumPy, JAX's random module also allows sampling of vectors of numbers.
However, JAX does not provide a sequential equivalence guarantee, because doing so would interfere with the vectorization on SIMD hardware (requirement #3 above).
In the example below, sampling 3 values out of a normal distribution individually using three subkeys gives a different result to using giving a single key and specifying `shape=(3,)`:
```{code-cell} ipython3
:id: 4nB_TA54D-HT
:outputId: 2f259f63-3c45-46c8-f597-4e53dc63cb56
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)
key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
```
+++ {"id": "_vBAaU2jrWPk"}
Note that contrary to our recommendation above, we use `key` directly as an input to `random.normal()` in the second example. This is because we won't reuse it anywhere else, so we don't violate the single-use principle.

File diff suppressed because one or more lines are too long

View File

@ -1,536 +0,0 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3
---
+++ {"id": "-h05_PNNhZ-D"}
# Working with Pytrees
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/05.1-pytrees.ipynb)
*Author: Vladimir Mikulik*
Often, we want to operate on objects that look like dicts of arrays, or lists of lists of dicts, or other nested structures. In JAX, we refer to these as *pytrees*, but you can sometimes see them called *nests*, or just *trees*.
JAX has built-in support for such objects, both in its library functions as well as through the use of functions from [`jax.tree_utils`](https://jax.readthedocs.io/en/latest/jax.tree_util.html) (with the most common ones also available as `jax.tree_*`). This section will explain how to use them, give some useful snippets and point out common gotchas.
+++ {"id": "9UjxVY9ulSCn"}
## What is a pytree?
As defined in the [JAX pytree docs](https://jax.readthedocs.io/en/latest/pytrees.html):
> a pytree is a container of leaf elements and/or more pytrees. Containers include lists, tuples, and dicts. A leaf element is anything thats not a pytree, e.g. an array. In other words, a pytree is just a possibly-nested standard or user-registered Python container. If nested, note that the container types do not need to match. A single “leaf”, i.e. a non-container object, is also considered a pytree.
Some example pytrees:
```{code-cell} ipython3
---
executionInfo:
elapsed: 11002
status: ok
timestamp: 1692698031720
user:
displayName: ''
userId: ''
user_tz: -60
id: Wh6BApZ9lrR1
outputId: df1fa4cd-88a6-4d71-a376-b2ddf91568dd
---
import jax
import jax.numpy as jnp
example_trees = [
[1, 'a', object()],
(1, (2, 3), ()),
[1, {'k1': 2, 'k2': (3, 4)}, 5],
{'a': 2, 'b': (2, 3)},
jnp.array([1, 2, 3]),
]
# Let's see how many leaves they have:
for pytree in example_trees:
leaves = jax.tree_util.tree_leaves(pytree)
print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
```
+++ {"id": "_tWkkGNwW8vf"}
We've also introduced our first `jax.tree_*` function, which allowed us to extract the flattened leaves from the trees.
+++ {"id": "RcsmneIGlltm"}
## Why pytrees?
In machine learning, some places where you commonly find pytrees are:
* Model parameters
* Dataset entries
* RL agent observations
They also often arise naturally when working in bulk with datasets (e.g., lists of lists of dicts).
+++ {"id": "sMrSGSIJn9MD"}
## Common pytree functions
Perhaps the most commonly used pytree function is `jax.tree_map`. It works analogously to Python's native `map`, but on entire pytrees:
```{code-cell} ipython3
:id: wZRcuQu4n7o5
:outputId: 3528bc9f-54ed-49c8-b79a-1cbea176c0f3
list_of_lists = [
[1, 2, 3],
[1, 2],
[1, 2, 3, 4]
]
jax.tree_map(lambda x: x*2, list_of_lists)
```
+++ {"id": "xu8X3fk4orC9"}
`jax.tree_map` also works with multiple arguments:
```{code-cell} ipython3
:id: KVpB4r1OkeUK
:outputId: 33f88a7e-aac7-48cd-d207-2c531cd37733
another_list_of_lists = list_of_lists
jax.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
```
+++ {"id": "dkRKy3LvowAb"}
When using multiple arguments with `jax.tree_map`, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc.
+++ {"id": "Lla4hDW6sgMZ"}
## Example: ML model parameters
A simple example of training an MLP displays some ways in which pytree operations come in useful:
```{code-cell} ipython3
:id: j2ZUzWx8tKB2
import numpy as np
def init_mlp_params(layer_widths):
params = []
for n_in, n_out in zip(layer_widths[:-1], layer_widths[1:]):
params.append(
dict(weights=np.random.normal(size=(n_in, n_out)) * np.sqrt(2/n_in),
biases=np.ones(shape=(n_out,))
)
)
return params
params = init_mlp_params([1, 128, 128, 1])
```
+++ {"id": "kUFwJOspuGvU"}
We can use `jax.tree_map` to check that the shapes of our parameters are what we expect:
```{code-cell} ipython3
:id: ErWsXuxXse-z
:outputId: d3e549ab-40ef-470e-e460-1b5939d9696f
jax.tree_map(lambda x: x.shape, params)
```
+++ {"id": "zQtRKaj4ua6-"}
Now, let's train our MLP:
```{code-cell} ipython3
:id: iL4GvW9OuZ-X
def forward(params, x):
*hidden, last = params
for layer in hidden:
x = jax.nn.relu(x @ layer['weights'] + layer['biases'])
return x @ last['weights'] + last['biases']
def loss_fn(params, x, y):
return jnp.mean((forward(params, x) - y) ** 2)
LEARNING_RATE = 0.0001
@jax.jit
def update(params, x, y):
grads = jax.grad(loss_fn)(params, x, y)
# Note that `grads` is a pytree with the same structure as `params`.
# `jax.grad` is one of the many JAX functions that has
# built-in support for pytrees.
# This is handy, because we can apply the SGD update using tree utils:
return jax.tree_map(
lambda p, g: p - LEARNING_RATE * g, params, grads
)
```
```{code-cell} ipython3
:id: B3HniT9-xohz
:outputId: d77e9811-373e-45d6-ccbe-edb6f43120d7
import matplotlib.pyplot as plt
xs = np.random.normal(size=(128, 1))
ys = xs ** 2
for _ in range(1000):
params = update(params, xs, ys)
plt.scatter(xs, ys)
plt.scatter(xs, forward(params, xs), label='Model prediction')
plt.legend();
```
+++ {"id": "lNAvmpzdoE9l"}
## Key paths
In a pytree each leaf has a _key path_. A key path for a leaf is a `list` of _keys_, where the length of the list is equal to the depth of the leaf in the pytree . Each _key_ is a [hashable object](https://docs.python.org/3/glossary.html#term-hashable) that represents an index into the corresponding pytree node type. The type of the key depends on the pytree node type; for example, the type of keys for `dict`s is different from the type of keys for `tuple`s.
For built-in pytree node types, the set of keys for any pytree node instance is unique. For a pytree comprising nodes with this property, the key path for each leaf is unique.
The APIs for working with key paths are:
* [`jax.tree_util.tree_flatten_with_path`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_flatten_with_path.html): Works similarly with `jax.tree_util.tree_flatten`, but returns key paths.
* [`jax.tree_util.tree_map_with_path`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.tree_map_with_path.html): Works similarly with `jax.tree_util.tree_map`, but the function also takes key paths as arguments.
* [`jax.tree_util.keystr`](https://jax.readthedocs.io/en/latest/_autosummary/jax.tree_util.keystr.html): Given a general key path, returns a reader-friendly string expression.
One use case is to print debugging information related to a certain leaf value:
```{code-cell} ipython3
:id: G6E2YzhvoE9l
:outputId: 5aec83c8-e15e-48eb-b2c3-6fa0164344b5
import collections
ATuple = collections.namedtuple("ATuple", ('name'))
tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)
for key_path, value in flattened:
print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
```
+++ {"id": "zrKqmANgoE9l"}
To express key paths, JAX provides a few default key types for the built-in pytree node types, namely:
* `SequenceKey(idx: int)`: for lists and tuples.
* `DictKey(key: Hashable)`: for dictionaries.
* `GetAttrKey(name: str)`: for `namedtuple`s and preferably custom pytree nodes (more in the next section)
You are free to define your own key types for your own custom nodes. They will work with `jax.tree_util.keystr` as long as their `__str__()` method is also overridden with a reader-friendly expression.
```{code-cell} ipython3
:id: ohDq0kGuoE9l
:outputId: 9b8ff3ec-3461-482e-ff27-30dc2a7e68c9
for key_path, _ in flattened:
print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
```
+++ {"id": "sBxOB21YNEDA"}
## Custom pytree nodes
So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define your own container class, it will be considered a leaf, even if it has trees inside it:
```{code-cell} ipython3
:id: CK8LN2PRFnQf
class MyContainer:
"""A named container."""
def __init__(self, name: str, a: int, b: int, c: int):
self.name = name
self.a = a
self.b = b
self.c = c
```
```{code-cell} ipython3
:id: OPGe2R7ZOXCT
:outputId: 40db1f41-9df8-4dea-972a-6a7bc44a49c6
jax.tree_util.tree_leaves([
MyContainer('Alice', 1, 2, 3),
MyContainer('Bob', 4, 5, 6)
])
```
+++ {"id": "vk4vucGXPADj"}
Accordingly, if we try to use a tree map expecting our leaves to be the elements inside the container, we will get an error:
```{code-cell} ipython3
:id: vIr9_JOIOku7
:outputId: dadc9c15-4a10-4fac-e70d-f23e7085cf74
try:
jax.tree_map(lambda x: x + 1, [
MyContainer('Alice', 1, 2, 3),
MyContainer('Bob', 4, 5, 6)
])
except TypeError as e:
print(f'TypeError: {e}')
```
+++ {"id": "nAZ4FR2lPN51", "tags": ["raises-exception"]}
To solve this, we need to register our container with JAX by telling it how to flatten and unflatten it:
```{code-cell} ipython3
:id: 2RR5cDFvoE9m
:outputId: 94745373-abe4-4bca-967c-4133e8027c30
from typing import Iterable
def flatten_MyContainer(container) -> tuple[Iterable[int], str]:
"""Returns an iterable over container contents, and aux data."""
flat_contents = [container.a, container.b, container.c]
# we don't want the name to appear as a child, so it is auxiliary data.
# auxiliary data is usually a description of the structure of a node,
# e.g., the keys of a dict -- anything that isn't a node's children.
aux_data = container.name
return flat_contents, aux_data
def unflatten_MyContainer(
aux_data: str, flat_contents: Iterable[int]) -> MyContainer:
"""Converts aux data and the flat contents into a MyContainer."""
return MyContainer(aux_data, *flat_contents)
jax.tree_util.register_pytree_node(
MyContainer, flatten_MyContainer, unflatten_MyContainer)
jax.tree_util.tree_leaves([
MyContainer('Alice', 1, 2, 3),
MyContainer('Bob', 4, 5, 6)
])
```
+++ {"id": "JXaEe76ZoE9m"}
Alternatively, using the key path API mentioned above, you can register this container with its keys in mind by defining how the keys should look like for each flattened-out value.
```{code-cell} ipython3
:id: D_juQx-2OybX
:outputId: ee2cf4ad-ec21-4636-c9c5-2c64b81429bb
class MyKeyPathContainer(MyContainer):
pass
def flatten_with_keys_MyKeyPathContainer(container) -> tuple[Iterable[int], str]:
"""Returns an iterable over container contents, and aux data."""
# GetAttrKey is a common way to express an attribute key. Users are free
# to pick any other expression that fits their use cases the best.
flat_contents = [(jax.tree_util.GetAttrKey('a'), container.a),
(jax.tree_util.GetAttrKey('b'), container.b),
(jax.tree_util.GetAttrKey('c'), container.c)]
# we don't want the name to appear as a child, so it is auxiliary data.
# auxiliary data is usually a description of the structure of a node,
# e.g., the keys of a dict -- anything that isn't a node's children.
aux_data = container.name
return flat_contents, aux_data
def unflatten_MyKeyPathContainer(
aux_data: str, flat_contents: Iterable[int]) -> MyKeyPathContainer:
"""Converts aux data and the flat contents into a MyContainer."""
return MyKeyPathContainer(aux_data, *flat_contents)
jax.tree_util.register_pytree_with_keys(
MyKeyPathContainer, flatten_with_keys_MyKeyPathContainer, unflatten_MyKeyPathContainer)
jax.tree_util.tree_leaves([
MyKeyPathContainer('Alice', 1, 2, 3),
MyKeyPathContainer('Bob', 4, 5, 6)
])
```
+++ {"id": "HPX23W4zoE9m"}
`register_pytree_with_keys` is an extended API of `register_pytree_node`, and containers registered in either way can freely use all the `tree_util` utilities without error.
When a container registered with `register_pytree_node` uses `.*_with_path` APIs, the keys being returned will be a series of "flat index" fallbacks:
```{code-cell} ipython3
:id: E1BwD2aZoE9m
:outputId: 4fe12b06-aef4-426a-a732-891affa63842
flattened, _ = jax.tree_util.tree_flatten_with_path(MyContainer('Alice', 1, 2, 3))
for key_path, value in flattened:
print(f'MyContainer container{jax.tree_util.keystr(key_path)}: {value}')
flattened, _ = jax.tree_util.tree_flatten_with_path(MyKeyPathContainer('Alice', 1, 2, 3))
for key_path, value in flattened:
print(f'MyKeyPathContainer container{jax.tree_util.keystr(key_path)}: {value}')
```
+++ {"id": "JgnAp7fFShEB"}
Modern Python comes equipped with helpful tools to make defining containers easier. Some of these will work with JAX out-of-the-box, but others require more care. For instance, a `NamedTuple` subclass doesn't need to be registered to be considered a pytree node type:
```{code-cell} ipython3
:id: 8DNoLABtO0fr
:outputId: 9a448508-43eb-4450-bfaf-eeeb59a9e349
from typing import NamedTuple, Any
class MyOtherContainer(NamedTuple):
name: str
a: Any
b: Any
c: Any
# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box:
jax.tree_util.tree_leaves([
MyOtherContainer('Alice', 1, 2, 3),
MyOtherContainer('Bob', 4, 5, 6)
])
```
+++ {"id": "TVdtzJDVTZb6"}
Notice that the `name` field now appears as a leaf, as all tuple elements are children. That's the price we pay for not having to register the class the hard way.
+++ {"id": "wDbVszv-oE9n"}
One shortcut is to use `jax.tree_util.register_static` to register a type as being a node without children:
```{code-cell} ipython3
---
executionInfo:
elapsed: 59
status: ok
timestamp: 1692698060536
user:
displayName: ''
userId: ''
user_tz: -60
id: Rclc079ioE9n
outputId: 6b6a4402-8fc1-409c-b6da-88568a612e1b
---
from typing import NamedTuple, Any
@jax.tree_util.register_static
class StaticStr(str):
pass
class YetAnotherContainer(NamedTuple):
name: StaticStr
a: Any
b: Any
c: Any
# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box:
jax.tree_util.tree_leaves([
YetAnotherContainer(StaticStr('Alice'), 1, 2, 3),
YetAnotherContainer(StaticStr('Bob'), 4, 5, 6)
])
```
+++ {"id": "kNsTszcEEHD0"}
## Common pytree gotchas and patterns
+++ {"id": "0ki-JDENzyL7"}
### Gotchas
#### Mistaking nodes for leaves
A common problem to look out for is accidentally introducing tree nodes instead of leaves:
```{code-cell} ipython3
:id: N-th4jOAGJlM
:outputId: 23eed14d-d383-4d88-d6f9-02bac06020df
a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]
# Try to make another tree with ones instead of zeros
shapes = jax.tree_map(lambda x: x.shape, a_tree)
jax.tree_map(jnp.ones, shapes)
```
+++ {"id": "q8d4y-hfHTWh"}
What happened is that the `shape` of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling `jnp.ones` on e.g. `(2, 3)`, it's called on `2` and `3`.
The solution will depend on the specifics, but there are two broadly applicable options:
* rewrite the code to avoid the intermediate `tree_map`.
* convert the tuple into an `np.array` or `jnp.array`, which makes the entire
sequence a leaf.
+++ {"id": "4OKlbFlEIda-"}
#### Handling of None
`jax.tree_utils` treats `None` as a node without children, not as a leaf:
```{code-cell} ipython3
:id: gIwlwo2MJcEC
:outputId: 1e59f323-a7b7-42be-8603-afa4693c00cc
jax.tree_util.tree_leaves([None, None, None])
```
+++ {"id": "pwNz-rp1JvW4"}
### Patterns
#### Transposing trees
If you would like to transpose a pytree, i.e. turn a list of trees into a tree of lists, you can do so using `jax.tree_map`:
```{code-cell} ipython3
:id: UExN7-G7qU-F
:outputId: fd049086-ef37-44db-8e2c-9f1bd9fad950
def tree_transpose(list_of_trees):
"""Convert a list of trees of identical structure into a single tree of lists."""
return jax.tree_map(lambda *xs: list(xs), *list_of_trees)
# Convert a dataset from row-major to column-major:
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)
```
+++ {"id": "Ao6R2ffm2CF4"}
For more complicated transposes, JAX provides `jax.tree_transpose`, which is more verbose, but allows you specify the structure of the inner and outer Pytree for more flexibility:
```{code-cell} ipython3
:id: bZvVwxshz1D3
:outputId: a0314dc8-4267-41e6-a763-931d40433c26
jax.tree_transpose(
outer_treedef = jax.tree_structure([0 for e in episode_steps]),
inner_treedef = jax.tree_structure(episode_steps[0]),
pytree_to_transpose = episode_steps
)
```
+++ {"id": "KlYA2R6N2h_8"}
## More Information
For more information on pytrees in JAX and the operations that are available, see the [Pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) section in the JAX documentation.

File diff suppressed because one or more lines are too long

View File

@ -1,414 +0,0 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3
---
+++ {"id": "tCOWitsAS1EE"}
# Parallel Evaluation in JAX
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/06-parallelism.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/06-parallelism.ipynb)
*Authors: Vladimir Mikulik & Roman Ring*
In this section we will discuss the facilities built into JAX for single-program, multiple-data (SPMD) code.
SPMD refers to a parallelism technique where the same computation (e.g., the forward pass of a neural net) is run on different input data (e.g., different inputs in a batch) in parallel on different devices (e.g., several TPUs).
Conceptually, this is not very different from vectorisation, where the same operations occur in parallel in different parts of memory on the same device. We have already seen that vectorisation is supported in JAX as a program transformation, `jax.vmap`. JAX supports device parallelism analogously, using `jax.pmap` to transform a function written for one device into a function that runs in parallel on multiple devices. This colab will teach you all about it.
+++ {"id": "7mCgBzix2fd3"}
## TPU Setup
This notebook requires multiple accelerators and we recommend running it using Kaggle TPU VMs.
+++ {"id": "gN6VbcdRTcdE"}
Next run the following to see the TPU devices you have available:
```{code-cell} ipython3
:id: tqbpCcqY3Cn7
:outputId: 1fb88cf7-35f7-4565-f370-51586213b988
import jax
jax.devices()
```
+++ {"id": "4_EDa0Dlgtf8"}
## The basics
The most basic use of `jax.pmap` is completely analogous to `jax.vmap`, so let's return to the convolution example from the [Vectorisation notebook](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/03-vectorization.ipynb).
```{code-cell} ipython3
:id: IIQKBr-CgtD2
:outputId: 6e7f8755-fdfd-4cf9-e2b5-a10c5a870dd4
import numpy as np
import jax.numpy as jnp
x = np.arange(5)
w = np.array([2., 3., 4.])
def convolve(x, w):
output = []
for i in range(1, len(x)-1):
output.append(jnp.dot(x[i-1:i+2], w))
return jnp.array(output)
convolve(x, w)
```
+++ {"id": "lqxz9NNJOQ9Z"}
Now, let's convert our `convolve` function into one that runs on entire batches of data. In anticipation of spreading the batch across several devices, we'll make the batch size equal to the number of devices:
```{code-cell} ipython3
:id: ll-hEa0jihzx
:outputId: 788be05a-10d4-4a05-8d9d-49d0083541ab
n_devices = jax.local_device_count()
xs = np.arange(5 * n_devices).reshape(-1, 5)
ws = np.stack([w] * n_devices)
xs
```
```{code-cell} ipython3
:id: mi-nysDWYbn4
:outputId: 2d115fc3-52f5-4a68-c3a7-115111a83657
ws
```
+++ {"id": "8kseIB09YWJw"}
As before, we can vectorise using `jax.vmap`:
```{code-cell} ipython3
:id: TNb9HsFXYVOI
:outputId: 2e60e07a-6687-49ab-a455-60d2ec484363
jax.vmap(convolve)(xs, ws)
```
+++ {"id": "TDF1vzt_5GMC"}
To spread out the computation across multiple devices, just replace `jax.vmap` with `jax.pmap`:
```{code-cell} ipython3
:id: KWoextrails4
:outputId: bad1fbb7-226a-4538-e442-20ce0c1c8fad
jax.pmap(convolve)(xs, ws)
```
+++ {"id": "E69cVxQPksxe"}
Note that the parallelized `convolve` returns a `jax.Array`. That is because the elements of this array are sharded across all of the devices used in the parallelism. If we were to run another parallel computation, the elements would stay on their respective devices, without incurring cross-device communication costs.
```{code-cell} ipython3
:id: P9dUyk-ciquy
:outputId: 99ea4c6e-cff7-4611-e9e5-bf016fa9716c
jax.pmap(convolve)(xs, jax.pmap(convolve)(xs, ws))
```
+++ {"id": "iuHqht-OYqca"}
The outputs of the inner `jax.pmap(convolve)` have never left their devices when being fed into the outer `jax.pmap(convolve)`.
+++ {"id": "vEFAJXN2q3dV"}
## Specifying `in_axes`
Like with `vmap`, we can use `in_axes` to specify whether an argument to the parallelized function should be broadcast (`None`), or whether it should be split along a given axis. Note, however, that unlike `vmap`, only the leading axis (`0`) is supported by `pmap` at the time of writing this guide.
```{code-cell} ipython3
:id: 6Es5WVuRlXnB
:outputId: 7e9612ae-d6e0-4d79-a228-f0403fcf8237
jax.pmap(convolve, in_axes=(0, None))(xs, w)
```
+++ {"id": "EoN6drHDOlk4"}
Notice how we get equivalent output to what we observe above with `jax.pmap(convolve)(xs, ws)`, where we manually replicated `w` when creating `ws`. Here, it is replicated via broadcasting, by specifying it as `None` in `in_axes`.
+++ {"id": "rRE8STSU5cjx"}
Keep in mind that when calling the transformed function, the size of the specified axis in arguments must not exceed the number of devices available to the host.
+++ {"id": "0lZnqImd7G6U"}
## `pmap` and `jit`
`jax.pmap` JIT-compiles the function given to it as part of its operation, so there is no need to additionally `jax.jit` it.
+++ {"id": "1jZqk_2AwO4y"}
## Communication between devices
The above is enough to perform simple parallel operations, e.g. batching a simple MLP forward pass across several devices. However, sometimes we need to pass information between the devices. For example, perhaps we are interested in normalizing the output of each device so they sum to 1.
For that, we can use special [collective ops](https://jax.readthedocs.io/en/latest/jax.lax.html#parallel-operators) (such as the `jax.lax.p*` ops `psum`, `pmean`, `pmax`, ...). In order to use the collective ops we must specify the name of the `pmap`-ed axis through the `axis_name` argument, and then refer to it when calling the op. Here's how to do that:
```{code-cell} ipython3
:id: 0nCxGwqmtd3w
:outputId: 6f9c93b0-51ed-40c5-ca5a-eacbaf40e686
def normalized_convolution(x, w):
output = []
for i in range(1, len(x)-1):
output.append(jnp.dot(x[i-1:i+2], w))
output = jnp.array(output)
return output / jax.lax.psum(output, axis_name='p')
jax.pmap(normalized_convolution, axis_name='p')(xs, ws)
```
+++ {"id": "9ENYsJS42YVK"}
The `axis_name` is just a string label that allows collective operations like `jax.lax.psum` to refer to the axis bound by `jax.pmap`. It can be named anything you want -- in this case, `p`. This name is essentially invisible to anything but those functions, and those functions use it to know which axis to communicate across.
`jax.vmap` also supports `axis_name`, which allows `jax.lax.p*` operations to be used in the vectorisation context in the same way they would be used in a `jax.pmap`:
```{code-cell} ipython3
:id: nT61xAYJUqCW
:outputId: e8831025-78a6-4a2b-a60a-3c77b35214ef
jax.vmap(normalized_convolution, axis_name='p')(xs, ws)
```
+++ {"id": "JSK-9dbWWV2O"}
Note that `normalized_convolution` will no longer work without being transformed by `jax.pmap` or `jax.vmap`, because `jax.lax.psum` expects there to be a named axis (`'p'`, in this case), and those two transformations are the only way to bind one.
## Nesting `jax.pmap` and `jax.vmap`
The reason we specify `axis_name` as a string is so we can use collective operations when nesting `jax.pmap` and `jax.vmap`. For example:
```python
jax.vmap(jax.pmap(f, axis_name='i'), axis_name='j')
```
A `jax.lax.psum(..., axis_name='i')` in `f` would refer only to the pmapped axis, since they share the `axis_name`.
In general, `jax.pmap` and `jax.vmap` can be nested in any order, and with themselves (so you can have a `pmap` within another `pmap`, for instance).
+++ {"id": "WzQHxnHkCxej"}
## Example
Here's an example of a regression training loop with data parallelism, where each batch is split into sub-batches which are evaluated on separate devices.
There are two places to pay attention to:
* the `update()` function
* the replication of parameters and splitting of data across devices.
If this example is too confusing, you can find the same example, but without parallelism, in the next notebook, [State in JAX](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb). Once that example makes sense, you can compare the differences to understand how parallelism changes the picture.
```{code-cell} ipython3
:id: cI8xQqzRrc-4
from typing import NamedTuple
import functools
class Params(NamedTuple):
weight: jnp.ndarray
bias: jnp.ndarray
def init(rng) -> Params:
"""Returns the initial model params."""
weights_key, bias_key = jax.random.split(rng)
weight = jax.random.normal(weights_key, ())
bias = jax.random.normal(bias_key, ())
return Params(weight, bias)
def loss_fn(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> jnp.ndarray:
"""Computes the least squares error of the model's predictions on x against y."""
pred = params.weight * xs + params.bias
return jnp.mean((pred - ys) ** 2)
LEARNING_RATE = 0.005
# So far, the code is identical to the single-device case. Here's what's new:
# Remember that the `axis_name` is just an arbitrary string label used
# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it
# 'num_devices', but could have used anything, so long as `pmean` used the same.
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> tuple[Params, jnp.ndarray]:
"""Performs one SGD update step on params using the given data."""
# Compute the gradients on the given minibatch (individually on each device).
loss, grads = jax.value_and_grad(loss_fn)(params, xs, ys)
# Combine the gradient across all devices (by taking their mean).
grads = jax.lax.pmean(grads, axis_name='num_devices')
# Also combine the loss. Unnecessary for the update, but useful for logging.
loss = jax.lax.pmean(loss, axis_name='num_devices')
# Each device performs its own update, but since we start with the same params
# and synchronise gradients, the params stay in sync.
new_params = jax.tree_map(
lambda param, g: param - g * LEARNING_RATE, params, grads)
return new_params, loss
```
+++ {"id": "RWce8YZ4Pcmf"}
Here's how `update()` works:
Undecorated and without the `pmean`s, `update()` takes data tensors of shape `[batch, ...]`, computes the loss function on that batch and evaluates its gradients.
We want to spread the `batch` dimension across all available devices. To do that, we add a new axis using `pmap`. The arguments to the decorated `update()` thus need to have shape `[num_devices, batch_per_device, ...]`. So, to call the new `update()`, we'll need to reshape data batches so that what used to be `batch` is reshaped to `[num_devices, batch_per_device]`. That's what `split()` does below. Additionally, we'll need to replicate our model parameters, adding the `num_devices` axis. This reshaping is how a pmapped function knows which devices to send which data.
At some point during the update step, we need to combine the gradients computed by each device -- otherwise, the updates performed by each device would be different. That's why we use `jax.lax.pmean` to compute the mean across the `num_devices` axis, giving us the average gradient of the batch. That average gradient is what we use to compute the update.
Aside on naming: here, we use `num_devices` for the `axis_name` for didactic clarity while introducing `jax.pmap`. However, in some sense that is tautologous: any axis introduced by a pmap will represent a number of devices. Therefore, it's common to see the axis be named something semantically meaningful, like `batch`, `data` (signifying data parallelism) or `model` (signifying model parallelism).
```{code-cell} ipython3
:id: _CTtLrsQ-0kK
# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
xs = np.random.normal(size=(128, 1))
noise = 0.5 * np.random.normal(size=(128, 1))
ys = xs * true_w + true_b + noise
# Initialise parameters and replicate across devices.
params = init(jax.random.key(123))
n_devices = jax.local_device_count()
replicated_params = jax.tree_map(lambda x: jnp.array([x] * n_devices), params)
```
+++ {"id": "dmCMyLP9SV99"}
So far, we've just constructed arrays with an additional leading dimension. The params are all still on the host (CPU). `pmap` will communicate them to the devices when `update()` is first called, and each copy will stay on its own device subsequently.
```{code-cell} ipython3
:id: YSCgHguTSdGW
:outputId: a8bf28df-3747-4d49-e340-b7696cf0c27d
type(replicated_params.weight)
```
+++ {"id": "90VtjPbeY-hD"}
The params will become a jax.Array when they are returned by our pmapped `update()` (see further down).
+++ {"id": "eGVKxk1CV-m1"}
We do the same to the data:
```{code-cell} ipython3
:id: vY61QJoFWCII
:outputId: f436a15f-db97-44cc-df33-bbb4ff222987
def split(arr):
"""Splits the first axis of `arr` evenly across the number of devices."""
return arr.reshape(n_devices, arr.shape[0] // n_devices, *arr.shape[1:])
# Reshape xs and ys for the pmapped `update()`.
x_split = split(xs)
y_split = split(ys)
type(x_split)
```
+++ {"id": "RzfJ-oK5WERq"}
The data is just a reshaped vanilla NumPy array. Hence, it cannot be anywhere but on the host, as NumPy runs on CPU only. Since we never modify it, it will get sent to the device at each `update` call, like in a real pipeline where data is typically streamed from CPU to the device at each step.
```{code-cell} ipython3
:id: atOTi7EeSQw-
:outputId: c8daf141-63c4-481f-afa5-684c5f7b698d
def type_after_update(name, obj):
print(f"after first `update()`, `{name}` is a", type(obj))
# Actual training loop.
for i in range(1000):
# This is where the params and data gets communicated to devices:
replicated_params, loss = update(replicated_params, x_split, y_split)
# The returned `replicated_params` and `loss` are now both jax.Arrays,
# indicating that they're on the devices.
# `x_split`, of course, remains a NumPy array on the host.
if i == 0:
type_after_update('replicated_params.weight', replicated_params.weight)
type_after_update('loss', loss)
type_after_update('x_split', x_split)
if i % 100 == 0:
# Note that loss is actually an array of shape [num_devices], with identical
# entries, because each device returns its copy of the loss.
# So, we take the first element to print it.
print(f"Step {i:3d}, loss: {loss[0]:.3f}")
# Plot results.
# Like the loss, the leaves of params have an extra leading dimension,
# so we take the params from the first device.
params = jax.device_get(jax.tree_map(lambda x: x[0], replicated_params))
```
```{code-cell} ipython3
:id: rvVCACv9UZcF
:outputId: 5c472d0f-1236-401b-be55-86e3dc43875d
import matplotlib.pyplot as plt
plt.scatter(xs, ys)
plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
plt.legend()
plt.show()
```
+++ {"id": "4wFJcqbhbn81"}
## Aside: hosts and devices in JAX
When running on TPU, the idea of a 'host' becomes important. A host is the CPU that manages several devices. A single host can only manage so many devices (usually 8), so when running very large parallel programs, multiple hosts are needed, and some finesse is required to manage them.
```{code-cell} ipython3
:id: 3DO8NwW5hurX
:outputId: 6df0bdd7-fee2-4805-9bfe-38e41bdaeb50
jax.devices()
```
+++ {"id": "sJwayfCoy15a"}
When running on CPU you can always emulate an arbitrary number of devices with a nifty `--xla_force_host_platform_device_count` XLA flag, e.g. by executing the following before importing JAX:
```python
import os
os.environ['XLA_FLAGS'] = '--xla_force_host_platform_device_count=8'
jax.devices()
```
```
[CpuDevice(id=0),
CpuDevice(id=1),
CpuDevice(id=2),
CpuDevice(id=3),
CpuDevice(id=4),
CpuDevice(id=5),
CpuDevice(id=6),
CpuDevice(id=7)]
```
This is especially useful for debugging and testing locally or even for prototyping in Colab since a CPU runtime is faster to (re-)start.

File diff suppressed because one or more lines are too long

View File

@ -1,267 +0,0 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
name: python3
---
+++ {"id": "Ga0xSM8xhBIm"}
# Stateful Computations in JAX
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/07-state.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/jax-101/07-state.ipynb)
*Authors: Vladimir Mikulik*
This section explores how JAX constrains the implementation of stateful programs.
+++ {"id": "Avjnyrjojo8z"}
## Motivation
In machine learning, program state most often comes in the form of:
* model parameters,
* optimizer state, and
* stateful layers, such as [BatchNorm](https://en.wikipedia.org/wiki/Batch_normalization).
Some JAX transformations, most notably `jax.jit`, impose constraints on the functions they transform. In particular, the function transformed by `jax.jit` must have no side-effects. This is because any such side-effects will only be executed once, when the python version of the function is run during compilation. These side-effects will not be executed by the compiled function on subsequent runs.
Changing program state is one kind of side-effect. So, if we can't have side effects, how do we update model parameters, the optimizer state, and use stateful layers in our models? This colab will explain this in detail, but the short answer is: with [functional programming](https://en.wikipedia.org/wiki/Functional_programming).
+++ {"id": "s_-6semKkSzp"}
## A simple example: Counter
Let's start by looking at a simple stateful program: a counter.
```{code-cell} ipython3
:id: B3aoCHpjg8gm
:outputId: 5cbcfbf5-5c42-498f-a175-050438518337
import jax
import jax.numpy as jnp
class Counter:
"""A simple counter."""
def __init__(self):
self.n = 0
def count(self) -> int:
"""Increments the counter and returns the new value."""
self.n += 1
return self.n
def reset(self):
"""Resets the counter to zero."""
self.n = 0
counter = Counter()
for _ in range(3):
print(counter.count())
```
+++ {"id": "SQ-RNLfdiw04"}
The `n` attribute maintains the counter's _state_ between successive calls of `count`. It is modified as a side effect of calling `count`.
Let's say we want to count fast, so we `jax.jit` the `count` method. (In this example, this wouldn't actually help speed anyway, for many reasons, but treat this as a toy model of wanting to JIT-compile the update of model parameters, where `jax.jit` makes an enormous difference).
```{code-cell} ipython3
:id: 5jSjmJMon03W
:outputId: d952f16b-9b30-4753-ed94-cc914a929a36
counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
print(fast_count())
```
+++ {"id": "weiI0V7_pKGv"}
Oh no! Our counter isn't working. This is because the line
```
self.n += 1
```
in `count` is only called once, when JAX compiles the method call. Moreover, since the return value doesn't depend on the arguments to `count`, once it returns the first 1, subsequent calls to `fast_count` will always return 1. This won't do. So, how do we fix it?
## The solution: explicit state
Part of the problem with our counter was that the returned value didn't depend on the arguments, meaning a constant was "baked into" the compiled output. But it shouldn't be a constant -- it should depend on the state. Well, then why don't we make the state into an argument?
```{code-cell} ipython3
:id: 53pSdK4KoOEZ
:outputId: 5ac72b9c-7029-4bf2-de8d-1d412bd74c79
CounterState = int
class CounterV2:
def count(self, n: CounterState) -> tuple[int, CounterState]:
# You could just return n+1, but here we separate its role as
# the output and as the counter state for didactic purposes.
return n+1, n+1
def reset(self) -> CounterState:
return 0
counter = CounterV2()
state = counter.reset()
for _ in range(3):
value, state = counter.count(state)
print(value)
```
+++ {"id": "PrBjmgZtq89b"}
In this new version of `Counter`, we moved `n` to be an argument of `count`, and added another return value that represents the new, updated, state. To use this counter, we now need to keep track of the state explicitly. But in return, we can now safely `jax.jit` this counter:
```{code-cell} ipython3
:id: LO4Xzcq_q8PH
:outputId: 25c06a56-f2bf-4c54-a3c3-6e093d484362
state = counter.reset()
fast_count = jax.jit(counter.count)
for _ in range(3):
value, state = fast_count(state)
print(value)
```
+++ {"id": "MzMSWD2_sgnh"}
## A general strategy
We can apply the same process to any stateful method to convert it into a stateless one. We took a class of the form
```python
class StatefulClass
state: State
def stateful_method(*args, **kwargs) -> Output:
```
and turned it into a class of the form
```python
class StatelessClass
def stateless_method(state: State, *args, **kwargs) -> (Output, State):
```
This is a common [functional programming](https://en.wikipedia.org/wiki/Functional_programming) pattern, and, essentially, is the way that state is handled in all JAX programs.
Notice that the need for a class becomes less clear once we have rewritten it this way. We could just keep `stateless_method`, since the class is no longer doing any work. This is because, like the strategy we just applied, object-oriented programming (OOP) is a way to help programmers understand program state.
In our case, the `CounterV2` class is nothing more than a namespace bringing all the functions that use `CounterState` into one location. Exercise for the reader: do you think it makes sense to keep it as a class?
Incidentally, you've already seen an example of this strategy in the JAX pseudo-randomness API, `jax.random`, shown in the [Random Numbers section](https://colab.research.google.com/github/google/jax/blob/main/docs/jax-101/05-random-numbers.ipynb). Unlike Numpy, which manages random state using stateful classes, JAX requires the programmer to work directly with the random generator state -- the PRNG key.
+++ {"id": "I2SqRx14_z98"}
## Simple worked example: Linear Regression
Let's apply this strategy to a simple machine learning model: linear regression via gradient descent.
Here, we only deal with one kind of state: the model parameters. But generally, you'll see many kinds of state being threaded in and out of JAX functions, like optimizer state, layer statistics for batchnorm, and others.
The function to look at carefully is `update`.
```{code-cell} ipython3
:id: wQdU7DoAseW6
from typing import NamedTuple
class Params(NamedTuple):
weight: jnp.ndarray
bias: jnp.ndarray
def init(rng) -> Params:
"""Returns the initial model params."""
weights_key, bias_key = jax.random.split(rng)
weight = jax.random.normal(weights_key, ())
bias = jax.random.normal(bias_key, ())
return Params(weight, bias)
def loss(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
"""Computes the least squares error of the model's predictions on x against y."""
pred = params.weight * x + params.bias
return jnp.mean((pred - y) ** 2)
LEARNING_RATE = 0.005
@jax.jit
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray) -> Params:
"""Performs one SGD update step on params using the given data."""
grad = jax.grad(loss)(params, x, y)
# If we were using Adam or another stateful optimizer,
# we would also do something like
# ```
# updates, new_optimizer_state = optimizer(grad, optimizer_state)
# ```
# and then use `updates` instead of `grad` to actually update the params.
# (And we'd include `new_optimizer_state` in the output, naturally.)
new_params = jax.tree_map(
lambda param, g: param - g * LEARNING_RATE, params, grad)
return new_params
```
+++ {"id": "dKySWouu2-Hu"}
Notice that we manually pipe the params in and out of the update function.
```{code-cell} ipython3
:id: jQCYYy0yxO6K
:outputId: 1f3b69d2-e90b-4065-cbcc-6422978d25c2
import matplotlib.pyplot as plt
rng = jax.random.key(42)
# Generate true data from y = w*x + b + noise
true_w, true_b = 2, -1
x_rng, noise_rng = jax.random.split(rng)
xs = jax.random.normal(x_rng, (128, 1))
noise = jax.random.normal(noise_rng, (128, 1)) * 0.5
ys = xs * true_w + true_b + noise
# Fit regression
params = init(rng)
for _ in range(1000):
params = update(params, xs, ys)
plt.scatter(xs, ys)
plt.plot(xs, params.weight * xs + params.bias, c='red', label='Model Prediction')
plt.legend();
```
+++ {"id": "1wq3L6Xg1UHP"}
## Taking it further
The strategy described above is how any (jitted) JAX program must handle state.
Handling parameters manually seems fine if you're dealing with two parameters, but what if it's a neural net with dozens of layers? You might already be getting worried about two things:
1) Are we supposed to initialize them all manually, essentially repeating what we already write in the forward pass definition?
2) Are we supposed to pipe all these things around manually?
The details can be tricky to handle, but there are examples of libraries that take care of this for you. See [JAX Neural Network Libraries](https://github.com/google/jax#neural-network-libraries) for some examples.

View File

@ -1,9 +0,0 @@
:orphan:
Introduction to `pjit`
======================
This content is no longer relevant, because :func:`~jax.pjit` and :func:`~jax.jit`
have been merged into a single unified interface.
For an updated guide to compiling and executing JAX functions in multi-host or multi-core environments,
see :doc:`../notebooks/Distributed_arrays_and_automatic_parallelization`.

View File

@ -1,24 +0,0 @@
:orphan:
.. _Jax-101:
Tutorial: JAX 101
=================
This is a tutorial developed by engineers and researchers at DeepMind_.
.. toctree::
:maxdepth: 1
:caption: Tutorials
01-jax-basics
02-jitting
03-vectorization
04-advanced-autodiff
05-random-numbers
05.1-pytrees
06-parallelism
07-state
.. _Deepmind: http://deepmind.com

View File

@ -12,7 +12,7 @@ operations (e.g. {func}`jax.lax.psum` ) in multi-process settings, although
other communication methods may be useful too depending on your use case (e.g.
RPC, [mpi4jax](https://github.com/mpi4jax/mpi4jax)). If youre not already
familiar with JAXs collective operations, we recommend starting with the
{doc}`/jax-101/06-parallelism` notebook. An important requirement of
{doc}`/sharded-computation` section. An important requirement of
multi-process environments in JAX is direct communication links between
accelerators, e.g. the high-speed interconnects for Cloud TPUs or
[NCCL](https://developer.nvidia.com/nccl) for GPUs. These links allow
@ -123,10 +123,11 @@ global devices.
So how do you actually run a computation involving cross-process communication?
**Use the same parallel evaluation APIs that you would in a single process!**
For example, {func}`~jax.pmap` can be used to run a parallel computation across
For example, {func}`~jax.experimental.shard_map.shard_map` can be used to
run a parallel computation across
multiple processes. (If youre not already familiar with how to use
{func}`~jax.pmap` to run across multiple devices within a single process, check
out the {doc}`/jax-101/06-parallelism` notebook.) Each process should call the
`shard_map` to run across multiple devices within a single process, check
out the {doc}`/sharded-computation` tutorial.) Each process should call the
same pmapped function and pass in arguments to be mapped across its *local*
devices (i.e., the pmapped axis size is equal to the number of local devices).
Similarly, the function will return outputs sharded across *local* devices only.

View File

@ -32,7 +32,7 @@
"source": [
"![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n",
"\n",
"Let's combine everything we showed in the [quickstart notebook](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n",
"Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n",
"\n",
"Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model."
]

View File

@ -35,7 +35,7 @@ limitations under the License.
![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)
Let's combine everything we showed in the [quickstart notebook](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).
Let's combine everything we showed in the [quickstart](https://colab.research.google.com/github/google/jax/blob/main/docs/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.

View File

@ -44,7 +44,7 @@
"\n",
"![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)\n",
"\n",
"Let's combine everything we showed in the [quickstart notebook](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n",
"Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n",
"\n",
"Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model."
]

View File

@ -42,7 +42,7 @@ _Forked from_ `neural_network_and_data_loading.ipynb`
![JAX](https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png)
Let's combine everything we showed in the [quickstart notebook](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).
Let's combine everything we showed in the [quickstart](https://jax.readthedocs.io/en/latest/quickstart.html) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.

View File

@ -1,609 +0,0 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "xtWX4x9DCF5_"
},
"source": [
"# JAX Quickstart\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/quickstart.ipynb)\n",
"\n",
"**JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.**\n",
"\n",
"With its updated version of [Autograd](https://github.com/hips/autograd), JAX\n",
"can automatically differentiate native Python and NumPy code. It can\n",
"differentiate through a large subset of Pythons features, including loops, ifs,\n",
"recursion, and closures, and it can even take derivatives of derivatives of\n",
"derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily\n",
"to any order.\n",
"\n",
"Whats new is that JAX uses\n",
"[XLA](https://www.tensorflow.org/xla)\n",
"to compile and run your NumPy code on accelerators, like GPUs and TPUs.\n",
"Compilation happens under the hood by default, with library calls getting\n",
"just-in-time compiled and executed. But JAX even lets you just-in-time compile\n",
"your own Python functions into XLA-optimized kernels using a one-function API.\n",
"Compilation and automatic differentiation can be composed arbitrarily, so you\n",
"can express sophisticated algorithms and get maximal performance without having\n",
"to leave Python."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"id": "SY8mDvEvCGqk"
},
"outputs": [],
"source": [
"import jax.numpy as jnp\n",
"from jax import grad, jit, vmap\n",
"from jax import random"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "FQ89jHCYfhpg"
},
"source": [
"## Multiplying Matrices"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Xpy1dSgNqCP4"
},
"source": [
"We'll be generating random data in the following examples. One big difference between NumPy and JAX is how you generate random numbers. For more details, see [Common Gotchas in JAX].\n",
"\n",
"[Common Gotchas in JAX]: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"id": "u0nseKZNqOoH",
"outputId": "03e20e21-376c-41bb-a6bb-57431823691b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[-0.3721109 0.26423115 -0.18252768 -0.7368197 -0.44030377 -0.1521442\n",
" -0.67135346 -0.5908641 0.73168886 0.5673026 ]\n"
]
}
],
"source": [
"key = random.key(0)\n",
"x = random.normal(key, (10,))\n",
"print(x)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "hDJF0UPKnuqB"
},
"source": [
"Let's dive right in and multiply two big matrices."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"id": "eXn8GUl6CG5N",
"outputId": "ffce6bdc-86e6-4af0-ab5d-65d235022db9"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"13.5 ms ± 1.89 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"size = 3000\n",
"x = random.normal(key, (size, size), dtype=jnp.float32)\n",
"%timeit jnp.dot(x, x.T).block_until_ready() # runs on the GPU"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0AlN7EbonyaR"
},
"source": [
"We added that `block_until_ready` because JAX uses asynchronous execution by default (see {ref}`async-dispatch`).\n",
"\n",
"JAX NumPy functions work on regular NumPy arrays."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"id": "ZPl0MuwYrM7t",
"outputId": "71219657-b559-474e-a877-5441ee39f18f"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"80 ms ± 30.2 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"import numpy as np\n",
"x = np.random.normal(size=(size, size)).astype(np.float32)\n",
"%timeit jnp.dot(x, x.T).block_until_ready()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_SrcB2IurUuE"
},
"source": [
"That's slower because it has to transfer data to the GPU every time. You can ensure that an NDArray is backed by device memory using {func}`~jax.device_put`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"id": "Jj7M7zyRskF0",
"outputId": "a649a6d3-cf28-445e-c3fc-bcfe3069482c"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"15.8 ms ± 113 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)\n"
]
}
],
"source": [
"from jax import device_put\n",
"\n",
"x = np.random.normal(size=(size, size)).astype(np.float32)\n",
"x = device_put(x)\n",
"%timeit jnp.dot(x, x.T).block_until_ready()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "clO9djnen8qi"
},
"source": [
"The output of {func}`~jax.device_put` still acts like an NDArray, but it only copies values back to the CPU when they're needed for printing, plotting, saving to disk, branching, etc. The behavior of {func}`~jax.device_put` is equivalent to the function `jit(lambda x: x)`, but it's faster."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ghkfKNQttDpg"
},
"source": [
"If you have a GPU (or TPU!) these calls run on the accelerator and have the potential to be much faster than on CPU.\n",
"See {ref}`faq-jax-vs-numpy` for more comparison of performance characteristics of NumPy and JAX"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "iOzp0P_GoJhb"
},
"source": [
"JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:\n",
"\n",
" - {func}`~jax.jit`, for speeding up your code\n",
" - {func}`~jax.grad`, for taking derivatives\n",
" - {func}`~jax.vmap`, for automatic vectorization or batching.\n",
"\n",
"Let's go over these, one-by-one. We'll also end up composing these in interesting ways."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bTTrTbWvgLUK"
},
"source": [
"## Using {func}`~jax.jit` to speed up functions"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "YrqE32mvE3b7"
},
"source": [
"JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one). However, in the above example, JAX is dispatching kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the `@jit` decorator to compile multiple operations together using [XLA](https://www.tensorflow.org/xla). Let's try that."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"id": "qLGdCtFKFLOR",
"outputId": "870253fa-ba1b-47ec-c5a4-1c6f706be996"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1.07 ms ± 261 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"def selu(x, alpha=1.67, lmbda=1.05):\n",
" return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)\n",
"\n",
"x = random.normal(key, (1000000,))\n",
"%timeit selu(x).block_until_ready()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "a_V8SruVHrD_"
},
"source": [
"We can speed it up with `@jit`, which will jit-compile the first time `selu` is called and will be cached thereafter."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"id": "fh4w_3NpFYTp",
"outputId": "4d56b4f2-5d58-4689-ecc2-ac361c0245cd"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"127 µs ± 1.43 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
],
"source": [
"selu_jit = jit(selu)\n",
"%timeit selu_jit(x).block_until_ready()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "HxpBc4WmfsEU"
},
"source": [
"## Taking derivatives with {func}`~jax.grad`\n",
"\n",
"In addition to evaluating numerical functions, we also want to transform them. One transformation is [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation). In JAX, just like in [Autograd](https://github.com/HIPS/autograd), you can compute gradients with the {func}`~jax.grad` function."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"id": "IMAgNJaMJwPD",
"outputId": "6646cc65-b52f-4825-ff7f-e50b67083493"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.25 0.19661194 0.10499357]\n"
]
}
],
"source": [
"def sum_logistic(x):\n",
" return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))\n",
"\n",
"x_small = jnp.arange(3.)\n",
"derivative_fn = grad(sum_logistic)\n",
"print(derivative_fn(x_small))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PtNs881Ohioc"
},
"source": [
"Let's verify with finite differences that our result is correct."
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"id": "JXI7_OZuKZVO",
"outputId": "18c1f913-d5d6-4895-f71e-e62180c3ad1b"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[0.24998187 0.1965761 0.10502338]\n"
]
}
],
"source": [
"def first_finite_differences(f, x):\n",
" eps = 1e-3\n",
" return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)\n",
" for v in jnp.eye(len(x))])\n",
"\n",
"\n",
"print(first_finite_differences(sum_logistic, x_small))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "Q2CUZjOWNZ-3"
},
"source": [
"Taking derivatives is as easy as calling {func}`~jax.grad`. {func}`~jax.grad` and {func}`~jax.jit` compose and can be mixed arbitrarily. In the above example we jitted `sum_logistic` and then took its derivative. We can go further:"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"id": "TO4g8ny-OEi4",
"outputId": "1a0421e6-60e9-42e3-dc9c-e558a69bbf17"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-0.0353256\n"
]
}
],
"source": [
"print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "yCJ5feKvhnBJ"
},
"source": [
"For more advanced autodiff, you can use {func}`jax.vjp` for reverse-mode vector-Jacobian products and {func}`jax.jvp` for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose them to make a function that efficiently computes full Hessian matrices:"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"id": "Z-JxbiNyhxEW"
},
"outputs": [],
"source": [
"from jax import jacfwd, jacrev\n",
"def hessian(fun):\n",
" return jit(jacfwd(jacrev(fun)))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TI4nPsGafxbL"
},
"source": [
"## Auto-vectorization with {func}`~jax.vmap`"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "PcxkONy5aius"
},
"source": [
"JAX has one more transformation in its API that you might find useful: {func}`~jax.vmap`, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a functions primitive operations for better performance. When composed with {func}`~jax.jit`, it can be just as fast as adding the batch dimensions by hand."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "TPiX4y-bWLFS"
},
"source": [
"We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using {func}`~jax.vmap`. Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions."
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"id": "8w0Gpsn8WYYj"
},
"outputs": [],
"source": [
"mat = random.normal(key, (150, 100))\n",
"batched_x = random.normal(key, (10, 100))\n",
"\n",
"def apply_matrix(v):\n",
" return jnp.dot(mat, v)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "0zWsc0RisQWx"
},
"source": [
"Given a function such as `apply_matrix`, we can loop over a batch dimension in Python, but usually the performance of doing so is poor."
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"id": "KWVc9BsZv0Ki",
"outputId": "bea78b6d-cd17-45e6-c361-1c55234e77c0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Naively batched\n",
"3.12 ms ± 176 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)\n"
]
}
],
"source": [
"def naively_batched_apply_matrix(v_batched):\n",
" return jnp.stack([apply_matrix(v) for v in v_batched])\n",
"\n",
"print('Naively batched')\n",
"%timeit naively_batched_apply_matrix(batched_x).block_until_ready()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "qHfKaLE9stbA"
},
"source": [
"We know how to batch this operation manually. In this case, `jnp.dot` handles extra batch dimensions transparently."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "ipei6l8nvrzH",
"outputId": "335cdc4c-c603-497b-fc88-3fa37c5630c2"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Manually batched\n",
"45.6 µs ± 5.03 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
],
"source": [
"@jit\n",
"def batched_apply_matrix(v_batched):\n",
" return jnp.dot(v_batched, mat.T)\n",
"\n",
"print('Manually batched')\n",
"%timeit batched_apply_matrix(batched_x).block_until_ready()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1eF8Nhb-szAb"
},
"source": [
"However, suppose we had a more complicated function without batching support. We can use {func}`~jax.vmap` to add batching support automatically."
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"id": "67Oeknf5vuCl",
"outputId": "9c680e74-ebb5-4563-ebfc-869fd82de091"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Auto-vectorized with vmap\n",
"48.3 µs ± 1.06 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)\n"
]
}
],
"source": [
"@jit\n",
"def vmap_batched_apply_matrix(v_batched):\n",
" return vmap(apply_matrix)(v_batched)\n",
"\n",
"print('Auto-vectorized with vmap')\n",
"%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "pYVl3Z2nbZhO"
},
"source": [
"Of course, {func}`~jax.vmap` can be arbitrarily composed with {func}`~jax.jit`, {func}`~jax.grad`, and any other JAX transformation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "WwNnjaI4th_8"
},
"source": [
"This is just a taste of what JAX can do. We're really excited to see what you do with it!"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"name": "JAX Quickstart.ipynb",
"provenance": [],
"toc_visible": true
},
"jupytext": {
"formats": "ipynb,md:myst"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@ -1,293 +0,0 @@
---
jupytext:
formats: ipynb,md:myst
text_representation:
extension: .md
format_name: myst
format_version: 0.13
jupytext_version: 1.16.1
kernelspec:
display_name: Python 3
language: python
name: python3
---
+++ {"id": "xtWX4x9DCF5_"}
# JAX Quickstart
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb) [![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/google/jax/blob/main/docs/notebooks/quickstart.ipynb)
**JAX is NumPy on the CPU, GPU, and TPU, with great automatic differentiation for high-performance machine learning research.**
With its updated version of [Autograd](https://github.com/hips/autograd), JAX
can automatically differentiate native Python and NumPy code. It can
differentiate through a large subset of Pythons features, including loops, ifs,
recursion, and closures, and it can even take derivatives of derivatives of
derivatives. It supports reverse-mode as well as forward-mode differentiation, and the two can be composed arbitrarily
to any order.
Whats new is that JAX uses
[XLA](https://www.tensorflow.org/xla)
to compile and run your NumPy code on accelerators, like GPUs and TPUs.
Compilation happens under the hood by default, with library calls getting
just-in-time compiled and executed. But JAX even lets you just-in-time compile
your own Python functions into XLA-optimized kernels using a one-function API.
Compilation and automatic differentiation can be composed arbitrarily, so you
can express sophisticated algorithms and get maximal performance without having
to leave Python.
```{code-cell} ipython3
:id: SY8mDvEvCGqk
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
```
+++ {"id": "FQ89jHCYfhpg"}
## Multiplying Matrices
+++ {"id": "Xpy1dSgNqCP4"}
We'll be generating random data in the following examples. One big difference between NumPy and JAX is how you generate random numbers. For more details, see [Common Gotchas in JAX].
[Common Gotchas in JAX]: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#%F0%9F%94%AA-Random-Numbers
```{code-cell} ipython3
:id: u0nseKZNqOoH
:outputId: 03e20e21-376c-41bb-a6bb-57431823691b
key = random.key(0)
x = random.normal(key, (10,))
print(x)
```
+++ {"id": "hDJF0UPKnuqB"}
Let's dive right in and multiply two big matrices.
```{code-cell} ipython3
:id: eXn8GUl6CG5N
:outputId: ffce6bdc-86e6-4af0-ab5d-65d235022db9
size = 3000
x = random.normal(key, (size, size), dtype=jnp.float32)
%timeit jnp.dot(x, x.T).block_until_ready() # runs on the GPU
```
+++ {"id": "0AlN7EbonyaR"}
We added that `block_until_ready` because JAX uses asynchronous execution by default (see {ref}`async-dispatch`).
JAX NumPy functions work on regular NumPy arrays.
```{code-cell} ipython3
:id: ZPl0MuwYrM7t
:outputId: 71219657-b559-474e-a877-5441ee39f18f
import numpy as np
x = np.random.normal(size=(size, size)).astype(np.float32)
%timeit jnp.dot(x, x.T).block_until_ready()
```
+++ {"id": "_SrcB2IurUuE"}
That's slower because it has to transfer data to the GPU every time. You can ensure that an NDArray is backed by device memory using {func}`~jax.device_put`.
```{code-cell} ipython3
:id: Jj7M7zyRskF0
:outputId: a649a6d3-cf28-445e-c3fc-bcfe3069482c
from jax import device_put
x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)
%timeit jnp.dot(x, x.T).block_until_ready()
```
+++ {"id": "clO9djnen8qi"}
The output of {func}`~jax.device_put` still acts like an NDArray, but it only copies values back to the CPU when they're needed for printing, plotting, saving to disk, branching, etc. The behavior of {func}`~jax.device_put` is equivalent to the function `jit(lambda x: x)`, but it's faster.
+++ {"id": "ghkfKNQttDpg"}
If you have a GPU (or TPU!) these calls run on the accelerator and have the potential to be much faster than on CPU.
See {ref}`faq-jax-vs-numpy` for more comparison of performance characteristics of NumPy and JAX
+++ {"id": "iOzp0P_GoJhb"}
JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:
- {func}`~jax.jit`, for speeding up your code
- {func}`~jax.grad`, for taking derivatives
- {func}`~jax.vmap`, for automatic vectorization or batching.
Let's go over these, one-by-one. We'll also end up composing these in interesting ways.
+++ {"id": "bTTrTbWvgLUK"}
## Using {func}`~jax.jit` to speed up functions
+++ {"id": "YrqE32mvE3b7"}
JAX runs transparently on the GPU or TPU (falling back to CPU if you don't have one). However, in the above example, JAX is dispatching kernels to the GPU one operation at a time. If we have a sequence of operations, we can use the `@jit` decorator to compile multiple operations together using [XLA](https://www.tensorflow.org/xla). Let's try that.
```{code-cell} ipython3
:id: qLGdCtFKFLOR
:outputId: 870253fa-ba1b-47ec-c5a4-1c6f706be996
def selu(x, alpha=1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = random.normal(key, (1000000,))
%timeit selu(x).block_until_ready()
```
+++ {"id": "a_V8SruVHrD_"}
We can speed it up with `@jit`, which will jit-compile the first time `selu` is called and will be cached thereafter.
```{code-cell} ipython3
:id: fh4w_3NpFYTp
:outputId: 4d56b4f2-5d58-4689-ecc2-ac361c0245cd
selu_jit = jit(selu)
%timeit selu_jit(x).block_until_ready()
```
+++ {"id": "HxpBc4WmfsEU"}
## Taking derivatives with {func}`~jax.grad`
In addition to evaluating numerical functions, we also want to transform them. One transformation is [automatic differentiation](https://en.wikipedia.org/wiki/Automatic_differentiation). In JAX, just like in [Autograd](https://github.com/HIPS/autograd), you can compute gradients with the {func}`~jax.grad` function.
```{code-cell} ipython3
:id: IMAgNJaMJwPD
:outputId: 6646cc65-b52f-4825-ff7f-e50b67083493
def sum_logistic(x):
return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
print(derivative_fn(x_small))
```
+++ {"id": "PtNs881Ohioc"}
Let's verify with finite differences that our result is correct.
```{code-cell} ipython3
:id: JXI7_OZuKZVO
:outputId: 18c1f913-d5d6-4895-f71e-e62180c3ad1b
def first_finite_differences(f, x):
eps = 1e-3
return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps)
for v in jnp.eye(len(x))])
print(first_finite_differences(sum_logistic, x_small))
```
+++ {"id": "Q2CUZjOWNZ-3"}
Taking derivatives is as easy as calling {func}`~jax.grad`. {func}`~jax.grad` and {func}`~jax.jit` compose and can be mixed arbitrarily. In the above example we jitted `sum_logistic` and then took its derivative. We can go further:
```{code-cell} ipython3
:id: TO4g8ny-OEi4
:outputId: 1a0421e6-60e9-42e3-dc9c-e558a69bbf17
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
```
+++ {"id": "yCJ5feKvhnBJ"}
For more advanced autodiff, you can use {func}`jax.vjp` for reverse-mode vector-Jacobian products and {func}`jax.jvp` for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose them to make a function that efficiently computes full Hessian matrices:
```{code-cell} ipython3
:id: Z-JxbiNyhxEW
from jax import jacfwd, jacrev
def hessian(fun):
return jit(jacfwd(jacrev(fun)))
```
+++ {"id": "TI4nPsGafxbL"}
## Auto-vectorization with {func}`~jax.vmap`
+++ {"id": "PcxkONy5aius"}
JAX has one more transformation in its API that you might find useful: {func}`~jax.vmap`, the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a functions primitive operations for better performance. When composed with {func}`~jax.jit`, it can be just as fast as adding the batch dimensions by hand.
+++ {"id": "TPiX4y-bWLFS"}
We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using {func}`~jax.vmap`. Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.
```{code-cell} ipython3
:id: 8w0Gpsn8WYYj
mat = random.normal(key, (150, 100))
batched_x = random.normal(key, (10, 100))
def apply_matrix(v):
return jnp.dot(mat, v)
```
+++ {"id": "0zWsc0RisQWx"}
Given a function such as `apply_matrix`, we can loop over a batch dimension in Python, but usually the performance of doing so is poor.
```{code-cell} ipython3
:id: KWVc9BsZv0Ki
:outputId: bea78b6d-cd17-45e6-c361-1c55234e77c0
def naively_batched_apply_matrix(v_batched):
return jnp.stack([apply_matrix(v) for v in v_batched])
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
```
+++ {"id": "qHfKaLE9stbA"}
We know how to batch this operation manually. In this case, `jnp.dot` handles extra batch dimensions transparently.
```{code-cell} ipython3
:id: ipei6l8nvrzH
:outputId: 335cdc4c-c603-497b-fc88-3fa37c5630c2
@jit
def batched_apply_matrix(v_batched):
return jnp.dot(v_batched, mat.T)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
```
+++ {"id": "1eF8Nhb-szAb"}
However, suppose we had a more complicated function without batching support. We can use {func}`~jax.vmap` to add batching support automatically.
```{code-cell} ipython3
:id: 67Oeknf5vuCl
:outputId: 9c680e74-ebb5-4563-ebfc-869fd82de091
@jit
def vmap_batched_apply_matrix(v_batched):
return vmap(apply_matrix)(v_batched)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
```
+++ {"id": "pYVl3Z2nbZhO"}
Of course, {func}`~jax.vmap` can be arbitrarily composed with {func}`~jax.jit`, {func}`~jax.grad`, and any other JAX transformation.
+++ {"id": "WwNnjaI4th_8"}
This is just a taste of what JAX can do. We're really excited to see what you do with it!

View File

@ -5,6 +5,7 @@ sphinx-book-theme>=1.0.1 # Older versions fail to pin pydata-sphinx-theme
sphinx-copybutton>=0.5.0
sphinx-remove-toctrees
sphinx-design
sphinxext-rediraffe
myst-nb>=1.0.0
# Packages used for CI tests.

18
docs/tutorials.rst Normal file
View File

@ -0,0 +1,18 @@
.. _jax-tutorials:
JAX tutorials
=============
.. toctree::
:maxdepth: 1
quickstart
key-concepts
jit-compilation
automatic-vectorization
automatic-differentiation
debugging
random-numbers
working-with-pytrees
sharded-computation
stateful-computations

View File

@ -1,55 +0,0 @@
:orphan:
.. _jax-tutorials:
JAX tutorials
=============
.. note::
The tutorials below are a work in progress; for the time being, please refer
to the older tutorial content at :ref:`Jax-101`, :ref:`beginner-guide` and
:ref:`user-guides`.
JAX 101
-------
.. toctree::
:maxdepth: 1
quickstart
key-concepts
jit-compilation
automatic-vectorization
automatic-differentiation
debugging
random-numbers
working-with-pytrees
sharded-computation
stateful-computations
simple-neural-network
JAX 201
-------
.. toctree::
:maxdepth: 1
parallelism
advanced-autodiff
gradient-checkpointing
advanced-debugging
external-callbacks
profiling-and-performance
JAX 301
-------
.. toctree::
:maxdepth: 1
jax-primitives
jaxpr
advanced-compilation

View File

@ -11,6 +11,7 @@ or deployed codebases.
:maxdepth: 1
:caption: Debugging and Performance
notebooks/thinking_in_jax
profiling
device_memory_profiling
debugging/index