DOC: many small fixes
parent df103f7e66
commit 7392a57b75
Changed files:
CHANGELOG.md
README.md
design_notes/
docs/
    autodidax.ipynb  autodidax.md  autodidax.py  contributing.md  custom_vjp_update.md  developer.md  device_memory_profiling.md  faq.rst
    jax-101/
        01-jax-basics.ipynb  01-jax-basics.md  04-advanced-autodiff.ipynb  04-advanced-autodiff.md  05-random-numbers.ipynb  05-random-numbers.md  05.1-pytrees.ipynb  05.1-pytrees.md  06-parallelism.ipynb  06-parallelism.md
    jax.numpy.rst  jaxpr.rst  multi_process.md
    notebooks/
        Common_Gotchas_in_JAX.ipynb  Common_Gotchas_in_JAX.md  Custom_derivative_rules_for_Python_code.ipynb  Custom_derivative_rules_for_Python_code.md  How_JAX_primitives_work.ipynb  How_JAX_primitives_work.md  Neural_Network_and_Data_Loading.ipynb  Neural_Network_and_Data_Loading.md  Writing_custom_interpreters_in_Jax.ipynb  Writing_custom_interpreters_in_Jax.md  XLA_in_Python.ipynb  XLA_in_Python.md  autodiff_cookbook.ipynb  autodiff_cookbook.md  convolutions.ipynb  convolutions.md  maml.ipynb  maml.md  neural_network_with_tfds_data.ipynb  neural_network_with_tfds_data.md  quickstart.ipynb  quickstart.md  score_matching.ipynb  score_matching.md  thinking_in_jax.ipynb  thinking_in_jax.md  xmap_tutorial.ipynb  xmap_tutorial.md
    pytrees.md  rank_promotion_warning.rst  transformations.md
jax/
    _src/
    experimental/
@@ -50,7 +50,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
 * Bug fixes:
 * Tightened the checks for lax.argmin and lax.argmax to ensure they are
-not used with invalid `axis` value, or with an empty reduction dimension.
+not used with an invalid `axis` value, or with an empty reduction dimension.
 ({jax-issue}`#7196`)
@@ -333,7 +333,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
 * Bug fixes:
 * `jax.numpy.arccosh` now returns the same branch as `numpy.arccosh` for
 complex inputs ({jax-issue}`#5156`)
-* `host_callback.id_tap` now works for `jax.pmap` also. There is a
+* `host_callback.id_tap` now works for `jax.pmap` also. There is an
 optional parameter for `id_tap` and `id_print` to request that the
 device from which the value is tapped be passed as a keyword argument
 to the tap function ({jax-issue}`#5182`).
@@ -359,7 +359,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
 * New features:
 * Add `jax.device_put_replicated`
 * Add multi-host support to `jax.experimental.sharded_jit`
-* Add support for differentiating eigenvaleus computed by `jax.numpy.linalg.eig`
+* Add support for differentiating eigenvalues computed by `jax.numpy.linalg.eig`
 * Add support for building on Windows platforms
 * Add support for general in_axes and out_axes in `jax.pmap`
 * Add complex support for `jax.numpy.linalg.slogdet`
@@ -504,7 +504,7 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.
 * [GitHub commits](https://github.com/google/jax/compare/jax-v0.1.73...jax-v0.1.74).
 * New Features:
 * BFGS (#3101)
-* TPU suppot for half-precision arithmetic (#3878)
+* TPU support for half-precision arithmetic (#3878)
 * Bug Fixes:
 * Prevent some accidental dtype warnings (#3874)
 * Fix a multi-threading bug in custom derivatives (#3845, #3869)
@@ -109,7 +109,8 @@ or the [examples](https://github.com/google/jax/tree/main/examples).
 ## Transformations

 At its core, JAX is an extensible system for transforming numerical functions.
-Here are four of primary interest: `grad`, `jit`, `vmap`, and `pmap`.
+Here are four transformations of primary interest: `grad`, `jit`, `vmap`, and
+`pmap`.

 ### Automatic differentiation with `grad`

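For context, a minimal illustrative sketch of these transformations in use (not taken from the changed files; the toy `predict` and `loss` functions are placeholders):

```python
import jax
import jax.numpy as jnp

def predict(w, x):
    # Toy model: one linear layer followed by tanh.
    return jnp.tanh(x @ w)

def loss(w, x, y):
    return jnp.mean((predict(w, x) - y) ** 2)

w = jnp.ones(3)
x = jnp.ones((5, 3))
y = jnp.zeros(5)

grad_loss = jax.grad(loss)                      # reverse-mode gradient in w
fast_loss = jax.jit(loss)                       # compile end to end with XLA
batched = jax.vmap(predict, in_axes=(None, 0))  # auto-vectorize over the batch

print(grad_loss(w, x, y), fast_loss(w, x, y), batched(w, x).shape)
```

`pmap` follows the same functional pattern but maps the computation over multiple devices.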
@@ -99,7 +99,7 @@ The name "omnistaging" means staging out everything possible.

 ### Toy example

-iJAX transformations like `jit` and `pmap` stage out computations to XLA. That
+JAX transformations like `jit` and `pmap` stage out computations to XLA. That
 is, we apply them to functions comprising multiple primitive operations so that
 rather being executed one at a time from Python the operations are all part of
 one end-to-end optimized XLA computation.
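A minimal sketch of what "staging out" looks like from the user side: `jax.make_jaxpr` shows a whole Python function captured as one staged computation (the function `f` here is a made-up example):

```python
import jax
import jax.numpy as jnp

def f(x):
    # Three primitive operations; under jit they become one XLA computation.
    y = jnp.sin(x)
    z = y * 2.0
    return jnp.sum(z)

print(jax.make_jaxpr(f)(jnp.arange(3.0)))  # the staged-out representation
print(jax.jit(f)(jnp.arange(3.0)))         # compiled and run end to end
```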
@@ -666,7 +666,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"Notice both `lift` and `sublift` package a value into a `JVPTracer` with the\n",
+"Notice both `pure` and `lift` package a value into a `JVPTracer` with the\n",
 "minimal amount of context, which is a zero tangent value.\n",
 "\n",
 "Let's add some JVP rules for primitives:"
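For readers skimming this hunk, a hedged, simplified sketch (not the autodidax source itself) of what `pure` and `lift` do, i.e. wrap a plain value as a tracer whose tangent is zero:

```python
import numpy as np

class JVPTracer:
    def __init__(self, trace, primal, tangent):
        self._trace = trace
        self.primal = primal
        self.tangent = tangent

class JVPTrace:
    def pure(self, val):
        # A constant does not depend on the inputs, so its tangent is zero.
        return JVPTracer(self, val, np.zeros_like(val))

    lift = pure  # values lifted from lower traces get the same minimal context

t = JVPTrace().pure(np.arange(3.0))
print(t.primal, t.tangent)
```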
@@ -1312,7 +1312,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"### Jaxpr data strutures\n",
+"### Jaxpr data structures\n",
 "\n",
 "The jaxpr term syntax is roughly:\n",
 "\n",
@@ -2720,7 +2720,7 @@
 " g:float64[] = neg e\n",
 " in ( g ) }\n",
 "```\n",
-"This second jaxpr is represents the linear computation that we want from\n",
+"This second jaxpr represents the linear computation that we want from\n",
 "`linearize`.\n",
 "\n",
 "However, unlike in this jaxpr example, we want the computation on known values\n",
@@ -2729,7 +2729,7 @@
 "operations out of Python first before sorting out what can be evaluated now\n",
 "and what must be delayed, we want only to form a jaxpr for those operations\n",
 "that _must_ be delayed due to a dependence on unknown inputs. In the context\n",
-"of automatic differentiation, this is the feature ultimately enables us to\n",
+"of automatic differentiation, this is the feature that ultimately enables us to\n",
 "handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control\n",
 "flow works because partial evaluation keeps the primal computation in Python.\n",
 "As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out\n",
@@ -2874,9 +2874,10 @@
 "(evaluating it in Python) and avoid forming tracers corresponding to the\n",
 "output. If instead any input is unknown then we instead stage out into a\n",
 "`JaxprEqnRecipe` representing the primitive application. To build the tracers\n",
-"representing unknown outputs, we need avals, which get from the abstract eval\n",
-"rules. (Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s\n",
-"reference tracers; we avoid circular garbage by using weakrefs.)\n",
+"representing unknown outputs, we need avals, which we get from the abstract\n",
+"eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and\n",
+"`JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using\n",
+"weakrefs.)\n",
 "\n",
 "That `process_primitive` logic applies to most primitives, but `xla_call_p`\n",
 "requires recursive treatment. So we special-case its rule in a\n",
@@ -3312,7 +3313,7 @@
 "metadata": {},
 "source": [
 "We use `UndefPrimal` instances to indicate which arguments with respect to\n",
-"with we want to transpose. These arise because in general, being explicit\n",
+"which we want to transpose. These arise because in general, being explicit\n",
 "about closed-over values, we want to transpose functions of type\n",
 "`a -> b -o c` to functions of type `a -> c -o b`. Even more generally, the\n",
 "inputs with respect to which the function is linear could be scattered through\n",
@@ -505,7 +505,7 @@ class JVPTrace(Trace):
 jvp_rules = {}
 ```

-Notice both `lift` and `sublift` package a value into a `JVPTracer` with the
+Notice both `pure` and `lift` package a value into a `JVPTracer` with the
 minimal amount of context, which is a zero tangent value.

 Let's add some JVP rules for primitives:
@@ -960,7 +960,7 @@ jaxpr and then interpreting the jaxpr.)

 +++

-### Jaxpr data strutures
+### Jaxpr data structures

 The jaxpr term syntax is roughly:

@@ -2012,7 +2012,7 @@ and tangent jaxprs:
 g:float64[] = neg e
 in ( g ) }
 ```
-This second jaxpr is represents the linear computation that we want from
+This second jaxpr represents the linear computation that we want from
 `linearize`.

 However, unlike in this jaxpr example, we want the computation on known values
@@ -2021,7 +2021,7 @@ forming a jaxpr for the entire function `(a1, a2) -> (b1, b2)`, staging all
 operations out of Python first before sorting out what can be evaluated now
 and what must be delayed, we want only to form a jaxpr for those operations
 that _must_ be delayed due to a dependence on unknown inputs. In the context
-of automatic differentiation, this is the feature ultimately enables us to
+of automatic differentiation, this is the feature that ultimately enables us to
 handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control
 flow works because partial evaluation keeps the primal computation in Python.
 As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out
@@ -2122,9 +2122,10 @@ inputs are known then we can bind the primitive on the known values
 (evaluating it in Python) and avoid forming tracers corresponding to the
 output. If instead any input is unknown then we instead stage out into a
 `JaxprEqnRecipe` representing the primitive application. To build the tracers
-representing unknown outputs, we need avals, which get from the abstract eval
-rules. (Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s
-reference tracers; we avoid circular garbage by using weakrefs.)
+representing unknown outputs, we need avals, which we get from the abstract
+eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and
+`JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using
+weakrefs.)

 That `process_primitive` logic applies to most primitives, but `xla_call_p`
 requires recursive treatment. So we special-case its rule in a
@@ -2468,7 +2469,7 @@ register_pytree_node(UndefPrimal,
 ```

 We use `UndefPrimal` instances to indicate which arguments with respect to
-with we want to transpose. These arise because in general, being explicit
+which we want to transpose. These arise because in general, being explicit
 about closed-over values, we want to transpose functions of type
 `a -> b -o c` to functions of type `a -> c -o b`. Even more generally, the
 inputs with respect to which the function is linear could be scattered through
@@ -486,7 +486,7 @@ class JVPTrace(Trace):
 jvp_rules = {}
 # -

-# Notice both `lift` and `sublift` package a value into a `JVPTracer` with the
+# Notice both `pure` and `lift` package a value into a `JVPTracer` with the
 # minimal amount of context, which is a zero tangent value.
 #
 # Let's add some JVP rules for primitives:
@@ -919,7 +919,7 @@ jacfwd(f, np.arange(3.))
 # control flow, any transformation could be implemented by first tracing to a
 # jaxpr and then interpreting the jaxpr.)

-# ### Jaxpr data strutures
+# ### Jaxpr data structures
 #
 # The jaxpr term syntax is roughly:
 #
@@ -1930,7 +1930,7 @@ def vspace(aval: ShapedArray) -> ShapedArray:
 # g:float64[] = neg e
 # in ( g ) }
 # ```
-# This second jaxpr is represents the linear computation that we want from
+# This second jaxpr represents the linear computation that we want from
 # `linearize`.
 #
 # However, unlike in this jaxpr example, we want the computation on known values
@@ -1939,7 +1939,7 @@ def vspace(aval: ShapedArray) -> ShapedArray:
 # operations out of Python first before sorting out what can be evaluated now
 # and what must be delayed, we want only to form a jaxpr for those operations
 # that _must_ be delayed due to a dependence on unknown inputs. In the context
-# of automatic differentiation, this is the feature ultimately enables us to
+# of automatic differentiation, this is the feature that ultimately enables us to
 # handle functions like `grad(lambda x: x**2 if x > 0 else 0.)`. Python control
 # flow works because partial evaluation keeps the primal computation in Python.
 # As a consequence, our `Trace` and `Tracer` subclasses must on the fly sort out
@@ -2036,9 +2036,10 @@ class PartialEvalTracer(Tracer):
 # (evaluating it in Python) and avoid forming tracers corresponding to the
 # output. If instead any input is unknown then we instead stage out into a
 # `JaxprEqnRecipe` representing the primitive application. To build the tracers
-# representing unknown outputs, we need avals, which get from the abstract eval
-# rules. (Notice that tracers reference `JaxprEqnRecipe`s, and `JaxprEqnRecipe`s
-# reference tracers; we avoid circular garbage by using weakrefs.)
+# representing unknown outputs, we need avals, which we get from the abstract
+# eval rules. (Notice that tracers reference `JaxprEqnRecipe`s, and
+# `JaxprEqnRecipe`s reference tracers; we avoid circular garbage by using
+# weakrefs.)
 #
 # That `process_primitive` logic applies to most primitives, but `xla_call_p`
 # requires recursive treatment. So we special-case its rule in a
@@ -2376,7 +2377,7 @@ register_pytree_node(UndefPrimal,
 # -

 # We use `UndefPrimal` instances to indicate which arguments with respect to
-# with we want to transpose. These arise because in general, being explicit
+# which we want to transpose. These arise because in general, being explicit
 # about closed-over values, we want to transpose functions of type
 # `a -> b -o c` to functions of type `a -> c -o b`. Even more generally, the
 # inputs with respect to which the function is linear could be scattered through
@@ -160,6 +160,6 @@ fix the issues you can push new commits to your branch.
 Once your PR has been reviewed, a JAX maintainer will mark it as `Pull Ready`. This
 will trigger a larger set of tests, including tests on GPU and TPU backends that are
 not available via standard GitHub CI. Detailed results of these tests are not publicly
-viweable, but the JAX mantainer assigned to your PR will communicate with you regarding
+viewable, but the JAX maintainer assigned to your PR will communicate with you regarding
 any failures these might uncover; it's not uncommon, for example, that numerical tests
 need different tolerances on TPU than on CPU.
@@ -87,7 +87,7 @@ skip_app.defvjp(skip_app_fwd, skip_app_bwd)
 ## Explanation

 Passing `Tracer`s into `nondiff_argnums` arguments was always buggy. While there
-were some cases which worked correctly, others would lead to complex and
+were some cases that worked correctly, others would lead to complex and
 confusing error messages.

 The essence of the bug was that `nondiff_argnums` was implemented in a way that
@@ -85,7 +85,7 @@ You can either install Python using its
 [Windows installer](https://www.python.org/downloads/), or if you prefer, you
 can use [Anaconda](https://docs.anaconda.com/anaconda/install/windows/)
 or [Miniconda](https://docs.conda.io/en/latest/miniconda.html#windows-installers)
-to setup a Python environment.
+to set up a Python environment.

 Some targets of Bazel use bash utilities to do scripting, so [MSYS2](https://www.msys2.org)
 is needed. See [Installing Bazel on Windows](https://docs.bazel.build/versions/master/install-windows.html#installing-compilers-and-language-runtimes)
@@ -174,7 +174,7 @@ python tests/lax_numpy_test.py --test_targets="testPad"

 The Colab notebooks are tested for errors as part of the documentation build.

-Note that to run the full pmap tests on a (multi-core) CPU only machine, you
+Note that to run the full pmap tests on a (multi-core) CPU-only machine, you
 can run:

 ```
@@ -278,7 +278,7 @@ See `exclude_patterns` in [conf.py](https://github.com/google/jax/blob/main/docs

 ## Documentation building on readthedocs.io

-JAX's auto-generated documentations is at <https://jax.readthedocs.io/>.
+JAX's auto-generated documentation is at <https://jax.readthedocs.io/>.

 The documentation building is controlled for the entire project by the
 [readthedocs JAX settings](https://readthedocs.org/dashboard/jax). The current settings
@@ -80,7 +80,7 @@ For more information about how to interpret callgraph visualizations, see the

 Functions compiled with {func}`jax.jit` are opaque to the device memory profiler.
 That is, any memory allocated inside a `jit`-compiled function will be
-attributed to the function as whole.
+attributed to the function as a whole.

 In the example, the call to `block_until_ready()` is to ensure that `func2`
 completes before the device memory profile is collected. See
@@ -90,7 +90,7 @@ completes before the device memory profile is collected. See

 We can also use the JAX device memory profiler to track down memory leaks by using
 `pprof` to visualize the change in memory usage between two device memory profiles
-taken at different times. For example consider the following program which
+taken at different times. For example, consider the following program which
 accumulates JAX arrays into a constantly-growing Python list.

 ```python
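For context, a hedged sketch of taking two device-memory snapshots for `pprof` to diff, using `jax.profiler.save_device_memory_profile` (not from the changed files; file names are arbitrary):

```python
import jax
import jax.numpy as jnp
import jax.profiler

# Snapshot before the suspect code runs.
jax.profiler.save_device_memory_profile("memory_before.prof")

leak = []
for _ in range(10):
    # Keeping device arrays alive in a Python list grows device memory.
    leak.append(jnp.ones((1000, 1000)))
leak[-1].block_until_ready()

# Snapshot after; compare the two with:
#   pprof --web --diff_base memory_before.prof memory_after.prof
jax.profiler.save_device_memory_profile("memory_after.prof")
```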
docs/faq.rst
@@ -60,10 +60,10 @@ If your ``jit`` decorated function takes tens of seconds (or more!) to run the
 first time you call it, but executes quickly when called again, JAX is taking a
 long time to trace or compile your code.

-This is usually a symptom of calling your function generating a large amount of
+This is usually a sign that calling your function generates a large amount of
 code in JAX's internal representation, typically because it makes heavy use of
-Python control flow such as ``for`` loop. For a handful of loop iterations
-Python is OK, but if you need _many_ loop iterations, you should rewrite your
+Python control flow such as ``for`` loops. For a handful of loop iterations,
+Python is OK, but if you need *many* loop iterations, you should rewrite your
 code to make use of JAX's
 `structured control flow primitives <https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#Structured-control-flow-primitives>`_
 (such as :func:`lax.scan`) or avoid wrapping the loop with ``jit`` (you can
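A minimal sketch of the kind of rewrite this suggests, replacing a long Python ``for`` loop with :func:`lax.scan` (the toy loop body is made up):

```python
import jax
import jax.numpy as jnp
from jax import lax

# Python loop: unrolled into many operations at trace time.
def cumsum_python(xs):
    total = 0.0
    for x in xs:          # traced once per iteration
        total = total + x
    return total

# lax.scan: a single structured primitive, regardless of the loop length.
def cumsum_scan(xs):
    def body(carry, x):
        return carry + x, None
    total, _ = lax.scan(body, 0.0, xs)
    return total

xs = jnp.arange(10000.0)
print(jax.jit(cumsum_scan)(xs))
```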
@@ -206,7 +206,7 @@ running full applications, which inevitably include some amount of both data
 transfer and compilation. Also, we were careful to pick large enough arrays
 (1000x1000) and an intensive enough computation (the ``@`` operator is
 performing matrix-matrix multiplication) to amortize the increased overhead of
-JAX/accelerators vs NumPy/CPU. For example, if switch this example to use
+JAX/accelerators vs NumPy/CPU. For example, if we switch this example to use
 10x10 input instead, JAX/GPU runs 10x slower than NumPy/CPU (100 µs vs 10 µs).

 .. _To JIT or not to JIT: https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#to-jit-or-not-to-jit
@@ -322,7 +322,7 @@ are not careful you may obtain a ``NaN`` for reverse differentiation::
 jax.grad(my_log)(0.) ==> NaN

 A short explanation is that during ``grad`` computation the adjoint corresponding
-to the undefined ``jnp.log(x)`` is a ``NaN`` and when it gets accumulated to the
+to the undefined ``jnp.log(x)`` is a ``NaN`` and it gets accumulated to the
 adjoint of the ``jnp.where``. The correct way to write such functions is to ensure
 that there is a ``jnp.where`` *inside* the partially-defined function, to ensure
 that the adjoint is always finite::
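A hedged sketch of the pattern described above, with the inner ``jnp.where`` keeping the adjoint finite (function names are illustrative):

```python
import jax
import jax.numpy as jnp

# Problematic: the log is still evaluated (and differentiated) at x == 0.
def my_log(x):
    return jnp.where(x > 0.0, jnp.log(x), 0.0)

# Safe: an inner jnp.where keeps the argument of log away from 0,
# so the adjoint of the untaken branch is finite.
def safe_log(x):
    return jnp.where(x > 0.0, jnp.log(jnp.where(x > 0.0, x, 1.0)), 0.0)

print(jax.grad(my_log)(0.0))    # nan
print(jax.grad(safe_log)(0.0))  # 0.0
```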
@@ -219,7 +219,7 @@
 "\n",
 "(Like $\\nabla$, `jax.grad` will only work on functions with a scalar output -- it will raise an error otherwise.)\n",
 "\n",
-"This makes the JAX API quite different to other autodiff libraries like Tensorflow and PyTorch, where to compute the gradient we use the loss tensor itself (e.g. by calling `loss.backward()`). The JAX API works directly with functions, staying closer to the underlying math. Once you become accustomed to this way of doing things, it feels natural: your loss function in code really is a function of parameters and data, and you find its gradient just like you would in the math.\n",
+"This makes the JAX API quite different from other autodiff libraries like Tensorflow and PyTorch, where to compute the gradient we use the loss tensor itself (e.g. by calling `loss.backward()`). The JAX API works directly with functions, staying closer to the underlying math. Once you become accustomed to this way of doing things, it feels natural: your loss function in code really is a function of parameters and data, and you find its gradient just like you would in the math.\n",
 "\n",
 "This way of doing things makes it straightforward to control things like which variables to differentiate with respect to. By default, `jax.grad` will find the gradient with respect to the first argument. In the example below, the result of `sum_squared_error_dx` will be the gradient of `sum_squared_error` with respect to `x`."
 ]
@@ -120,7 +120,7 @@ Analogously, `jax.grad(f)` is the function that computes the gradient, so `jax.g

 (Like $\nabla$, `jax.grad` will only work on functions with a scalar output -- it will raise an error otherwise.)

-This makes the JAX API quite different to other autodiff libraries like Tensorflow and PyTorch, where to compute the gradient we use the loss tensor itself (e.g. by calling `loss.backward()`). The JAX API works directly with functions, staying closer to the underlying math. Once you become accustomed to this way of doing things, it feels natural: your loss function in code really is a function of parameters and data, and you find its gradient just like you would in the math.
+This makes the JAX API quite different from other autodiff libraries like Tensorflow and PyTorch, where to compute the gradient we use the loss tensor itself (e.g. by calling `loss.backward()`). The JAX API works directly with functions, staying closer to the underlying math. Once you become accustomed to this way of doing things, it feels natural: your loss function in code really is a function of parameters and data, and you find its gradient just like you would in the math.

 This way of doing things makes it straightforward to control things like which variables to differentiate with respect to. By default, `jax.grad` will find the gradient with respect to the first argument. In the example below, the result of `sum_squared_error_dx` will be the gradient of `sum_squared_error` with respect to `x`.

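A minimal sketch of the behaviour described in the last sentence, with `jax.grad` differentiating with respect to the first argument by default (names follow the tutorial's example):

```python
import jax
import jax.numpy as jnp

def sum_squared_error(x, y):
    return jnp.sum((x - y) ** 2)

# By default, grad differentiates with respect to the first argument (x).
sum_squared_error_dx = jax.grad(sum_squared_error)

# argnums selects other arguments, or several at once.
sum_squared_error_dy = jax.grad(sum_squared_error, argnums=1)
sum_squared_error_dxy = jax.grad(sum_squared_error, argnums=(0, 1))

x = jnp.arange(3.0)
y = jnp.ones(3)
print(sum_squared_error_dx(x, y))
```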
@@ -290,7 +290,7 @@
 "\n",
 "This update is not the gradient of any loss function.\n",
 "\n",
-"However it can be **written** as the gradient of the pseudo loss function\n",
+"However, it can be **written** as the gradient of the pseudo loss function\n",
 "\n",
 "$$\n",
 "L(\\theta) = [r_t + v_{\\theta}(s_t) - v_{\\theta}(s_{t-1})]^2\n",
@@ -186,7 +186,7 @@ $$

 This update is not the gradient of any loss function.

-However it can be **written** as the gradient of the pseudo loss function
+However, it can be **written** as the gradient of the pseudo loss function

 $$
 L(\theta) = [r_t + v_{\theta}(s_t) - v_{\theta}(s_{t-1})]^2
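A hedged sketch of how such a pseudo loss is typically implemented, treating the bootstrapped target as a constant with `jax.lax.stop_gradient` (the linear value function here is a made-up stand-in):

```python
import jax
import jax.numpy as jnp

def value_fn(theta, s):
    # Toy linear value function.
    return jnp.dot(theta, s)

def pseudo_loss(theta, s_prev, r_t, s_t):
    # The target r_t + v(s_t) is held fixed when differentiating.
    target = r_t + value_fn(theta, s_t)
    return (jax.lax.stop_gradient(target) - value_fn(theta, s_prev)) ** 2

theta = jnp.array([0.1, -0.1, 0.0])
s_prev, s_t = jnp.array([1.0, 2.0, -1.0]), jnp.array([2.0, 1.0, 0.0])
r_t = 1.0
print(jax.grad(pseudo_loss)(theta, s_prev, r_t, s_t))
```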
@@ -258,7 +258,7 @@
 "\n",
 "This doesn't seem to be a major issue in NumPy, as it is already enforced by Python, but it becomes an issue in JAX. \n",
 "\n",
-"Making this code reproducible in JAX would require to enforce this specific order of execution. This would violate requirement #2, as JAX should be able to parallelize `bar` and `baz` when jitting as these functions don't actually depend on each other.\n",
+"Making this code reproducible in JAX would require enforcing this specific order of execution. This would violate requirement #2, as JAX should be able to parallelize `bar` and `baz` when jitting as these functions don't actually depend on each other.\n",
 "\n",
 "To avoid this issue, JAX does not use a global state. Instead, random functions explicitly consume the state, which is referred to as a `key` ."
 ]
@@ -140,7 +140,7 @@ The output of this code can only satisfy requirement #1 if we assume a specific

 This doesn't seem to be a major issue in NumPy, as it is already enforced by Python, but it becomes an issue in JAX.

-Making this code reproducible in JAX would require to enforce this specific order of execution. This would violate requirement #2, as JAX should be able to parallelize `bar` and `baz` when jitting as these functions don't actually depend on each other.
+Making this code reproducible in JAX would require enforcing this specific order of execution. This would violate requirement #2, as JAX should be able to parallelize `bar` and `baz` when jitting as these functions don't actually depend on each other.

 To avoid this issue, JAX does not use a global state. Instead, random functions explicitly consume the state, which is referred to as a `key` .

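A minimal sketch of the explicit-key style referred to above: every random call consumes a key, and fresh keys come from `jax.random.split`:

```python
import jax

key = jax.random.PRNGKey(42)

# Split the key so each random call (e.g. one for bar, one for baz)
# gets its own, independent state.
key, subkey_bar, subkey_baz = jax.random.split(key, 3)

bar = jax.random.normal(subkey_bar, ())
baz = jax.random.normal(subkey_baz, ())

# The results are reproducible and independent of evaluation order.
print(bar, baz)
```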
@@ -338,7 +338,7 @@
 "source": [
 "## Custom pytree nodes\n",
 "\n",
-"So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define my own container class, it will be considered a leaf, even if it has trees inside it:"
+"So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define your own container class, it will be considered a leaf, even if it has trees inside it:"
 ]
 },
 {
@@ -193,7 +193,7 @@ plt.legend();

 ## Custom pytree nodes

-So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define my own container class, it will be considered a leaf, even if it has trees inside it:
+So far, we've only been considering pytrees of lists, tuples, and dicts; everything else is considered a leaf. Therefore, if you define your own container class, it will be considered a leaf, even if it has trees inside it:

 ```{code-cell}
 :id: CK8LN2PRFnQf
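A hedged sketch of turning such a container into a pytree node so its contents stop being treated as a single leaf, via `jax.tree_util.register_pytree_node` (the `MyContainer` class is a made-up example):

```python
import jax
import jax.numpy as jnp
from jax import tree_util

class MyContainer:
    def __init__(self, name, a, b):
        self.name = name   # static metadata
        self.a = a         # array leaves
        self.b = b

def flatten(c):
    # Return the children (leaves) and the auxiliary static data.
    return (c.a, c.b), c.name

def unflatten(name, children):
    return MyContainer(name, *children)

tree_util.register_pytree_node(MyContainer, flatten, unflatten)

c = MyContainer("demo", jnp.zeros(3), jnp.ones(3))
# tree_map now reaches inside the container instead of treating it as a leaf.
print(jax.tree_util.tree_map(lambda x: x + 1, c).a)
```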
@@ -555,7 +555,7 @@
 },
 "outputs": [],
 "source": [
-"from typing import NamedTuple\n",
+"from typing import NamedTuple, Tuple\n",
 "import functools\n",
 "\n",
 "class Params(NamedTuple):\n",
@@ -585,7 +585,7 @@
 "# to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it\n",
 "# 'num_devices', but could have used anything, so long as `pmean` used the same.\n",
 "@functools.partial(jax.pmap, axis_name='num_devices')\n",
-"def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Params:\n",
+"def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:\n",
 " \"\"\"Performs one SGD update step on params using the given data.\"\"\"\n",
 "\n",
 " # Compute the gradients on the given minibatch (individually on each device).\n",
@@ -228,7 +228,7 @@ If this example is too confusing, you can find the same example, but without par
 ```{code-cell}
 :id: cI8xQqzRrc-4

-from typing import NamedTuple
+from typing import NamedTuple, Tuple
 import functools

 class Params(NamedTuple):
@@ -258,7 +258,7 @@ LEARNING_RATE = 0.005
 # to later tell `jax.lax.pmean` which axis to reduce over. Here, we call it
 # 'num_devices', but could have used anything, so long as `pmean` used the same.
 @functools.partial(jax.pmap, axis_name='num_devices')
-def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Params:
+def update(params: Params, xs: jnp.ndarray, ys: jnp.ndarray) -> Tuple[Params, jnp.ndarray]:
 """Performs one SGD update step on params using the given data."""

 # Compute the gradients on the given minibatch (individually on each device).
@@ -12,8 +12,8 @@ While JAX tries to follow the NumPy API as closely as possible, sometimes JAX
 cannot follow NumPy exactly.

 * Notably, since JAX arrays are immutable, NumPy APIs that mutate arrays
-in-place cannot be implemented in JAX. However, often JAX is able to provide a
-alternative API that is purely functional. For example, instead of in-place
+in-place cannot be implemented in JAX. However, often JAX is able to provide
+an alternative API that is purely functional. For example, instead of in-place
 array updates (:code:`x[i] = y`), JAX provides an alternative pure indexed
 update function :func:`jax.ops.index_update`.

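A hedged sketch of the functional-update style this refers to, using the :func:`jax.ops.index_update` API named in the text; it assumes a JAX version of the same vintage as this documentation (later releases replace it with ``x.at[i].set(y)``):

```python
import jax.numpy as jnp
import jax.ops

x = jnp.zeros(5)

# NumPy would mutate in place with x[2] = 10.0; JAX arrays are immutable,
# so the functional API returns a new array and leaves x unchanged.
y = jax.ops.index_update(x, jax.ops.index[2], 10.0)

print(x)  # [0. 0. 0. 0. 0.]
print(y)  # [0. 0. 10. 0. 0.]
```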
@@ -65,8 +65,8 @@ Equations are printed as follows::
 Eqn ::= let Var+ = Primitive [ Param* ] Expr+

 where:
-* ``Var+`` are one or more intermediate variables to be defined as the
-output of a primitive invocation (some primitives can return multiple values)
+* ``Var+`` are one or more intermediate variables to be defined as the output
+of a primitive invocation (some primitives can return multiple values).
 * ``Expr+`` are one or more atomic expressions, each either a variable or a
 literal constant. A special variable ``unitvar`` or literal ``unit``,
 printed as ``*``, represents a value that is not needed
@@ -235,7 +235,7 @@ The cond primitive has a number of parameters:
 parameters are used linearly in the conditional.

 The above instance of the cond primitive takes two operands. The first
-one (``c``) is the branch index, then ``b`` is the operand (``arg``) to
+one (``d``) is the branch index, then ``b`` is the operand (``arg``) to
 be passed to whichever jaxpr in ``branches`` is selected by the branch
 index.

@@ -267,7 +267,7 @@ Another example, using :py:func:`lax.cond`:
 In this case, the boolean predicate is converted to an integer index
 (0 or 1), and ``branches`` are jaxprs that correspond to the false and
 true branch functionals, in that order. Again, each functional takes
-one input variable, corresponding to ``xtrue`` and ``xfalse``
+one input variable, corresponding to ``xfalse`` and ``xtrue``
 respectively.

 The following example shows a more complicated situation when the input
@@ -341,7 +341,7 @@ For example, here is an example fori loop
 cond_nconsts=0 ] c a 0 b d
 in (e,) }

-The while primitive takes 5 arguments: ``c a 0 b e``, as follows:
+The while primitive takes 5 arguments: ``c a 0 b d``, as follows:

 * 0 constants for ``cond_jaxpr`` (since ``cond_nconsts`` is 0)
 * 2 constants for ``body_jaxpr`` (``c``, and ``a``)
@@ -406,8 +406,8 @@ XLA_call
 ^^^^^^^^

 The call primitive arises from JIT compilation, and it encapsulates
-a sub-jaxpr along with parameters the specify the backend and the device the
-computation should run. For example
+a sub-jaxpr along with parameters that specify the backend and the device on
+which the computation should run. For example

 >>> from jax import jit
 >>>
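A minimal sketch of how such a call equation can be seen, by printing the jaxpr of a jitted function (the exact equation name and formatting vary by JAX version):

```python
from jax import jit, make_jaxpr
import jax.numpy as jnp

@jit
def func(x):
    return jnp.sin(x) * 2.0

# The jitted body shows up as a single call equation wrapping a sub-jaxpr.
print(make_jaxpr(func)(jnp.arange(3.0)))
```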
@@ -480,10 +480,10 @@ captured using the ``xla_pmap`` primitive. Consider this example
 out_axes=(0,) ] b a
 in (c,) }

-The ``xla_pmap`` primitive specifies the name of the axis (parameter ``rows``)
-and the body of the function to be mapped as the ``call_jaxpr`` parameter.
-value of this parameter is a Jaxpr with 3 input variables:
+The ``xla_pmap`` primitive specifies the name of the axis (parameter
+``axis_name``) and the body of the function to be mapped as the ``call_jaxpr``
+parameter. The value of this parameter is a Jaxpr with 2 input variables.

 The parameter ``in_axes`` specifies which of the input variables should be
 mapped and which should be broadcast. In our example, the value of ``extra``
-is broadcast, the other input values are mapped.
+is broadcast and the value of ``arr`` is mapped.
@@ -16,7 +16,7 @@ with JAX’s collective operations, we recommend starting with the
 environments in JAX is direct communication links between accelerators, e.g. the
 high-speed interconnects for Cloud TPUs or
 [NCCL](https://developer.nvidia.com/nccl) for GPUs. These links are what allow
-collective operations to run across multiple process’ worth of accelerators.
+collective operations to run across multiple processes’ worth of accelerators.


 ## Multi-process programming model
@@ -80,7 +80,7 @@ out the {doc}`/jax-101/06-parallelism` notebook.) Each process should call the
 same pmapped function and pass in arguments to be mapped across its _local_
 devices (i.e., the pmapped axis size is equal to the number of local
 devices). Similarly, the function will return outputs sharded across _local_
-devices only. Inside the function however, collective communication operations
+devices only. Inside the function, however, collective communication operations
 are run across all _global_ devices, across all processes. Conceptually, this
 can be thought of as running a pmap over a single array sharded across hosts,
 where each host “sees” only its local shard of the input and output.
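A hedged sketch of this programming model: each process pmaps over its local devices, while the collective reduces over all global devices (run the same script on every process; cluster/process initialization is deliberately elided here):

```python
import functools
import jax
import jax.numpy as jnp

n_local = jax.local_device_count()    # devices attached to this process
n_global = jax.device_count()         # devices across all processes

@functools.partial(jax.pmap, axis_name='i')
def global_sum(x):
    # psum reduces over every global device participating in the pmap.
    return jax.lax.psum(x, axis_name='i')

# Each process passes only its local shard (leading axis == n_local)...
xs = jnp.ones(n_local)
# ...and gets back locally sharded outputs; each entry equals n_global.
print(global_sum(xs))
```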
@@ -955,7 +955,7 @@
 "source": [
 "JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state! \n",
 "\n",
-"Reusing the same state will cause __sadness__ and __monotony__, depriving the enduser of __lifegiving chaos__:"
+"Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:"
 ]
 },
 {
@@ -461,7 +461,7 @@ key

 JAX's random functions produce pseudorandom numbers from the PRNG state, but __do not__ change the state!

-Reusing the same state will cause __sadness__ and __monotony__, depriving the enduser of __lifegiving chaos__:
+Reusing the same state will cause __sadness__ and __monotony__, depriving the end user of __lifegiving chaos__:

 ```{code-cell} ipython3
 :id: 7zUdQMynoE5e
@@ -850,7 +850,7 @@
 "id": "p2xFQAte19sF"
 },
 "source": [
-"This is an iterative procedure for numerically solving the equation $x = f(a, x)$ for $x$, by iterating $x_{t+1} = f(a, x_t)$ until $x_{t+1}$ is sufficiently close to $x_t$. The result $x^*$ depends on the parameters $a$, and so we can think of there being a function $a \\mapsto x^*(a)$ that is implicity defined by equation $x = f(a, x)$.\n",
+"This is an iterative procedure for numerically solving the equation $x = f(a, x)$ for $x$, by iterating $x_{t+1} = f(a, x_t)$ until $x_{t+1}$ is sufficiently close to $x_t$. The result $x^*$ depends on the parameters $a$, and so we can think of there being a function $a \\mapsto x^*(a)$ that is implicitly defined by equation $x = f(a, x)$.\n",
 "\n",
 "We can use `fixed_point` to run iterative procedures to convergence, for example running Newton's method to calculate square roots while only executing adds, multiplies, and divides:"
 ]
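A hedged sketch of the Newton's-method use mentioned in that cell, assuming a `fixed_point(f, a, x_guess)` helper like the one the notebook defines (the simplified Python-loop version below stands in for the notebook's own implementation):

```python
import jax.numpy as jnp

def fixed_point(f, a, x_guess):
    # Plain Python iteration; the notebook's version adds custom derivative rules.
    x = x_guess
    for _ in range(100):
        x_new = f(a, x)
        if jnp.allclose(x_new, x):
            break
        x = x_new
    return x

def newton_sqrt(a):
    # Newton's update for x**2 = a uses only add, multiply, and divide.
    update = lambda a, x: 0.5 * (x + a / x)
    return fixed_point(update, a, a)

print(newton_sqrt(jnp.array(2.0)))  # ~1.4142135
```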
@@ -1414,7 +1414,7 @@
 "id": "kZ0yc-Ihoezk"
 },
 "source": [
-"Calling a `jax.custom_jvp` function with keyword arguments, or writing a `jax.custom_jvp` function definition with default arguments, are both allowed so long as they can be unambiguosly mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism."
+"Calling a `jax.custom_jvp` function with keyword arguments, or writing a `jax.custom_jvp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism."
 ]
 },
 {
@@ -1782,7 +1782,7 @@
 "id": "GwC26P9kn8qw"
 },
 "source": [
-"Calling a `jax.custom_vjp` function with keyword arguments, or writing a `jax.custom_vjp` function definition with default arguments, are both allowed so long as they can be unambiguosly mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism."
+"Calling a `jax.custom_vjp` function with keyword arguments, or writing a `jax.custom_vjp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism."
 ]
 },
 {
@@ -437,7 +437,7 @@ def fixed_point(f, a, x_guess):

 +++ {"id": "p2xFQAte19sF"}

-This is an iterative procedure for numerically solving the equation $x = f(a, x)$ for $x$, by iterating $x_{t+1} = f(a, x_t)$ until $x_{t+1}$ is sufficiently close to $x_t$. The result $x^*$ depends on the parameters $a$, and so we can think of there being a function $a \mapsto x^*(a)$ that is implicity defined by equation $x = f(a, x)$.
+This is an iterative procedure for numerically solving the equation $x = f(a, x)$ for $x$, by iterating $x_{t+1} = f(a, x_t)$ until $x_{t+1}$ is sufficiently close to $x_t$. The result $x^*$ depends on the parameters $a$, and so we can think of there being a function $a \mapsto x^*(a)$ that is implicitly defined by equation $x = f(a, x)$.

 We can use `fixed_point` to run iterative procedures to convergence, for example running Newton's method to calculate square roots while only executing adds, multiplies, and divides:

@@ -739,7 +739,7 @@ print(grad(f, 1)(2., 3.))

 +++ {"id": "kZ0yc-Ihoezk"}

-Calling a `jax.custom_jvp` function with keyword arguments, or writing a `jax.custom_jvp` function definition with default arguments, are both allowed so long as they can be unambiguosly mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism.
+Calling a `jax.custom_jvp` function with keyword arguments, or writing a `jax.custom_jvp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism.

 +++ {"id": "3FGwfT67PDs9"}

@@ -919,7 +919,7 @@ print(grad(f)(2., 3.))

 +++ {"id": "GwC26P9kn8qw"}

-Calling a `jax.custom_vjp` function with keyword arguments, or writing a `jax.custom_vjp` function definition with default arguments, are both allowed so long as they can be unambiguosly mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism.
+Calling a `jax.custom_vjp` function with keyword arguments, or writing a `jax.custom_vjp` function definition with default arguments, are both allowed so long as they can be unambiguously mapped to positional arguments based on the function signature retrieved by the standard library `inspect.signature` mechanism.

 +++ {"id": "XfH-ae8bYt6-"}

@@ -42,7 +42,7 @@
 "Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically\n",
 "as \"multiply_add(x, y, z) = x * y + z\". \n",
 "This function operates on 3 identically-shaped tensors of floating point \n",
-"values and performs the opertions pointwise."
+"values and performs the operations pointwise."
 ]
 },
 {
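A hedged, minimal sketch of what defining such a primitive involves, following the tutorial's naming (evaluation rule only; the differentiation and compilation rules come later in the notebook):

```python
import numpy as np
from jax import core

# A new primitive object named "multiply_add".
multiply_add_p = core.Primitive("multiply_add")

def multiply_add(x, y, z):
    """multiply_add(x, y, z) = x * y + z, dispatched through the primitive."""
    return multiply_add_p.bind(x, y, z)

# The concrete-evaluation ("impl") rule: how to execute it outside tracing.
def multiply_add_impl(x, y, z):
    return np.add(np.multiply(x, y), z)

multiply_add_p.def_impl(multiply_add_impl)

print(multiply_add(2.0, 3.0, 10.0))  # 16.0
```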
@@ -601,7 +601,7 @@
 "\n",
 "JAX compilation works by compiling each primitive into a graph of XLA operations.\n",
 "\n",
-"This is biggest hurdle to adding new functionality to JAX, because the \n",
+"This is the biggest hurdle to adding new functionality to JAX, because the \n",
 "set of XLA operations is limited, and JAX already has pre-defined primitives\n",
 "for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++."
 ]
@@ -976,7 +976,7 @@
 "Observe also that JAX uses the special abstract tangent value `Zero` for\n",
 "the tangent corresponding to the 3rd argument of `ma`. This reflects the \n",
 "fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,\n",
-"which flow to 3rd argument to `multiply_add_prim`.\n",
+"which flows to the 3rd argument to `multiply_add_prim`.\n",
 "\n",
 "Observe also that during the abstract evaluation of the tangent we pass the \n",
 "value 0.0 as the tangent for the 3rd argument. This is due to the use\n",
@@ -1147,7 +1147,7 @@
 " w.r.t. tangents in multiply_add_value_and_jvp:\n",
 " output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))\n",
 " \n",
-" Always one of the first two multiplicative arguments are constants.\n",
+" Always one of the first two multiplicative arguments is a constant.\n",
 "\n",
 " Args:\n",
 " ct: the cotangent of the output of the primitive.\n",
@@ -1254,7 +1254,7 @@
 "#### JIT of reverse differentiation \n",
 "\n",
 "Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only\n",
-"abstract values, while in the absensce of JIT we used `ConcreteArray`."
+"abstract values, while in the absence of JIT we used `ConcreteArray`."
 ]
 },
 {
@@ -49,7 +49,7 @@ one can define a new primitive that encapsulates the behavior of the function.
 Consider that we want to add to JAX support for a multiply-add function with three arguments, defined mathematically
 as "multiply_add(x, y, z) = x * y + z".
 This function operates on 3 identically-shaped tensors of floating point
-values and performs the opertions pointwise.
+values and performs the operations pointwise.

 +++ {"id": "HIJYIHNTD1yI"}

@@ -347,7 +347,7 @@ with expectNotImplementedError():

 JAX compilation works by compiling each primitive into a graph of XLA operations.

-This is biggest hurdle to adding new functionality to JAX, because the
+This is the biggest hurdle to adding new functionality to JAX, because the
 set of XLA operations is limited, and JAX already has pre-defined primitives
 for most of them. However, XLA includes a `CustomCall` operation that can be used to encapsulate arbitrary functionality defined using C++.

@@ -530,7 +530,7 @@ for the differentiation point, and abstract values for the tangents.
 Observe also that JAX uses the special abstract tangent value `Zero` for
 the tangent corresponding to the 3rd argument of `ma`. This reflects the
 fact that we do not differentiate w.r.t. the 2nd argument to `square_add_prim`,
-which flow to 3rd argument to `multiply_add_prim`.
+which flows to the 3rd argument to `multiply_add_prim`.

 Observe also that during the abstract evaluation of the tangent we pass the
 value 0.0 as the tangent for the 3rd argument. This is due to the use
@@ -630,7 +630,7 @@ def multiply_add_transpose(ct, x, y, z):
 w.r.t. tangents in multiply_add_value_and_jvp:
 output_tangent(xt, yt, zt) = multiply_add_prim(xt, y, multiply_add_prim(x, yt, zt))

-Always one of the first two multiplicative arguments are constants.
+Always one of the first two multiplicative arguments is a constant.

 Args:
 ct: the cotangent of the output of the primitive.
@@ -679,7 +679,7 @@ last use of `multiply_add_prim`: `multiply_add_prim(xt, y, ...)` where `y` is th
 #### JIT of reverse differentiation

 Notice that the abstract evaluation of the `multiply_add_value_and_jvp` is using only
-abstract values, while in the absensce of JIT we used `ConcreteArray`.
+abstract values, while in the absence of JIT we used `ConcreteArray`.

 ```{code-cell} ipython3
 :id: FZ-JGbWZPq2-
@@ -12,7 +12,7 @@
 "\n",
 "**Copyright 2018 Google LLC.**\n",
 "\n",
-"Licensed under the Apache License, Version 2.0 (the \"License\");you may not use this file except in compliance with the License.\n",
+"Licensed under the Apache License, Version 2.0 (the \"License\"); you may not use this file except in compliance with the License.\n",
 "You may obtain a copy of the License at\n",
 "\n",
 "https://www.apache.org/licenses/LICENSE-2.0\n",
@@ -34,7 +34,7 @@
 "\n",
 "Let's combine everything we showed in the [quickstart notebook](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).\n",
 "\n",
-"Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for builidng our model."
+"Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model."
 ]
 },
 {
@@ -20,7 +20,7 @@ kernelspec:

 **Copyright 2018 Google LLC.**

-Licensed under the Apache License, Version 2.0 (the "License");you may not use this file except in compliance with the License.
+Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 https://www.apache.org/licenses/LICENSE-2.0
@@ -37,7 +37,7 @@ limitations under the License.

 Let's combine everything we showed in the [quickstart notebook](https://colab.research.google.com/github/google/jax/blob/main/docs/notebooks/quickstart.ipynb) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use PyTorch's data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library).

-Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for builidng our model.
+Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.

 ```{code-cell} ipython3
 :id: OksHydJDtbbI
@@ -146,12 +146,12 @@
 "id": "k-HxK9iagnH6"
 },
 "source": [
-"* `jaxpr.invars` - the `invars` of a Jaxpr are a list of the input variables to Jaxpr, analogous to arguments in Python functions\n",
+"* `jaxpr.invars` - the `invars` of a Jaxpr are a list of the input variables to Jaxpr, analogous to arguments in Python functions.\n",
 "* `jaxpr.outvars` - the `outvars` of a Jaxpr are the variables that are returned by the Jaxpr. Every Jaxpr has multiple outputs.\n",
-"* `jaxpr.constvars` - the `constvars` are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later)\n",
-"* `jaxpr.eqns` - a list of equations, which are essentially let-bindings. Each equation is list of input variables, a list of output variables, and a *primitive*, which is used to evaluate inputs to produce outputs. Each equation also has a `params`, a dictionary of parameters.\n",
+"* `jaxpr.constvars` - the `constvars` are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later).\n",
+"* `jaxpr.eqns` - a list of equations, which are essentially let-bindings. Each equation is a list of input variables, a list of output variables, and a *primitive*, which is used to evaluate inputs to produce outputs. Each equation also has a `params`, a dictionary of parameters.\n",
 "\n",
-"All together, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. We'll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want."
+"Altogether, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. We'll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want."
 ]
 },
 {
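A minimal sketch of inspecting those fields on a real jaxpr (output details vary by JAX version):

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.sin(x) + y * 2.0

closed_jaxpr = jax.make_jaxpr(f)(1.0, 2.0)
jaxpr = closed_jaxpr.jaxpr

print(jaxpr.invars)     # input variables, like the function's arguments
print(jaxpr.outvars)    # variables returned by the jaxpr
print(jaxpr.constvars)  # constants captured from the trace
for eqn in jaxpr.eqns:  # one let-binding per primitive application
    print(eqn.primitive, eqn.invars, eqn.outvars, eqn.params)
```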
@@ -335,7 +335,7 @@
 "\n",
 "An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.\n",
 "\n",
-"It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L141-L187)."
+"It turns out that this interpreter will also look similar to the \"transpose\" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234)."
 ]
 },
 {
@@ -105,12 +105,12 @@ examine_jaxpr(jax.make_jaxpr(bar)(jnp.ones((5, 10)), jnp.ones(5), jnp.ones(10)))

 +++ {"id": "k-HxK9iagnH6"}

-* `jaxpr.invars` - the `invars` of a Jaxpr are a list of the input variables to Jaxpr, analogous to arguments in Python functions
+* `jaxpr.invars` - the `invars` of a Jaxpr are a list of the input variables to Jaxpr, analogous to arguments in Python functions.
 * `jaxpr.outvars` - the `outvars` of a Jaxpr are the variables that are returned by the Jaxpr. Every Jaxpr has multiple outputs.
-* `jaxpr.constvars` - the `constvars` are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later)
-* `jaxpr.eqns` - a list of equations, which are essentially let-bindings. Each equation is list of input variables, a list of output variables, and a *primitive*, which is used to evaluate inputs to produce outputs. Each equation also has a `params`, a dictionary of parameters.
+* `jaxpr.constvars` - the `constvars` are a list of variables that are also inputs to the Jaxpr, but correspond to constants from the trace (we'll go over these in more detail later).
+* `jaxpr.eqns` - a list of equations, which are essentially let-bindings. Each equation is a list of input variables, a list of output variables, and a *primitive*, which is used to evaluate inputs to produce outputs. Each equation also has a `params`, a dictionary of parameters.

-All together, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. We'll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want.
+Altogether, a Jaxpr encapsulates a simple program that can be evaluated with inputs to produce an output. We'll go over how exactly to do this later. The important thing to note now is that a Jaxpr is a data structure that can be manipulated and evaluated in whatever way we want.

 +++ {"id": "NwY7TurYn6sr"}

@@ -235,7 +235,7 @@ Furthermore, this interpreter does not handle `subjaxprs`, which we will not cov

 An `inverse` interpreter doesn't look too different from `eval_jaxpr`. We'll first set up the registry which will map primitives to their inverses. We'll then write a custom interpreter that looks up primitives in the registry.

-It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L141-L187).
+It turns out that this interpreter will also look similar to the "transpose" interpreter used in reverse-mode autodifferentiation [found here](https://github.com/google/jax/blob/main/jax/interpreters/ad.py#L164-L234).

 ```{code-cell} ipython3
 :id: gSMIT2z1vUpO
@@ -18,7 +18,7 @@
 "\n",
 "XLA computations are built as computation graphs in HLO IR, which is then lowered to LLO that is device specific (CPU, GPU, TPU, etc.). \n",
 "\n",
-"As end users we interact with the computational primitives offered to us by the HLO spec.\n",
+"As end users, we interact with the computational primitives offered to us by the HLO spec.\n",
 "\n",
 "**Caution: This is a pedagogical notebook covering some low level XLA details, the APIs herein are neither public nor stable!**"
 ]
@@ -26,7 +26,7 @@ XLA is the compiler that JAX uses, and the compiler that TF uses for TPUs and wi

 XLA computations are built as computation graphs in HLO IR, which is then lowered to LLO that is device specific (CPU, GPU, TPU, etc.).

-As end users we interact with the computational primitives offered to us by the HLO spec.
+As end users, we interact with the computational primitives offered to us by the HLO spec.

 **Caution: This is a pedagogical notebook covering some low level XLA details, the APIs herein are neither public nor stable!**

@@ -699,7 +699,7 @@
 "\n",
 "To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with \"tall\" Jacobians, but inefficient for \"wide\" Jacobians.\n",
 "\n",
-"If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\\mathbb{R}^n$ to a scalar loss value in $\\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\\partial f(x) \\in \\mathbb{R}^{1 \\times n}$, which we often identify with the Gradient vector $\\nabla f(x) \\in \\mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluating the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale.\n",
+"If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\\mathbb{R}^n$ to a scalar loss value in $\\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\\partial f(x) \\in \\mathbb{R}^{1 \\times n}$, which we often identify with the Gradient vector $\\nabla f(x) \\in \\mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluate the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale.\n",
 "\n",
 "To do better for functions like this, we just need to use reverse-mode."
 ]
@ -1493,7 +1493,7 @@
|
||||
"id": "jqCvEE8qwGw7"
|
||||
},
|
||||
"source": [
|
||||
"For geneneral $\\mathbb{C} \\to \\mathbb{C}$ functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can't hope to represent all of them with in a complex number. But we can for holomorphic functions! A holomorphic function is precisely a $\\mathbb{C} \\to \\mathbb{C}$ function with the special property that its derivative can be represented as a single complex number. (The [Cauchy-Riemann equations](https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations) ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to `vjp` with a covector of `1.0`.\n",
|
||||
"For general $\\mathbb{C} \\to \\mathbb{C}$ functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can't hope to represent all of them within a complex number. But we can for holomorphic functions! A holomorphic function is precisely a $\\mathbb{C} \\to \\mathbb{C}$ function with the special property that its derivative can be represented as a single complex number. (The [Cauchy-Riemann equations](https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations) ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to `vjp` with a covector of `1.0`.\n",
|
||||
"\n",
|
||||
"Because this only works for holomorphic functions, to use this trick we need to promise JAX that our function is holomorphic; otherwise, JAX will raise an error when `grad` is used for a complex-output function:"
|
||||
]
|
||||
@ -1574,8 +1574,8 @@
|
||||
"There are some useful upshots for how `grad` works here:\n",
|
||||
"\n",
|
||||
"1. We can use `grad` on holomorphic $\\mathbb{C} \\to \\mathbb{C}$ functions.\n",
|
||||
"2. We can use `grad` to optimize $f : \\mathbb{C} \\to \\mathbb{R}$ functions, like real-valued loss functions of complex parameters `x`, by taking steps in the dierction of the conjugate of `grad(f)(x)`.\n",
|
||||
"3. If we have an $\\mathbb{R} \\to \\mathbb{R}$ function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in covolutions) then `grad` still works and we get the same result that an implementation using only real values would have given.\n",
|
||||
"2. We can use `grad` to optimize $f : \\mathbb{C} \\to \\mathbb{R}$ functions, like real-valued loss functions of complex parameters `x`, by taking steps in the direction of the conjugate of `grad(f)(x)`.\n",
|
||||
"3. If we have an $\\mathbb{R} \\to \\mathbb{R}$ function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions) then `grad` still works and we get the same result that an implementation using only real values would have given.\n",
|
||||
"\n",
|
||||
"In any case, JVPs and VJPs are always unambiguous. And if we wanted to compute the full Jacobian matrix of a non-holomorphic $\\mathbb{C} \\to \\mathbb{C}$ function, we can do it with JVPs or VJPs!"
|
||||
]
|
||||
@ -1637,7 +1637,7 @@
|
||||
"\n",
|
||||
"In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful. \n",
|
||||
"\n",
|
||||
"There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in a \"Advanced Autodiff Cookbook\" include:\n",
|
||||
"There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in an \"Advanced Autodiff Cookbook\" include:\n",
|
||||
"\n",
|
||||
" - Gauss-Newton Vector Products, linearizing once\n",
|
||||
" - Custom VJPs and JVPs\n",
|
||||
|
@ -396,7 +396,7 @@ That memory complexity sounds pretty compelling! So why don't we see forward-mod
|
||||
|
||||
To answer that, first think about how you could use a JVP to build a full Jacobian matrix. If we apply a JVP to a one-hot tangent vector, it reveals one column of the Jacobian matrix, corresponding to the nonzero entry we fed in. So we can build a full Jacobian one column at a time, and to get each column costs about the same as one function evaluation. That will be efficient for functions with "tall" Jacobians, but inefficient for "wide" Jacobians.
|
||||
|
||||
If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\mathbb{R}^n$ to a scalar loss value in $\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\partial f(x) \in \mathbb{R}^{1 \times n}$, which we often identify with the Gradient vector $\nabla f(x) \in \mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluating the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale.
|
||||
If you're doing gradient-based optimization in machine learning, you probably want to minimize a loss function from parameters in $\mathbb{R}^n$ to a scalar loss value in $\mathbb{R}$. That means the Jacobian of this function is a very wide matrix: $\partial f(x) \in \mathbb{R}^{1 \times n}$, which we often identify with the Gradient vector $\nabla f(x) \in \mathbb{R}^n$. Building that matrix one column at a time, with each call taking a similar number of FLOPs to evaluate the original function, sure seems inefficient! In particular, for training neural networks, where $f$ is a training loss function and $n$ can be in the millions or billions, this approach just won't scale.
|
||||
|
||||
To do better for functions like this, we just need to use reverse-mode.
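As a small illustration of the column-at-a-time picture (a sketch, not part of the original notebook): one JVP per one-hot tangent vector yields one column of the Jacobian, which is essentially what `jax.jacfwd` does for us.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.stack([x[0] ** 2 + x[1], jnp.sin(x[1]) * x[2], x.sum()])

x = jnp.arange(3.0)

# One JVP per one-hot tangent vector reveals one column of the Jacobian.
cols = [jax.jvp(f, (x,), (jnp.eye(3)[i],))[1] for i in range(3)]
jac_by_columns = jnp.stack(cols, axis=1)

print(jnp.allclose(jac_by_columns, jax.jacfwd(f)(x)))   # True
```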
|
||||
|
||||
@ -895,7 +895,7 @@ grad(f)(z)
|
||||
|
||||
+++ {"id": "jqCvEE8qwGw7"}
|
||||
|
||||
For geneneral $\mathbb{C} \to \mathbb{C}$ functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can't hope to represent all of them with in a complex number. But we can for holomorphic functions! A holomorphic function is precisely a $\mathbb{C} \to \mathbb{C}$ function with the special property that its derivative can be represented as a single complex number. (The [Cauchy-Riemann equations](https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations) ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to `vjp` with a covector of `1.0`.
|
||||
For general $\mathbb{C} \to \mathbb{C}$ functions, the Jacobian has 4 real-valued degrees of freedom (as in the 2x2 Jacobian matrices above), so we can't hope to represent all of them within a complex number. But we can for holomorphic functions! A holomorphic function is precisely a $\mathbb{C} \to \mathbb{C}$ function with the special property that its derivative can be represented as a single complex number. (The [Cauchy-Riemann equations](https://en.wikipedia.org/wiki/Cauchy%E2%80%93Riemann_equations) ensure that the above 2x2 Jacobians have the special form of a scale-and-rotate matrix in the complex plane, i.e. the action of a single complex number under multiplication.) And we can reveal that one complex number using a single call to `vjp` with a covector of `1.0`.
|
||||
|
||||
Because this only works for holomorphic functions, to use this trick we need to promise JAX that our function is holomorphic; otherwise, JAX will raise an error when `grad` is used for a complex-output function:
|
||||
|
||||
@ -930,8 +930,8 @@ grad(f, holomorphic=True)(z) # f is not actually holomorphic!
|
||||
There are some useful upshots for how `grad` works here:
|
||||
|
||||
1. We can use `grad` on holomorphic $\mathbb{C} \to \mathbb{C}$ functions.
|
||||
2. We can use `grad` to optimize $f : \mathbb{C} \to \mathbb{R}$ functions, like real-valued loss functions of complex parameters `x`, by taking steps in the dierction of the conjugate of `grad(f)(x)`.
|
||||
3. If we have an $\mathbb{R} \to \mathbb{R}$ function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in covolutions) then `grad` still works and we get the same result that an implementation using only real values would have given.
|
||||
2. We can use `grad` to optimize $f : \mathbb{C} \to \mathbb{R}$ functions, like real-valued loss functions of complex parameters `x`, by taking steps in the direction of the conjugate of `grad(f)(x)`.
|
||||
3. If we have an $\mathbb{R} \to \mathbb{R}$ function that just happens to use some complex-valued operations internally (some of which must be non-holomorphic, e.g. FFTs used in convolutions) then `grad` still works and we get the same result that an implementation using only real values would have given.
|
||||
|
||||
In any case, JVPs and VJPs are always unambiguous. And if we wanted to compute the full Jacobian matrix of a non-holomorphic $\mathbb{C} \to \mathbb{C}$ function, we can do it with JVPs or VJPs!
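A minimal sketch of upshot 1, assuming nothing beyond the `holomorphic` flag discussed above:

```python
import jax
import jax.numpy as jnp

f = lambda z: jnp.sin(z)        # holomorphic C -> C
z = 3.0 + 4.0j

# With holomorphic=True, grad returns the single complex derivative, here cos(z).
print(jax.grad(f, holomorphic=True)(z))
print(jnp.cos(z))
```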
|
||||
|
||||
@ -960,7 +960,7 @@ grad(f, holomorphic=True)(A)
|
||||
|
||||
In this notebook, we worked through some easy, and then progressively more complicated, applications of automatic differentiation in JAX. We hope you now feel that taking derivatives in JAX is easy and powerful.
|
||||
|
||||
There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in a "Advanced Autodiff Cookbook" include:
|
||||
There's a whole world of other autodiff tricks and functionality out there. Topics we didn't cover, but hope to in an "Advanced Autodiff Cookbook" include:
|
||||
|
||||
- Gauss-Newton Vector Products, linearizing once
|
||||
- Custom VJPs and JVPs
|
||||
|
@ -169,7 +169,7 @@
|
||||
"source": [
|
||||
"For the more general types of batched convolutions often useful in the context of building deep neural networks, JAX and XLA offer the very general N-dimensional __conv_general_dilated__ function, but it's not very obvious how to use it. We'll give some examples of the common use-cases.\n",
|
||||
"\n",
|
||||
"A survey of the family of convolutional operators, [a guide to convolutional arithmetic](https://arxiv.org/abs/1603.07285) is highly recommended reading!\n",
|
||||
"A survey of the family of convolutional operators, [a guide to convolutional arithmetic](https://arxiv.org/abs/1603.07285), is highly recommended reading!\n",
|
||||
"\n",
|
||||
"Let's define a simple diagonal edge kernel:"
|
||||
]
|
||||
|
@ -111,7 +111,7 @@ Like in the one-dimensional case, we use `mode='same'` to specify how we would l
|
||||
|
||||
For the more general types of batched convolutions often useful in the context of building deep neural networks, JAX and XLA offer the very general N-dimensional __conv_general_dilated__ function, but it's not very obvious how to use it. We'll give some examples of the common use-cases.
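As a hedged sketch of what such a call can look like (the dummy shapes here are chosen purely for illustration; the notebook's own worked examples follow):

```python
import jax.numpy as jnp
from jax import lax

imgs = jnp.zeros((1, 3, 32, 32))      # NCHW: 1 image, 3 channels, 32x32
kernels = jnp.zeros((8, 3, 3, 3))     # OIHW: 8 filters, 3 input channels, 3x3

out = lax.conv_general_dilated(
    imgs, kernels,
    window_strides=(1, 1),
    padding='SAME',
    dimension_numbers=('NCHW', 'OIHW', 'NCHW'))
print(out.shape)                      # (1, 8, 32, 32)
```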
|
||||
|
||||
A survey of the family of convolutional operators, [a guide to convolutional arithmetic](https://arxiv.org/abs/1603.07285) is highly recommended reading!
|
||||
A survey of the family of convolutional operators, [a guide to convolutional arithmetic](https://arxiv.org/abs/1603.07285), is highly recommended reading!
|
||||
|
||||
Let's define a simple diagonal edge kernel:
|
||||
|
||||
|
@ -25,10 +25,10 @@
|
||||
"In this notebook we'll go through:\n",
|
||||
"\n",
|
||||
"- how to take gradients, gradients of gradients.\n",
|
||||
"- how to fit a sinusoid function with a neural network (and do auto-batching with vmap)\n",
|
||||
"- how to implement MAML and check its numerics\n",
|
||||
"- how to fit a sinusoid function with a neural network (and do auto-batching with vmap).\n",
|
||||
"- how to implement MAML and check its numerics.\n",
|
||||
"- how to implement MAML for sinusoid task (single-task objective, batching task instances).\n",
|
||||
"- extending MAML to handle batching at the task-level"
|
||||
"- extending MAML to handle batching at the task-level."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -526,7 +526,7 @@
|
||||
"source": [
|
||||
"## Batching Meta-Gradient Across Tasks\n",
|
||||
"\n",
|
||||
"Kind of does the job but not that great. Let's reduce the variance of gradients in outer loop by averaging across a batch of tasks (not just one task at a time). \n",
|
||||
"Kind of does the job but not that great. Let's reduce the variance of the gradients in the outer loop by averaging across a batch of tasks (not just one task at a time). \n",
|
||||
"\n",
|
||||
"vmap is awesome it enables nice handling of batching at two levels: inner-level \"intra-task\" batching, and outer level batching across tasks.\n",
|
||||
"\n",
|
||||
|
@ -31,10 +31,10 @@ Pedagogical tutorial for implementing Model-Agnostic Meta-Learning with JAX's aw
|
||||
In this notebook we'll go through:
|
||||
|
||||
- how to take gradients, gradients of gradients.
|
||||
- how to fit a sinusoid function with a neural network (and do auto-batching with vmap)
|
||||
- how to implement MAML and check its numerics
|
||||
- how to fit a sinusoid function with a neural network (and do auto-batching with vmap).
|
||||
- how to implement MAML and check its numerics.
|
||||
- how to implement MAML for the sinusoid task (single-task objective, batching task instances).
|
||||
- extending MAML to handle batching at the task-level
|
||||
- extending MAML to handle batching at the task-level.
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: zKVdo3FtgyhE
|
||||
@ -287,7 +287,7 @@ plt.legend()
|
||||
|
||||
## Batching Meta-Gradient Across Tasks
|
||||
|
||||
Kind of does the job but not that great. Let's reduce the variance of gradients in outer loop by averaging across a batch of tasks (not just one task at a time).
|
||||
Kind of does the job but not that great. Let's reduce the variance of the gradients in the outer loop by averaging across a batch of tasks (not just one task at a time).
|
||||
|
||||
vmap is awesome; it enables nice handling of batching at two levels: inner-level "intra-task" batching, and outer-level batching across tasks.
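Here is a minimal sketch of batching at the two levels with nested `vmap` (the toy model and names below are illustrative, not the notebook's MAML code):

```python
import jax
import jax.numpy as jnp

# Toy per-example loss for one task (illustrative stand-in for the MAML inner loss).
def example_loss(params, x, y):
    pred = params['w'] * x + params['b']
    return (pred - y) ** 2

# Inner vmap: batch over the examples within a single task.
def task_loss(params, xs, ys):
    return jnp.mean(jax.vmap(example_loss, in_axes=(None, 0, 0))(params, xs, ys))

# Outer vmap: batch the same computation over a batch of tasks.
batched_task_loss = jax.vmap(task_loss, in_axes=(None, 0, 0))

params = {'w': jnp.float32(1.0), 'b': jnp.float32(0.0)}
xs = jnp.ones((4, 8))   # 4 tasks, 8 examples each
ys = jnp.ones((4, 8))
print(batched_task_loss(params, xs, ys).shape)   # (4,)
```

Taking the gradient of the mean of `batched_task_loss` over the task axis gives the reduced-variance outer-loop gradient described above.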
|
||||
|
||||
|
@ -46,7 +46,7 @@
|
||||
"\n",
|
||||
"Let's combine everything we showed in the [quickstart notebook](https://colab.research.google.com/github/google/jax/blob/main/notebooks/quickstart.ipynb) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).\n",
|
||||
"\n",
|
||||
"Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for builidng our model."
|
||||
"Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model."
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -44,7 +44,7 @@ _Forked from_ `neural_network_and_data_loading.ipynb`
|
||||
|
||||
Let's combine everything we showed in the [quickstart notebook](https://colab.research.google.com/github/google/jax/blob/main/notebooks/quickstart.ipynb) to train a simple neural network. We will first specify and train a simple MLP on MNIST using JAX for the computation. We will use the `tensorflow/datasets` data loading API to load images and labels (because it's pretty great, and the world doesn't need yet another data loading library :P).
|
||||
|
||||
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for builidng our model.
|
||||
Of course, you can use JAX with any API that is compatible with NumPy to make specifying the model a bit more plug-and-play. Here, just for explanatory purposes, we won't use any neural network libraries or special APIs for building our model.
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: OksHydJDtbbI
|
||||
|
@ -196,7 +196,7 @@
|
||||
"id": "iOzp0P_GoJhb"
|
||||
},
|
||||
"source": [
|
||||
"JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there's three main ones:\n",
|
||||
"JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:\n",
|
||||
"\n",
|
||||
" - {func}`~jax.jit`, for speeding up your code\n",
|
||||
" - {func}`~jax.grad`, for taking derivatives\n",
|
||||
|
@ -127,7 +127,7 @@ x = np.random.normal(size=(size, size)).astype(np.float32)
|
||||
|
||||
+++ {"id": "iOzp0P_GoJhb"}
|
||||
|
||||
JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there's three main ones:
|
||||
JAX is much more than just a GPU-backed NumPy. It also comes with a few program transformations that are useful when writing numerical code. For now, there are three main ones:
|
||||
|
||||
- {func}`~jax.jit`, for speeding up your code
|
||||
- {func}`~jax.grad`, for taking derivatives
|
||||
|
@ -77,7 +77,7 @@
|
||||
"source": [
|
||||
"## Compute score matching objective\n",
|
||||
"\n",
|
||||
"The method we apply here was originally proposed by [Hyvarinen et al. (2005)](http://jmlr.org/papers/volume6/hyvarinen05a/old.pdf). The idea behind score matching is to __learn scores:__ the gradients of $\\log p(x)$ w.r.t. $x$. When trained this model can \"improve\" a sample $x$ by changing it in the direction of highest log-probability. However, training such model can get tricky. When predicting a continuous variable, ML folks usually minimize squared error:\n",
|
||||
"The method we apply here was originally proposed by [Hyvarinen et al. (2005)](http://jmlr.org/papers/volume6/hyvarinen05a/old.pdf). The idea behind score matching is to __learn scores:__ the gradients of $\\log p(x)$ w.r.t. $x$. When trained this model can \"improve\" a sample $x$ by changing it in the direction of highest log-probability. However, training such a model can get tricky. When predicting a continuous variable, ML folks usually minimize squared error:\n",
|
||||
"\n",
|
||||
"$$ L_{mse} = E_{x \\sim p(x)} \\left\\lVert model(x) - \\nabla_x \\log p(x) \\right\\lVert_2^2 $$\n",
|
||||
"\n",
|
||||
@ -85,7 +85,7 @@
|
||||
"\n",
|
||||
"$$ L_{matching} = E_{x \\sim p(x)} \\space tr( \\space \\mathbf{J}_x [\\space model(x) \\space]) + \\frac12 \\left\\Vert model(x) \\right\\lVert_2^2 $$\n",
|
||||
"\n",
|
||||
"Here $tr( \\space \\mathbf{J}_x [\\space model(x) \\space])$ is a trace of Jacobian of $model(x)$ w.r.t. $x$. Now all it takes is to minimize the second objective with backpropagation... that is, if you can compute Jacobians. Thankfully, we have __jax__!"
|
||||
"Here $tr( \\space \\mathbf{J}_x [\\space model(x) \\space])$ is the trace of the Jacobian of $model(x)$ w.r.t. $x$. Now all it takes is to minimize the second objective with backpropagation... that is, if you can compute Jacobians. Thankfully, we have __jax__!"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -50,7 +50,7 @@ plt.scatter(*sample_batch(10**4).T, alpha=0.1)
|
||||
|
||||
## Compute score matching objective
|
||||
|
||||
The method we apply here was originally proposed by [Hyvarinen et al. (2005)](http://jmlr.org/papers/volume6/hyvarinen05a/old.pdf). The idea behind score matching is to __learn scores:__ the gradients of $\log p(x)$ w.r.t. $x$. When trained this model can "improve" a sample $x$ by changing it in the direction of highest log-probability. However, training such model can get tricky. When predicting a continuous variable, ML folks usually minimize squared error:
|
||||
The method we apply here was originally proposed by [Hyvarinen et al. (2005)](http://jmlr.org/papers/volume6/hyvarinen05a/old.pdf). The idea behind score matching is to __learn scores:__ the gradients of $\log p(x)$ w.r.t. $x$. When trained this model can "improve" a sample $x$ by changing it in the direction of highest log-probability. However, training such a model can get tricky. When predicting a continuous variable, ML folks usually minimize squared error:
|
||||
|
||||
$$ L_{mse} = E_{x \sim p(x)} \left\lVert model(x) - \nabla_x \log p(x) \right\rVert_2^2 $$
|
||||
|
||||
@ -58,7 +58,7 @@ One can't minimize this explicitly because the real $\nabla_x \log p(x)$ is usua
|
||||
|
||||
$$ L_{matching} = E_{x \sim p(x)} \space tr( \space \mathbf{J}_x [\space model(x) \space]) + \frac12 \left\lVert model(x) \right\rVert_2^2 $$
|
||||
|
||||
Here $tr( \space \mathbf{J}_x [\space model(x) \space])$ is a trace of Jacobian of $model(x)$ w.r.t. $x$. Now all it takes is to minimize the second objective with backpropagation... that is, if you can compute Jacobians. Thankfully, we have __jax__!
|
||||
Here $tr( \space \mathbf{J}_x [\space model(x) \space])$ is the trace of the Jacobian of $model(x)$ w.r.t. $x$. Now all it takes is to minimize the second objective with backpropagation... that is, if you can compute Jacobians. Thankfully, we have __jax__!
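For concreteness, here is a hedged sketch of computing $L_{matching}$ for a toy linear model using `jax.jacfwd` (the model and shapes are illustrative, not the notebook's network):

```python
import jax
import jax.numpy as jnp

# Toy stand-in for the score model: a linear map x -> x @ w.
def model(w, x):
    return x @ w                      # predicted score for one sample, shape (d,)

def score_matching_loss(w, xs):
    jac_fn = jax.jacfwd(model, argnums=1)                       # (d, d) Jacobian w.r.t. x
    traces = jax.vmap(lambda x: jnp.trace(jac_fn(w, x)))(xs)    # tr J_x[model(x)]
    sq_norms = jax.vmap(lambda x: 0.5 * jnp.sum(model(w, x) ** 2))(xs)
    return jnp.mean(traces + sq_norms)

w = 0.1 * jnp.eye(2)
xs = jnp.ones((16, 2))
print(score_matching_loss(w, xs))
```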
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: 98wjxKcNG6TI
|
||||
|
@ -774,7 +774,7 @@
|
||||
"\n",
|
||||
"These tracer objects are what `jax.jit` uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the **shape** and **dtype** of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.\n",
|
||||
"\n",
|
||||
"When we call the compiled fuction again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:"
|
||||
"When we call the compiled function again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
@ -349,7 +349,7 @@ Notice that the print statements execute, but rather than printing the data we p
|
||||
|
||||
These tracer objects are what `jax.jit` uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the **shape** and **dtype** of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.
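A minimal standalone sketch of this caching behavior (illustrative, separate from the notebook's own cells):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    print("tracing with:", x)     # executes only while tracing
    return jnp.sum(x ** 2)

x = jnp.arange(4.0)
f(x)   # prints a tracer such as Traced<ShapedArray(float32[4])>
f(x)   # no print: the cached compiled computation is reused
```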
|
||||
|
||||
When we call the compiled fuction again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:
|
||||
When we call the compiled function again on matching inputs, no re-compilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:
|
||||
|
||||
```{code-cell} ipython3
|
||||
:id: xGntvzNH7skE
|
||||
|
@ -407,7 +407,7 @@
|
||||
"id": "8E7ISmwju0x1"
|
||||
},
|
||||
"source": [
|
||||
"While this might seem like a handful at first, if you've seen code that uses `jnp.einsum` you are already familiar with this approach. The `einsum` function interprets an expression such as `nk,km->nm` assigning names (each letter is considered a separate name) to positional axes, performing necessary broadcasts and reductions, and finally putting back the results in positional axes, according to the order given by the right hand side of the `->` separator. While `einsum` never lets you interact with named axes directly, they do appear naturally in its implementation. `xmap` is a _generalized einsum_ because named axes are now first-class and you get to implement the function that can manipulate them.\n",
|
||||
"While this might seem like a handful at first, if you've seen code that uses `jnp.einsum` you are already familiar with this approach. The `einsum` function interprets an expression such as `nk,km->nm` assigning names (each letter is considered a separate name) to positional axes, performing necessary broadcasts and reductions, and finally putting back the results in positional axes, according to the order given by the right-hand side of the `->` separator. While `einsum` never lets you interact with named axes directly, they do appear naturally in its implementation. `xmap` is a _generalized einsum_ because named axes are now first-class and you get to implement the function that can manipulate them.\n",
|
||||
"\n",
|
||||
"Continuing this analogy, `xmap(my_func, ...)` from the above example is equivalent to `jnp.einsum('bx->xb')`. But of course not every `xmap`ped function will have an equivalent `einsum`.\n",
|
||||
"\n",
|
||||
@ -543,7 +543,7 @@
|
||||
"\n",
|
||||
"> While the rule for broadcasting named axes might seem like an arbitrary extension of the NumPy model, it is actually consistent with it.\n",
|
||||
">\n",
|
||||
"> Broadcasting first looks for pairs of dimensions it considers as equivalent in its both operands. For all matched pairs, it asserts that both sizes are equal or one of them is 1. All unpaired dimensions are carried over to the result.\n",
|
||||
"> Broadcasting first looks for pairs of dimensions it considers as equivalent in both operands. For all matched pairs, it asserts that both sizes are equal or one of them is 1. All unpaired dimensions are carried over to the result.\n",
|
||||
">\n",
|
||||
"> Now, in the positional world the way NumPy broadcasting chooses to form the pairs is by right-aligning the shapes. But our axes are named, so there is a straightforward way of finding equivalent axes: just check their names for equality!"
|
||||
]
|
||||
@ -603,7 +603,7 @@
|
||||
"\n",
|
||||
"Similarly to how we have extended reductions with support for named axes, we've also made it possible to contract over named axes using `jnp.einsum`.\n",
|
||||
"\n",
|
||||
"Operands and results still use a convention of one letter per positional axes, but now it is also possible to mention named axes in curly braces. For example `n{b,k}` implies that a value will have a single positional dimension `n` and named dimensions `b` and `k` (their order doesn't matter). Following the usual einsum semantics, any named axes that appears in inputs, but do not appear in an output will be contracted (summed after all multiplications are performed).\n",
|
||||
"Operands and results still use a convention of one letter per positional axis, but now it is also possible to mention named axes in curly braces. For example, `n{b,k}` implies that a value will have a single positional dimension `n` and named dimensions `b` and `k` (their order doesn't matter). Following the usual einsum semantics, any named axes that appear in inputs, but do not appear in an output will be contracted (summed after all multiplications are performed).\n",
|
||||
"\n",
|
||||
"It is acceptable to omit a named dimension from _all arguments and the result_ in which case it will be treated according to the usual broadcasting semantics. However, it is not acceptable to mention a named axis in one argument that has it in its named shape and skip it in another argument that also has it in its named shape. Of course, skipping it in the arguments that don't have it is required.\n",
|
||||
"\n",
|
||||
@ -674,7 +674,7 @@
|
||||
"id": "ZHoCsWkCEnKt"
|
||||
},
|
||||
"source": [
|
||||
"## Parallelism suport\n",
|
||||
"## Parallelism support\n",
|
||||
"\n",
|
||||
"While the new programming paradigm can be nice at times, the killer feature of `xmap` is its ability to parallelize code over supercomputer-scale hardware meshes!\n",
|
||||
"\n",
|
||||
@ -682,7 +682,7 @@
|
||||
"\n",
|
||||
"In all the previous examples, we haven't said a word about parallelism and for a good reason. By default `xmap` doesn't perform any parallelization and vectorizes the computation in the same way `vmap` does (i.e. it still executes on a single device). To partition the computation over multiple accelerators we have to introduce one more concept: _resource axes_.\n",
|
||||
"\n",
|
||||
"The basic idea is that logical axes (the ones that appear in named shapes) assume that we have abundant hadware and memory, but before the program is to be executed, they have to be placed somewhere. The default (`vmap`-like) evaluation style pays a high memory cost on the deafult JAX device. By mapping logical axes to (one or more) resource axes through the `axis_resources` argument, we can control how `xmap` evaluates the computation."
|
||||
"The basic idea is that logical axes (the ones that appear in named shapes) assume that we have abundant hardware and memory, but before the program is to be executed, they have to be placed somewhere. The default (`vmap`-like) evaluation style pays a high memory cost on the default JAX device. By mapping logical axes to (one or more) resource axes through the `axis_resources` argument, we can control how `xmap` evaluates the computation."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -723,7 +723,7 @@
|
||||
"\n",
|
||||
"Well, it depends, but one good choice is... a hardware mesh!\n",
|
||||
"\n",
|
||||
"For our purposes a mesh is an nd-array of devices with named axes. But, beacuse NumPy doesn't support named axes (that's our extension!), the meshes are represented by a pair of an nd-array of JAX device objects (as obtained from `jax.devices()` or `jax.local_devices()`) and a tuple of resource axis names of length matching the rank of the array."
|
||||
"For our purposes a mesh is an nd-array of devices with named axes. But, because NumPy doesn't support named axes (that's our extension!), the meshes are represented by a pair of an nd-array of JAX device objects (as obtained from `jax.devices()` or `jax.local_devices()`) and a tuple of resource axis names of length matching the rank of the array."
|
||||
]
|
||||
},
|
||||
{
|
||||
@ -819,11 +819,11 @@
|
||||
"source": [
|
||||
"### Is my data replicated? Or partitioned? Where is it?\n",
|
||||
"\n",
|
||||
"Named axes also give us a neat way of reasoning about partitioning and replication. A value is partitioned over a mesh axis if an only if it has a named axis that has been mapped to that mesh axis in its shape. Otherwise, it will be replicated over all slices along that axis.\n",
|
||||
"Named axes also give us a neat way of reasoning about partitioning and replication. A value is partitioned over a mesh axis if and only if it has a named axis that has been mapped to that mesh axis in its shape. Otherwise, it will be replicated over all slices along that axis.\n",
|
||||
"\n",
|
||||
"For example, assume that we're in an `xmap` that had `axis_resources={'a': 'x', 'b': 'y'}` specified (i.e. we are running the computation over a 2D mesh with `x` and `y` axes with sizes 2 and 3 respectively). Then:\n",
|
||||
"* An array of type `f32[(5, 5), {}]` is completely replicated over the whole mesh. All devices store a local copy of the value.\n",
|
||||
"* An array of type `f32[(6,), {'a': 8}]` is partitioned over mesh axis `x`, beacuse it has `'a'` in its named shape, and `'a'` is mapped to `x`. It is replicated over mesh axis `y`. To put it differently, all devices in a slice of the mesh with the same `x` coordinate will store a local copy of a chunk of this array. But, mesh slices with different `x` coordinates will store different chunks of the data.\n",
|
||||
"* An array of type `f32[(6,), {'a': 8}]` is partitioned over mesh axis `x`, because it has `'a'` in its named shape, and `'a'` is mapped to `x`. It is replicated over mesh axis `y`. To put it differently, all devices in a slice of the mesh with the same `x` coordinate will store a local copy of a chunk of this array. But, mesh slices with different `x` coordinates will store different chunks of the data.\n",
|
||||
"* An array of type `f32[(), {'a': 8, 'c': 7}]` is partitioned just like in the previous case: split over the `x` mesh axis and replicated over the `y` axis. Named dimensions with no resources specified are no different than positional dimensions when considering partitioning, so `'c'` has no influence on it.\n",
|
||||
"* An array of type `f32[(), {'a': 8, 'b': 12}]` is completely partitioned over the whole mesh. Every device holds a distinct chunk of the data."
|
||||
]
|
||||
|
@ -274,7 +274,7 @@ assert (y == x.T).all() # The first dimension was removed from x and then re-in
|
||||
|
||||
+++ {"id": "8E7ISmwju0x1"}
|
||||
|
||||
While this might seem like a handful at first, if you've seen code that uses `jnp.einsum` you are already familiar with this approach. The `einsum` function interprets an expression such as `nk,km->nm` assigning names (each letter is considered a separate name) to positional axes, performing necessary broadcasts and reductions, and finally putting back the results in positional axes, according to the order given by the right hand side of the `->` separator. While `einsum` never lets you interact with named axes directly, they do appear naturally in its implementation. `xmap` is a _generalized einsum_ because named axes are now first-class and you get to implement the function that can manipulate them.
|
||||
While this might seem like a handful at first, if you've seen code that uses `jnp.einsum` you are already familiar with this approach. The `einsum` function interprets an expression such as `nk,km->nm` assigning names (each letter is considered a separate name) to positional axes, performing necessary broadcasts and reductions, and finally putting back the results in positional axes, according to the order given by the right-hand side of the `->` separator. While `einsum` never lets you interact with named axes directly, they do appear naturally in its implementation. `xmap` is a _generalized einsum_ because named axes are now first-class and you get to implement the function that can manipulate them.
|
||||
|
||||
Continuing this analogy, `xmap(my_func, ...)` from the above example is equivalent to `jnp.einsum('bx->xb')`. But of course not every `xmap`ped function will have an equivalent `einsum`.
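As a quick refresher on the positional convention being generalized here (plain `jnp.einsum`, nothing `xmap`-specific):

```python
import jax.numpy as jnp

a = jnp.ones((3, 4))   # 'n' and 'k' in the expression below
b = jnp.ones((4, 5))   # 'k' and 'm'

# 'k' appears in the inputs but not in the output, so it is contracted.
c = jnp.einsum('nk,km->nm', a, b)
print(c.shape)   # (3, 5)
```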
|
||||
|
||||
@ -375,7 +375,7 @@ No shape errors can occur when operating over named axes, because `xmap` enforce
|
||||
|
||||
> While the rule for broadcasting named axes might seem like an arbitrary extension of the NumPy model, it is actually consistent with it.
|
||||
>
|
||||
> Broadcasting first looks for pairs of dimensions it considers as equivalent in its both operands. For all matched pairs, it asserts that both sizes are equal or one of them is 1. All unpaired dimensions are carried over to the result.
|
||||
> Broadcasting first looks for pairs of dimensions it considers as equivalent in both operands. For all matched pairs, it asserts that both sizes are equal or one of them is 1. All unpaired dimensions are carried over to the result.
|
||||
>
|
||||
> Now, in the positional world the way NumPy broadcasting chooses to form the pairs is by right-aligning the shapes. But our axes are named, so there is a straightforward way of finding equivalent axes: just check their names for equality!
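A small positional-broadcasting example for reference (a sketch, nothing named-axis-specific):

```python
import jax.numpy as jnp

x = jnp.ones((3, 1, 5))
y = jnp.ones((4, 5))
# Right-aligning the shapes pairs (1, 4) and (5, 5); the unpaired 3 is carried over.
print((x + y).shape)   # (3, 4, 5)
```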
|
||||
|
||||
@ -420,7 +420,7 @@ positional_broadcast_and_reduce(jnp.arange(2, dtype=np.float32),
|
||||
|
||||
Similarly to how we have extended reductions with support for named axes, we've also made it possible to contract over named axes using `jnp.einsum`.
|
||||
|
||||
Operands and results still use a convention of one letter per positional axes, but now it is also possible to mention named axes in curly braces. For example `n{b,k}` implies that a value will have a single positional dimension `n` and named dimensions `b` and `k` (their order doesn't matter). Following the usual einsum semantics, any named axes that appears in inputs, but do not appear in an output will be contracted (summed after all multiplications are performed).
|
||||
Operands and results still use a convention of one letter per positional axis, but now it is also possible to mention named axes in curly braces. For example, `n{b,k}` implies that a value will have a single positional dimension `n` and named dimensions `b` and `k` (their order doesn't matter). Following the usual einsum semantics, any named axes that appear in inputs, but do not appear in an output will be contracted (summed after all multiplications are performed).
|
||||
|
||||
It is acceptable to omit a named dimension from _all arguments and the result_, in which case it will be treated according to the usual broadcasting semantics. However, it is not acceptable to mention a named axis in one argument that has it in its named shape and skip it in another argument that also has it in its named shape. Of course, skipping it in the arguments that don't have it is required.
|
||||
|
||||
@ -466,7 +466,7 @@ xmap(lambda x: lax.pshuffle(x, 'i', list(reversed(range(8)))),
|
||||
|
||||
+++ {"id": "ZHoCsWkCEnKt"}
|
||||
|
||||
## Parallelism suport
|
||||
## Parallelism support
|
||||
|
||||
While the new programming paradigm can be nice at times, the killer feature of `xmap` is its ability to parallelize code over supercomputer-scale hardware meshes!
|
||||
|
||||
@ -474,7 +474,7 @@ While the new programming paradigm can be nice at times, the killer feature of `
|
||||
|
||||
In all the previous examples, we haven't said a word about parallelism and for a good reason. By default `xmap` doesn't perform any parallelization and vectorizes the computation in the same way `vmap` does (i.e. it still executes on a single device). To partition the computation over multiple accelerators we have to introduce one more concept: _resource axes_.
|
||||
|
||||
The basic idea is that logical axes (the ones that appear in named shapes) assume that we have abundant hadware and memory, but before the program is to be executed, they have to be placed somewhere. The default (`vmap`-like) evaluation style pays a high memory cost on the deafult JAX device. By mapping logical axes to (one or more) resource axes through the `axis_resources` argument, we can control how `xmap` evaluates the computation.
|
||||
The basic idea is that logical axes (the ones that appear in named shapes) assume that we have abundant hardware and memory, but before the program is to be executed, they have to be placed somewhere. The default (`vmap`-like) evaluation style pays a high memory cost on the default JAX device. By mapping logical axes to (one or more) resource axes through the `axis_resources` argument, we can control how `xmap` evaluates the computation.
|
||||
|
||||
```{code-cell}
|
||||
:id: NnggOzOD8rl1
|
||||
@ -500,7 +500,7 @@ Both `local_matmul` and `distr_matmul` implement matrix multiplication, but `dis
|
||||
|
||||
Well, it depends, but one good choice is... a hardware mesh!
|
||||
|
||||
For our purposes a mesh is an nd-array of devices with named axes. But, beacuse NumPy doesn't support named axes (that's our extension!), the meshes are represented by a pair of an nd-array of JAX device objects (as obtained from `jax.devices()` or `jax.local_devices()`) and a tuple of resource axis names of length matching the rank of the array.
|
||||
For our purposes a mesh is an nd-array of devices with named axes. But, because NumPy doesn't support named axes (that's our extension!), the meshes are represented by a pair of an nd-array of JAX device objects (as obtained from `jax.devices()` or `jax.local_devices()`) and a tuple of resource axis names of length matching the rank of the array.
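A hedged sketch of constructing such a mesh, assuming at least 8 local devices and the `mesh` context manager from `jax.experimental.maps` (whose definition appears later in this diff):

```python
import numpy as np
import jax
from jax.experimental.maps import mesh   # context manager; see jax/experimental/maps.py below

devices = np.array(jax.devices()[:8]).reshape((2, 4))   # assumes >= 8 devices are available
with mesh(devices, ('x', 'y')):
    ...   # xmap calls here can map named axes onto the 'x' and 'y' resource axes
```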
|
||||
|
||||
+++ {"id": "x3EXj1TMwZtS"}
|
||||
|
||||
@ -556,11 +556,11 @@ Anyway, the best part of it is that specifying `axis_resources` **never changes
|
||||
|
||||
### Is my data replicated? Or partitioned? Where is it?
|
||||
|
||||
Named axes also give us a neat way of reasoning about partitioning and replication. A value is partitioned over a mesh axis if an only if it has a named axis that has been mapped to that mesh axis in its shape. Otherwise, it will be replicated over all slices along that axis.
|
||||
Named axes also give us a neat way of reasoning about partitioning and replication. A value is partitioned over a mesh axis if and only if it has a named axis that has been mapped to that mesh axis in its shape. Otherwise, it will be replicated over all slices along that axis.
|
||||
|
||||
For example, assume that we're in an `xmap` that had `axis_resources={'a': 'x', 'b': 'y'}` specified (i.e. we are running the computation over a 2D mesh with `x` and `y` axes with sizes 2 and 3 respectively). Then:
|
||||
* An array of type `f32[(5, 5), {}]` is completely replicated over the whole mesh. All devices store a local copy of the value.
|
||||
* An array of type `f32[(6,), {'a': 8}]` is partitioned over mesh axis `x`, beacuse it has `'a'` in its named shape, and `'a'` is mapped to `x`. It is replicated over mesh axis `y`. To put it differently, all devices in a slice of the mesh with the same `x` coordinate will store a local copy of a chunk of this array. But, mesh slices with different `x` coordinates will store different chunks of the data.
|
||||
* An array of type `f32[(6,), {'a': 8}]` is partitioned over mesh axis `x`, because it has `'a'` in its named shape, and `'a'` is mapped to `x`. It is replicated over mesh axis `y`. To put it differently, all devices in a slice of the mesh with the same `x` coordinate will store a local copy of a chunk of this array. But, mesh slices with different `x` coordinates will store different chunks of the data.
|
||||
* An array of type `f32[(), {'a': 8, 'c': 7}]` is partitioned just like in the previous case: split over the `x` mesh axis and replicated over the `y` axis. Named dimensions with no resources specified are no different than positional dimensions when considering partitioning, so `'c'` has no influence on it.
|
||||
* An array of type `f32[(), {'a': 8, 'b': 12}]` is completely partitioned over the whole mesh. Every device holds a distinct chunk of the data.
|
||||
|
||||
|
@ -29,7 +29,7 @@ That is:
|
||||
contains pytrees, is considered a pytree.
|
||||
|
||||
For each entry in the pytree container registry, a container-like type is
|
||||
registered with a pair of functions which specify how to convert an instance of
|
||||
registered with a pair of functions that specify how to convert an instance of
|
||||
the container type to a `(children, metadata)` pair and how to convert such a
|
||||
pair back to an instance of the container type. Using these functions, JAX can
|
||||
canonicalize any tree of registered container objects into tuples.
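A minimal sketch of such a registration (the `Point` class here is hypothetical; the document's own `RegisteredSpecial` examples appear further down):

```python
import jax
from jax.tree_util import register_pytree_node

class Point:
    def __init__(self, x, y):
        self.x, self.y = x, y

def flatten_point(p):
    return (p.x, p.y), None           # (children, metadata)

def unflatten_point(metadata, children):
    return Point(*children)

register_pytree_node(Point, flatten_point, unflatten_point)

leaves, treedef = jax.tree_util.tree_flatten(Point(1.0, 2.0))
print(leaves)    # [1.0, 2.0]
```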
|
||||
@ -64,7 +64,7 @@ pytrees with values in the argument pytrees, the parameter pytrees are often
|
||||
constrained to be tree prefixes of the argument pytrees.
|
||||
|
||||
For example, if we pass the following input to {func}`~jax.vmap` (note that the input
|
||||
arguments to a function considered a tuple):
|
||||
arguments to a function are considered a tuple):
|
||||
|
||||
```
|
||||
(a1, {"k1": a2, "k2": a3})
|
||||
@ -281,7 +281,7 @@ class RegisteredSpecial2(Special):
|
||||
show_example(RegisteredSpecial2(1., 2.))
|
||||
```
|
||||
|
||||
JAX needs sometimes to compare `treedef` for equality. Therefore, care must be
|
||||
JAX sometimes needs to compare `treedef` for equality. Therefore, care must be
|
||||
taken to ensure that the auxiliary data specified in the flattening recipe
|
||||
supports a meaningful equality comparison.
|
||||
|
||||
|
@ -3,8 +3,8 @@ Rank promotion warning
|
||||
|
||||
`NumPy broadcasting rules
|
||||
<https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#general-broadcasting-rules>`_
|
||||
allow automatic promotion of arguments from one rank (number of array axes) to
|
||||
another. This behavior can be convenient when intended but can also lead to
|
||||
allow the automatic promotion of arguments from one rank (number of array axes)
|
||||
to another. This behavior can be convenient when intended but can also lead to
|
||||
surprising bugs where a silent rank promotion masks an underlying shape error.
|
||||
|
||||
Here's an example of rank promotion:
|
||||
|
@ -18,7 +18,8 @@ kernelspec:
|
||||
```
|
||||
|
||||
At its core, JAX is an extensible system for transforming numerical functions.
|
||||
This section will discuss four that are of primary interest: {func}`grad`, {func}`jit`, {func}`vmap`, and {func}`pmap`.
|
||||
This section will discuss four transformations that are of primary interest:
|
||||
{func}`grad`, {func}`jit`, {func}`vmap`, and {func}`pmap`.
|
||||
|
||||
## Automatic differentiation with `grad`
|
||||
|
||||
|
@ -738,7 +738,7 @@ def grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
has_aux: bool = False, holomorphic: bool = False,
|
||||
allow_int: bool = False,
|
||||
reduce_axes: Sequence[AxisName] = ()) -> Callable:
|
||||
"""Creates a function which evaluates the gradient of ``fun``.
|
||||
"""Creates a function that evaluates the gradient of ``fun``.
|
||||
|
||||
Args:
|
||||
fun: Function to be differentiated. Its arguments at positions specified by
|
||||
@ -809,7 +809,7 @@ def value_and_grad(fun: Callable, argnums: Union[int, Sequence[int]] = 0,
|
||||
has_aux: bool = False, holomorphic: bool = False,
|
||||
allow_int: bool = False, reduce_axes: Sequence[AxisName] = ()
|
||||
) -> Callable[..., Tuple[Any, Any]]:
|
||||
"""Create a function which evaluates both ``fun`` and the gradient of ``fun``.
|
||||
"""Create a function that evaluates both ``fun`` and the gradient of ``fun``.
|
||||
|
||||
Args:
|
||||
fun: Function to be differentiated. Its arguments at positions specified by
|
||||
@ -1875,7 +1875,7 @@ def jvp(fun: Callable, primals, tangents) -> Tuple[Any, Any]:
|
||||
array, scalar, or standard Python container of arrays or scalars.
|
||||
primals: The primal values at which the Jacobian of ``fun`` should be
|
||||
evaluated. Should be either a tuple or a list of arguments,
|
||||
and its length should equal to the number of positional parameters of
|
||||
and its length should be equal to the number of positional parameters of
|
||||
``fun``.
|
||||
tangents: The tangent vector for which the Jacobian-vector product should be
|
||||
evaluated. Should be either a tuple or a list of tangents, with the same
|
||||
|
@ -132,7 +132,7 @@ class custom_jvp(Generic[ReturnValue]):
|
||||
jvp: a Python callable representing the custom JVP rule. When there are no
|
||||
``nondiff_argnums``, the ``jvp`` function should accept two arguments,
|
||||
where the first is a tuple of primal inputs and the second is a tuple of
|
||||
tangent inputs. The lengths of both tuples is equal to the number of
|
||||
tangent inputs. The lengths of both tuples are equal to the number of
|
||||
parameters of the ``custom_jvp`` function. The ``jvp`` function should
|
||||
produce as output a pair where the first element is the primal output
|
||||
and the second element is the tangent output. Elements of the input and
|
||||
|
@ -1569,7 +1569,7 @@ def sort(operand: Union[Array, Sequence[Array]], dimension: int = -1,
|
||||
|
||||
def sort_key_val(keys: Array, values: Array, dimension: int = -1,
|
||||
is_stable: bool = True) -> Tuple[Array, Array]:
|
||||
"""Sorts ``keys`` along ``dimension`` and applies same permutation to ``values``."""
|
||||
"""Sorts ``keys`` along ``dimension`` and applies the same permutation to ``values``."""
|
||||
dimension = canonicalize_axis(dimension, len(keys.shape))
|
||||
k, v = sort_p.bind(keys, values, dimension=dimension, is_stable=is_stable, num_keys=1)
|
||||
return k, v
|
||||
|
@ -208,8 +208,8 @@ def vectorize(pyfunc, *, excluded=frozenset(), signature=None):
|
||||
Returns:
|
||||
Vectorized version of the given function.
|
||||
|
||||
Here a few examples of how one could write vectorized linear algebra routines
|
||||
using :func:`vectorize`:
|
||||
Here are a few examples of how one could write vectorized linear algebra
|
||||
routines using :func:`vectorize`:
|
||||
|
||||
>>> from functools import partial
|
||||
|
||||
|
@ -72,7 +72,7 @@ def minimize(
|
||||
|
||||
Args:
|
||||
fun: the objective function to be minimized, ``fun(x, *args) -> float``,
|
||||
where ``x`` is an 1-D array with shape ``(n,)`` and ``args`` is a tuple
|
||||
where ``x`` is a 1-D array with shape ``(n,)`` and ``args`` is a tuple
|
||||
of the fixed parameters needed to completely specify the function.
|
||||
``fun`` must support differentiation.
|
||||
x0: initial guess. Array of real elements of size ``(n,)``, where ``n`` is
|
||||
|
@ -303,7 +303,7 @@ Using :func:`call` to call a TensorFlow function, with reverse-mode autodiff sup
|
||||
Another possible use for host computation is to invoke a library written for
|
||||
another framework, such as TensorFlow.
|
||||
In this case it becomes interesting to support JAX autodiff for host callbacks
|
||||
by defering to the autodiff mechanism in TensorFlow,
|
||||
by deferring to the autodiff mechanism in TensorFlow,
|
||||
using the :func:`jax.custom_vjp` mechanism.
|
||||
|
||||
This is relatively easy to do, once one understands both the JAX custom VJP
|
||||
@ -363,7 +363,7 @@ error during the processing of the callback (whether raised by the user-code
|
||||
itself or due to a mismatch of the returned value and the expected return_shape)
|
||||
we send the device a "fake" result of shape ``int8[12345]``.
|
||||
This will make the device
|
||||
computation abort because the received data is different than then one that
|
||||
computation abort because the received data is different than the one that
|
||||
it expects. On CPU the runtime will crash with a distinctive error message:
|
||||
|
||||
```
|
||||
@ -416,25 +416,20 @@ for the C++ outfeed `receiver backend
|
||||
<https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/xla/python/outfeed_receiver.cc>`_.
|
||||
|
||||
* ``TF_CPP_MIN_LOG_LEVEL=0``: will turn on INFO logging, needed for all below.
|
||||
* ``TF_CPP_MIN_VLOG_LEVEL=3``: will turn make all VLOG logging up to level 3
|
||||
behave like INFO logs. This may be too much, but you will see which
|
||||
modules are logging relevant info, and then you can select which modules
|
||||
to log from:
|
||||
* `TF_CPP_VMODULE=<module_name>=3`` (the module name can be either C++ or
|
||||
* ``TF_CPP_MIN_VLOG_LEVEL=3``: will make all VLOG logging up to level 3 behave
|
||||
like INFO logs. This may be too much, but you will see which modules are
|
||||
logging relevant info, and then you can select which modules to log from.
|
||||
* ``TF_CPP_VMODULE=<module_name>=3`` (the module name can be either C++ or
|
||||
Python, without the extension).
|
||||
|
||||
You should also use the ``--verbosity=2`` flag so that you see the logs
|
||||
from Python.
|
||||
|
||||
For example, you can try to enable logging in the ``host_callback`` module:
|
||||
```
|
||||
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
|
||||
```
|
||||
``TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=host_callback=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple``
|
||||
|
||||
If you want to enable logging in lower-level implementation modules try:
|
||||
```
|
||||
TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple
|
||||
```
|
||||
``TF_CPP_MIN_LOG_LEVEL=0 TF_CPP_VMODULE=outfeed_receiver=3,host_callback=3,outfeed_receiver_py=3,outfeed_thunk=3,infeed_thunk=3,cpu_transfer_manager=3,cpu_runtime=3,xfeed_manager=3,pjrt_client=3 python tests/host_callback_test.py --verbosity=2 HostCallbackIdTapTest.test_tap_jit_simple``
|
||||
|
||||
(For bazel tests use --test_arg=--vmodule=...
|
||||
|
||||
@ -1829,7 +1824,7 @@ def barrier_wait(logging_name: Optional[str] = None):
|
||||
"""Blocks the calling thread until all current outfeed is processed.
|
||||
|
||||
Waits until all callbacks from computations already running on all devices
|
||||
has been received and processed by the Python callbacks. Raises
|
||||
have been received and processed by the Python callbacks. Raises
|
||||
CallbackException if there were exceptions while processing the callbacks.
|
||||
|
||||
This works by enqueueing a special tap computation to all devices to which
|
||||
|
@ -58,7 +58,7 @@ special `loops.scope` object and use `for` loops over special
|
||||
|
||||
Loops constructed with `range` must have literal constant bounds. If you need
|
||||
loops with dynamic bounds, you can use the more general `while_range` iterator.
|
||||
However, in that case that `grad` transformation is not supported::
|
||||
However, in that case the `grad` transformation is not supported::
|
||||
|
||||
s.idx = start
|
||||
for _ in s.while_range(lambda: s.idx < end):
|
||||
@ -93,7 +93,7 @@ Restrictions:
|
||||
* Once the loop starts all updates to loop state must be with new values of the
|
||||
same abstract values as the values on loop start.
|
||||
* For a `while` loop, the conditional function is not allowed to modify the
|
||||
scope state. This is a checked error. Also, for `while` loops the `grad`
|
||||
scope state. This is a checked error. Also, for `while` loops, the `grad`
|
||||
transformation does not work. An alternative that allows `grad` is a bounded
|
||||
loop (`range`).
|
||||
|
||||
|
@ -210,7 +210,7 @@ def serial_loop(name: ResourceAxisName, length: int):
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mesh(devices: np.ndarray, axis_names: Sequence[ResourceAxisName]):
|
||||
"""Declare the hardware resources available in scope of this manager.
|
||||
"""Declare the hardware resources available in the scope of this manager.
|
||||
|
||||
In particular, all ``axis_names`` become valid resource names inside the
|
||||
managed block and can be used e.g. in the ``axis_resources`` argument of
|
||||
@ -391,9 +391,9 @@ def xmap(fun: Callable,
  While it is possible to assign multiple axis names to a single resource axis,
  care has to be taken to ensure that none of those named axes co-occur in a
  ``named_shape`` of any value in the named program. At the moment this is
  **completely unchecked** and will result in **undefined behavior**. Final
  release of :py:func:`xmap` will enforce this invariant, but it is work
  in progress.
  **completely unchecked** and will result in **undefined behavior**. The
  final release of :py:func:`xmap` will enforce this invariant, but it is a
  work in progress.

  Note that you do not have to worry about any of this for as long as no
  resource axis is repeated in ``axis_resources.values()``.

@ -401,7 +401,7 @@ def xmap(fun: Callable,

  Note that any assignment of ``axis_resources`` doesn't ever change the
  results of the computation, but only how it is carried out (e.g. how many
  devices are used). This makes it easy to try out various ways of
  partitioning a single program in many distributed scenarions (both small- and
  partitioning a single program in many distributed scenarios (both small- and
  large-scale), to maximize the performance. As such, :py:func:`xmap` can be
  seen as a way to seamlessly interpolate between :py:func:`vmap` and
  :py:func:`pmap`-style execution.
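To make this concrete, here is a hedged sketch (not taken from the docstring) that maps a single named axis onto the mesh axis ``'x'`` declared as above; changing or dropping ``axis_resources`` changes only how the work is distributed, not the result. The function ``scale`` and the axis name ``'batch'`` are made up for the example.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.maps import xmap, mesh

def scale(v):
  return 2.0 * v

devices = np.array(jax.local_devices())
x = jnp.arange(float(2 * devices.size))  # length divisible by the mesh axis size

with mesh(devices, ('x',)):
  f = xmap(scale,
           in_axes=['batch', ...],         # axis 0 of the input is named 'batch'
           out_axes=['batch', ...],
           axis_resources={'batch': 'x'})  # partition 'batch' over mesh axis 'x'
  print(f(x))  # same values as scale(x), just computed across the mesh
```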
@ -489,7 +489,7 @@ def xmap(fun: Callable,
  out_axes={}) # Loss is reduced over all axes, including batch!

  .. note::
    When using ``axis_resources`` along with a mesh that is controled by
    When using ``axis_resources`` along with a mesh that is controlled by
    multiple JAX hosts, keep in mind that in any given process :py:func:`xmap`
    only expects the data slice that corresponds to its local devices to be
    specified. This is in line with the current multi-host :py:func:`pmap`
@ -213,7 +213,7 @@ def sgd(step_size):
  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
      that maps the iteration index to a positive scalar.

  Returns:
    An (init_fun, update_fun, get_params) triple.

@ -233,7 +233,7 @@ def momentum(step_size: Schedule, mass: float):

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
      that maps the iteration index to a positive scalar.
    mass: positive scalar representing the momentum coefficient.

  Returns:

@ -260,7 +260,7 @@ def nesterov(step_size: Schedule, mass: float):

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
      that maps the iteration index to a positive scalar.
    mass: positive scalar representing the momentum coefficient.

  Returns:

@ -290,7 +290,7 @@ def adagrad(step_size, momentum=0.9):

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
      that maps the iteration index to a positive scalar.
    momentum: optional, a positive scalar value for momentum

  Returns:

@ -324,7 +324,7 @@ def rmsprop(step_size, gamma=0.9, eps=1e-8):

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
      that maps the iteration index to a positive scalar.
    gamma: Decay parameter.
    eps: Epsilon parameter.

@ -355,7 +355,7 @@ def rmsprop_momentum(step_size, gamma=0.9, eps=1e-8, momentum=0.9):

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
      that maps the iteration index to a positive scalar.
    gamma: Decay parameter.
    eps: Epsilon parameter.
    momentum: Momentum parameter.

@ -386,7 +386,7 @@ def adam(step_size, b1=0.9, b2=0.999, eps=1e-8):

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
      that maps the iteration index to a positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay rate

@ -422,7 +422,7 @@ def adamax(step_size, b1=0.9, b2=0.999, eps=1e-8):

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
      that maps the iteration index to a positive scalar.
    b1: optional, a positive scalar value for beta_1, the exponential decay rate
      for the first moment estimates (default 0.9).
    b2: optional, a positive scalar value for beta_2, the exponential decay rate

@ -460,7 +460,7 @@ def sm3(step_size, momentum=0.9):

  Args:
    step_size: positive scalar, or a callable representing a step size schedule
      that maps the iteration index to positive scalar.
      that maps the iteration index to a positive scalar.
    momentum: optional, a positive scalar value for momentum

  Returns:
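All of the optimizers above share the same ``(init_fun, update_fun, get_params)`` interface, and every ``step_size`` may be either a constant or a schedule. Below is a hedged usage sketch, assuming these are the optimizers in ``jax.experimental.optimizers`` and that ``exponential_decay`` is available there as a schedule; ``loss``, ``params``, and ``batch`` are placeholders.

```python
import jax
from jax.experimental import optimizers

def loss(params, batch):  # placeholder loss for the sketch
  inputs, targets = batch
  return ((params * inputs - targets) ** 2).mean()

# step_size: a positive scalar, or a schedule mapping step index -> scalar.
schedule = optimizers.exponential_decay(1e-3, decay_steps=1000, decay_rate=0.9)
opt_init, opt_update, get_params = optimizers.adam(step_size=schedule)

@jax.jit
def step(i, opt_state, batch):
  params = get_params(opt_state)
  grads = jax.grad(loss)(params, batch)
  return opt_update(i, grads, opt_state)

# Typical driver: opt_state = opt_init(initial_params), then repeatedly
# opt_state = step(i, opt_state, batch), and finally get_params(opt_state).
```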
@ -56,9 +56,9 @@ def pjit(fun: Callable,
  version of ``fun`` would not fit in a single device's memory, or to speed up
  ``fun`` by running each operation in parallel across multiple devices.

  The partitioning over devices happens automatically based on
  propagation of input partitioning specified in ``in_axis_resources`` and
  output partitioning specified in ``out_axis_resources``. The resources
  The partitioning over devices happens automatically based on the
  propagation of the input partitioning specified in ``in_axis_resources`` and
  the output partitioning specified in ``out_axis_resources``. The resources
  specified in those two arguments must refer to mesh axes, as defined by
  the :py:func:`jax.experimental.maps.mesh` context manager. Note that the mesh
  definition at ``pjit`` application time is ignored, and the returned function
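For illustration only, a minimal sketch of the mechanism described above: the mesh comes from the ``with mesh(...)`` block active at the call site, and the input/output partitioning is expressed with ``PartitionSpec``s referring to its axis names. The import paths for ``pjit`` and ``PartitionSpec`` and the toy function are assumptions of this sketch, not taken from the docstring.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.maps import mesh
from jax.experimental import PartitionSpec as P  # assumed import path
from jax.experimental.pjit import pjit

f = pjit(lambda a: a * 2,
         in_axis_resources=P('x'),   # shard the input's only axis over mesh axis 'x'
         out_axis_resources=P('x'))  # keep the output sharded the same way

devices = np.array(jax.local_devices())
x = jnp.arange(float(4 * devices.size))
with mesh(devices, ('x',)):  # the mesh active at *call* time is the one used
  print(f(x))
```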
@ -84,7 +84,7 @@ def pjit(fun: Callable,
  mesh axis size, and outputs will be similarly sized according to the local
  mesh. ``fun`` will still be executed across *all* devices in the mesh,
  including those from other processes, and will be given a global view of the
  data spread accross multiple processes as a single array. However, outside
  data spread across multiple processes as a single array. However, outside
  of ``pjit`` every process only "sees" its local piece of the input and output,
  corresponding to its local sub-mesh.
@ -140,7 +140,7 @@ def pjit(fun: Callable,
  Returns:
    A wrapped version of ``fun``, set up for just-in-time compilation and
    automatic partitioned by the mesh available at each call site.
    automaticly partitioned by the mesh available at each call site.

  For example, a convolution operator can be automatically partitioned over
  an arbitrary set of devices by a single ```pjit`` application: