mirror of https://github.com/ROCm/jax.git, synced 2025-04-16 03:46:06 +00:00
update AOT walkthrough to cover explicit tracing stage
parent 005c14b4da
commit af381a73a3
91
docs/aot.md
@@ -4,9 +4,10 @@

 <!--* freshness: { reviewed: '2024-06-12' } *-->

-JAX offers several transformations, such as `jax.jit` and `jax.pmap`, returning
-a function that is compiled and runs on accelerators or the CPU. As the JIT
-acronym indicates, all compilation happens _just-in-time_ for execution.
+JAX's `jax.jit` transformation returns a function that, when called,
+compiles a computation and runs it on accelerators (or the CPU). As
+the JIT acronym indicates, all compilation happens _just-in-time_ for
+execution.

 Some situations call for _ahead-of-time_ (AOT) compilation instead. When you
 want to fully compile prior to execution time, or you want control over when
@@ -18,10 +19,14 @@ function/callable output by {func}`jax.jit`, say `f = jax.jit(F)` for some input
 callable `F`. When it is invoked with arguments, say `f(x, y)` where `x` and `y`
 are arrays, JAX does the following in order:

-1. **Stage out** a specialized version of the original Python callable `F` to an
-internal representation. The specialization reflects a restriction of `F` to
-input types inferred from properties of the arguments `x` and `y` (usually
-their shape and element type).
+1. **Stage out** a specialized version of the original Python callable
+`F` to an internal representation. The specialization reflects a
+restriction of `F` to input types inferred from properties of the
+arguments `x` and `y` (usually their shape and element type). JAX
+carries out this specialization by a process that we call
+_tracing_. During tracing, JAX stages the specialization of `F` to
+a jaxpr, which is a function in the [Jaxpr intermediate
+language](https://jax.readthedocs.io/en/latest/jaxpr.html).

 2. **Lower** this specialized, staged-out computation to the XLA compiler's
 input language, StableHLO.
@@ -31,9 +36,8 @@ are arrays, JAX does the following in order:

 4. **Execute** the compiled executable with the arrays `x` and `y` as arguments.

-JAX's AOT API gives you direct control over steps #2, #3, and #4 (but [not
-#1](#inspecting-staged-out-computations)), plus some other features along the
-way. An example:
+JAX's AOT API gives you direct control over each of these steps, plus
+some other features along the way. An example:

 ```python
 >>> import jax
@@ -41,7 +45,13 @@ way. An example:
 >>> def f(x, y): return 2 * x + y
 >>> x, y = 3, 4

->>> lowered = jax.jit(f).lower(x, y)
+>>> traced = jax.jit(f).trace(x, y)
+
+>>> # Print the specialized, staged-out representation (as Jaxpr IR)
+>>> print(traced.jaxpr)
+{ lambda ; a:i32[] b:i32[]. let c:i32[] = mul 2 a; d:i32[] = add c b in (d,) }
+
+>>> lowered = traced.lower()

 >>> # Print lowered HLO
 >>> print(lowered.as_text())
@@ -67,37 +77,36 @@ Array(10, dtype=int32, weak_type=True)
 ```

 Note that the lowered objects can be used only in the same process
-in which they were lowered. For exporting use cases,
-see the {ref}`export` APIs.
+in which they were lowered. For exporting use cases, see the {ref}`export` APIs.

 See the {mod}`jax.stages` documentation for more details on what functionality
 the lowering and compiled functions provide.

 All optional arguments to `jit`---such as `static_argnums`---are respected in
-the corresponding lowering, compilation, and execution.
+the corresponding tracing, lowering, compilation, and execution.

-In the example above, we can replace the arguments to `lower` with any objects
+In the example above, we can replace the arguments to `trace` with any objects
 that have `shape` and `dtype` attributes:

 ```python
 >>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
->>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
+>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x, y)
 Array(10, dtype=int32)

 ```

-More generally, `lower` only needs its arguments to structurally supply what JAX
+More generally, `trace` only needs its arguments to structurally supply what JAX
 must know for specialization and lowering. For typical array arguments like the
 ones above, this means `shape` and `dtype` fields. For static arguments, by
 contrast, JAX needs actual array values (more on this
-[below](#lowering-with-static-arguments)).
+[below](#tracing-with-static-arguments)).
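
As a minimal sketch of the point about structural arguments (it is not part of the commit), the following traces against a non-scalar `ShapeDtypeStruct`, so no input arrays exist until the compiled function is called. The name `affine` and the `(3, 4)` float32 shape are illustrative assumptions, and the sketch assumes a JAX version that provides `jit(...).trace`.

```python
import jax
import jax.numpy as jnp

def affine(x, y):
    # Hypothetical stand-in for the walkthrough's `f`.
    return 2 * x + y

# Abstract description of the inputs: shape and dtype only, no data.
spec = jax.ShapeDtypeStruct((3, 4), jnp.dtype('float32'))

traced = jax.jit(affine).trace(spec, spec)  # stage out to a jaxpr
jaxpr_text = str(traced.jaxpr)              # textual Jaxpr IR
compiled = traced.lower().compile()         # lower, then compile

# Concrete arrays are needed only at execution time, and they must
# match the shape and dtype that were traced.
x = jnp.ones((3, 4), dtype=jnp.float32)
y = jnp.arange(12, dtype=jnp.float32).reshape(3, 4)
result = compiled(x, y)
```

Here tracing, lowering, and compilation all complete before any array is allocated; only the final call touches data.
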

 Invoking an AOT-compiled function with arguments that are incompatible with its
-lowering raises an error:
+tracing raises an error:

 ```python
 >>> x_1d = y_1d = jnp.arange(3)
->>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d) # doctest: +IGNORE_EXCEPTION_DETAIL
+>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_1d, y_1d) # doctest: +IGNORE_EXCEPTION_DETAIL
 ...
 Traceback (most recent call last):
 TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
@@ -105,7 +114,7 @@ Argument 'x' compiled with int32[] and called with int32[3]
 Argument 'y' compiled with int32[] and called with int32[3]

 >>> x_f = y_f = jnp.float32(72.)
->>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f) # doctest: +IGNORE_EXCEPTION_DETAIL
+>>> jax.jit(f).trace(i32_scalar, i32_scalar).lower().compile()(x_f, y_f) # doctest: +IGNORE_EXCEPTION_DETAIL
 ...
 Traceback (most recent call last):
 TypeError: Argument types differ from the types for which this computation was compiled. The mismatches are:
@@ -119,14 +128,14 @@ transformations](#aot-compiled-functions-cannot-be-transformed) such as
 `jax.jit`, {func}`jax.grad`, and {func}`jax.vmap`.


-## Lowering with static arguments
+## Tracing with static arguments

-Lowering with static arguments underscores the interaction between options
-passed to `jax.jit`, the arguments passed to `lower`, and the arguments needed
+Tracing with static arguments underscores the interaction between options
+passed to `jax.jit`, the arguments passed to `trace`, and the arguments needed
 to invoke the resulting compiled function. Continuing with our example above:

 ```python
->>> lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8)
+>>> lowered_with_x = jax.jit(f, static_argnums=0).trace(7, 8).lower()

 >>> # Lowered HLO, specialized to the *value* of the first argument (7)
 >>> print(lowered_with_x.as_text())
@@ -143,30 +152,29 @@ Array(19, dtype=int32, weak_type=True)

 ```

-The result of `lower` is not safe to serialize directly for use
-in a different process.
-See {ref}`export` for additional APIs for this purpose.
-
-Note that `lower` here takes two arguments as usual, but the subsequent compiled
+Note that `trace` here takes two arguments as usual, but the subsequent compiled
 function accepts only the remaining non-static second argument. The static first
 argument (value 7) is taken as a constant at lowering time and built into the
 lowered computation, where it is possibly folded in with other constants. In
 this case, its multiplication by 2 is simplified, resulting in the constant 14.

-Although the second argument to `lower` above can be replaced by a hollow
+Although the second argument to `trace` above can be replaced by a hollow
 shape/dtype structure, it is necessary that the static first argument be a
-concrete value. Otherwise, lowering would err:
+concrete value. Otherwise, tracing errs:

 ```python
->>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar) # doctest: +SKIP
+>>> jax.jit(f, static_argnums=0).trace(i32_scalar, i32_scalar) # doctest: +SKIP
 Traceback (most recent call last):
 TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'

->>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
+>>> jax.jit(f, static_argnums=0).trace(10, i32_scalar).lower().compile()(5)
 Array(25, dtype=int32)

 ```

+The results of `trace` and of `lower` are not safe to serialize directly for use
+in a different process. See {ref}`export` for additional APIs for this purpose.
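
To make the static-argument behavior in this section concrete, here is a small hedged sketch (not part of the commit) that builds two specializations of the same function, one per static value, and calls each with only the non-static argument. The variable names `compiled_x7` and `compiled_x10` are illustrative, and the sketch assumes a JAX version that provides `jit(...).trace`.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return 2 * x + y

i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
jitted = jax.jit(f, static_argnums=0)

# Each distinct static value of `x` yields its own specialization;
# `y` can stay abstract because it is not static.
compiled_x7 = jitted.trace(7, i32_scalar).lower().compile()
compiled_x10 = jitted.trace(10, i32_scalar).lower().compile()

# Only the non-static argument `y` is passed at call time.
y = jnp.int32(5)
out_x7 = compiled_x7(y)    # computes 2 * 7 + 5
out_x10 = compiled_x10(y)  # computes 2 * 10 + 5
```

Because the static value is baked in at trace time, changing it means tracing, lowering, and compiling again rather than passing a different argument.
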
+
 ## AOT-compiled functions cannot be transformed

 Compiled functions are specialized to a particular set of argument "types," such
@@ -187,7 +195,7 @@ in transformations. Example:
 >>> z, zs = make_z(3, 2), make_z(4, 3, 2)

 >>> g_jit = jax.jit(g)
->>> g_aot = jax.jit(g).lower(z).compile()
+>>> g_aot = jax.jit(g).trace(z).lower().compile()

 >>> jax.vmap(g_jit)(zs)
 Array([[ 1., 5., 9.],
@@ -218,7 +226,7 @@ a text representation. Compiled functions do the same, and also offer cost and
 memory analyses from the compiler. All of these are provided via methods on the
 {class}`jax.stages.Lowered` and {class}`jax.stages.Compiled` objects (e.g.,
 `lowered.as_text()` and `compiled.cost_analysis()` above).
-You can obtain more debbugging information, e.g., source location,
+You can obtain more debugging information, e.g., source location,
 by using the `debug_info` parameter to `lowered.as_text()`.

 These methods are meant as an aid for manual inspection and debugging, not as a
@@ -238,14 +246,3 @@ platform, and runtime. This makes for two important caveats:
 remain the same on the following day.

 When in doubt, see the package API documentation for {mod}`jax.stages`.
-
-
-## Inspecting staged-out computations
-
-Stage #1 in the list at the top of this note mentions specialization and
-staging, prior to lowering. JAX's internal notion of a function specialized to
-the types of its arguments is not always a reified data structure in memory. To
-explicitly construct a view of JAX's specialization of a function in the
-internal [Jaxpr intermediate
-language](https://jax.readthedocs.io/en/latest/jaxpr.html), see
-{func}`jax.make_jaxpr`.