mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
write in-process AOT walkthrough doc
This commit is contained in:
parent
43db06491c
commit
bb68fbeefa
236
docs/aot.md
Normal file
@ -0,0 +1,236 @@
# Ahead-of-time lowering and compilation

JAX offers several transformations, such as `jax.jit` and `jax.pmap`, returning
a function that is compiled and runs on accelerators or the CPU. As the JIT
acronym indicates, all compilation happens _just-in-time_ for execution.

Some situations call for _ahead-of-time_ (AOT) compilation instead. When you
want to fully compile prior to execution time, or you want control over when
different parts of the compilation process take place, JAX has some options for
you.

First, let's review the stages of compilation. Suppose that `f` is a
function/callable output by {func}`jax.jit`, say `f = jax.jit(F)` for some input
callable `F`. When it is invoked with arguments, say `f(x, y)` where `x` and `y`
are arrays, JAX does the following in order:

1. **Stage out** a specialized version of the original Python callable `F` to an
   internal representation. The specialization reflects a restriction of `F` to
   input types inferred from properties of the arguments `x` and `y` (usually
   their shape and element type).

2. **Lower** this specialized, staged-out computation to the XLA compiler's
   input language, MHLO.

3. **Compile** the lowered HLO program to produce an optimized executable for
   the target device (CPU, GPU, or TPU).

4. **Execute** the compiled executable with the arrays `x` and `y` as arguments.

JAX's AOT API gives you direct control over steps #2, #3, and #4 (but [not
#1](#inspecting-staged-out-computations)), plus some other features along the
way. An example:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> import numpy as np

>>> def f(x, y): return 2 * x + y
>>> x, y = 3, 4

>>> lowered = jax.jit(f).lower(x, y)

>>> # Print lowered HLO
>>> print(lowered.as_text())
module @jit_f.0 {
  func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
    %0 = mhlo.constant dense<2> : tensor<i32>
    %1 = mhlo.multiply %0, %arg0 : tensor<i32>
    %2 = mhlo.add %1, %arg1 : tensor<i32>
    return %2 : tensor<i32>
  }
}

>>> compiled = lowered.compile()

>>> # Query for cost analysis, print FLOP estimate
>>> compiled.cost_analysis()[0]['flops']
2.0

>>> # Execute the compiled function!
>>> compiled(x, y)
DeviceArray(10, dtype=int32)
```

See the {mod}`jax.stages` documentation for more details on what functionality
the lowered and compiled functions provide.

In place of `jax.jit` above, you can also `lower(...)` the result of
{func}`jax.pmap`, as well as `pjit` and `xmap` (from
{mod}`jax.experimental.pjit` and {mod}`jax.experimental.maps` respectively). In
each case, you can `compile()` the result similarly.

All optional arguments to `jit`---such as `static_argnums`---are respected in
the corresponding lowering, compilation, and execution. Again the same goes for
`pmap`, `pjit`, and `xmap`.

In the example above, we can replace the arguments to `lower` with any objects
that have `shape` and `dtype` attributes:

```python
>>> i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x, y)
DeviceArray(10, dtype=int32)
```

More generally, `lower` only needs its arguments to structurally supply what JAX
must know for specialization and lowering. For typical array arguments like the
ones above, this means `shape` and `dtype` fields. For static arguments, by
contrast, JAX needs actual concrete values (more on this
[below](#lowering-with-static-arguments)).

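As a minimal sketch of lowering from shape/dtype descriptions alone, the snippet below compiles a function without allocating any input data first. The function `affine` and the chosen shapes are hypothetical and exist only for illustration; `jax.ShapeDtypeStruct` is the same placeholder type used in the example above.

```python
import jax
import jax.numpy as jnp

def affine(w, x):
    # Hypothetical example function: a matrix-vector product plus a constant.
    return w @ x + 1.0

# Describe the inputs by shape and dtype only; no array data is created here.
w_spec = jax.ShapeDtypeStruct((4, 3), jnp.float32)
x_spec = jax.ShapeDtypeStruct((3,), jnp.float32)

compiled_affine = jax.jit(affine).lower(w_spec, x_spec).compile()

# Concrete arrays are only needed at execution time.
w = jnp.ones((4, 3), dtype=jnp.float32)
x = jnp.arange(3, dtype=jnp.float32)
out = compiled_affine(w, x)
print(out.shape, out.dtype)
```
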
Invoking an AOT-compiled function with arguments that are incompatible with its
lowering raises an error:

```python
>>> x_1d = y_1d = jnp.arange(3)
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_1d, y_1d)
...
TypeError: Computation compiled for input types:
  ShapedArray(int32[]), ShapedArray(int32[])
called with:
  ShapedArray(int32[3]), ShapedArray(int32[3])

>>> x_f = y_f = 72.0
>>> jax.jit(f).lower(i32_scalar, i32_scalar).compile()(x_f, y_f)
...
TypeError: Computation compiled for input types:
  ShapedArray(int32[]), ShapedArray(int32[])
called with:
  ShapedArray(float32[]), ShapedArray(float32[])
```
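
One way to live with this strictness, sketched here under assumptions rather than as a recommended pattern, is to try the AOT executable first and fall back to the ordinary `jax.jit` path when the argument types do not match. The helper `call_with_fallback` is hypothetical. Note that catching `TypeError` this broadly can also hide unrelated type errors.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return 2 * x + y

i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))
f_jit = jax.jit(f)
f_aot = f_jit.lower(i32_scalar, i32_scalar).compile()

def call_with_fallback(x, y):
    # Hypothetical helper: use the AOT executable when the argument types
    # match what it was compiled for, otherwise fall back to just-in-time.
    try:
        return f_aot(x, y)
    except TypeError:
        return f_jit(x, y)

# Matches the compiled signature (two int32 scalars).
scalar_result = call_with_fallback(jnp.int32(3), jnp.int32(4))

# Does not match (int32[3] inputs), so the jit path handles it.
vector_result = call_with_fallback(jnp.arange(3), jnp.arange(3))
```
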

Relatedly, AOT-compiled functions [cannot be transformed by JAX's just-in-time
transformations](#aot-compiled-functions-cannot-be-transformed) such as
`jax.jit`, {func}`jax.grad`, and {func}`jax.vmap`.


## Lowering with static arguments

Lowering with static arguments underscores the interaction between options
passed to `jax.jit`, the arguments passed to `lower`, and the arguments needed
to invoke the resulting compiled function. Continuing with our example above:

```python
>>> lowered_with_x = jax.jit(f, static_argnums=0).lower(7, 8)

>>> # Lowered HLO, specialized to the *value* of the first argument (7)
>>> print(lowered_with_x.as_text())
module @jit_f.1 {
  func.func public @main(%arg0: tensor<i32>) -> tensor<i32> {
    %0 = mhlo.constant dense<14> : tensor<i32>
    %1 = mhlo.add %0, %arg0 : tensor<i32>
    return %1 : tensor<i32>
  }
}
>>> lowered_with_x.compile()(5)
DeviceArray(19, dtype=int32)
```

Note that `lower` here takes two arguments as usual, but the subsequent compiled
function accepts only the remaining non-static second argument. The static first
argument (value 7) is taken as a constant at lowering time and built into the
lowered computation, where it is possibly folded in with other constants. In
this case, its multiplication by 2 is simplified, resulting in the constant 14.

Although the second argument to `lower` above can be replaced by a hollow
shape/dtype structure, it is necessary that the static first argument be a
concrete value. Otherwise, lowering would err:

```python
>>> jax.jit(f, static_argnums=0).lower(i32_scalar, i32_scalar)
TypeError: unsupported operand type(s) for *: 'int' and 'ShapeDtypeStruct'

>>> jax.jit(f, static_argnums=0).lower(10, i32_scalar).compile()(5)
DeviceArray(25, dtype=int32)
```
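
Because each static value yields its own specialized executable, a program that needs several static values has to lower and compile once per value. The sketch below keeps those executables in a plain dictionary. The cache `compiled_by_x` and the helper `compiled_for` are hypothetical names introduced for illustration, not part of JAX.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return 2 * x + y

i32_scalar = jax.ShapeDtypeStruct((), jnp.dtype('int32'))

# Hypothetical cache: static value of x -> executable specialized to it.
compiled_by_x = {}

def compiled_for(x_static):
    # Lower and compile at most once per distinct static value.
    if x_static not in compiled_by_x:
        lowered = jax.jit(f, static_argnums=0).lower(x_static, i32_scalar)
        compiled_by_x[x_static] = lowered.compile()
    return compiled_by_x[x_static]

# Each executable takes only the non-static second argument.
result_a = compiled_for(7)(jnp.int32(5))   # computes 2 * 7 + 5
result_b = compiled_for(10)(jnp.int32(5))  # computes 2 * 10 + 5
```
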

## AOT-compiled functions cannot be transformed

Compiled functions are specialized to a particular set of argument "types," such
as arrays with a specific shape and element type in our running example. From
JAX's internal point of view, transformations such as {func}`jax.vmap` alter the
type signature of functions in a way that invalidates the compiled-for type
signature. As a policy, JAX simply does not allow compiled functions to be
involved in transformations. Example:

```python
>>> def g(x):
...   assert x.shape == (3, 2)
...   return x @ jnp.ones(2)

>>> def make_z(*shape):
...   return jnp.arange(np.prod(shape)).reshape(shape)

>>> z, zs = make_z(3, 2), make_z(4, 3, 2)

>>> g_jit = jax.jit(g)
>>> g_aot = jax.jit(g).lower(z).compile()

>>> jax.vmap(g_jit)(zs)
DeviceArray([[ 1.,  5.,  9.],
             [13., 17., 21.],
             [25., 29., 33.],
             [37., 41., 45.]], dtype=float32)

>>> jax.vmap(g_aot)(zs)
TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax.interpreters.batching.BatchTracer'>.
```

A similar error is raised when `g_aot` is involved in autodiff
(e.g. {func}`jax.grad`). For consistency, transformation by `jax.jit` is
disallowed as well, even though `jit` does not meaningfully modify its
argument's type signature.
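
If a transformed function is what you need ahead of time, one option is to apply the transformation to the plain Python function first and only then lower and compile. The sketch below does this for the `vmap` case, redefining `g` and the batched input so that it runs on its own. The name `g_batched_aot` is introduced here for illustration.

```python
import jax
import jax.numpy as jnp
import numpy as np

def g(x):
    assert x.shape == (3, 2)
    return x @ jnp.ones(2)

zs = jnp.arange(np.prod((4, 3, 2))).reshape((4, 3, 2))

# Transform first, then lower and compile the batched function for the
# batched input type. The compiled result is specialized to int32[4,3,2].
g_batched_aot = jax.jit(jax.vmap(g)).lower(zs).compile()
batched = g_batched_aot(zs)
print(batched.shape)
```
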


## Debug information and analyses, when available

In addition to the primary AOT functionality (separate and explicit lowering,
compilation, and execution), JAX's various AOT stages also offer some additional
features to help with debugging and gathering compiler feedback.

For instance, as the initial example above shows, lowered functions often offer
a text representation. Compiled functions do the same, and also offer cost and
memory analyses from the compiler. All of these are provided via methods on the
{class}`jax.stages.Lowered` and {class}`jax.stages.Compiled` objects (e.g.,
`lowered.as_text()` and `compiled.cost_analysis()` above).

These methods are meant as an aid for manual inspection and debugging, not as a
reliably programmable API. Their availability and output vary by compiler,
platform, and runtime. This makes for two important caveats:

1. If some functionality is unavailable on JAX's current backend, then the
   method for it returns something trivial (and `False`-like). For example, if
   the compiler underlying JAX does not provide a cost analysis, then
   `compiled.cost_analysis()` will be `None`.

2. If some functionality is available, there are still very limited guarantees
   on what the corresponding method provides. The return value is not required
   to be consistent---in type, structure, or value---across JAX configurations,
   backends/platforms, versions, or even invocations of the method. JAX cannot
   guarantee that the output of `compiled.cost_analysis()` on one day will
   remain the same on the following day.

When in doubt, see the package API documentation for {mod}`jax.stages`.
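
Given these caveats, code that reads an analysis should treat every part of the result as optional. The following is a defensive sketch, not a supported pattern: the helper `flops_estimate` is hypothetical, and it assumes the result is either `None`, a dictionary, or a sequence whose first entry is a dictionary (the initial example above indexes the result with `[0]`; other JAX versions may return a dictionary directly).

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return 2 * x + y

def flops_estimate(compiled):
    # Hypothetical helper: return a FLOP estimate as a float, or None when
    # no estimate can be read from the analysis.
    analysis = compiled.cost_analysis()
    if not analysis:  # None or an empty container
        return None
    if isinstance(analysis, (list, tuple)):
        analysis = analysis[0]  # assumed layout: sequence of dictionaries
    if not isinstance(analysis, dict):
        return None
    flops = analysis.get('flops')
    return None if flops is None else float(flops)

compiled = jax.jit(f).lower(jnp.int32(3), jnp.int32(4)).compile()
estimate = flops_estimate(compiled)
print('FLOP estimate:', estimate)
```
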


## Inspecting staged-out computations

Stage #1 in the list at the top of this note mentions specialization and
staging, prior to lowering. JAX's internal notion of a function specialized to
the types of its arguments is not always a reified data structure in memory. To
explicitly construct a view of JAX's specialization of a function in the
internal [Jaxpr intermediate
language](https://jax.readthedocs.io/en/latest/jaxpr.html), see
{func}`jax.make_jaxpr`.
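
As a brief sketch, {func}`jax.make_jaxpr` can be applied to the running example `f` to view its staged-out form. The variable names `closed_jaxpr` and `primitive_names` are introduced here for illustration, and the exact printed text may differ between JAX versions.

```python
import jax

def f(x, y):
    return 2 * x + y

# Trace f with example arguments to obtain its Jaxpr, specialized to the
# argument types (here, two integer scalars).
closed_jaxpr = jax.make_jaxpr(f)(3, 4)
print(closed_jaxpr)

# Each equation in the Jaxpr applies one primitive operation.
primitive_names = [eqn.primitive.name for eqn in closed_jaxpr.jaxpr.eqns]
print(primitive_names)
```
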
|
@ -30,6 +30,7 @@ parallelize, Just-In-Time compile to GPU/TPU, and more.

faq
async_dispatch
aot
jaxpr
notebooks/convolutions
pytrees
|
@ -17,6 +17,8 @@ JAX transformations that compile just in time for execution, such as
``jax.jit`` and ``jax.pmap``, also support a common means of explicit
lowering and compilation *ahead of time*. This module defines types
that represent the stages of this process.

For more, see the `AOT walkthrough <https://jax.readthedocs.io/en/latest/aot.html>`_.
"""

from jax._src.stages import (