mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #20897 from jakevdp:doc-updates
PiperOrigin-RevId: 627486188
This commit is contained in:
commit
8ead2df7bb
@ -15,7 +15,8 @@ kernelspec:
|
||||
(automatic-differentiation)=
|
||||
# Automatic differentiation
|
||||
|
||||
In this section, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general automatic differentiation (autodiff) system. Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as:
|
||||
In this section, you will learn about fundamental applications of automatic differentiation (autodiff) in JAX. JAX has a pretty general autodiff system.
|
||||
Computing gradients is a critical part of modern machine learning methods, and this tutorial will walk you through a few introductory autodiff topics, such as:
|
||||
|
||||
- {ref}`automatic-differentiation-taking-gradients`
|
||||
- {ref}`automatic-differentiation-linear logistic regression`
|
||||
|
@ -55,7 +55,10 @@ Importantly, notice that the jaxpr does not capture the side-effect present in t
|
||||
This is a feature, not a bug: JAX transformations are designed to understand side-effect-free (a.k.a. functionally pure) code.
|
||||
If *pure function* and *side-effect* are unfamiliar terms, this is explained in a little more detail in [🔪 JAX - The Sharp Bits 🔪: Pure Functions](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#pure-functions).
|
||||
|
||||
Of course, impure functions can still be written and even run, but JAX gives no guarantees about their behaviour under transformations. However, as a rule of thumb, you can expect (but shouldn't rely on) the side-effects of a JIT-compiled function to run once (during the first call), and never again, due to JAX's traced execution model.
|
||||
Impure functions are dangerous because under JAX transformations they are likely not to behave as intended; they might fail silently, or produce surprising downstream errors like leaked Tracers.
|
||||
Moreover, JAX often can't detect when side effects are present.
|
||||
(If you want debug printing, use {func}`jax.debug.print`. To express general side-effects at the cost of performance, see {func}`jax.experimental.io_callback`.
|
||||
To check for tracer leaks at the cost of performance, use with {func}`jax.check_tracer_leaks`).
|
||||
|
||||
When tracing, JAX wraps each argument by a *tracer* object. These tracers then record all JAX operations performed on them during the function call (which happens in regular Python). Then, JAX uses the tracer records to reconstruct the entire function. The output of that reconstruction is the jaxpr. Since the tracers do not record the Python side-effects, they do not appear in the jaxpr. However, the side-effects still happen during the trace itself.
|
||||
|
||||
@ -126,7 +129,6 @@ Here's what just happened:
|
||||
1) We defined `selu_jit` as the compiled version of `selu`.
|
||||
|
||||
2) We called `selu_jit` once on `x`. This is where JAX does its tracing -- it needs to have some inputs to wrap in tracers, after all. The jaxpr is then compiled using XLA into very efficient code optimized for your GPU or TPU. Finally, the compiled code is executed to satisfy the call. Subsequent calls to `selu_jit` will use the compiled code directly, skipping the python implementation entirely.
|
||||
|
||||
(If we didn't include the warm-up call separately, everything would still work, but then the compilation time would be included in the benchmark. It would still be faster, because we run many loops in the benchmark, but it wouldn't be a fair comparison.)
|
||||
|
||||
3) We timed the execution speed of the compiled version. (Note the use of {func}`~jax.block_until_ready`, which is required due to JAX's {ref}`async-dispatch`).
|
||||
@ -219,23 +221,6 @@ def g_jit_decorated(x, n):
|
||||
print(g_jit_decorated(10, 20))
|
||||
```
|
||||
|
||||
## When to use JIT
|
||||
|
||||
In many of the examples above, using `jit` is not worth it:
|
||||
|
||||
```{code-cell}
|
||||
print("g jitted:")
|
||||
%timeit g_jit_correct(10, 20).block_until_ready()
|
||||
|
||||
print("g:")
|
||||
%timeit g(10, 20)
|
||||
```
|
||||
|
||||
This is because {func}`jax.jit` introduces some overhead itself, and so it usually only saves time if the compiled function is nontrivial, or if you will run it numerous times.
|
||||
Fortunately, this is common in machine learning, where we tend to compile a large, complicated model, then run it for millions of iterations.
|
||||
|
||||
Generally, you want to JIT-compile the largest possible chunk of your computation; ideally, the entire update step. This gives the compiler maximum freedom to optimise.
|
||||
|
||||
## JIT and caching
|
||||
|
||||
With the compilation overhead of the first JIT call, understanding how and when {func}`jax.jit` caches previous compilations is key to using it effectively.
|
||||
|
@ -26,7 +26,7 @@ has some important differences.
|
||||
|
||||
### Array creation
|
||||
|
||||
JAX arrays are never constructed directly, but rather are constructed via JAX API functions.
|
||||
We typically don't call the {class}`jax.Array` constructor directly, but rather create arrays via JAX API functions.
|
||||
For example, {mod}`jax.numpy` provides familar NumPy-style array construction functionality
|
||||
such as {func}`jax.numpy.zeros`, {func}`jax.numpy.linspace`, {func}`jax.numpy.arange`, etc.
|
||||
|
||||
@ -93,7 +93,7 @@ key to using JAX effectively, and we'll cover them in detail in later sections.
|
||||
(key-concepts-tracing)=
|
||||
## Tracing
|
||||
|
||||
The magic behind transformations is the notion of a {term}`Tracers <Tracer>`.
|
||||
The magic behind transformations is the notion of a {term}`Tracer`.
|
||||
Tracers are abstract stand-ins for array objects, and are passed to JAX functions in order
|
||||
to extract the sequence of operations that the function encodes.
|
||||
|
||||
@ -119,10 +119,8 @@ of input operations to a transformed sequence of operations.
|
||||
(key-concepts-jaxprs)=
|
||||
## Jaxprs
|
||||
|
||||
JAX has its own intermediate representation for sequences of operations, and these are
|
||||
known as {term}`jaxprs <jaxpr>`. A jaxpr (short for *JAX eXPRession*) represents a list
|
||||
of core units of computation called {term}`primitives <primitive>` that represent the
|
||||
effect of a computation.
|
||||
JAX has its own intermediate representation for sequences of operations, known as a {term}`jaxpr`.
|
||||
A jaxpr (short for *JAX exPRession*) is a simple representation of a functional program, comprising a sequence of {term}`primitive` operations.
|
||||
|
||||
For example, consider the `selu` function we defined above:
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user