
DOC: many small fixes

elliotwaite 2021-08-02 17:57:09 -07:00
parent df103f7e66
commit 7392a57b75
65 changed files with 199 additions and 199 deletions

@ -50,7 +50,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Bug fixes:
* Tightened the checks for lax.argmin and lax.argmax to ensure they are
not used with invalid `axis` value, or with an empty reduction dimension.
not used with an invalid `axis` value, or with an empty reduction dimension.
({jax-issue}`#7196`)
@ -333,7 +333,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* Bug fixes:
* `jax.numpy.arccosh` now returns the same branch as `numpy.arccosh` for
complex inputs ({jax-issue}`#5156`)
* `host_callback.id_tap` now works for `jax.pmap` also. There is a
* `host_callback.id_tap` now works for `jax.pmap` also. There is an
optional parameter for `id_tap` and `id_print` to request that the
device from which the value is tapped be passed as a keyword argument
to the tap function ({jax-issue}`#5182`).
@ -359,7 +359,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* New features:
* Add `jax.device_put_replicated`
* Add multi-host support to `jax.experimental.sharded_jit`
* Add support for differentiating eigenvaleus computed by `jax.numpy.linalg.eig`
* Add support for differentiating eigenvalues computed by `jax.numpy.linalg.eig`
* Add support for building on Windows platforms
* Add support for general in_axes and out_axes in `jax.pmap`
* Add complex support for `jax.numpy.linalg.slogdet`
@ -504,7 +504,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.73...jax-v0.1.74).
* New Features:
* BFGS (#3101)
* TPU suppot for half-precision arithmetic (#3878)
* TPU support for half-precision arithmetic (#3878)
* Bug Fixes:
* Prevent some accidental dtype warnings (#3874)
* Fix a multi-threading bug in custom derivatives (#3845, #3869)

@ -109,7 +109,8 @@ or the [examples](https://github.com/google/jax/tree/main/examples).
## Transformations
At its core, JAX is an extensible system for transforming numerical functions.
Here are four of primary interest: `grad`, `jit`, `vmap`, and `pmap`.
Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and
`pmap`.
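
As a quick, illustrative sketch of the four transformations named above (the toy `loss` function and the values are made up for the example):

```python
import jax
import jax.numpy as jnp

def loss(x):
    return jnp.sum(jnp.tanh(x) ** 2)

grads = jax.grad(loss)(jnp.arange(3.0))     # reverse-mode gradient of a scalar function
fast_loss = jax.jit(loss)                   # compile the whole function with XLA
batched = jax.vmap(loss)(jnp.ones((4, 3)))  # vectorize over a leading batch axis
# jax.pmap(loss) maps in the same way, but across devices rather than an array axis.
```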
### Automatic differentiation with `grad`

@ -99,7 +99,7 @@ The name "omnistaging" means staging out everything possible.
### Toy example
iJAX transformations like `jit` and `pmap` stage out computations to XLA. That
JAX transformations like `jit` and `pmap` stage out computations to XLA. That
is, we apply them to functions comprising multiple primitive operations so that
rather being executed one at a time from Python the operations are all part of
one end-to-end optimized XLA computation.

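A small sketch of what "staging out" means in practice (toy function, illustrative only): under `jit`, the whole Python function is traced into a single jaxpr and compiled as one end-to-end XLA computation.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * 2.0 + 1.0

print(jax.make_jaxpr(f)(3.0))  # the staged-out representation of every primitive in f
y = jax.jit(f)(3.0)            # the same staged computation, compiled and run by XLA
```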
@ -666,7 +666,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Notice both `lift` and `sublift` package a value into a `JVPTracer` with the\n",
"Notice both `pure` and `lift` package a value into a `JVPTracer` with the\n",
"minimal amount of context, which is a zero tangent value.\n",
"\n",
"Let's add some JVP rules for primitives:"
@ -1312,7 +1312,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Jaxpr data strutures\n",
"### Jaxpr data structures\n",
"\n",
"The jaxpr term syntax is roughly:\n",
"\n",
@ -2720,7 +2720,7 @@
" g:float64[] = neg e\n",
" in ( g ) }\n",
"```\n",
"This second jaxpr is represents the linear computation that we want from\n",
"This second jaxpr represents the linear computation that we want from\n",
"`linearize`.\n",
"\n",
"However, unlike in this jaxpr example, we want the computation on known values\n",
@ -2729,7 +2729,7 @@
"operations out of Python first before sorting out what can be evaluated now\n",
"and what must be delayed, we want only to form a jaxpr for those operations\n",
"that _must_ be delayed due to a dependence on unknown inputs. In the context\n",
"of automatic differentiation, this is the feature ultimately enables us to\n",
"of automatic differentiation, this is the feature that ultimately enables us to\n",
"handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control\n",
"flow works because partial evaluation keeps the primal computation in Python.\n",
"As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out\n",
@ -2874,9 +2874,10 @@
"(evaluating it in Python) and avoid forming tracers corresponding to the\n",
"output. If instead any input is unknown then we instead stage out into a\n",
"`JaxprEqnRecipe` representing the primitive application. To build the tracers\n",
"representing unknown outputs, we need avals, which get from the abstract eval\n",
"rules. (Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s\n",
"reference tracers; we avoid circular garbage by using weakrefs.)\n",
"representing unknown outputs, we need avals, which we get from the abstract\n",
"eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and\n",
"`JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using\n",
"weakrefs.)\n",
"\n",
"That `process_primitive` logic applies to most primitives, but `xla_call_p`\n",
"requires recursive treatment. So we special-case its rule in a\n",
@ -3312,7 +3313,7 @@
"metadata": {},
"source": [
"We use `UndefPrimal` instances to indicate which arguments with respect to\n",
"with we want to transpose. These arise because in general, being explicit\n",
"which we want to transpose. These arise because in general, being explicit\n",
"about closed-over values, we want to transpose functions of type\n",
"`a -> b -o c` to functions of type `a -> c -o b`. Even more generally, the\n",
"inputs with respect to which the function is linear could be scattered through\n",

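The user-facing counterpart of the JVP rules discussed in these hunks is `jax.jvp`; a minimal, illustrative sketch:

```python
import jax
import jax.numpy as jnp

# Push a tangent through sin at the primal point 3.0.
primal_out, tangent_out = jax.jvp(jnp.sin, (3.0,), (1.0,))
print(primal_out, tangent_out)  # sin(3.0) and cos(3.0) * 1.0
```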
@ -505,7 +505,7 @@ class JVPTrace(Trace):
jvp_rules = {}
```
Notice both `lift` and `sublift` package a value into a `JVPTracer` with the
Notice both `pure` and `lift` package a value into a `JVPTracer` with the
minimal amount of context, which is a zero tangent value.
Let's add some JVP rules for primitives:
@ -960,7 +960,7 @@ jaxpr and then interpreting the jaxpr.)
+++
### Jaxpr data strutures
### Jaxpr data structures
The jaxpr term syntax is roughly:
@ -2012,7 +2012,7 @@ and tangent jaxprs:
g:float64[] = neg e
in ( g ) }
```
This second jaxpr is represents the linear computation that we want from
This second jaxpr represents the linear computation that we want from
`linearize`.
However, unlike in this jaxpr example, we want the computation on known values
@ -2021,7 +2021,7 @@ forming a jaxpr for the entire function `(a1, a2) -> (b1, b2)`, staging all
operations out of Python first before sorting out what can be evaluated now
and what must be delayed, we want only to form a jaxpr for those operations
that _must_ be delayed due to a dependence on unknown inputs. In the context
of automatic differentiation, this is the feature ultimately enables us to
of automatic differentiation, this is the feature that ultimately enables us to
handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control
flow works because partial evaluation keeps the primal computation in Python.
As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out
@ -2122,9 +2122,10 @@ inputs are known then we can bind the primitive on the known values
(evaluating it in Python) and avoid forming tracers corresponding to the
output. If instead any input is unknown then we instead stage out into a
`JaxprEqnRecipe` representing the primitive application. To build the tracers
representing unknown outputs, we need avals, which get from the abstract eval
rules. (Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s
reference tracers; we avoid circular garbage by using weakrefs.)
representing unknown outputs, we need avals, which we get from the abstract
eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and
`JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using
weakrefs.)
That `process_primitive` logic applies to most primitives, but `xla_call_p`
requires recursive treatment. So we special-case its rule in a
@ -2468,7 +2469,7 @@ register_pytree_node(UndefPrimal,
```
We use `UndefPrimal` instances to indicate which arguments with respect to
with we want to transpose. These arise because in general, being explicit
which we want to transpose. These arise because in general, being explicit
about closed-over values, we want to transpose functions of type
`a -> b -o c` to functions of type `a -> c -o b`. Even more generally, the
inputs with respect to which the function is linear could be scattered through

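The `linearize` transformation these hunks describe is also exposed as a public API; a minimal, illustrative sketch:

```python
import jax
import jax.numpy as jnp

y, f_lin = jax.linearize(jnp.sin, 3.0)  # primal output plus the linearized (JVP) map
print(y)           # sin(3.0)
print(f_lin(1.0))  # cos(3.0) * 1.0, the tangent map applied to 1.0
```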
@ -486,7 +486,7 @@ class JVPTrace(Trace):
jvp_rules = {}
# -
# Notice both `lift` and `sublift` package a value into a `JVPTracer` with the
# Notice both `pure` and `lift` package a value into a `JVPTracer` with the
# minimal amount of context, which is a zero tangent value.
#
# Let's add some JVP rules for primitives:
@ -919,7 +919,7 @@ jacfwd(f, np.arange(3.))
# control flow, any transformation could be implemented by first tracing to a
# jaxpr and then interpreting the jaxpr.)
# ### Jaxpr data strutures
# ### Jaxpr data structures
#
# The jaxpr term syntax is roughly:
#
@ -1930,7 +1930,7 @@ def vspace(aval: ShapedArray) -> ShapedArray:
# g:float64[] = neg e
# in ( g ) }
# ```
# This second jaxpr is represents the linear computation that we want from
# This second jaxpr represents the linear computation that we want from
# `linearize`.
#
# However, unlike in this jaxpr example, we want the computation on known values
@ -1939,7 +1939,7 @@ def vspace(aval: ShapedArray) -> ShapedArray:
# operations out of Python first before sorting out what can be evaluated now
# and what must be delayed, we want only to form a jaxpr for those operations
# that _must_ be delayed due to a dependence on unknown inputs. In the context
# of automatic differentiation, this is the feature ultimately enables us to
# of automatic differentiation, this is the feature that ultimately enables us to
# handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control
# flow works because partial evaluation keeps the primal computation in Python.
# As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out
@ -2036,9 +2036,10 @@ class PartialEvalTracer(Tracer):
# (evaluating it in Python) and avoid forming tracers corresponding to the
# output. If instead any input is unknown then we instead stage out into a
# `JaxprEqnRecipe` representing the primitive application. To build the tracers
# representing unknown outputs, we need avals, which get from the abstract eval
# rules. (Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s
# reference tracers; we avoid circular garbage by using weakrefs.)
# representing unknown outputs, we need avals, which we get from the abstract
# eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and
# `JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using
# weakrefs.)
#
# That `process_primitive` logic applies to most primitives, but `xla_call_p`
# requires recursive treatment. So we special-case its rule in a
@ -2376,7 +2377,7 @@ register_pytree_node(UndefPrimal,
# -
# We use `UndefPrimal` instances to indicate which arguments with respect to
# with we want to transpose. These arise because in general, being explicit
# which we want to transpose. These arise because in general, being explicit
# about closed-over values, we want to transpose functions of type
# `a -> b -o c` to functions of type `a -> c -o b`. Even more generally, the
# inputs with respect to which the function is linear could be scattered through

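The transposition machinery described here backs the public `jax.linear_transpose`; a minimal sketch, using a function that is linear in its input (the shapes are arbitrary):

```python
import jax
import jax.numpy as jnp

f = lambda x: 3.0 * x                       # linear in x
f_t = jax.linear_transpose(f, jnp.ones(3))  # the example argument only fixes shape/dtype
print(f_t(jnp.ones(3)))                     # (array of 3.0s,), the transposed linear map
```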
@ -160,6 +160,6 @@ fix the issues you can push new commits to your branch.
Once your PR has been reviewed, a JAX maintainer will mark it as `Pull Ready`. This
will trigger a larger set of tests, including tests on GPU and TPU backends that are
not available via standard GitHub CI. Detailed results of these tests are not publicly
viweable, but the JAX mantainer assigned to your PR will communicate with you regarding
viewable, but the JAX maintainer assigned to your PR will communicate with you regarding
any failures these might uncover; it's not uncommon, for example, that numerical tests
need different tolerances on TPU than on CPU.

@ -87,7 +87,7 @@ skip_app.defvjp(skip_app_fwd, skip_app_bwd)
## Explanation
Passing `Tracer`s into `nondiff_argnums` arguments was always buggy. While there
were some cases which worked correctly, others would lead to complex and
were some cases that worked correctly, others would lead to complex and
confusing error messages.
The essence of the bug was that `nondiff_argnums` was implemented in a way that

@ -85,7 +85,7 @@ You can either install Python using its
[Windows installer](https://www.python.org/downloads/), or if you prefer, you
can use [Anaconda](https://docs.anaconda.com/anaconda/install/windows/)
or [Miniconda](https://docs.conda.io/en/latest/miniconda.html#windows-installers)
to setup a Python environment.
to set up a Python environment.
Some targets of Bazel use bash utilities to do scripting, so [MSYS2](https://www.msys2.org)
is needed. See [Installing Bazel on Windows](https://docs.bazel.build/versions/master/install-windows.html#installing-compilers-and-language-runtimes)
@ -174,7 +174,7 @@ python tests/lax_numpy_test.py --test_targets="testPad"
The Colab notebooks are tested for errors as part of the documentation build.
Note that to run the full pmap tests on a (multi-core) CPU only machine, you
Note that to run the full pmap tests on a (multi-core) CPU-only machine, you
can run:
```
@ -278,7 +278,7 @@ See `exclude_patterns` in [conf.py](https://github.com/google/jax/blob/main/docs
## Documentation building on readthedocs.io
JAX's auto-generated documentations is at <https://jax.readthedocs.io/>.
JAX's auto-generated documentation is at <https://jax.readthedocs.io/>.
The documentation building is controlled for the entire project by the
[readthedocs JAX settings](https://readthedocs.org/dashboard/jax). The current settings

@ -80,7 +80,7 @@ For more information about how to interpret callgraph visualizations, see the
Functions compiled with {func}`jax.jit` are opaque to the device memory profiler.
That is, any memory allocated inside a `jit`-compiled function will be
attributed to the function as whole.
attributed to the function as a whole.
In the example, the call to `block_until_ready()` is to ensure that `func2`
completes before the device memory profile is collected. See
@ -90,7 +90,7 @@ completes before the device memory profile is collected. See
We can also use the JAX device memory profiler to track down memory leaks by using
`pprof` to visualize the change in memory usage between two device memory profiles
taken at different times. For example consider the following program which
taken at different times. For example, consider the following program which
accumulates JAX arrays into a constantly-growing Python list.
```python

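# Illustrative sketch only (the helper name func2 is a stand-in, not the notebook's code):
# taking a device memory profile around a jit-compiled call, as described above.
import jax
import jax.numpy as jnp

def func2(x):
    return jnp.dot(x, x.T)

x = jnp.ones((1000, 1000))
y = jax.jit(func2)(x).block_until_ready()  # make sure the work finishes first
jax.profiler.save_device_memory_profile("memory.prof")  # inspect with pprof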
@ -60,10 +60,10 @@ If your ``jit`` decorated function takes tens of seconds (or more!) to run the
first time you call it, but executes quickly when called again, JAX is taking a
long time to trace or compile your code.
This is usually a symptom of calling your function generating a large amount of
This is usually a sign that calling your function generates a large amount of
code in JAX's internal representation, typically because it makes heavy use of
Python control flow such as ``for`` loop. For a handful of loop iterations
Python is OK, but if you need _many_ loop iterations, you should rewrite your
Python control flow such as ``for`` loops. For a handful of loop iterations,
Python is OK, but if you need *many* loop iterations, you should rewrite your
code to make use of JAX's
`structured control flow primitives <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Structured-control-flow-primitives>`_
(such as :func:`lax.scan`) or avoid wrapping the loop with ``jit`` (you can
@ -206,7 +206,7 @@ running full applications, which inevitably include some amount of both data
transfer and compilation. Also, we were careful to pick large enough arrays
(1000x1000) and an intensive enough computation (the ``@`` operator is
performing matrix-matrix multiplication) to amortize the increased overhead of
JAX/accelerators vs NumPy/CPU. For example, if switch this example to use
JAX/accelerators vs NumPy/CPU. For example, if we switch this example to use
10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs).
.. _To JIT or not to JIT: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit
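
Returning to the loop-rewriting advice above, a minimal `lax.scan` sketch (the cumulative-sum example is made up for illustration):

```python
import jax
from jax import lax
import jax.numpy as jnp

def cumsum(xs):
    def step(carry, x):
        new_carry = carry + x
        return new_carry, new_carry  # (carry for the next step, per-step output)
    init = jnp.zeros((), dtype=xs.dtype)
    _, ys = lax.scan(step, init, xs)
    return ys

print(jax.jit(cumsum)(jnp.arange(5.0)))  # [ 0.  1.  3.  6. 10.]
```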
@ -322,7 +322,7 @@ are not careful you may obtain a ``NaN`` for reverse differentiation::
jax.grad(my_log)(0.) ==> NaN
A short explanation is that during ``grad`` computation the adjoint corresponding
to the undefined ``jnp.log(x)`` is a ``NaN`` and when it gets accumulated to the
to the undefined ``jnp.log(x)`` is a ``NaN`` and it gets accumulated to the
adjoint of the ``jnp.where``. The correct way to write such functions is to ensure
that there is a ``jnp.where`` *inside* the partially-defined function, to ensure
that the adjoint is always finite::

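A minimal sketch of the "``jnp.where`` inside" pattern described above (the function name is illustrative):

```python
import jax
import jax.numpy as jnp

def safe_log(x):
    # The inner where keeps log from ever seeing x <= 0, so its adjoint stays finite.
    return jnp.where(x > 0., jnp.log(jnp.where(x > 0., x, 1.)), 0.)

print(jax.grad(safe_log)(0.))  # 0.0 rather than nan
```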
@ -219,7 +219,7 @@
"\n",
"(Like $\\nabla$, `jax.grad` will only work on functions with a scalar output -- it will raise an error otherwise.)\n",
"\n",
"This makes the JAX API quite different to other autodiff libraries like Tensorflow and PyTorch, where to compute the gradient we use the loss tensor itself (e.g. by calling `loss.backward()`). The JAX API works directly with functions, staying closer to the underlying math. Once you become accustomed to this way of doing things, it feels natural: your loss function in code really is a function of parameters and data, and you find its gradient just like you would in the math.\n",
"This makes the JAX API quite different from other autodiff libraries like Tensorflow and PyTorch, where to compute the gradient we use the loss tensor itself (e.g. by calling `loss.backward()`). The JAX API works directly with functions, staying closer to the underlying math. Once you become accustomed to this way of doing things, it feels natural: your loss function in code really is a function of parameters and data, and you find its gradient just like you would in the math.\n",
"\n",
"This way of doing things makes it straightforward to control things like which variables to differentiate with respect to. By default, `jax.grad` will find the gradient with respect to the first argument. In the example below, the result of `sum_squared_error_dx` will be the gradient of `sum_squared_error` with respect to `x`."
]

@ -120,7 +120,7 @@ Analogously, `jax.grad(f)` is the function that computes the gradient, so `jax.g
(Like $\nabla$, `jax.grad` will only work on functions with a scalar output -- it will raise an error otherwise.)
This makes the JAX API quite different to other autodiff libraries like Tensorflow and PyTorch, where to compute the gradient we use the loss tensor itself (e.g. by calling `loss.backward()`). The JAX API works directly with functions, staying closer to the underlying math. Once you become accustomed to this way of doing things, it feels natural: your loss function in code really is a function of parameters and data, and you find its gradient just like you would in the math.
This makes the JAX API quite different from other autodiff libraries like Tensorflow and PyTorch, where to compute the gradient we use the loss tensor itself (e.g. by calling `loss.backward()`). The JAX API works directly with functions, staying closer to the underlying math. Once you become accustomed to this way of doing things, it feels natural: your loss function in code really is a function of parameters and data, and you find its gradient just like you would in the math.
This way of doing things makes it straightforward to control things like which variables to differentiate with respect to. By default, `jax.grad` will find the gradient with respect to the first argument. In the example below, the result of `sum_squared_error_dx` will be the gradient of `sum_squared_error` with respect to `x`.

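A minimal sketch consistent with the `sum_squared_error` example the text refers to (the array values are made up):

```python
import jax
import jax.numpy as jnp

def sum_squared_error(x, y):
    return jnp.sum((x - y) ** 2)

# By default, jax.grad differentiates with respect to the first argument (x).
sum_squared_error_dx = jax.grad(sum_squared_error)
print(sum_squared_error_dx(jnp.asarray([1.0, 2.0]), jnp.asarray([1.1, 2.1])))
```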
@ -290,7 +290,7 @@
"\n",
"This update is not the gradient of any loss function.\n",
"\n",
"However it can be **written** as the gradient of the pseudo loss function\n",
"However, it can be **written** as the gradient of the pseudo loss function\n",
"\n",
"$$\n",
"L(\\theta) = [r_t + v_{\\theta}(s_t) - v_{\\theta}(s_{t-1})]^2\n",

@ -186,7 +186,7 @@ $$
This update is not the gradient of any loss function.
However it can be **written** as the gradient of the pseudo loss function
However, it can be **written** as the gradient of the pseudo loss function
$$
L(\theta) = [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2

@ -258,7 +258,7 @@
"\n",
"This doesn't seem to be a major issue in NumPy, as it is already enforced by Python, but it becomes an issue in JAX. \n",
"\n",
"Making this code reproducible in JAX would require to enforce this specific order of execution. This would violate requirement #2, as JAX should be able to parallelize `bar` and `baz` when jitting as these functions don't actually depend on each other.\n",
"Making this code reproducible in JAX would require enforcing this specific order of execution. This would violate requirement #2, as JAX should be able to parallelize `bar` and `baz` when jitting as these functions don't actually depend on each other.\n",
"\n",
"To avoid this issue, JAX does not use a global state. Instead, random functions explicitly consume the state, which is referred to as a `key` ."
]

@ -140,7 +140,7 @@ The output of this code can only satisfy requirement #1 if we assume a specific
This doesn't seem to be a major issue in NumPy, as it is already enforced by Python, but it becomes an issue in JAX.
Making this code reproducible in JAX would require to enforce this specific order of execution. This would violate requirement #2, as JAX should be able to parallelize `bar` and `baz` when jitting as these functions don't actually depend on each other.
Making this code reproducible in JAX would require enforcing this specific order of execution. This would violate requirement #2, as JAX should be able to parallelize `bar` and `baz` when jitting as these functions don't actually depend on each other.
To avoid this issue, JAX does not use a global state. Instead, random functions explicitly consume the state, which is referred to as a `key` .

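A minimal sketch of the explicit key handling described above:

```python
import jax

key = jax.random.PRNGKey(42)
key, subkey = jax.random.split(key)  # derive a fresh subkey; never reuse a consumed key
x = jax.random.normal(subkey, (3,))  # random functions consume the key explicitly
```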
@ -338,7 +338,7 @@
"source": [
"## Custom pytree nodes\n",
"\n",
"So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define my own container class, it will be considered a leaf, even if it has trees inside it:"
"So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define your own container class, it will be considered a leaf, even if it has trees inside it:"
]
},
{

@ -193,7 +193,7 @@ plt.legend();
## Custom pytree nodes
So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define my own container class, it will be considered a leaf, even if it has trees inside it:
So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define your own container class, it will be considered a leaf, even if it has trees inside it:
```{code-cell}
:id: CK8LN2PRFnQf

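# Illustrative sketch (the container class is made up): registering a custom
# container so it is traversed as a pytree node rather than treated as a leaf.
import jax
from jax.tree_util import register_pytree_node

class MyContainer:
    def __init__(self, a, b):
        self.a, self.b = a, b

register_pytree_node(
    MyContainer,
    lambda c: ((c.a, c.b), None),                  # flatten: (children, aux_data)
    lambda aux, children: MyContainer(*children),  # unflatten
)

print(jax.tree_util.tree_leaves([MyContainer(1, 2)]))  # [1, 2]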
@ -555,7 +555,7 @@
},
"outputs": [],
"source": [
"from typing import NamedTuple\n",
"from typing import NamedTuple, Tuple\n",
"import functools\n",
"\n",
"class Params(NamedTuple):\n",
@ -585,7 +585,7 @@
"# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it\n",
"# 'num_devices', but could have used anything, so long as `pmean` used the same.\n",
"@functools.partial(jax.pmap, axis_name='num_devices')\n",
"def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Params:\n",
"def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:\n",
" \"\"\"Performs one SGD update step on params using the given data.\"\"\"\n",
"\n",
" # Compute the gradients on the given minibatch (individually on each device).\n",

@ -228,7 +228,7 @@ If this example is too confusing, you can find the same example, but without par
```{code-cell}
:id: cI8xQqzRrc-4
from typing import NamedTuple
from typing import NamedTuple, Tuple
import functools
class Params(NamedTuple):
@ -258,7 +258,7 @@ LEARNING_RATE = 0.005
# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it
# 'num_devices', but could have used anything, so long as `pmean` used the same.
@functools.partial(jax.pmap, axis_name='num_devices')
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Params:
def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
"""Performs one SGD update step on params using the given data."""
# Compute the gradients on the given minibatch (individually on each device).

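# Illustrative sketch (toy function, not part of the changed notebook): a pmapped
# function using the named axis with a collective, in the spirit of `update` above.
import functools
import jax
import jax.numpy as jnp

@functools.partial(jax.pmap, axis_name='num_devices')
def normalized(x):
    return x / jax.lax.pmean(x, axis_name='num_devices')  # mean across mapped devices

n = jax.local_device_count()
print(normalized(jnp.arange(1.0, n + 1.0)))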
@ -12,8 +12,8 @@ While JAX tries to follow the NumPy API as closely as possible, sometimes JAX
cannot follow NumPy exactly.
* Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays
in-place cannot be implemented in JAX. However, often JAX is able to provide a
alternative API that is purely functional. For example, instead of in-place
in-place cannot be implemented in JAX. However, often JAX is able to provide
an alternative API that is purely functional. For example, instead of in-place
array updates (:code:`x[i] = y`), JAX provides an alternative pure indexed
update function :func:`jax.ops.index_update`.

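A minimal sketch of the pure, functional update style described above, using the indexed-update syntax on JAX arrays:

```python
import jax.numpy as jnp

x = jnp.zeros(3)
# x[0] = 10.0 would fail: JAX arrays are immutable.
y = x.at[0].set(10.0)  # returns a new array with the update applied; x is unchanged
```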
@ -65,8 +65,8 @@ Equations are printed as follows::
Eqn ::= let Var+ = Primitive [ Param* ] Expr+
where:
* ``Var+`` are one or more intermediate variables to be defined as the
output of a primitive invocation (some primitives can return multiple values)
* ``Var+`` are one or more intermediate variables to be defined as the output
of a primitive invocation (some primitives can return multiple values).
* ``Expr+`` are one or more atomic expressions, each either a variable or a
literal constant. A special variable ``unitvar`` or literal ``unit``,
printed as ``*``, represents a value that is not needed
@ -235,7 +235,7 @@ The cond primitive has a number of parameters:
parameters are used linearly in the conditional.
The above instance of the cond primitive takes two operands. The first
one (``c``) is the branch index, then ``b`` is the operand (``arg``) to
one (``d``) is the branch index, then ``b`` is the operand (``arg``) to
be passed to whichever jaxpr in ``branches`` is selected by the branch
index.
@ -267,7 +267,7 @@ Another example, using :py:func:`lax.cond`:
In this case, the boolean predicate is converted to an integer index
(0 or 1), and ``branches`` are jaxprs that correspond to the false and
true branch functionals, in that order. Again, each functional takes
one input variable, corresponding to ``xtrue`` and ``xfalse``
one input variable, corresponding to ``xfalse`` and ``xtrue``
respectively.
The following example shows a more complicated situation when the input
@ -341,7 +341,7 @@ For example, here is an example fori loop
cond_nconsts=0 ] c a 0 b d
in (e,) }
The while primitive takes 5 arguments: ``c a 0 b e``, as follows:
The while primitive takes 5 arguments: ``c a 0 b d``, as follows:
* 0 constants for ``cond_jaxpr`` (since ``cond_nconsts`` is 0)
* 2 constants for ``body_jaxpr`` (``c``, and ``a``)
@ -406,8 +406,8 @@ XLA_call
^^^^^^^^
The call primitive arises from JIT compilation, and it encapsulates
a sub-jaxpr along with parameters the specify the backend and the device the
computation should run. For example
a sub-jaxpr along with parameters that specify the backend and the device on
which the computation should run. For example
>>> from jax import jit
>>>
@ -480,10 +480,10 @@ captured using the ``xla_pmap`` primitive. Consider this example
out_axes=(0,) ] b a
in (c,) }
The ``xla_pmap`` primitive specifies the name of the axis (parameter ``rows``)
and the body of the function to be mapped as the ``call_jaxpr`` parameter.
value of this parameter is a Jaxpr with 3 input variables:
The ``xla_pmap`` primitive specifies the name of the axis (parameter
``axis_name``) and the body of the function to be mapped as the ``call_jaxpr``
parameter. The value of this parameter is a Jaxpr with 2 input variables.
The parameter ``in_axes`` specifies which of the input variables should be
mapped and which should be broadcast. In our example, the value of ``extra``
is broadcast, the other input values are mapped.
is broadcast and the value of ``arr`` is mapped.

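To see a ``cond`` primitive like the ones discussed above, one can print the jaxpr of a ``lax.switch``; a minimal, illustrative sketch:

```python
import jax
from jax import lax

branches = [lambda x: x + 1.0, lambda x: x - 2.0, lambda x: x * 3.0]

def one_of_three(index, arg):
    return lax.switch(index, branches, arg)

print(jax.make_jaxpr(one_of_three)(1, 5.0))  # the jaxpr contains a cond with 3 branches
```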
@ -16,7 +16,7 @@ with JAXs collective operations, we recommend starting with the
environments in JAX is direct communication links between accelerators, e.g. the
high-speed interconnects for Cloud TPUs or
[NCCL](https://developer.nvidia.com/nccl) for GPUs. These links are what allow
collective operations to run across multiple process worth of accelerators.
collective operations to run across multiple processes worth of accelerators.
## Multi-process programming model
@ -80,7 +80,7 @@ out the {doc}`/jax-101/06-parallelism` notebook.) Each process should call the
same pmapped function and pass in arguments to be mapped across its _local_
devices (i.e., the pmapped axis size is equal to the number of local
devices). Similarly, the function will return outputs sharded across _local_
devices only. Inside the function however, collective communication operations
devices only. Inside the function, however, collective communication operations
are run across all _global_ devices, across all processes. Conceptually, this
can be thought of as running a pmap over a single array sharded across hosts,
where each host “sees” only its local shard of the input and output.

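A minimal sketch of the local-in, global-collective behavior described above (the same script is run on every process):

```python
import jax
import jax.numpy as jnp

n_local = jax.local_device_count()
out = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(jnp.ones(n_local))
# Each process maps only over its local devices, yet every entry of `out` equals
# jax.device_count(), the total number of devices across all processes.
```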
@ -955,7 +955,7 @@
"source": [
"JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state! \n",
"\n",
"Reusing the same state will cause __sadness__ and __monotony__, depriving the enduser of __lifegiving chaos__:"
"Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:"
]
},
{

@ -461,7 +461,7 @@ key
JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state!
Reusing the same state will cause __sadness__ and __monotony__, depriving the enduser of __lifegiving chaos__:
Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:
```{code-cell} ipython3
:id: 7zUdQMynoE5e

@ -850,7 +850,7 @@
"id": "p2xFQAte19sF"
},
"source": [
"This is an iterative procedure for numerically solving the equation $x = f(a, x)$ for $x$, by iterating $x_{t+1} = f(a, x_t)$ until $x_{t+1}$ is sufficiently close to $x_t$. The result $x^*$ depends on the parameters $a$, and so we can think of there being a function $a \\mapsto x^*(a)$ that is implicity defined by equation $x = f(a, x)$.\n",
"This is an iterative procedure for numerically solving the equation $x = f(a, x)$ for $x$, by iterating $x_{t+1} = f(a, x_t)$ until $x_{t+1}$ is sufficiently close to $x_t$. The result $x^*$ depends on the parameters $a$, and so we can think of there being a function $a \\mapsto x^*(a)$ that is implicitly defined by equation $x = f(a, x)$.\n",
"\n",
"We can use `fixed_point` to run iterative procedures to convergence, for example running Newton's method to calculate square roots while only executing adds, multiplies, and divides:"
]
@ -1414,7 +1414,7 @@
"id": "kZ0yc-Ihoezk"
},
"source": [
"Calling a `jax.custom_jvp` function with keyword arguments, or writing a `jax.custom_jvp` function definition with default arguments, are both allowed so long as they can be unambiguosly mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism."
"Calling a `jax.custom_jvp` function with keyword arguments, or writing a `jax.custom_jvp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism."
]
},
{
@ -1782,7 +1782,7 @@
"id": "GwC26P9kn8qw"
},
"source": [
"Calling a `jax.custom_vjp` function with keyword arguments, or writing a `jax.custom_vjp` function definition with default arguments, are both allowed so long as they can be unambiguosly mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism."
"Calling a `jax.custom_vjp` function with keyword arguments, or writing a `jax.custom_vjp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism."
]
},
{

@ -437,7 +437,7 @@ def fixed_point(f, a, x_guess):
+++ {"id": "p2xFQAte19sF"}
This is an iterative procedure for numerically solving the equation $x = f(a, x)$ for $x$, by iterating $x_{t+1} = f(a, x_t)$ until $x_{t+1}$ is sufficiently close to $x_t$. The result $x^*$ depends on the parameters $a$, and so we can think of there being a function $a \mapsto x^*(a)$ that is implicity defined by equation $x = f(a, x)$.
This is an iterative procedure for numerically solving the equation $x = f(a, x)$ for $x$, by iterating $x_{t+1} = f(a, x_t)$ until $x_{t+1}$ is sufficiently close to $x_t$. The result $x^*$ depends on the parameters $a$, and so we can think of there being a function $a \mapsto x^*(a)$ that is implicitly defined by equation $x = f(a, x)$.
We can use `fixed_point` to run iterative procedures to convergence, for example running Newton's method to calculate square roots while only executing adds, multiplies, and divides:
@ -739,7 +739,7 @@ print(grad(f, 1)(2., 3.))
+++ {"id": "kZ0yc-Ihoezk"}
Calling a `jax.custom_jvp` function with keyword arguments, or writing a `jax.custom_jvp` function definition with default arguments, are both allowed so long as they can be unambiguosly mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism.
Calling a `jax.custom_jvp` function with keyword arguments, or writing a `jax.custom_jvp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism.
+++ {"id": "3FGwfT67PDs9"}
@ -919,7 +919,7 @@ print(grad(f)(2., 3.))
+++ {"id": "GwC26P9kn8qw"}
Calling a `jax.custom_vjp` function with keyword arguments, or writing a `jax.custom_vjp` function definition with default arguments, are both allowed so long as they can be unambiguosly mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism.
Calling a `jax.custom_vjp` function with keyword arguments, or writing a `jax.custom_vjp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism.
+++ {"id": "XfH-ae8bYt6-"}

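A minimal `jax.custom_jvp` sketch in the spirit of this section, using a standard numerically-stable `log1pexp` example:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    x, = primals
    t, = tangents
    ans = log1pexp(x)
    return ans, t * (1.0 - 1.0 / (1.0 + jnp.exp(x)))  # numerically stable derivative

print(jax.grad(log1pexp)(100.0))  # 1.0, where the naive rule overflows to nan
```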
@ -42,7 +42,7 @@
"Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically\n",
"as \"multiply_add(x, y, z) = x * y + z\". \n",
"This function operates on 3 identically-shaped tensors of floating point \n",
"values and performs the opertions pointwise."
"values and performs the operations pointwise."
]
},
{
@ -601,7 +601,7 @@
"\n",
"JAX compilation works by compiling each primitive into a graph of XLA operations.\n",
"\n",
"This is biggest hurdle to adding new functionality to JAX, because the \n",
"This is the biggest hurdle to adding new functionality to JAX, because the \n",
"set of XLA operations is limited, and JAX already has pre-defined primitives\n",
"for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++."
]
@ -976,7 +976,7 @@
"Observe also that JAX uses the special abstract tangent value `Zero` for\n",
"the tangent corresponding to the 3rd argument of `ma`. This reflects the \n",
"fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,\n",
"which flow to 3rd argument to `multiply_add_prim`.\n",
"which flows to the 3rd argument to `multiply_add_prim`.\n",
"\n",
"Observe also that during the abstract evaluation of the tangent we pass the \n",
"value 0.0 as the tangent for the 3rd argument. This is due to the use\n",
@ -1147,7 +1147,7 @@
" w.r.t. tangents in multiply_add_value_and_jvp:\n",
" output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))\n",
" \n",
" Always one of the first two multiplicative arguments are constants.\n",
" Always one of the first two multiplicative arguments is a constant.\n",
"\n",
" Args:\n",
" ct: the cotangent of the output of the primitive.\n",
@ -1254,7 +1254,7 @@
"#### JIT of reverse differentiation \n",
"\n",
"Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only\n",
"abstract values, while in the absensce of JIT we used `ConcreteArray`."
"abstract values, while in the absence of JIT we used `ConcreteArray`."
]
},
{

@ -49,7 +49,7 @@ one can define a new primitive that encapsulates the behavior of the function.
Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically
as "multiply_add(x, y, z) = x * y + z".
This function operates on 3 identically-shaped tensors of floating point
values and performs the opertions pointwise.
values and performs the operations pointwise.
+++ {"id": "HIJYIHNTD1yI"}
@ -347,7 +347,7 @@ with expectNotImplementedError():
JAX compilation works by compiling each primitive into a graph of XLA operations.
This is biggest hurdle to adding new functionality to JAX, because the
This is the biggest hurdle to adding new functionality to JAX, because the
set of XLA operations is limited, and JAX already has pre-defined primitives
for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++.
@ -530,7 +530,7 @@ for the differentiation point, and abstract values for the tangents.
Observe also that JAX uses the special abstract tangent value `Zero` for
the tangent corresponding to the 3rd argument of `ma`. This reflects the
fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,
which flow to 3rd argument to `multiply_add_prim`.
which flows to the 3rd argument to `multiply_add_prim`.
Observe also that during the abstract evaluation of the tangent we pass the
value 0.0 as the tangent for the 3rd argument. This is due to the use
@ -630,7 +630,7 @@ def multiply_add_transpose(ct, x, y, z):
w.r.t. tangents in multiply_add_value_and_jvp:
output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))
Always one of the first two multiplicative arguments are constants.
Always one of the first two multiplicative arguments is a constant.
Args:
ct: the cotangent of the output of the primitive.
@ -679,7 +679,7 @@ last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is th
#### JIT of reverse differentiation
Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only
abstract values, while in the absensce of JIT we used `ConcreteArray`.
abstract values, while in the absence of JIT we used `ConcreteArray`.
```{code-cell} ipython3
:id: FZ-JGbWZPq2-

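# Illustrative sketch of the bare bones of the multiply_add primitive described
# above: create the primitive, give it a concrete implementation, and bind it
# from a user-facing wrapper. (jax.core is an internal, unstable API.)
import numpy as np
from jax import core

multiply_add_p = core.Primitive("multiply_add")

def multiply_add_impl(x, y, z):
    # Concrete evaluation, used when no JAX transformation is active.
    return np.add(np.multiply(x, y), z)

multiply_add_p.def_impl(multiply_add_impl)

def multiply_add_prim(x, y, z):
    return multiply_add_p.bind(x, y, z)

print(multiply_add_prim(2.0, 3.0, 10.0))  # 16.0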
@ -12,7 +12,7 @@
"\n",
"**Copyright 2018 Google LLC.**\n",
"\n",
"Licensed under the Apache License, Version 2.0 (the \"License\");you may not use this file except in compliance with the License.\n",
"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License.\n",
"You may obtain a copy of the License at\n",
"\n",
"https://www.apache.org/licenses/LICENSE-2.0\n",
@ -34,7 +34,7 @@
"\n",
"Let's combine everything we showed in the [quickstart notebook](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n",
"\n",
"Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for builidng our model."
"Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model."
]
},
{

@ -20,7 +20,7 @@ kernelspec:
**Copyright 2018 Google LLC.**
Licensed under the Apache License, Version 2.0 (the "License");you may not use this file except in compliance with the License.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
@ -37,7 +37,7 @@ limitations under the License.
Let's combine everything we showed in the [quickstart notebook](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for builidng our model.
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.
```{code-cell} ipython3
:id: OksHydJDtbbI

@ -146,12 +146,12 @@
"id": "k-HxK9iagnH6"
},
"source": [
"* `jaxpr.invars` - the `invars` of a Jaxpr are a list of the input variables to Jaxpr, analogous to arguments in Python functions\n",
"* `jaxpr.invars` - the `invars` of a Jaxpr are a list of the input variables to Jaxpr, analogous to arguments in Python functions.\n",
"* `jaxpr.outvars` - the `outvars` of a Jaxpr are the variables that are returned by the Jaxpr. Every Jaxpr has multiple outputs.\n",
"* `jaxpr.constvars` - the `constvars` are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later)\n",
"* `jaxpr.eqns` - a list of equations, which are essentially let-bindings. Each equation is list of input variables, a list of output variables, and a *primitive*, which is used to evaluate inputs to produce outputs. Each equation also has a `params`, a dictionary of parameters.\n",
"* `jaxpr.constvars` - the `constvars` are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later).\n",
"* `jaxpr.eqns` - a list of equations, which are essentially let-bindings. Each equation is a list of input variables, a list of output variables, and a *primitive*, which is used to evaluate inputs to produce outputs. Each equation also has a `params`, a dictionary of parameters.\n",
"\n",
"All together, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. We'll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want."
"Altogether, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. We'll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want."
]
},
{
@ -335,7 +335,7 @@
"\n",
"An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.\n",
"\n",
"It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L141-L187)."
"It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234)."
]
},
{

@ -105,12 +105,12 @@ examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))
+++ {"id": "k-HxK9iagnH6"}
* `jaxpr.invars` - the `invars` of a Jaxpr are a list of the input variables to Jaxpr, analogous to arguments in Python functions
* `jaxpr.invars` - the `invars` of a Jaxpr are a list of the input variables to Jaxpr, analogous to arguments in Python functions.
* `jaxpr.outvars` - the `outvars` of a Jaxpr are the variables that are returned by the Jaxpr. Every Jaxpr has multiple outputs.
* `jaxpr.constvars` - the `constvars` are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later)
* `jaxpr.eqns` - a list of equations, which are essentially let-bindings. Each equation is list of input variables, a list of output variables, and a *primitive*, which is used to evaluate inputs to produce outputs. Each equation also has a `params`, a dictionary of parameters.
* `jaxpr.constvars` - the `constvars` are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later).
* `jaxpr.eqns` - a list of equations, which are essentially let-bindings. Each equation is a list of input variables, a list of output variables, and a *primitive*, which is used to evaluate inputs to produce outputs. Each equation also has a `params`, a dictionary of parameters.
All together, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. We'll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want.
Altogether, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. We'll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want.
+++ {"id": "NwY7TurYn6sr"}
@ -235,7 +235,7 @@ Furthermore, this interpreter does not handle `subjaxprs`, which we will not cov
An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.
It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L141-L187).
It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234).
```{code-cell} ipython3
:id: gSMIT2z1vUpO

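# Illustrative sketch (toy function): inspecting the Jaxpr fields listed above.
import jax
import jax.numpy as jnp

def f(x):
    return jnp.exp(x) + 1.0

closed_jaxpr = jax.make_jaxpr(f)(1.0)
jaxpr = closed_jaxpr.jaxpr
print(jaxpr.invars)      # input variables
print(jaxpr.outvars)     # output variables
print(jaxpr.constvars)   # constants captured from the trace
for eqn in jaxpr.eqns:   # let-binding equations
    print(eqn.primitive, eqn.invars, eqn.outvars, eqn.params)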
@ -18,7 +18,7 @@
"\n",
"XLA computations are built as computation graphs in HLO IR, which is then lowered to LLO that is device specific (CPU, GPU, TPU, etc.). \n",
"\n",
"As end users we interact with the computational primitives offered to us by the HLO spec.\n",
"As end users, we interact with the computational primitives offered to us by the HLO spec.\n",
"\n",
"**Caution: This is a pedagogical notebook covering some low level XLA details, the APIs herein are neither public nor stable!**"
]

@ -26,7 +26,7 @@ XLA is the compiler that JAX uses, and the compiler that TF uses for TPUs and wi
XLA computations are built as computation graphs in HLO IR, which is then lowered to LLO that is device specific (CPU, GPU, TPU, etc.).
As end users we interact with the computational primitives offered to us by the HLO spec.
As end users, we interact with the computational primitives offered to us by the HLO spec.
**Caution: This is a pedagogical notebook covering some low level XLA details, the APIs herein are neither public nor stable!**

@ -699,7 +699,7 @@
"\n",
"To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with \"tall\" Jacobians, but inefficient for \"wide\" Jacobians.\n",
"\n",
"If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\\mathbb{R}^n$ to a scalar loss value in $\\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\\partial f(x) \\in \\mathbb{R}^{1 \\times n}$, which we often identify with the Gradient vector $\\nabla f(x) \\in \\mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluating the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale.\n",
"If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\\mathbb{R}^n$ to a scalar loss value in $\\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\\partial f(x) \\in \\mathbb{R}^{1 \\times n}$, which we often identify with the Gradient vector $\\nabla f(x) \\in \\mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluate the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale.\n",
"\n",
"To do better for functions like this, we just need to use reverse-mode."
]
@ -1493,7 +1493,7 @@
"id": "jqCvEE8qwGw7"
},
"source": [
"For geneneral $\\mathbb{C} \\to \\mathbb{C}$ functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can't hope to represent all of them with in a complex number. But we can for holomorphic functions! A holomorphic function is precisely a $\\mathbb{C} \\to \\mathbb{C}$ function with the special property that its derivative can be represented as a single complex number. (The [Cauchy-Riemann equations](https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations) ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to `vjp` with a covector of `1.0`.\n",
"For general $\\mathbb{C} \\to \\mathbb{C}$ functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can't hope to represent all of them within a complex number. But we can for holomorphic functions! A holomorphic function is precisely a $\\mathbb{C} \\to \\mathbb{C}$ function with the special property that its derivative can be represented as a single complex number. (The [Cauchy-Riemann equations](https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations) ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to `vjp` with a covector of `1.0`.\n",
"\n",
"Because this only works for holomorphic functions, to use this trick we need to promise JAX that our function is holomorphic; otherwise, JAX will raise an error when `grad` is used for a complex-output function:"
]
@ -1574,8 +1574,8 @@
"There are some useful upshots for how `grad` works here:\n",
"\n",
"1. We can use `grad` on holomorphic $\\mathbb{C} \\to \\mathbb{C}$ functions.\n",
"2. We can use `grad` to optimize $f : \\mathbb{C} \\to \\mathbb{R}$ functions, like real-valued loss functions of complex parameters `x`, by taking steps in the dierction of the conjugate of `grad(f)(x)`.\n",
"3. If we have an $\\mathbb{R} \\to \\mathbb{R}$ function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in covolutions) then `grad` still works and we get the same result that an implementation using only real values would have given.\n",
"2. We can use `grad` to optimize $f : \\mathbb{C} \\to \\mathbb{R}$ functions, like real-valued loss functions of complex parameters `x`, by taking steps in the direction of the conjugate of `grad(f)(x)`.\n",
"3. If we have an $\\mathbb{R} \\to \\mathbb{R}$ function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions) then `grad` still works and we get the same result that an implementation using only real values would have given.\n",
"\n",
"In any case, JVPs and VJPs are always unambiguous. And if we wanted to compute the full Jacobian matrix of a non-holomorphic $\\mathbb{C} \\to \\mathbb{C}$ function, we can do it with JVPs or VJPs!"
]
@ -1637,7 +1637,7 @@
"\n",
"In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful. \n",
"\n",
"There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in a \"Advanced Autodiff Cookbook\" include:\n",
"There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in an \"Advanced Autodiff Cookbook\" include:\n",
"\n",
" - Gauss-Newton Vector Products, linearizing once\n",
" - Custom VJPs and JVPs\n",

@ -396,7 +396,7 @@ That memory complexity sounds pretty compelling! So why don't we see forward-mod
To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with "tall" Jacobians, but inefficient for "wide" Jacobians.
If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\mathbb{R}^n$ to a scalar loss value in $\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\partial f(x) \in \mathbb{R}^{1 \times n}$, which we often identify with the Gradient vector $\nabla f(x) \in \mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluating the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale.
If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\mathbb{R}^n$ to a scalar loss value in $\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\partial f(x) \in \mathbb{R}^{1 \times n}$, which we often identify with the Gradient vector $\nabla f(x) \in \mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluate the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale.
To do better for functions like this, we just need to use reverse-mode.
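To make the column-at-a-time picture concrete, here is a small hedged sketch (the function `f` and its shapes are made up for illustration): each `jax.jvp` call against a one-hot tangent recovers one Jacobian column, while a single reverse-mode pass recovers the entire gradient row of a scalar loss.

```python
import jax
import jax.numpy as jnp

def f(x):  # R^2 -> R^3, so its Jacobian is "tall" (3 x 2)
    return jnp.array([x[0] * x[1], jnp.sin(x[0]), x[1] ** 2])

x = jnp.array([1.0, 2.0])

# One JVP per one-hot tangent reveals one column of the Jacobian.
cols = [jax.jvp(f, (x,), (jnp.eye(2)[i],))[1] for i in range(2)]
jac = jnp.stack(cols, axis=1)
assert jnp.allclose(jac, jax.jacfwd(f)(x))

# For a scalar loss the Jacobian is one wide row, and a single
# reverse-mode pass (what `jax.grad` does) recovers all of it at once.
loss = lambda x: jnp.sum(jnp.tanh(x) ** 2)
print(jax.grad(loss)(x))
```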
@ -895,7 +895,7 @@ grad(f)(z)
+++ {"id": "jqCvEE8qwGw7"}
For geneneral $\mathbb{C} \to \mathbb{C}$ functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can't hope to represent all of them with in a complex number. But we can for holomorphic functions! A holomorphic function is precisely a $\mathbb{C} \to \mathbb{C}$ function with the special property that its derivative can be represented as a single complex number. (The [Cauchy-Riemann equations](https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations) ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to `vjp` with a covector of `1.0`.
For general $\mathbb{C} \to \mathbb{C}$ functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can't hope to represent all of them within a complex number. But we can for holomorphic functions! A holomorphic function is precisely a $\mathbb{C} \to \mathbb{C}$ function with the special property that its derivative can be represented as a single complex number. (The [Cauchy-Riemann equations](https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations) ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to `vjp` with a covector of `1.0`.
Because this only works for holomorphic functions, to use this trick we need to promise JAX that our function is holomorphic; otherwise, JAX will raise an error when `grad` is used for a complex-output function:
@ -930,8 +930,8 @@ grad(f, holomorphic=True)(z) # f is not actually holomorphic!
There are some useful upshots for how `grad` works here:
1. We can use `grad` on holomorphic $\mathbb{C} \to \mathbb{C}$ functions.
2. We can use `grad` to optimize $f : \mathbb{C} \to \mathbb{R}$ functions, like real-valued loss functions of complex parameters `x`, by taking steps in the dierction of the conjugate of `grad(f)(x)`.
3. If we have an $\mathbb{R} \to \mathbb{R}$ function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in covolutions) then `grad` still works and we get the same result that an implementation using only real values would have given.
2. We can use `grad` to optimize $f : \mathbb{C} \to \mathbb{R}$ functions, like real-valued loss functions of complex parameters `x`, by taking steps in the direction of the conjugate of `grad(f)(x)`.
3. If we have an $\mathbb{R} \to \mathbb{R}$ function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions) then `grad` still works and we get the same result that an implementation using only real values would have given.
In any case, JVPs and VJPs are always unambiguous. And if we wanted to compute the full Jacobian matrix of a non-holomorphic $\mathbb{C} \to \mathbb{C}$ function, we can do it with JVPs or VJPs!
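As a hedged illustration of points 1 and 2 above (the particular functions are illustrative, not taken from the notebook):

```python
import jax
import jax.numpy as jnp

# 1. A holomorphic C -> C function: its derivative is one complex number.
f = lambda z: jnp.sin(z)
z = jnp.array(3.0 + 4.0j)
print(jax.grad(f, holomorphic=True)(z))     # approximately cos(z)

# 2. A real-valued loss of a complex parameter: optimize by stepping
#    along the conjugate of grad(f)(x).
loss = lambda z: jnp.abs(z) ** 2
g = jax.grad(loss)(z)
z_next = z - 0.1 * jnp.conj(g)              # gradient-descent step
```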
@ -960,7 +960,7 @@ grad(f, holomorphic=True)(A)
In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful.
There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in a "Advanced Autodiff Cookbook" include:
There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in an "Advanced Autodiff Cookbook" include:
- Gauss-Newton Vector Products, linearizing once
- Custom VJPs and JVPs

@ -169,7 +169,7 @@
"source": [
"For the more general types of batched convolutions often useful in the context of building deep neural networks, JAX and XLA offer the very general N-dimensional __conv_general_dilated__ function, but it's not very obvious how to use it. We'll give some examples of the common use-cases.\n",
"\n",
"A survey of the family of convolutional operators, [a guide to convolutional arithmetic](https://arxiv.org/abs/1603.07285) is highly recommended reading!\n",
"A survey of the family of convolutional operators, [a guide to convolutional arithmetic](https://arxiv.org/abs/1603.07285), is highly recommended reading!\n",
"\n",
"Let's define a simple diagonal edge kernel:"
]

@ -111,7 +111,7 @@ Like in the one-dimensional case, we use `mode='same'` to specify how we would l
For the more general types of batched convolutions often useful in the context of building deep neural networks, JAX and XLA offer the very general N-dimensional __conv_general_dilated__ function, but it's not very obvious how to use it. We'll give some examples of the common use-cases.
A survey of the family of convolutional operators, [a guide to convolutional arithmetic](https://arxiv.org/abs/1603.07285) is highly recommended reading!
A survey of the family of convolutional operators, [a guide to convolutional arithmetic](https://arxiv.org/abs/1603.07285), is highly recommended reading!
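As a quick, hedged orientation before the notebook's own examples (the shapes and dimension numbers below are illustrative assumptions), a typical NHWC call looks roughly like this:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
images = jax.random.normal(key, (1, 32, 32, 3))   # NHWC batch of images
kernel = jax.random.normal(key, (3, 3, 3, 8))     # HWIO: 3x3 window, 3 in, 8 out

out = jax.lax.conv_general_dilated(
    images, kernel,
    window_strides=(1, 1),
    padding='SAME',
    dimension_numbers=('NHWC', 'HWIO', 'NHWC'))
print(out.shape)   # (1, 32, 32, 8)
```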
Let's define a simple diagonal edge kernel:

@ -25,10 +25,10 @@
"In this notebook we'll go through:\n",
"\n",
"- how to take gradients, gradients of gradients.\n",
"- how to fit a sinusoid function with a neural network (and do auto-batching with vmap)\n",
"- how to implement MAML and check its numerics\n",
"- how to fit a sinusoid function with a neural network (and do auto-batching with vmap).\n",
"- how to implement MAML and check its numerics.\n",
"- how to implement MAML for sinusoid task (single-task objective, batching task instances).\n",
"- extending MAML to handle batching at the task-level"
"- extending MAML to handle batching at the task-level."
]
},
{
@ -526,7 +526,7 @@
"source": [
"## Batching Meta-Gradient Across Tasks\n",
"\n",
"Kind of does the job but not that great. Let's reduce the variance of gradients in outer loop by averaging across a batch of tasks (not just one task at a time). \n",
"Kind of does the job but not that great. Let's reduce the variance of the gradients in the outer loop by averaging across a batch of tasks (not just one task at a time). \n",
"\n",
"vmap is awesome it enables nice handling of batching at two levels: inner-level \"intra-task\" batching, and outer level batching across tasks.\n",
"\n",

@ -31,10 +31,10 @@ Pedagogical tutorial for implementing Model-Agnostic Meta-Learning with JAX's aw
In this notebook we'll go through:
- how to take gradients, gradients of gradients.
- how to fit a sinusoid function with a neural network (and do auto-batching with vmap)
- how to implement MAML and check its numerics
- how to fit a sinusoid function with a neural network (and do auto-batching with vmap).
- how to implement MAML and check its numerics.
- how to implement MAML for sinusoid task (single-task objective, batching task instances).
- extending MAML to handle batching at the task-level
- extending MAML to handle batching at the task-level.
```{code-cell} ipython3
:id: zKVdo3FtgyhE
@ -287,7 +287,7 @@ plt.legend()
## Batching Meta-Gradient Across Tasks
Kind of does the job but not that great. Let's reduce the variance of gradients in outer loop by averaging across a batch of tasks (not just one task at a time).
Kind of does the job but not that great. Let's reduce the variance of the gradients in the outer loop by averaging across a batch of tasks (not just one task at a time).
vmap is awesome. It enables nice handling of batching at two levels: inner-level "intra-task" batching and outer-level batching across tasks.
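A hedged sketch of those two levels (the per-example loss below is a stand-in, not the notebook's model):

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example loss; `params` is a 2-vector standing in for a net.
def loss(params, x, y):
    pred = jnp.tanh(params[0] * x + params[1])
    return (pred - y) ** 2

# Inner level: batch over the examples inside one task.
task_loss = jax.vmap(loss, in_axes=(None, 0, 0))

# Outer level: batch over tasks, each with its own (adapted) parameters.
all_tasks_loss = jax.vmap(task_loss, in_axes=(0, 0, 0))

params = jnp.zeros((4, 2))     # 4 tasks x 2 parameters
xs = jnp.ones((4, 16))         # 4 tasks x 16 examples each
ys = jnp.zeros((4, 16))
print(all_tasks_loss(params, xs, ys).shape)   # (4, 16)
```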

@ -46,7 +46,7 @@
"\n",
"Let's combine everything we showed in the [quickstart notebook](https://colab.research.google.com/github/google/jax/blob/main/notebooks/quickstart.ipynb) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n",
"\n",
"Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for builidng our model."
"Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model."
]
},
{

@ -44,7 +44,7 @@ _Forked from_ `neural_network_and_data_loading.ipynb`
Let's combine everything we showed in the [quickstart notebook](https://colab.research.google.com/github/google/jax/blob/main/notebooks/quickstart.ipynb) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use the `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for builidng our model.
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.
```{code-cell} ipython3
:id: OksHydJDtbbI

@ -196,7 +196,7 @@
"id": "iOzp0P_GoJhb"
},
"source": [
"JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there's three main ones:\n",
"JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:\n",
"\n",
" - {func}`~jax.jit`, for speeding up your code\n",
" - {func}`~jax.grad`, for taking derivatives\n",

@ -127,7 +127,7 @@ x = np.random.normal(size=(size, size)).astype(np.float32)
+++ {"id": "iOzp0P_GoJhb"}
JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there's three main ones:
JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:
- {func}`~jax.jit`, for speeding up your code
- {func}`~jax.grad`, for taking derivatives

@ -77,7 +77,7 @@
"source": [
"## Compute score matching objective\n",
"\n",
"The method we apply here was originally proposed by [Hyvarinen et al. (2005)](http://jmlr.org/papers/volume6/hyvarinen05a/old.pdf). The idea behind score matching is to __learn scores:__ the gradients of $\\log p(x)$ w.r.t. $x$. When trained this model can \"improve\" a sample $x$ by changing it in the direction of highest log-probability. However, training such model can get tricky. When predicting a continuous variable, ML folks usually minimize squared error:\n",
"The method we apply here was originally proposed by [Hyvarinen et al. (2005)](http://jmlr.org/papers/volume6/hyvarinen05a/old.pdf). The idea behind score matching is to __learn scores:__ the gradients of $\\log p(x)$ w.r.t. $x$. When trained this model can \"improve\" a sample $x$ by changing it in the direction of highest log-probability. However, training such a model can get tricky. When predicting a continuous variable, ML folks usually minimize squared error:\n",
"\n",
"$$ L_{mse} = E_{x \\sim p(x)} \\left\\lVert model(x) - \\nabla_x \\log p(x) \\right\\lVert_2^2 $$\n",
"\n",
@ -85,7 +85,7 @@
"\n",
"$$ L_{matching} = E_{x \\sim p(x)} \\space tr( \\space \\mathbf{J}_x [\\space model(x) \\space]) + \\frac12 \\left\\Vert model(x) \\right\\lVert_2^2 $$\n",
"\n",
"Here $tr( \\space \\mathbf{J}_x [\\space model(x) \\space])$ is a trace of Jacobian of $model(x)$ w.r.t. $x$. Now all it takes is to minimize the second objective with backpropagation... that is, if you can compute Jacobians. Thankfully, we have __jax__!"
"Here $tr( \\space \\mathbf{J}_x [\\space model(x) \\space])$ is the trace of the Jacobian of $model(x)$ w.r.t. $x$. Now all it takes is to minimize the second objective with backpropagation... that is, if you can compute Jacobians. Thankfully, we have __jax__!"
]
},
{

@ -50,7 +50,7 @@ plt.scatter(*sample_batch(10**4).T, alpha=0.1)
## Compute score matching objective
The method we apply here was originally proposed by [Hyvarinen et al. (2005)](http://jmlr.org/papers/volume6/hyvarinen05a/old.pdf). The idea behind score matching is to __learn scores:__ the gradients of $\log p(x)$ w.r.t. $x$. When trained this model can "improve" a sample $x$ by changing it in the direction of highest log-probability. However, training such model can get tricky. When predicting a continuous variable, ML folks usually minimize squared error:
The method we apply here was originally proposed by [Hyvarinen et al. (2005)](http://jmlr.org/papers/volume6/hyvarinen05a/old.pdf). The idea behind score matching is to __learn scores:__ the gradients of $\log p(x)$ w.r.t. $x$. When trained this model can "improve" a sample $x$ by changing it in the direction of highest log-probability. However, training such a model can get tricky. When predicting a continuous variable, ML folks usually minimize squared error:
$$ L_{mse} = E_{x \sim p(x)} \left\lVert model(x) - \nabla_x \log p(x) \right\lVert_2^2 $$
@ -58,7 +58,7 @@ One can't minimize this explicitly because the real $\nabla_x \log p(x)$ is usua
$$ L_{matching} = E_{x \sim p(x)} \space tr( \space \mathbf{J}_x [\space model(x) \space]) + \frac12 \left\Vert model(x) \right\lVert_2^2 $$
Here $tr( \space \mathbf{J}_x [\space model(x) \space])$ is a trace of Jacobian of $model(x)$ w.r.t. $x$. Now all it takes is to minimize the second objective with backpropagation... that is, if you can compute Jacobians. Thankfully, we have __jax__!
Here $tr( \space \mathbf{J}_x [\space model(x) \space])$ is the trace of the Jacobian of $model(x)$ w.r.t. $x$. Now all it takes is to minimize the second objective with backpropagation... that is, if you can compute Jacobians. Thankfully, we have __jax__!
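A hedged sketch of that objective for a toy model (the affine `score_model` is a placeholder for the notebook's network):

```python
import jax
import jax.numpy as jnp

def score_model(params, x):
    W, b = params                       # placeholder for a neural network
    return x @ W + b

def single_loss(params, x):
    model = lambda x: score_model(params, x)
    jac = jax.jacfwd(model)(x)          # d x d Jacobian of the model at x
    return jnp.trace(jac) + 0.5 * jnp.sum(model(x) ** 2)

def loss(params, xs):                   # average over a batch with vmap
    return jnp.mean(jax.vmap(single_loss, in_axes=(None, 0))(params, xs))

key = jax.random.PRNGKey(0)
params = (jax.random.normal(key, (2, 2)), jnp.zeros(2))
xs = jax.random.normal(key, (128, 2))
print(jax.grad(loss)(params, xs))       # gradients w.r.t. (W, b)
```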
```{code-cell} ipython3
:id: 98wjxKcNG6TI

@ -774,7 +774,7 @@
"\n",
"These tracer objects are what `jax.jit` uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the **shape** and **dtype** of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.\n",
"\n",
"When we call the compiled fuction again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:"
"When we call the compiled function again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:"
]
},
{

@ -349,7 +349,7 @@ Notice that the print statements execute, but rather than printing the data we p
These tracer objects are what `jax.jit` uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the **shape** and **dtype** of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.
When we call the compiled fuction again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:
When we call the compiled function again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:
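A hedged sketch of that behavior:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("tracing x =", x)     # executes only while tracing
    return jnp.dot(x, x)

x = jnp.arange(4.0)
print(f(x))        # prints the tracer, compiles, then runs
print(f(x + 1.0))  # same shape/dtype: runs the cached XLA computation, no print
```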
```{code-cell} ipython3
:id: xGntvzNH7skE

@ -407,7 +407,7 @@
"id": "8E7ISmwju0x1"
},
"source": [
"While this might seem like a handful at first, if you've seen code that uses `jnp.einsum` you are already familiar with this approach. The `einsum` function interprets an expression such as `nk,km->nm` assigning names (each letter is considered a separate name) to positional axes, performing necessary broadcasts and reductions, and finally putting back the results in positional axes, according to the order given by the right hand side of the `->` separator. While `einsum` never lets you interact with named axes directly, they do appear naturally in its implementation. `xmap` is a _generalized einsum_ because named axes are now first-class and you get to implement the function that can manipulate them.\n",
"While this might seem like a handful at first, if you've seen code that uses `jnp.einsum` you are already familiar with this approach. The `einsum` function interprets an expression such as `nk,km->nm` assigning names (each letter is considered a separate name) to positional axes, performing necessary broadcasts and reductions, and finally putting back the results in positional axes, according to the order given by the right-hand side of the `->` separator. While `einsum` never lets you interact with named axes directly, they do appear naturally in its implementation. `xmap` is a _generalized einsum_ because named axes are now first-class and you get to implement the function that can manipulate them.\n",
"\n",
"Continuing this analogy, `xmap(my_func, ...)` from the above example is equivalent to `jnp.einsum('bx->xb')`. But of course not every `xmap`ped function will have an equivalent `einsum`.\n",
"\n",
@ -543,7 +543,7 @@
"\n",
"> While the rule for broadcasting named axes might seem like an arbitrary extension of the NumPy model, it is actually consistent with it.\n",
">\n",
"> Broadcasting first looks for pairs of dimensions it considers as equivalent in its both operands. For all matched pairs, it asserts that both sizes are equal or one of them is 1. All unpaired dimensions are carried over to the result.\n",
"> Broadcasting first looks for pairs of dimensions it considers as equivalent in both operands. For all matched pairs, it asserts that both sizes are equal or one of them is 1. All unpaired dimensions are carried over to the result.\n",
">\n",
"> Now, in the positional world the way NumPy broadcasting chooses to form the pairs is by right-aligning the shapes. But our axes are named, so there is a straightforward way of finding equivalent axes: just check their names for equality!"
]
@ -603,7 +603,7 @@
"\n",
"Similarly to how we have extended reductions with support for named axes, we've also made it possible to contract over named axes using `jnp.einsum`.\n",
"\n",
"Operands and results still use a convention of one letter per positional axes, but now it is also possible to mention named axes in curly braces. For example `n{b,k}` implies that a value will have a single positional dimension `n` and named dimensions `b` and `k` (their order doesn't matter). Following the usual einsum semantics, any named axes that appears in inputs, but do not appear in an output will be contracted (summed after all multiplications are performed).\n",
"Operands and results still use a convention of one letter per positional axis, but now it is also possible to mention named axes in curly braces. For example, `n{b,k}` implies that a value will have a single positional dimension `n` and named dimensions `b` and `k` (their order doesn't matter). Following the usual einsum semantics, any named axes that appear in inputs, but do not appear in an output will be contracted (summed after all multiplications are performed).\n",
"\n",
"It is acceptable to omit a named dimension from _all arguments and the result_ in which case it will be treated according to the usual broadcasting semantics. However, it is not acceptable to mention a named axis in one argument that has it in its named shape and skip it in another argument that also has it in its named shape. Of course, skipping it in the arguments that don't have it is required.\n",
"\n",
@ -674,7 +674,7 @@
"id": "ZHoCsWkCEnKt"
},
"source": [
"## Parallelism suport\n",
"## Parallelism support\n",
"\n",
"While the new programming paradigm can be nice at times, the killer feature of `xmap` is its ability to parallelize code over supercomputer-scale hardware meshes!\n",
"\n",
@ -682,7 +682,7 @@
"\n",
"In all the previous examples, we haven't said a word about parallelism and for a good reason. By default `xmap` doesn't perform any parallelization and vectorizes the computation in the same way `vmap` does (i.e. it still executes on a single device). To partition the computation over multiple accelerators we have to introduce one more concept: _resource axes_.\n",
"\n",
"The basic idea is that logical axes (the ones that appear in named shapes) assume that we have abundant hadware and memory, but before the program is to be executed, they have to be placed somewhere. The default (`vmap`-like) evaluation style pays a high memory cost on the deafult JAX device. By mapping logical axes to (one or more) resource axes through the `axis_resources` argument, we can control how `xmap` evaluates the computation."
"The basic idea is that logical axes (the ones that appear in named shapes) assume that we have abundant hardware and memory, but before the program is to be executed, they have to be placed somewhere. The default (`vmap`-like) evaluation style pays a high memory cost on the default JAX device. By mapping logical axes to (one or more) resource axes through the `axis_resources` argument, we can control how `xmap` evaluates the computation."
]
},
{
@ -723,7 +723,7 @@
"\n",
"Well, it depends, but one good choice is... a hardware mesh!\n",
"\n",
"For our purposes a mesh is an nd-array of devices with named axes. But, beacuse NumPy doesn't support named axes (that's our extension!), the meshes are represented by a pair of an nd-array of JAX device objects (as obtained from `jax.devices()` or `jax.local_devices()`) and a tuple of resource axis names of length matching the rank of the array."
"For our purposes a mesh is an nd-array of devices with named axes. But, because NumPy doesn't support named axes (that's our extension!), the meshes are represented by a pair of an nd-array of JAX device objects (as obtained from `jax.devices()` or `jax.local_devices()`) and a tuple of resource axis names of length matching the rank of the array."
]
},
{
@ -819,11 +819,11 @@
"source": [
"### Is my data replicated? Or partitioned? Where is it?\n",
"\n",
"Named axes also give us a neat way of reasoning about partitioning and replication. A value is partitioned over a mesh axis if an only if it has a named axis that has been mapped to that mesh axis in its shape. Otherwise, it will be replicated over all slices along that axis.\n",
"Named axes also give us a neat way of reasoning about partitioning and replication. A value is partitioned over a mesh axis if and only if it has a named axis that has been mapped to that mesh axis in its shape. Otherwise, it will be replicated over all slices along that axis.\n",
"\n",
"For example, assume that we're in an `xmap` that had `axis_resources={'a': 'x', 'b': 'y'}` specified (i.e. we are running the computation over a 2D mesh with `x` and `y` axes with sizes 2 and 3 respectively). Then:\n",
"* An array of type `f32[(5, 5), {}]` is completely replicated over the whole mesh. All devices store a local copy of the value.\n",
"* An array of type `f32[(6,), {'a': 8}]` is partitioned over mesh axis `x`, beacuse it has `'a'` in its named shape, and `'a'` is mapped to `x`. It is replicated over mesh axis `y`. To put it differently, all devices in a slice of the mesh with the same `x` coordinate will store a local copy of a chunk of this array. But, mesh slices with different `x` coordinates will store different chunks of the data.\n",
"* An array of type `f32[(6,), {'a': 8}]` is partitioned over mesh axis `x`, because it has `'a'` in its named shape, and `'a'` is mapped to `x`. It is replicated over mesh axis `y`. To put it differently, all devices in a slice of the mesh with the same `x` coordinate will store a local copy of a chunk of this array. But, mesh slices with different `x` coordinates will store different chunks of the data.\n",
"* An array of type `f32[(), {'a': 8, 'c': 7}]` is partitioned just like in the previous case: split over the `x` mesh axis and replicated over the `y` axis. Named dimensions with no resources specified are no different than positional dimensions when considering partitioning, so `'c'` has no influence on it.\n",
"* An array of type `f32[(), {'a': 8, 'b': 12}]` is completely partitioned over the whole mesh. Every device holds a distinct chunk of the data."
]

@ -274,7 +274,7 @@ assert (y == x.T).all() # The first dimension was removed from x and then re-in
+++ {"id": "8E7ISmwju0x1"}
While this might seem like a handful at first, if you've seen code that uses `jnp.einsum` you are already familiar with this approach. The `einsum` function interprets an expression such as `nk,km->nm` assigning names (each letter is considered a separate name) to positional axes, performing necessary broadcasts and reductions, and finally putting back the results in positional axes, according to the order given by the right hand side of the `->` separator. While `einsum` never lets you interact with named axes directly, they do appear naturally in its implementation. `xmap` is a _generalized einsum_ because named axes are now first-class and you get to implement the function that can manipulate them.
While this might seem like a handful at first, if you've seen code that uses `jnp.einsum` you are already familiar with this approach. The `einsum` function interprets an expression such as `nk,km->nm` assigning names (each letter is considered a separate name) to positional axes, performing necessary broadcasts and reductions, and finally putting back the results in positional axes, according to the order given by the right-hand side of the `->` separator. While `einsum` never lets you interact with named axes directly, they do appear naturally in its implementation. `xmap` is a _generalized einsum_ because named axes are now first-class and you get to implement the function that can manipulate them.
Continuing this analogy, `xmap(my_func, ...)` from the above example is equivalent to `jnp.einsum('bx->xb')`. But of course not every `xmap`ped function will have an equivalent `einsum`.
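A hedged sketch of that equivalence, mirroring the transposition example above (treat it as illustrative, since `xmap` is experimental):

```python
import jax.numpy as jnp
from jax.experimental.maps import xmap

# Name the leading axis 'b', keep the trailing axis positional, and move
# 'b' to the end on output -- the xmap analogue of jnp.einsum('bx->xb').
swap = xmap(lambda row: row,
            in_axes=['b', ...],
            out_axes=[..., 'b'])

x = jnp.arange(6.0).reshape(2, 3)
print((swap(x) == x.T).all())   # True
```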
@ -375,7 +375,7 @@ No shape errors can occur when operating over named axes, because `xmap` enforce
> While the rule for broadcasting named axes might seem like an arbitrary extension of the NumPy model, it is actually consistent with it.
>
> Broadcasting first looks for pairs of dimensions it considers as equivalent in its both operands. For all matched pairs, it asserts that both sizes are equal or one of them is 1. All unpaired dimensions are carried over to the result.
> Broadcasting first looks for pairs of dimensions it considers as equivalent in both operands. For all matched pairs, it asserts that both sizes are equal or one of them is 1. All unpaired dimensions are carried over to the result.
>
> Now, in the positional world the way NumPy broadcasting chooses to form the pairs is by right-aligning the shapes. But our axes are named, so there is a straightforward way of finding equivalent axes: just check their names for equality!
@ -420,7 +420,7 @@ positional_broadcast_and_reduce(jnp.arange(2, dtype=np.float32),
Similarly to how we have extended reductions with support for named axes, we've also made it possible to contract over named axes using `jnp.einsum`.
Operands and results still use a convention of one letter per positional axes, but now it is also possible to mention named axes in curly braces. For example `n{b,k}` implies that a value will have a single positional dimension `n` and named dimensions `b` and `k` (their order doesn't matter). Following the usual einsum semantics, any named axes that appears in inputs, but do not appear in an output will be contracted (summed after all multiplications are performed).
Operands and results still use a convention of one letter per positional axis, but now it is also possible to mention named axes in curly braces. For example, `n{b,k}` implies that a value will have a single positional dimension `n` and named dimensions `b` and `k` (their order doesn't matter). Following the usual einsum semantics, any named axes that appear in inputs, but do not appear in an output will be contracted (summed after all multiplications are performed).
It is acceptable to omit a named dimension from _all arguments and the result_ in which case it will be treated according to the usual broadcasting semantics. However, it is not acceptable to mention a named axis in one argument that has it in its named shape and skip it in another argument that also has it in its named shape. Of course, skipping it in the arguments that don't have it is required.
@ -466,7 +466,7 @@ xmap(lambda x: lax.pshuffle(x, 'i', list(reversed(range(8)))),
+++ {"id": "ZHoCsWkCEnKt"}
## Parallelism suport
## Parallelism support
While the new programming paradigm can be nice at times, the killer feature of `xmap` is its ability to parallelize code over supercomputer-scale hardware meshes!
@ -474,7 +474,7 @@ While the new programming paradigm can be nice at times, the killer feature of `
In all the previous examples, we haven't said a word about parallelism and for a good reason. By default `xmap` doesn't perform any parallelization and vectorizes the computation in the same way `vmap` does (i.e. it still executes on a single device). To partition the computation over multiple accelerators we have to introduce one more concept: _resource axes_.
The basic idea is that logical axes (the ones that appear in named shapes) assume that we have abundant hadware and memory, but before the program is to be executed, they have to be placed somewhere. The default (`vmap`-like) evaluation style pays a high memory cost on the deafult JAX device. By mapping logical axes to (one or more) resource axes through the `axis_resources` argument, we can control how `xmap` evaluates the computation.
The basic idea is that logical axes (the ones that appear in named shapes) assume that we have abundant hardware and memory, but before the program is to be executed, they have to be placed somewhere. The default (`vmap`-like) evaluation style pays a high memory cost on the default JAX device. By mapping logical axes to (one or more) resource axes through the `axis_resources` argument, we can control how `xmap` evaluates the computation.
```{code-cell}
:id: NnggOzOD8rl1
@ -500,7 +500,7 @@ Both `local_matmul` and `distr_matmul` implement matrix multiplication, but `dis
Well, it depends, but one good choice is... a hardware mesh!
For our purposes a mesh is an nd-array of devices with named axes. But, beacuse NumPy doesn't support named axes (that's our extension!), the meshes are represented by a pair of an nd-array of JAX device objects (as obtained from `jax.devices()` or `jax.local_devices()`) and a tuple of resource axis names of length matching the rank of the array.
For our purposes a mesh is an nd-array of devices with named axes. But, because NumPy doesn't support named axes (that's our extension!), the meshes are represented by a pair of an nd-array of JAX device objects (as obtained from `jax.devices()` or `jax.local_devices()`) and a tuple of resource axis names of length matching the rank of the array.
+++ {"id": "x3EXj1TMwZtS"}
@ -556,11 +556,11 @@ Anyway, the best part of it is that specifying `axis_resources` **never changes
### Is my data replicated? Or partitioned? Where is it?
Named axes also give us a neat way of reasoning about partitioning and replication. A value is partitioned over a mesh axis if an only if it has a named axis that has been mapped to that mesh axis in its shape. Otherwise, it will be replicated over all slices along that axis.
Named axes also give us a neat way of reasoning about partitioning and replication. A value is partitioned over a mesh axis if and only if it has a named axis that has been mapped to that mesh axis in its shape. Otherwise, it will be replicated over all slices along that axis.
For example, assume that we're in an `xmap` that had `axis_resources={'a': 'x', 'b': 'y'}` specified (i.e. we are running the computation over a 2D mesh with `x` and `y` axes with sizes 2 and 3 respectively). Then:
* An array of type `f32[(5, 5), {}]` is completely replicated over the whole mesh. All devices store a local copy of the value.
* An array of type `f32[(6,), {'a': 8}]` is partitioned over mesh axis `x`, beacuse it has `'a'` in its named shape, and `'a'` is mapped to `x`. It is replicated over mesh axis `y`. To put it differently, all devices in a slice of the mesh with the same `x` coordinate will store a local copy of a chunk of this array. But, mesh slices with different `x` coordinates will store different chunks of the data.
* An array of type `f32[(6,), {'a': 8}]` is partitioned over mesh axis `x`, because it has `'a'` in its named shape, and `'a'` is mapped to `x`. It is replicated over mesh axis `y`. To put it differently, all devices in a slice of the mesh with the same `x` coordinate will store a local copy of a chunk of this array. But, mesh slices with different `x` coordinates will store different chunks of the data.
* An array of type `f32[(), {'a': 8, 'c': 7}]` is partitioned just like in the previous case: split over the `x` mesh axis and replicated over the `y` axis. Named dimensions with no resources specified are no different than positional dimensions when considering partitioning, so `'c'` has no influence on it.
* An array of type `f32[(), {'a': 8, 'b': 12}]` is completely partitioned over the whole mesh. Every device holds a distinct chunk of the data.
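A hedged end-to-end sketch tying these pieces together (it assumes the default devices form a 1D mesh; with a single device the `'a'` axis is simply not split):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap, mesh

devices = np.array(jax.devices())          # 1D hardware mesh, axis named 'x'

# The named axis 'a' is partitioned over mesh axis 'x'; everything else
# (here, the positional trailing axis) is replicated.
scale = xmap(lambda v: v * 2.0,
             in_axes=['a', ...],
             out_axes=['a', ...],
             axis_resources={'a': 'x'})

with mesh(devices, ('x',)):
    y = scale(jnp.ones((len(devices), 4)))
print(y.shape)                              # (num_devices, 4)
```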

@ -29,7 +29,7 @@ That is:
contains pytrees, is considered a pytree.
For each entry in the pytree container registry, a container-like type is
registered with a pair of functions which specify how to convert an instance of
registered with a pair of functions that specify how to convert an instance of
the container type to a `(children, metadata)` pair and how to convert such a
pair back to an instance of the container type. Using these functions, JAX can
canonicalize any tree of registered container objects into tuples.
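A minimal hedged sketch of such a registration (the `Point` class is hypothetical):

```python
import jax
from jax.tree_util import register_pytree_node

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

register_pytree_node(
    Point,
    lambda p: ((p.x, p.y), None),                  # instance -> (children, metadata)
    lambda metadata, children: Point(*children))   # and back again

p = Point(1.0, 2.0)
print(jax.tree_util.tree_leaves(p))                # [1.0, 2.0]
print(jax.tree_util.tree_map(lambda v: 10 * v, p).x)   # 10.0
```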
@ -64,7 +64,7 @@ pytrees with values in the argument pytrees, the parameter pytrees are often
constrained to be tree prefixes of the argument pytrees.
For example, if we pass the following input to {func}`~jax.vmap` (note that the input
arguments to a function considered a tuple):
arguments to a function are considered a tuple):
```
(a1, {"k1": a2, "k2": a3})
@ -281,7 +281,7 @@ class RegisteredSpecial2(Special):
show_example(RegisteredSpecial2(1., 2.))
```
JAX needs sometimes to compare `treedef` for equality. Therefore, care must be
JAX sometimes needs to compare `treedef` for equality. Therefore, care must be
taken to ensure that the auxiliary data specified in the flattening recipe
supports a meaningful equality comparison.

@ -3,8 +3,8 @@ Rank promotion warning
`NumPy broadcasting rules
<https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules>`_
allow automatic promotion of arguments from one rank (number of array axes) to
another. This behavior can be convenient when intended but can also lead to
allow the automatic promotion of arguments from one rank (number of array axes)
to another. This behavior can be convenient when intended but can also lead to
surprising bugs where a silent rank promotion masks an underlying shape error.
Here's an example of rank promotion:
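A hedged illustration of the kind of silent promotion meant here (not necessarily the document's own example):

```python
import jax.numpy as jnp

x = jnp.ones((3, 4))
y = jnp.ones(4)
print((x + y).shape)   # (3, 4): y was silently promoted from rank 1 to rank 2
```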

@ -18,7 +18,8 @@ kernelspec:
```
At its core, JAX is an extensible system for transforming numerical functions.
This section will discuss four that are of primary interest: {func}`grad`, {func}`jit`, {func}`vmap`, and {func}`pmap`.
This section will discuss four transformations that are of primary interest:
{func}`grad`, {func}`jit`, {func}`vmap`, and {func}`pmap`.
## Automatic differentiation with `grad`

@ -738,7 +738,7 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False,
reduce_axes: Sequence[AxisName] = ()) -> Callable:
"""Creates a function which evaluates the gradient of ``fun``.
"""Creates a function that evaluates the gradient of ``fun``.
Args:
fun: Function to be differentiated. Its arguments at positions specified by
@ -809,7 +809,7 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
has_aux: bool = False, holomorphic: bool = False,
allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()
) -> Callable[..., Tuple[Any, Any]]:
"""Create a function which evaluates both ``fun`` and the gradient of ``fun``.
"""Create a function that evaluates both ``fun`` and the gradient of ``fun``.
Args:
fun: Function to be differentiated. Its arguments at positions specified by
@ -1875,7 +1875,7 @@ def jvp(fun: Callable, primals, tangents) -> Tuple[Any, Any]:
array, scalar, or standard Python container of arrays or scalars.
primals: The primal values at which the Jacobian of ``fun`` should be
evaluated. Should be either a tuple or a list of arguments,
and its length should equal to the number of positional parameters of
and its length should be equal to the number of positional parameters of
``fun``.
tangents: The tangent vector for which the Jacobian-vector product should be
evaluated. Should be either a tuple or a list of tangents, with the same

@ -132,7 +132,7 @@ class custom_jvp(Generic[ReturnValue]):
jvp: a Python callable representing the custom JVP rule. When there are no
``nondiff_argnums``, the ``jvp`` function should accept two arguments,
where the first is a tuple of primal inputs and the second is a tuple of
tangent inputs. The lengths of both tuples is equal to the number of
tangent inputs. The lengths of both tuples are equal to the number of
parameters of the ``custom_jvp`` function. The ``jvp`` function should
produce as output a pair where the first element is the primal output
and the second element is the tangent output. Elements of the input and

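A hedged usage sketch of `custom_jvp`, using the numerically stable `log1pexp` rule as an example:

```python
import jax
import jax.numpy as jnp

@jax.custom_jvp
def log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

@log1pexp.defjvp
def log1pexp_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    y = log1pexp(x)
    return y, t * jax.nn.sigmoid(x)    # stable derivative, no overflow

print(jax.grad(log1pexp)(100.0))       # 1.0
```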
@ -1569,7 +1569,7 @@ def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
def sort_key_val(keys: Array, values: Array, dimension: int = -1,
is_stable: bool = True) -> Tuple[Array, Array]:
"""Sorts ``keys`` along ``dimension`` and applies same permutation to ``values``."""
"""Sorts ``keys`` along ``dimension`` and applies the same permutation to ``values``."""
dimension = canonicalize_axis(dimension, len(keys.shape))
k, v = sort_p.bind(keys, values, dimension=dimension, is_stable=is_stable, num_keys=1)
return k, v
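For reference, a hedged usage sketch:

```python
import jax.numpy as jnp
from jax import lax

keys = jnp.array([3, 1, 2])
vals = jnp.array([30.0, 10.0, 20.0])
sorted_keys, sorted_vals = lax.sort_key_val(keys, vals)
print(sorted_keys)   # [1 2 3]
print(sorted_vals)   # [10. 20. 30.]
```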

@ -208,8 +208,8 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None):
Returns:
Vectorized version of the given function.
Here a few examples of how one could write vectorized linear algebra routines
using :func:`vectorize`:
Here are a few examples of how one could write vectorized linear algebra
routines using :func:`vectorize`:
>>> from functools import partial

@ -72,7 +72,7 @@ def minimize(
Args:
fun: the objective function to be minimized, ``fun(x, *args) -> float``,
where ``x`` is an 1-D array with shape ``(n,)`` and ``args`` is a tuple
where ``x`` is a 1-D array with shape ``(n,)`` and ``args`` is a tuple
of the fixed parameters needed to completely specify the function.
``fun`` must support differentiation.
x0: initial guess. Array of real elements of size ``(n,)``, where ``n`` is

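A hedged usage sketch (a simple quadratic minimized with the BFGS method):

```python
import jax.numpy as jnp
from jax.scipy.optimize import minimize

fun = lambda x: jnp.sum((x - 3.0) ** 2)
result = minimize(fun, x0=jnp.zeros(2), method='BFGS')
print(result.x)   # approximately [3., 3.]
```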
@ -303,7 +303,7 @@ Using :func:`call` to call a TensorFlow function, with reverse-mode autodiff sup
Another possible use for host computation is to invoke a library written for
another framework, such as TensorFlow.
In this case it becomes interesting to support JAX autodiff for host callbacks
by defering to the autodiff mechanism in TensorFlow,
by deferring to the autodiff mechanism in TensorFlow,
using the :func:`jax.custom_vjp` mechanism.
This is relatively easy to do, once one understands both the JAX custom VJP
@ -363,7 +363,7 @@ error during the processing of the callback (whether raised by the user-code
itself or due to a mismatch of the returned value and the expected return_shape)
we send the device a "fake" result of shape ``int8[12345]``.
This will make the device
computation abort because the received data is different than then one that
computation abort because the received data is different than the one that
it expects. On CPU the runtime will crash with a distinctive error message:
```
@ -416,25 +416,20 @@ for the C++ outfeed `receiver backend
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_.
* ``TF_CPP_MIN_LOG_LEVEL=0``: will turn on INFO logging, needed for all below.
* ``TF_CPP_MIN_VLOG_LEVEL=3``: will turn make all VLOG logging up to level 3
behave like INFO logs. This may be too much, but you will see which
modules are logging relevant info, and then you can select which modules
to log from:
* `TF_CPP_VMODULE=<module_name>=3`` (the module name can be either C++ or
* ``TF_CPP_MIN_VLOG_LEVEL=3``: will make all VLOG logging up to level 3 behave
like INFO logs. This may be too much, but you will see which modules are
logging relevant info, and then you can select which modules to log from.
* ``TF_CPP_VMODULE=<module_name>=3`` (the module name can be either C++ or
Python, without the extension).
You should also use the ``--verbosity=2`` flag so that you see the logs
from Python.
For example, you can try to enable logging in the ``host_callback`` module:
```
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
```
``TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple``
If you want to enable logging in lower-level implementation modules try:
```
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
```
``TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple``
(For bazel tests use --test_arg=--vmodule=...
@ -1829,7 +1824,7 @@ def barrier_wait(logging_name: Optional[str] = None):
"""Blocks the calling thread until all current outfeed is processed.
Waits until all callbacks from computations already running on all devices
has been received and processed by the Python callbacks. Raises
have been received and processed by the Python callbacks. Raises
CallbackException if there were exceptions while processing the callbacks.
This works by enqueueing a special tap computation to all devices to which

@ -58,7 +58,7 @@ special `loops.scope` object and use `for` loops over special
Loops constructed with `range` must have literal constant bounds. If you need
loops with dynamic bounds, you can use the more general `while_range` iterator.
However, in that case that `grad` transformation is not supported::
However, in that case the `grad` transformation is not supported::
s.idx = start
for _ in s.while_range(lambda: s.idx < end):
@ -93,7 +93,7 @@ Restrictions:
* Once the loop starts all updates to loop state must be with new values of the
same abstract values as the values on loop start.
* For a `while` loop, the conditional function is not allowed to modify the
scope state. This is a checked error. Also, for `while` loops the `grad`
scope state. This is a checked error. Also, for `while` loops, the `grad`
transformation does not work. An alternative that allows `grad` is a bounded
loop (`range`).
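A hedged sketch of the bounded `range` form (the `loops` module is experimental, so treat this as illustrative):

```python
import jax.numpy as jnp
from jax.experimental import loops

def total(arr):
    with loops.Scope() as s:
        s.acc = jnp.zeros(())
        for i in s.range(arr.shape[0]):   # literal constant bound: grad works
            s.acc += arr[i]
        return s.acc

print(total(jnp.arange(5.0)))   # 10.0
```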

@ -210,7 +210,7 @@ def serial_loop(name: ResourceAxisName, length: int):
@contextlib.contextmanager
def mesh(devices: np.ndarray, axis_names: Sequence[ResourceAxisName]):
"""Declare the hardware resources available in scope of this manager.
"""Declare the hardware resources available in the scope of this manager.
In particular, all ``axis_names`` become valid resource names inside the
managed block and can be used e.g. in the ``axis_resources`` argument of
@ -391,9 +391,9 @@ def xmap(fun: Callable,
While it is possible to assign multiple axis names to a single resource axis,
care has to be taken to ensure that none of those named axes co-occur in a
``named_shape`` of any value in the named program. At the moment this is
**completely unchecked** and will result in **undefined behavior**. Final
release of :py:func:`xmap` will enforce this invariant, but it is work
in progress.
**completely unchecked** and will result in **undefined behavior**. The
final release of :py:func:`xmap` will enforce this invariant, but it is a
work in progress.
Note that you do not have to worry about any of this for as long as no
resource axis is repeated in ``axis_resources.values()``.
@ -401,7 +401,7 @@ def xmap(fun: Callable,
Note that any assignment of ``axis_resources`` doesn't ever change the
results of the computation, but only how it is carried out (e.g. how many
devices are used). This makes it easy to try out various ways of
partitioning a single program in many distributed scenarions (both small- and
partitioning a single program in many distributed scenarios (both small- and
large-scale), to maximize the performance. As such, :py:func:`xmap` can be
seen as a way to seamlessly interpolate between :py:func:`vmap` and
:py:func:`pmap`-style execution.
@ -489,7 +489,7 @@ def xmap(fun: Callable,
out_axes={}) # Loss is reduced over all axes, including batch!
.. note::
When using ``axis_resources`` along with a mesh that is controled by
When using ``axis_resources`` along with a mesh that is controlled by
multiple JAX hosts, keep in mind that in any given process :py:func:`xmap`
only expects the data slice that corresponds to its local devices to be
specified. This is in line with the current multi-host :py:func:`pmap`

@ -213,7 +213,7 @@ def sgd(step_size):
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to positive scalar.
that maps the iteration index to a positive scalar.
Returns:
An (init_fun, update_fun, get_params) triple.
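A hedged sketch of how such a triple is typically used (`adam` is shown, but the other optimizers expose the same pattern):

```python
import jax
import jax.numpy as jnp
from jax.experimental import optimizers

init_fun, update_fun, get_params = optimizers.adam(step_size=1e-2)

loss = lambda params: jnp.sum(params ** 2)

opt_state = init_fun(jnp.ones(3))
for step in range(100):
    grads = jax.grad(loss)(get_params(opt_state))
    opt_state = update_fun(step, grads, opt_state)
print(get_params(opt_state))   # parameters moving toward zero
```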
@ -233,7 +233,7 @@ def momentum(step_size: Schedule, mass: float):
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to positive scalar.
that maps the iteration index to a positive scalar.
mass: positive scalar representing the momentum coefficient.
Returns:
@ -260,7 +260,7 @@ def nesterov(step_size: Schedule, mass: float):
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to positive scalar.
that maps the iteration index to a positive scalar.
mass: positive scalar representing the momentum coefficient.
Returns:
@ -290,7 +290,7 @@ def adagrad(step_size, momentum=0.9):
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to positive scalar.
that maps the iteration index to a positive scalar.
momentum: optional, a positive scalar value for momentum
Returns:
@ -324,7 +324,7 @@ def rmsprop(step_size, gamma=0.9, eps=1e-8):
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to positive scalar.
that maps the iteration index to a positive scalar.
gamma: Decay parameter.
eps: Epsilon parameter.
@ -355,7 +355,7 @@ def rmsprop_momentum(step_size, gamma=0.9, eps=1e-8, momentum=0.9):
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to positive scalar.
that maps the iteration index to a positive scalar.
gamma: Decay parameter.
eps: Epsilon parameter.
momentum: Momentum parameter.
@ -386,7 +386,7 @@ def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to positive scalar.
that maps the iteration index to a positive scalar.
b1: optional, a positive scalar value for beta_1, the exponential decay rate
for the first moment estimates (default 0.9).
b2: optional, a positive scalar value for beta_2, the exponential decay rate
@ -422,7 +422,7 @@ def adamax(step_size, b1=0.9, b2=0.999, eps=1e-8):
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to positive scalar.
that maps the iteration index to a positive scalar.
b1: optional, a positive scalar value for beta_1, the exponential decay rate
for the first moment estimates (default 0.9).
b2: optional, a positive scalar value for beta_2, the exponential decay rate
@ -460,7 +460,7 @@ def sm3(step_size, momentum=0.9):
Args:
step_size: positive scalar, or a callable representing a step size schedule
that maps the iteration index to positive scalar.
that maps the iteration index to a positive scalar.
momentum: optional, a positive scalar value for momentum
Returns:

@ -56,9 +56,9 @@ def pjit(fun: Callable,
version of ``fun`` would not fit in a single device's memory, or to speed up
``fun`` by running each operation in parallel across multiple devices.
The partitioning over devices happens automatically based on
propagation of input partitioning specified in ``in_axis_resources`` and
output partitioning specified in ``out_axis_resources``. The resources
The partitioning over devices happens automatically based on the
propagation of the input partitioning specified in ``in_axis_resources`` and
the output partitioning specified in ``out_axis_resources``. The resources
specified in those two arguments must refer to mesh axes, as defined by
the :py:func:`jax.experimental.maps.mesh` context manager. Note that the mesh
definition at ``pjit`` application time is ignored, and the returned function
@ -84,7 +84,7 @@ def pjit(fun: Callable,
mesh axis size, and outputs will be similarly sized according to the local
mesh. ``fun`` will still be executed across *all* devices in the mesh,
including those from other processes, and will be given a global view of the
data spread accross multiple processes as a single array. However, outside
data spread across multiple processes as a single array. However, outside
of ``pjit`` every process only "sees" its local piece of the input and output,
corresponding to its local sub-mesh.
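A hedged single-process sketch (the mesh layout and partition specs are illustrative; with one device the input is simply not split):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.maps import mesh
from jax.experimental.pjit import pjit, PartitionSpec

# Shard the leading axis of the input over mesh axis 'x'; replicate the output.
f = pjit(lambda x: jnp.sum(x, axis=0),
         in_axis_resources=PartitionSpec('x', None),
         out_axis_resources=None)

devices = np.array(jax.devices())
with mesh(devices, ('x',)):
    y = f(jnp.ones((8 * len(devices), 4)))
print(y.shape)   # (4,)
```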
@ -140,7 +140,7 @@ def pjit(fun: Callable,
Returns:
A wrapped version of ``fun``, set up for just-in-time compilation and
automatic partitioned by the mesh available at each call site.
automatically partitioned by the mesh available at each call site.
For example, a convolution operator can be automatically partitioned over
an arbitrary set of devices by a single ``pjit`` application: