1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 13:26:06 +00:00

DOC: read-through and edit the new jax tutorials

This commit is contained in:
Jake VanderPlas 2024-03-20 18:18:08 -07:00
parent d6f074bf41
commit d6c07bdf51
14 changed files with 304 additions and 162 deletions

@ -1,3 +1,4 @@
(prng-design-jep)=
# JAX PRNG Design
We want a PRNG design that
1. is **expressive** in that it is convenient to use and it doesnt constrain the users ability to write numerical programs with exactly the behavior that they want,

@ -14,3 +14,9 @@ kernelspec:
(advanced-debugging)=
# Advanced debugging
```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials`.
For the time being, you may find some related content in the old documentation:
- {doc}`../debugging/index`
```

@ -63,7 +63,7 @@ dfdx = jax.grad(f)
The higher-order derivatives of $f$ are:
$$
\begin{array}{l}s
\begin{array}{l}
f'(x) = 3x^2 + 4x -3\\
f''(x) = 6x + 4\\
f'''(x) = 6\\
@ -105,27 +105,27 @@ print(d4fdx(1.))
The next example shows how to compute gradients with {func}`jax.grad` in a linear logistic regression model. First, the setup:
```{code-cell}
key = jax.random.PRNGKey(0)
key = jax.random.key(0)
def sigmoid(x):
return 0.5 * (jnp.tanh(x / 2) + 1)
return 0.5 * (jnp.tanh(x / 2) + 1)
# Outputs probability of a label being true.
def predict(W, b, inputs):
return sigmoid(jnp.dot(inputs, W) + b)
return sigmoid(jnp.dot(inputs, W) + b)
# Build a toy dataset.
inputs = jnp.array([[0.52, 1.12, 0.77],
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
[0.88, -1.08, 0.15],
[0.52, 0.06, -1.30],
[0.74, -2.49, 1.39]])
targets = jnp.array([True, True, False, True])
# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
preds = predict(W, b, inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
preds = predict(W, b, inputs)
label_probs = preds * targets + (1 - preds) * (1 - targets)
return -jnp.sum(jnp.log(label_probs))
# Initialize random model coefficients
key, W_key, b_key = jax.random.split(key, 3)
@ -138,20 +138,20 @@ Use the {func}`jax.grad` function with its `argnums` argument to differentiate a
```{code-cell}
# Differentiate `loss` with respect to the first positional argument:
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)
print(f'{W_grad=}')
# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)
print(f'{W_grad=}')
# But you can choose different values too, and drop the keyword:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)
print(f'{b_grad=}')
# Including tuple values
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print('W_grad', W_grad)
print('b_grad', b_grad)
print(f'{W_grad=}')
print(f'{b_grad=}')
```
The {func}`jax.grad` API has a direct correspondence to the excellent notation in Spivak's classic *Calculus on Manifolds* (1965), also used in Sussman and Wisdom's [*Structure and Interpretation of Classical Mechanics*](https://mitpress.mit.edu/9780262028967/structure-and-interpretation-of-classical-mechanics) (2015) and their [*Functional Differential Geometry*](https://mitpress.mit.edu/9780262019347/functional-differential-geometry) (2013). Both books are open-access. See in particular the "Prologue" section of *Functional Differential Geometry* for a defense of this notation.
@ -162,7 +162,8 @@ Essentially, when using the `argnums` argument, if `f` is a Python function for
(automatic-differentiation-nested-lists-tuples-and-dicts)=
## 3. Differentiating with respect to nested lists, tuples, and dicts
Differentiating with respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.
Due to JAX's PyTree abstraction (see {ref}`thinking-in-jax-pytrees`), differentiating with
respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.
Continuing the previous example:
@ -181,7 +182,7 @@ You can {ref}`pytrees-custom-pytree-nodes` to work with not just {func}`jax.grad
(automatic-differentiation-evaluating-using-jax-value_and_grad)=
## 4. Evaluating a function and its gradient using `jax.value_and_grad`
Another convenient function is {func}`jax.value_and_grad` for efficiently computing both a function's value as well as its gradient's value.
Another convenient function is {func}`jax.value_and_grad` for efficiently computing both a function's value as well as its gradient's value in one pass.
Continuing the previous examples:

@ -26,9 +26,9 @@ Let's begin with {func}`jax.debug.print`.
- Use {func}`jax.debug.print` for traced (dynamic) array values with {func}`jax.jit`, {func}`jax.vmap` and others.
- Use Python `print` for static values, such as dtypes and array shapes.
With some JAX transformations, such as {func}`jax.grad` and {func}`jax.vmap`, you can use Pythons built-in `print` function to print out numerical values. However, with {func}`jax.jit` for example, you need to use {func}`jax.debug.print`, because those transformations delay numerical evaluation.
Below is a basic example with {func}`jax.jit`:
Recall from {ref}`jit-compilation` that when transforming a function with {func}`jax.jit`,
the Python code is executed with abstract tracers in place of your arrays. Because of this,
the Python `print` statement will only print this tracer value:
```{code-cell}
import jax
@ -36,34 +36,48 @@ import jax.numpy as jnp
@jax.jit
def f(x):
jax.debug.print("This is `jax.debug.print` of x {x}", x=x)
y = jnp.sin(x)
jax.debug.print("This is `jax.debug.print` of y {y} 🤯", y=y)
return y
print("print(x) ->", x)
y = jnp.sin(x)
print("print(y) ->", y)
return y
f(2.)
result = f(2.)
```
{func}`jax.debug.print` can reveal the information about how computations are evaluated.
Python's `print` executes at trace-time, before the runtime values exist.
If you want to print the actual runtime values, you can use {func}`jax.debug.print`:
Here's an example with {func}`jax.vmap`:
```{code-cell}
@jax.jit
def f(x):
jax.debug.print("jax.debug.print(x) -> {x}", x=x)
y = jnp.sin(x)
jax.debug.print("jax.debug.print(y) -> {y}", y=y)
return y
result = f(2.)
```
Similarly, within {func}`jax.vmap`, using Python's `print` will only print the tracer;
to print the values being mapped over, use {func}`jax.debug.print`:
```{code-cell}
def f(x):
jax.debug.print("This is `jax.debug.print` of x: {}", x)
y = jnp.sin(x)
jax.debug.print("This is `jax.debug.print` of y: {}", y)
return y
jax.debug.print("jax.debug.print(x) -> {}", x)
y = jnp.sin(x)
jax.debug.print("jax.debug.print(y) -> {}", y)
return y
xs = jnp.arange(3.)
jax.vmap(f)(xs)
result = jax.vmap(f)(xs)
```
Here's an example with {func}`jax.lax.map`:
Here's the result with {func}`jax.lax.map`, which is a sequential map rather than a
vectorization:
```{code-cell}
jax.lax.map(f, xs)
result = jax.lax.map(f, xs)
```
Notice the order is different, as {func}`jax.vmap` and {func}`jax.lax.map` compute the same results in different ways. When debugging, the evaluation order details are exactly what you may need to inspect.
@ -72,10 +86,10 @@ Below is an example with {func}`jax.grad`, where {func}`jax.debug.print` only pr
```{code-cell}
def f(x):
jax.debug.print("This is `jax.debug.print` of x: {}", x)
return x ** 2
jax.debug.print("jax.debug.print(x) -> {}", x)
return x ** 2
jax.grad(f)(1.)
result = jax.grad(f)(1.)
```
Sometimes, when the arguments don't depend on one another, calls to {func}`jax.debug.print` may print them in a different order when staged out with a JAX transformation. If you need the original order, such as `x: ...` first and then `y: ...` second, add the `ordered=True` parameter.
@ -85,9 +99,11 @@ For example:
```{code-cell}
@jax.jit
def f(x, y):
jax.debug.print("This is `jax.debug.print of x: {}", x, ordered=True)
jax.debug.print("This is `jax.debug.print of y: {}", y, ordered=True)
return x + y
jax.debug.print("jax.debug.print(x) -> {}", x, ordered=True)
jax.debug.print("jax.debug.print(y) -> {}", y, ordered=True)
return x + y
f(1, 2)
```
To learn more about {func}`jax.debug.print` and its Sharp Bits, refer to {ref}`advanced-debugging`.
@ -101,11 +117,24 @@ To pause your compiled JAX program during certain points during debugging, you c
To print all available commands during a `breakpoint` debugging session, use the `help` command. (Full debugger commands, the Sharp Bits, its strengths and limitations are covered in {ref}`advanced-debugging`.)
Example:
Here is an example of what a debugger session might look like:
```{code-cell}
:tags: [raises-exception]
:tags: [skip-execution]
@jax.jit
def f(x):
y, z = jnp.sin(x, jnp.cos(x))
jax.debug.breakpoint()
return y * z
f(2.) # ==> Pauses during execution
```
![JAX debugger](../_static/debugger.gif)
For value-dependent breakpointing, you can use runtime conditionals like {func}`jax.lax.cond`:
```{code-cell}
def breakpoint_if_nonfinite(x):
is_finite = jnp.isfinite(x).all()
def true_fn(x):
@ -119,20 +148,32 @@ def f(x, y):
z = x / y
breakpoint_if_nonfinite(z)
return z
f(2., 1.) # ==> No breakpoint
```
```{code-cell}
:tags: [skip-execution]
f(2., 0.) # ==> Pauses during execution
```
![JAX debugger](../_static/debugger.gif)
## JAX `debug.callback` for more control during debugging
As mentioned in the beginning, {func}`jax.debug.print` is a small wrapper around {func}`jax.debug.callback`. The {func}`jax.debug.callback` method allows you to have greater control over string formatting and the debugging output, like printing or plotting. It is compatible with {func}`jax.jit`, {func}`jax.vmap`, {func}`jax.grad` and other transformations (refer to the {ref}`external-callbacks-flavors-of-callback` table in {ref]`external-callbacks` for more information).
Both {func}`jax.debug.print` and {func}`jax.debug.breakpoint` are implemented using
the more flexible {func}`jax.debug.callback`, which gives greater control over the
host-side logic executed via a Python callback.
It is compatible with {func}`jax.jit`, {func}`jax.vmap`, {func}`jax.grad` and other
transformations (refer to the {ref}`external-callbacks-flavors-of-callback` table in
{ref}`external-callbacks` for more information).
For example:
```{code-cell}
import logging
def log_value(x):
print("log:", x)
logging.warning(f'Logged value: {x}')
@jax.jit
def f(x):
@ -142,7 +183,7 @@ def f(x):
f(1.0);
```
This callback is compatible with {func}`jax.vmap` and {func}`jax.grad`:
This callback is compatible with other transformations, including {func}`jax.vmap` and {func}`jax.grad`:
```{code-cell}
x = jnp.arange(5.0)
@ -155,7 +196,7 @@ jax.grad(f)(1.0);
This can make {func}`jax.debug.callback` useful for general-purpose debugging.
You can learn more about different flavors of JAX callbacks in {ref}`external-callbacks-flavors-of-callback` and {ref}`external-callbacks-exploring-debug-callback`.
You can learn more about {func}`jax.debug.callback` and other kinds of JAX callbacks in {ref}`external-callbacks`.
## Next steps

@ -12,6 +12,13 @@ kernelspec:
name: python3
---
```{code-cell}
:tags: [remove-cell]
# This ensures that code cell tracebacks appearing below will be concise.
%xmode minimal
```
(external-callbacks)=
# External callbacks
@ -117,10 +124,6 @@ jax.lax.scan(body_fun, None, jnp.arange(5.0))[1]
However, because there is no way for JAX to introspect the content of the callback, `pure_callback` has undefined autodiff semantics:
```{code-cell}
%xmode minimal
```
```{code-cell}
:tags: [raises-exception]

@ -17,9 +17,7 @@ JAX 101
.. toctree::
:maxdepth: 1
installation
quickstart
jax-as-accelerated-numpy
thinking-in-jax
jit-compilation
automatic-vectorization
@ -55,3 +53,12 @@ JAX 301
jax-primitives
jaxpr
advanced-compilation
Reference
---------
.. toctree::
:maxdepth: 1
installation

@ -1,5 +1,5 @@
(installation)=
# How to install JAX
# Installing JAX
This guide provides instructions for:
@ -9,11 +9,14 @@ This guide provides instructions for:
**TL;DR** For most users, a typical JAX installation may look something like this:
| Hardware | Installation |
|------------------------------------|--------------------------------------------|
| CPU-only, Linux/macOS/Windows | `pip install -U "jax[cpu]"` |
| NVIDIA, CUDA 12, x86_64 | `pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html`|
* **CPU-only (Linux/macOS/Windows)**
```
pip install -U "jax[cpu]"
```
* **GPU (NVIDIA, CUDA 12, x86_64)**
```
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
(install-supported-platforms)=
## Supported platforms

@ -1,8 +0,0 @@
# JAX as accelerated NumPy
```{note}
This is a placeholder for a section in the new {ref}`jax-tutorials`.
For the time being, you may find some related content in the old documentation:
- {doc}`../jax-101/01-jax-basics`
```

@ -12,6 +12,13 @@ kernelspec:
name: python3
---
```{code-cell}
:tags: [remove-cell]
# This ensures that code cell tracebacks appearing below will be concise.
%xmode minimal
```
(jit-compilation)=
# Just-in-time compilation

@ -29,7 +29,11 @@ JAX can be installed for CPU on Linux, Windows, and macOS directly from the [Pyt
```
pip install "jax[cpu]"
```
For more detailed installation information, including installation with GPU support, check out {ref}`installation`.
or, for NVIDIA GPU:
```
pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
For more detailed platform-specific installation information, check out {ref}`installation`.
## JAX as NumPy
@ -121,6 +125,13 @@ In the above example we jitted `sum_logistic` and then took its derivative. We c
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
```
The {func}`jax.jacobian` transformation can be used to compute gradients of vector-valued functions:
```{code-cell}
from jax import jacobian
print(jacobian(jnp.exp)(x_small))
```
For more advanced autodiff, you can use {func}`jax.vjp` for reverse-mode vector-Jacobian products and {func}`jax.jvp` for forward-mode Jacobian-vector products.
The two can be composed arbitrarily with one another, and with other JAX transformations.
Here's one way to compose them to make a function that efficiently computes full Hessian matrices:
@ -140,7 +151,7 @@ For more on automatic differentiation in JAX, check out {ref}`automatic-differen
Another useful transformation is {func}`~jax.vmap`, the vectorizing map.
It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a functions primitive operations for better performance.
When composed with {func}`~jax.jit`, it can be just as fast as adding the batch dimensions manually.
When composed with {func}`~jax.jit`, it can be just as performant as manually rewriting your function operate over an extra batch dimension.
We're going to work with a simple example, and promote matrix-vector products into matrix-matrix products using {func}`~jax.vmap`.
Although this is easy to do by hand in this specific case, the same technique can apply to more complicated functions.

@ -26,7 +26,7 @@ To better understand the difference between the approaches taken by JAX and NumP
## Random numbers in NumPy
Pseudo random number generation is natively supported in NumPy by the {mod}`numpy.random` module.
In NumPy, pseudo random number generation is based on a global `state`, which can be set to a deterministic initial condition using {func}`np.random.seed`.
In NumPy, pseudo random number generation is based on a global `state`, which can be set to a deterministic initial condition using {func}`numpy.random.seed`.
```{code-cell}
import numpy as np
@ -192,4 +192,17 @@ key = random.key(42)
print("all at once: ", random.normal(key, shape=(3,)))
```
Note that contrary to our recommendation above, we use `key` directly as an input to {func}`random.normal` in the second example. This is because we won't reuse it anywhere else, so we don't violate the single-use principle.
The lack of sequential equivalence gives us freedom to write code more efficiently; for example,
instead of generating `sequence` above via a sequential loop, we can use {func}`jax.vmap` to
compute the same result in a vectorized manner:
```{code-cell}
import jax
print("vectorized:", jax.vmap(random.normal)(subkeys))
```
## Next Steps
For more information on JAX random numbers, refer to the documentation of the {mod}`jax.random`
module. If you're interested in the details of the design of JAX's random number generator,
see {ref}`prng-design-jep`.

@ -1,3 +1,4 @@
(single-host-sharding)=
# Sharded data on a single host
```{note}

@ -12,6 +12,13 @@ kernelspec:
name: python3
---
```{code-cell}
:tags: [remove-cell]
# This ensures that code cell tracebacks appearing below will be concise.
%xmode minimal
```
(thinking-in-jax)=
# How to think in JAX
@ -73,10 +80,6 @@ print(x)
The equivalent in JAX results in an error, as JAX arrays are immutable:
```{code-cell}
%xmode minimal
```
```python
:tags: [raises-exception]
# JAX: immutable arrays
@ -86,7 +89,7 @@ x[0] = 10
For updating individual elements, JAX provides an [indexed update syntax](https://jax.readthedocs.io/en/latest/jax.ops.html#indexed-update-operators) that returns an updated copy:
```python
```{code-cell}
y = x.at[0].set(10)
print(x)
print(y)
@ -98,53 +101,91 @@ print(y)
**Key concepts:**
- `jax.Array` is the default array implementation in JAX.
- The JAX array is a unified distributed datatype for representing arrays, even with physical storage spanning multiple devices
- Automatic parallelization: You can operate over sharded `jax.Array`s without copying data onto a device using the {func}`jax.jit` transformation. You can also replicate a `jax.Array` to every device on a mesh.
- JAX arrays may be stored on a single device, or sharded across many devices.
Consider this simple example:
When you create an array in JAX, the type is `jax.Array`:
```{code-cell}
import jax
from jax import Array
import jax.numpy as jnp
x = jnp.arange(5)
isinstance(x, jax.Array) # Returns True both inside and outside traced functions.
def f(x: Array) -> Array: # Type annotations are valid for traced and non-traced types.
return x
x = jnp.arange(10)
isinstance(x, jax.Array)
```
The `jax.Array` type also helps make parallelism a core feature of JAX.
`jax.Array` is also the appropriate type annotation for functions with array inputs or outputs:
```{code-cell}
def f(x: jax.Array) -> jax.Array:
return jnp.sin(x) ** 2 + jnp.cos(x) ** 2
```
JAX Array objects have a `devices` method that lets you inspect where the contents of the array are stored. In the simplest cases, this will be a single CPU device:
```{code-cell}
x.devices()
```
In general, an array may be *sharded* across multiple devices, in a manner that can be inspected via the `sharding` attribute:
```{code-cell}
x.sharding
```
In this case the sharding is on a single device, but in general a JAX array can be
sharded across multiple devices, or even multiple hosts.
To read more about sharded arrays and parallel computation, refer to {ref}`single-host-sharding`
(thinking-in-jax-pytrees)=
# Pytrees
## Pytrees
**Key concepts:**
- JAX supports a special data structure called a pytree when you need to operate on dictionaries of lists, for example.
- Use cases: machine learning model parameters, dataset entries, lists of lists of dictionaries.
- JAX supports tuples, dicts, lists, and more general containers of arrays through the
*pytree* abstraction.
JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — they are called JAX pytrees (also known as nests, or just trees). In the context of machine learning, a pytree can contain model parameters, dataset entries, and reinforcement learning agent observations.
Often it is convenient for applications to work with collections of arrays: for example,
a neural network might organize its parameters in a dictionary of arrays with meaningful
keys. Rather than handle such structures on a case-by-case basis, JAX relies on a *pytree*
abstraction to treat such collections in a uniform matter.
In JAX any pytree is safe to pass to transformed functions, which makes them much more flexible
than if they only accepted single arrays as arguments.
Below is an example of a simple pytree. In JAX, you can use {func}`jax.tree_util.tree_leaves`, to extract the flattened leaves from the trees, as demonstrated here:
Here are some examples of objects that can be treated as pytrees:
```{code-cell}
example_trees = [
[1, 'a', object()],
(1, (2, 3), ()),
[1, {'k1': 2, 'k2': (3, 4)}, 5],
{'a': 2, 'b': (2, 3)},
jnp.array([1, 2, 3]),
]
# (nested) list of parameters
params = [1, 2, (jnp.arange(3), jnp.ones(2))]
# Let's see how many leaves they have:
for pytree in example_trees:
leaves = jax.tree_util.tree_leaves(pytree)
print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
```
{func}`jax.tree_map` is the most commonly used pytree function in JAX. It works analogously to Python's native map, but on entire pytrees.
```{code-cell}
# Dictionary of parameters
params = {'n': 5, 'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
```
```{code-cell}
# Named tuple of parameters
from typing import NamedTuple
class Params(NamedTuple):
a: int
b: float
params = Params(1, 5.0)
print(jax.tree.structure(params))
print(jax.tree.leaves(params))
```
JAX has a number of general-purpose utilities for working with PyTrees; for example
the functions {func}`jax.tree.map` can be used to map a function to every leaf in a
tree, and {func}`jax.tree.reduce` can be used to apply a reduction across the leaves
in a tree.
You can learn more in the {ref}`working-with-pytrees` tutorial.
@ -312,7 +353,7 @@ f(x2, y2)
The extracted sequence of operations is encoded in a JAX expression, or *jaxpr* for short. You can view the jaxpr using the {func}`jax.make_jaxpr` transformation:
```python
```{code-cell}
from jax import make_jaxpr
def f(x, y):

@ -12,6 +12,13 @@ kernelspec:
name: python3
---
```{code-cell}
:tags: [remove-cell]
# This ensures that code cell tracebacks appearing below will be concise.
%xmode minimal
```
(working-with-pytrees)=
# Working with pytrees
@ -31,7 +38,7 @@ In the context of machine learning (ML), a pytree can contain:
When working with datasets, you can often come across pytrees (such as lists of lists of dicts).
Below is an example of a simple pytree. In JAX, you can use {func}`jax.tree_util.tree_leaves`, to extract the flattened leaves from the trees, as demonstrated here:
Below is an example of a simple pytree. In JAX, you can use {func}`jax.tree.leaves`, to extract the flattened leaves from the trees, as demonstrated here:
```{code-cell}
import jax
@ -47,8 +54,8 @@ example_trees = [
# Print how many leaves the pytrees have.
for pytree in example_trees:
# This `jax.tree_util.tree_leaves()` method extracts the flattened leaves from the pytrees.
leaves = jax.tree_util.tree_leaves(pytree)
# This `jax.tree.leaves()` method extracts the flattened leaves from the pytrees.
leaves = jax.tree.leaves(pytree)
print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")
```
@ -66,9 +73,9 @@ JAX will use these functions to canonicalize any tree of registered container ob
JAX provides a number of utilities to operate over pytrees. These can be found in the {mod}`jax.tree_util` subpackage.
### Common function: `jax.tree_map`
### Common function: `jax.tree.map`
The most commonly used pytree function is {func}`jax.tree_map`. It works analogously to Python's native `map`, but transparently operates over entire pytrees.
The most commonly used pytree function is {func}`jax.tree.map`. It works analogously to Python's native `map`, but transparently operates over entire pytrees.
Here's an example:
@ -79,20 +86,20 @@ list_of_lists = [
[1, 2, 3, 4]
]
jax.tree_map(lambda x: x*2, list_of_lists)
jax.tree.map(lambda x: x*2, list_of_lists)
```
{func}`jax.tree_map` also allows mapping a [N-ary](https://en.wikipedia.org/wiki/N-ary) function over multiple arguments. For example:
{func}`jax.tree.map` also allows mapping a [N-ary](https://en.wikipedia.org/wiki/N-ary) function over multiple arguments. For example:
```{code-cell}
another_list_of_lists = list_of_lists
jax.tree_map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
jax.tree.map(lambda x, y: x+y, list_of_lists, another_list_of_lists)
```
When using multiple arguments with {func}`jax.tree_map`, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc.
When using multiple arguments with {func}`jax.tree.map`, the structure of the inputs must exactly match. That is, lists must have the same number of elements, dicts must have the same keys, etc.
(pytrees-example-jax-tree-map-ml)=
### Example of `jax.tree_map` with ML model parameters
### Example of `jax.tree.map` with ML model parameters
This example demonstrates how pytree operations can be useful when training a simple [multi-layer perceptron (MLP)](https://en.wikipedia.org/wiki/Multilayer_perceptron).
@ -114,10 +121,10 @@ def init_mlp_params(layer_widths):
params = init_mlp_params([1, 128, 128, 1])
```
Use {func}`jax.tree_map` to check the shapes of the initial parameters:
Use {func}`jax.tree.map` to check the shapes of the initial parameters:
```{code-cell}
jax.tree_map(lambda x: x.shape, params)
jax.tree.map(lambda x: x.shape, params)
```
Next, define the functions for training the MLP model:
@ -147,7 +154,7 @@ def update(params, x, y):
# `jax.grad` is one of many JAX functions that has
# built-in support for pytrees.
# This is useful - you can apply the SGD update using JAX pytree utilities.
return jax.tree_map(
return jax.tree.map(
lambda p, g: p - LEARNING_RATE * g, params, grads
)
```
@ -155,7 +162,7 @@ def update(params, x, y):
(pytrees-custom-pytree-nodes)=
## Custom pytree nodes
This section explains how in JAX you can extend the set of Python types that will be considered _internal nodes_ in pytrees (pytree nodes) by using {func}`jax.tree_util.register_pytree_node` with {func}`jax.tree_map`.
This section explains how in JAX you can extend the set of Python types that will be considered _internal nodes_ in pytrees (pytree nodes) by using {func}`jax.tree_util.register_pytree_node` with {func}`jax.tree.map`.
Why would you need this? In the previous examples, pytrees were shown as lists, tuples, and dicts, with everything else as pytree leaves. This is because if you define your own container class, it will be considered to be a pytree leaf unless you _register_ it with JAX. This is also the case even if your container class has trees inside it. For example:
@ -165,23 +172,22 @@ class Special(object):
self.x = x
self.y = y
jax.tree_util.tree_leaves([
jax.tree.leaves([
Special(0, 1),
Special(2, 4),
])
```
Accordingly, if you try to use a {func}`jax.tree_map` expecting the leaves to be elements inside the container, you will get an error:
Accordingly, if you try to use a {func}`jax.tree.map` expecting the leaves to be elements inside the container, you will get an error:
```{code-cell}
try:
jax.tree_map(lambda x: x + 1,
[
Special(0, 1),
Special(2, 4),
])
except TypeError as e:
print(f'TypeError: {e}')
:tags: [raises-exception]
jax.tree.map(lambda x: x + 1,
[
Special(0, 1),
Special(2, 4)
])
```
As a solution, JAX allows to extend the set of types to be considered internal pytree nodes through a global registry of types. Additionally, the values of registered types are traversed recursively.
@ -235,11 +241,11 @@ register_pytree_node(
Now you can traverse the special container structure:
```{code-cell}
jax.tree_map(lambda x: x + 1,
[
RegisteredSpecial(0, 1),
RegisteredSpecial(2, 4),
])
jax.tree.map(lambda x: x + 1,
[
RegisteredSpecial(0, 1),
RegisteredSpecial(2, 4),
])
```
Modern Python comes equipped with helpful tools to make defining containers easier. Some will work with JAX out-of-the-box, but others require more care.
@ -257,7 +263,7 @@ class MyOtherContainer(NamedTuple):
# NamedTuple subclasses are handled as pytree nodes, so
# this will work out-of-the-box.
jax.tree_util.tree_leaves([
jax.tree.leaves([
MyOtherContainer('Alice', 1, 2, 3),
MyOtherContainer('Bob', 4, 5, 6)
])
@ -275,28 +281,28 @@ Some JAX function transformations take optional parameters that specify how cert
For example, if you pass the following input to {func}`jax.vmap` (note that the input arguments to a function are considered a tuple):
```
(a1, {"k1": a2, "k2": a3})
```python
vmap(f, in_axes=(a1, {"k1": a2, "k2": a3}))
```
then you can use the following `in_axes` pytree to specify that only the `k2` argument is mapped (`axis=0`), and the rest arent mapped over (`axis=None`):
```
(None, {"k1": None, "k2": 0})
```python
vmap(f, in_axes=(None, {"k1": None, "k2": 0}))
```
The optional parameter pytree structure must match that of the main input pytree. However, the optional parameters can optionally be specified as a “prefix” pytree, meaning that a single leaf value can be applied to an entire sub-pytree.
For example, if you have the same {func}`jax.vmap` input as above, but wish to only map over the dictionary argument, you can use:
```
(None, 0) # equivalent to (None, {"k1": 0, "k2": 0})
```python
vmap(f, in_axes=(None, 0)) # equivalent to (None, {"k1": 0, "k2": 0})
```
Alternatively, if you want every argument to be mapped, you can write a single leaf value that is applied over the entire argument tuple pytree:
```
0
```python
vmap(f, in_axes=0) # equivalent to (0, {"k1": 0, "k2": 0})
```
This happens to be the default `in_axes` value for {func}`jax.vmap`.
@ -312,8 +318,8 @@ For built-in pytree node types, the set of keys for any pytree node instance is
JAX has the following `jax.tree_util.*` methods for working with key paths:
- {func}`jax.tree_util.tree_flatten_with_path`: Works similarly to {func}`jax.tree_util.tree_flatten`, but returns key paths.
- {func}`jax.tree_util.tree_map_with_path``: Works similarly to {func}`jax.tree_util.tree_map`, but the function also takes key paths as arguments.
- {func}`jax.tree_util.tree_flatten_with_path`: Works similarly to {func}`jax.tree.flatten`, but returns key paths.
- {func}`jax.tree_util.tree_map_with_path`: Works similarly to {func}`jax.tree.map`, but the function also takes key paths as arguments.
- {func}`jax.tree_util.keystr`: Given a general key path, returns a reader-friendly string expression.
For example, one use case is to print debugging information related to a certain leaf value:
@ -327,7 +333,7 @@ tree = [1, {'k1': 2, 'k2': (3, 4)}, ATuple('foo')]
flattened, _ = jax.tree_util.tree_flatten_with_path(tree)
for key_path, value in flattened:
print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
print(f'Value of tree{jax.tree_util.keystr(key_path)}: {value}')
```
To express key paths, JAX provides a few default key types for the built-in pytree node types, namely:
@ -340,7 +346,7 @@ You are free to define your own key types for your custom nodes. They will work
```{code-cell}
for key_path, _ in flattened:
print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
print(f'Key path of tree{jax.tree_util.keystr(key_path)}: {repr(key_path)}')
```
(pytrees-common-pytree-gotchas)=
@ -356,26 +362,30 @@ A common gotcha to look out for is accidentally introducing _tree nodes_ instead
a_tree = [jnp.zeros((2, 3)), jnp.zeros((3, 4))]
# Try to make another pytree with ones instead of zeros.
shapes = jax.tree_map(lambda x: x.shape, a_tree)
jax.tree_map(jnp.ones, shapes)
shapes = jax.tree.map(lambda x: x.shape, a_tree)
jax.tree.map(jnp.ones, shapes)
```
What happened here is that the `shape` of an array is a tuple, which is a pytree node, with its elements as leaves. Thus, in the map, instead of calling `jnp.ones` on e.g. `(2, 3)`, it's called on `2` and `3`.
The solution will depend on the specifics, but there are two broadly applicable options:
- Rewrite the code to avoid the intermediate {func}`jax.tree_map`.
- Rewrite the code to avoid the intermediate {func}`jax.tree.map`.
- Convert the tuple into a NumPy array (`np.array`) or a JAX NumPy array (`jnp.array`), which makes the entire sequence a leaf.
### Handling of `None` by `jax.tree_utils`
### Handling of `None` by `jax.tree_util`
`jax.tree_utils` treats `None` as the absence of a pytree node, not as a leaf:
`jax.tree_util` functions treat `None` as the absence of a pytree node, not as a leaf:
```{code-cell}
jax.tree_util.tree_leaves([None, None, None])
jax.tree.leaves([None, None, None])
```
Note that this is different from how the (now deprecated) [`tree` (`dm_tree`)](https://github.com/google-deepmind/tree) library used to treat `None`.
To treat `None` as a leaf, you can use the `is_leaf` argument:
```{code-cell}
jax.tree.leaves([None, None, None], is_leaf=lambda x: x is None)
```
### Custom pytrees and initialization with unexpected values
@ -394,6 +404,11 @@ register_pytree_node(MyTree, lambda tree: ((tree.a,), None),
tree = MyTree(jnp.arange(5.0))
jax.vmap(lambda x: x)(tree) # Error because object() is passed to `MyTree`.
```
```{code-cell}
:tags: [raises-exception]
jax.jacobian(lambda x: x)(tree) # Error because MyTree(...) is passed to `MyTree`.
```
@ -429,30 +444,30 @@ def tree_unflatten(aux_data, children):
This section covers some of the most common patterns with JAX pytrees.
### Transposing pytrees with `jax.tree_map` and `jax.tree_util.tree_transpose`
### Transposing pytrees with `jax.tree.map` and `jax.tree.transpose`
To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func} `jax.tree_map` (more basic) and {func}`jax.tree_util.tree_transpose` (more flexible, complex and verbose).
To transpose a pytree (turn a list of trees into a tree of lists), JAX has two functions: {func} `jax.tree.map` (more basic) and {func}`jax.tree.transpose` (more flexible, complex and verbose).
**Option 1:** Use {func}`jax.tree_map`. Here's an example:
**Option 1:** Use {func}`jax.tree.map`. Here's an example:
```{code-cell}
def tree_transpose(list_of_trees):
"""
Converts a list of trees of identical structure into a single tree of lists.
"""
return jax.tree_map(lambda *xs: list(xs), *list_of_trees)
return jax.tree.map(lambda *xs: list(xs), *list_of_trees)
# Convert a dataset from row-major to column-major.
episode_steps = [dict(t=1, obs=3), dict(t=2, obs=4)]
tree_transpose(episode_steps)
```
**Option 2:** For more complex transposes, use {func}`jax.tree_util.tree_transpose`, which is more verbose, but allows you specify the structure of the inner and outer pytree for more flexibility. For example:
**Option 2:** For more complex transposes, use {func}`jax.tree.transpose`, which is more verbose, but allows you specify the structure of the inner and outer pytree for more flexibility. For example:
```{code-cell}
jax.tree_util.tree_transpose(
outer_treedef = jax.tree_util.tree_structure([0 for e in episode_steps]),
inner_treedef = jax.tree_util.tree_structure(episode_steps[0]),
jax.tree.transpose(
outer_treedef = jax.tree.structure([0 for e in episode_steps]),
inner_treedef = jax.tree.structure(episode_steps[0]),
pytree_to_transpose = episode_steps
)
```