Merge pull request #26476 from froystig:aot-doc-traced

PiperOrigin-RevId: 725902103
This commit is contained in:
jax authors 2025-02-11 22:01:21 -08:00
commit 914adaf60c

View File

@ -4,9 +4,10 @@
<!--* freshness: { reviewed: '2024-06-12' } *-->
JAX offers several transformations, such as `jax.jit` and `jax.pmap`, returning
a function that is compiled and runs on accelerators or the CPU. As the JIT
acronym indicates, all compilation happens _just-in-time_ for execution.
JAX's `jax.jit` transformation returns a function that, when called,
compiles a computation and runs it on accelerators (or the CPU). As
the JIT acronym indicates, all compilation happens _just-in-time_ for
execution.
Some situations call for _ahead-of-time_ (AOT) compilation instead. When you
want to fully compile prior to execution time, or you want control over when
@ -18,10 +19,14 @@ function/callable output by {func}`jax.jit`, say `f = jax.jit(F)` for some input
callable `F`. When it is invoked with arguments, say `f(x, y)` where `x` and `y`
are arrays, JAX does the following in order:
1. **Stage out** a specialized version of the original Python callable `F` to an
internal representation. The specialization reflects a restriction of `F` to
input types inferred from properties of the arguments `x` and `y` (usually
their shape and element type).
1. **Stage out** a specialized version of the original Python callable
`F` to an internal representation. The specialization reflects a
restriction of `F` to input types inferred from properties of the
arguments `x` and `y` (usually their shape and element type). JAX
carries out this specialization by a process that we call
_tracing_. During tracing, JAX stages the specialization of `F` to
a jaxpr, which is a function in the [Jaxpr intermediate
language](https://jax.readthedocs.io/en/latest/jaxpr.html).
2. **Lower** this specialized, staged-out computation to the XLA compiler's
input language, StableHLO.
@ -31,9 +36,8 @@ are arrays, JAX does the following in order:
4. **Execute** the compiled executable with the arrays `x` and `y` as arguments.
JAX's AOT API gives you direct control over steps #2, #3, and #4 (but [not
#1](#inspecting-staged-out-computations)), plus some other features along the
way. An example:
JAX's AOT API gives you direct control over each of these steps, plus
some other features along the way. An example:
```python
>>> import jax
@ -41,7 +45,13 @@ way. An example:
>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4
>>> lowered = jax.jit(f).lower(x, y)
>>> traced = jax.jit(f).trace(x, y)
>>> # Print the specialized, staged-out representation (as Jaxpr IR)
>>> print(traced.jaxpr)
{ lambda ; a:i32[] b:i32[]. let c:i32[] = mul 2 a; d:i32[] = add c b in (d,) }
>>> lowered = traced.lower()
>>> # Print lowered HLO
>>> print(lowered.as_text())
@ -67,37 +77,36 @@ Array(10, dtype=int32, weak_type=True)
```
Note that the lowered objects can be used only in the same process
in which they were lowered. For exporting use cases,
see the {ref}`export` APIs.
in which they were lowered. For exporting use cases, see the {ref}`export` APIs.
See the {mod}`jax.stages` documentation for more details on what functionality
the lowering and compiled functions provide.
All optional arguments to `jit`---such as `static_argnums`---are respected in
the corresponding lowering, compilation, and execution.
the corresponding tracing, lowering, compilation, and execution.
In the example above, we can replace the arguments to `lower` with any objects
In the example above, we can replace the arguments to `trace` with any objects
that have `shape` and `dtype` attributes:
```python
>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x, y)
Array(10, dtype=int32)
```
More generally, `lower` only needs its arguments to structurally supply what JAX
More generally, `trace` only needs its arguments to structurally supply what JAX
must know for specialization and lowering. For typical array arguments like the
ones above, this means `shape` and `dtype` fields. For static arguments, by
contrast, JAX needs actual array values (more on this
[below](#lowering-with-static-arguments)).
[below](#tracing-with-static-arguments)).
Invoking an AOT-compiled function with arguments that are incompatible with its
lowering raises an error:
tracing raises an error:
```python
>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d) # doctest: +IGNORE_EXCEPTION_DETAIL
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_1d, y_1d) # doctest: +IGNORE_EXCEPTION_DETAIL
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
@ -105,7 +114,7 @@ Argument 'x' compiled with int32[] and called with int32[3]
Argument 'y' compiled with int32[] and called with int32[3]
>>> x_f = y_f = jnp.float32(72.)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f) # doctest: +IGNORE_EXCEPTION_DETAIL
>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_f, y_f) # doctest: +IGNORE_EXCEPTION_DETAIL
...
Traceback (most recent call last):
TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
@ -119,14 +128,14 @@ transformations](#aot-compiled-functions-cannot-be-transformed) such as
`jax.jit`, {func}`jax.grad`, and {func}`jax.vmap`.
## Lowering with static arguments
## Tracing with static arguments
Lowering with static arguments underscores the interaction between options
passed to `jax.jit`, the arguments passed to `lower`, and the arguments needed
Tracing with static arguments underscores the interaction between options
passed to `jax.jit`, the arguments passed to `trace`, and the arguments needed
to invoke the resulting compiled function. Continuing with our example above:
```python
>>> lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8)
>>> lowered_with_x = jax.jit(f, static_argnums=0).trace(7, 8).lower()
>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
@ -143,30 +152,29 @@ Array(19, dtype=int32, weak_type=True)
```
The result of `lower` is not safe to serialize directly for use
in a different process.
See {ref}`export` for additional APIs for this purpose.
Note that `lower` here takes two arguments as usual, but the subsequent compiled
Note that `trace` here takes two arguments as usual, but the subsequent compiled
function accepts only the remaining non-static second argument. The static first
argument (value 7) is taken as a constant at lowering time and built into the
lowered computation, where it is possibly folded in with other constants. In
this case, its multiplication by 2 is simplified, resulting in the constant 14.
Although the second argument to `lower` above can be replaced by a hollow
Although the second argument to `trace` above can be replaced by a hollow
shape/dtype structure, it is necessary that the static first argument be a
concrete value. Otherwise, lowering would err:
concrete value. Otherwise, tracing errs:
```python
>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar) # doctest: +SKIP
>>> jax.jit(f, static_argnums=0).trace(i32_scalar, i32_scalar) # doctest: +SKIP
Traceback (most recent call last):
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'
>>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
>>> jax.jit(f, static_argnums=0).trace(10, i32_scalar).lower().compile()(5)
Array(25, dtype=int32)
```
The results of `trace` and of `lower` are not safe to serialize directly for use
in a different process. See {ref}`export` for additional APIs for this purpose.
## AOT-compiled functions cannot be transformed
Compiled functions are specialized to a particular set of argument "types," such
@ -187,7 +195,7 @@ in transformations. Example:
>>> z, zs = make_z(3, 2), make_z(4, 3, 2)
>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).lower(z).compile()
>>> g_aot = jax.jit(g).trace(z).lower().compile()
>>> jax.vmap(g_jit)(zs)
Array([[ 1., 5., 9.],
@ -218,7 +226,7 @@ a text representation. Compiled functions do the same, and also offer cost and
memory analyses from the compiler. All of these are provided via methods on the
{class}`jax.stages.Lowered` and {class}`jax.stages.Compiled` objects (e.g.,
`lowered.as_text()` and `compiled.cost_analysis()` above).
You can obtain more debbugging information, e.g., source location,
You can obtain more debugging information, e.g., source location,
by using the `debug_info` parameter to `lowered.as_text()`.
These methods are meant as an aid for manual inspection and debugging, not as a
@ -238,14 +246,3 @@ platform, and runtime. This makes for two important caveats:
remain the same on the following day.
When in doubt, see the package API documentation for {mod}`jax.stages`.
## Inspecting staged-out computations
Stage #1 in the list at the top of this note mentions specialization and
staging, prior to lowering. JAX's internal notion of a function specialized to
the types of its arguments is not always a reified data structure in memory. To
explicitly construct a view of JAX's specialization of a function in the
internal [Jaxpr intermediate
language](https://jax.readthedocs.io/en/latest/jaxpr.html), see
{func}`jax.make_jaxpr`.