TypeError: Cannot apply JAX transformations to a function lowered and compiled for a particular signature. Detected argument of Tracer type <class 'jax.interpreters.batching.BatchTracer'>.
```
A similar error is raised when `g_aot` is involved in automatic differentiation
(e.g., under {func}`jax.grad`). For consistency, transformation by `jax.jit` is
disallowed as well, even though `jit` does not meaningfully change the type
signature of the function it wraps.
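
The following sketch reproduces these errors for both autodiff and `jit`. The function `f` and the names `compiled`, `try_call`, `grad_error`, and `jit_error` are illustrative and not part of the example above. It assumes a JAX version that raises `TypeError` in these cases, as described here.

```python
import jax
import jax.numpy as jnp

def f(x):
    return 2.0 * x + 1.0

# Lower and compile for one concrete signature: a float32 scalar.
compiled = jax.jit(f).lower(jnp.float32(1.0)).compile()

# Calling the compiled function directly on a matching argument works.
direct_result = compiled(jnp.float32(3.0))

def try_call(thunk):
    """Run `thunk` and return the TypeError it raises, or None if it succeeds."""
    try:
        thunk()
    except TypeError as err:
        return err
    return None

# Differentiating through the compiled function hands it a tracer argument.
grad_error = try_call(lambda: jax.grad(lambda x: compiled(x))(jnp.float32(3.0)))

# Wrapping it in jit also hands it a tracer, so this is rejected too.
jit_error = try_call(lambda: jax.jit(lambda x: compiled(x))(jnp.float32(3.0)))

print(direct_result)
print(grad_error)
print(jit_error)
```

To differentiate or re-`jit` such a computation, apply the transformation to the original Python function before lowering, not to the compiled object.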
## Debug information and analyses, when available
In addition to the primary AOT functionality (separate and explicit lowering,
compilation, and execution), JAX's AOT stages offer features that help with
debugging and with gathering compiler feedback.

For instance, as the initial example above shows, lowered functions often offer
a text representation. Compiled functions do the same, and also offer cost and
memory analyses from the compiler. All of these are provided via methods on the
{class}`jax.stages.Lowered` and {class}`jax.stages.Compiled` objects (e.g.,
`lowered.as_text()` and `compiled.cost_analysis()` above).
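
The snippet below lowers and compiles a small function and collects each of these artifacts side by side. The function `f` and the variable names are illustrative. Because availability varies by backend, the sketch only gathers and prints the analyses and does not rely on their contents.

```python
import jax
import jax.numpy as jnp

def f(x, y):
    return jnp.dot(x, y) + 1.0

x = jnp.ones((4, 4), dtype=jnp.float32)

lowered = jax.jit(f).lower(x, x)     # a jax.stages.Lowered
compiled = lowered.compile()         # a jax.stages.Compiled

lowered_text = lowered.as_text()     # text of the lowered module (StableHLO by default in recent JAX)
compiled_text = compiled.as_text()   # compiler-specific text, or None if unavailable
cost = compiled.cost_analysis()      # backend-dependent; may be None
memory = compiled.memory_analysis()  # backend-dependent; may be None

result = compiled(x, x)              # executes the compiled computation

print(lowered_text)
print(cost)
print(memory)
```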

These methods are meant as an aid for manual inspection and debugging, not as a
reliably programmable API. Their availability and output vary by compiler,
platform, and runtime. This makes for two important caveats:
1. If some functionality is unavailable on JAX's current backend, then the
method for it returns something trivial (and `False`-like). For example, if
the compiler underlying JAX does not provide a cost analysis, then
`compiled.cost_analysis()` will be `None`.
2. If some functionality is available, there are still very limited guarantees
on what the corresponding method provides. The return value is not required
to be consistent---in type, structure, or value---across JAX configurations,
backends/platforms, versions, or even invocations of the method. JAX cannot
guarantee that the output of `compiled.cost_analysis()` on one day will
remain the same on the following day.

When in doubt, see the package API documentation for {mod}`jax.stages`.
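
Given these caveats, code that consumes an analysis should treat it as best-effort. The sketch below reads a FLOP estimate defensively. The helper name `estimated_flops` is hypothetical, and the `"flops"` key is an assumption about what XLA currently reports. The helper accepts either a single mapping or a list of per-program mappings, because the return type of `cost_analysis()` has differed across JAX versions.

```python
from collections.abc import Mapping

import jax
import jax.numpy as jnp

def estimated_flops(compiled):
    """Hypothetical helper: best-effort FLOP estimate, or None if unavailable."""
    analysis = compiled.cost_analysis()
    if not analysis:  # an unavailable analysis comes back as a False-like value
        return None
    # Accept either one mapping or a list of per-program mappings.
    if isinstance(analysis, (list, tuple)):
        analysis = analysis[0]
    if not isinstance(analysis, Mapping):
        return None
    flops = analysis.get("flops")  # assumed key name; not guaranteed by JAX
    return float(flops) if flops is not None else None

x = jnp.ones((128, 128), dtype=jnp.float32)
compiled = jax.jit(lambda a: a @ a).lower(x).compile()
flops = estimated_flops(compiled)
print("FLOP estimate:", flops if flops is not None else "unavailable")
```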
## Inspecting staged-out computations
Stage #1 in the list at the top of this note mentions specialization and
staging, prior to lowering. JAX's internal notion of a function specialized to
the types of its arguments is not always a reified data structure in memory. To
explicitly construct a view of JAX's specialization of a function in the
internal [Jaxpr intermediate
language](https://jax.readthedocs.io/en/latest/jaxpr.html), see