**Summary:** Checkify lets you add `jit`-able runtime error checking (e.g. out-of-bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code.
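Below is a minimal sketch of that pattern. It uses a hypothetical array-indexing function `f`, and the names `f` and `checked_f` are illustrative. `checkify.check` records a condition. `checkify.checkify` returns a version of the function that outputs an error value alongside its result, and that version can be wrapped in `jax.jit`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x, i):
  # Assert-like check: recorded as part of the computation, not raised eagerly.
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]

# checkify adds an error value as an extra output; the result is jit-able.
checked_f = checkify.checkify(f)
err, y = jax.jit(checked_f)(jnp.ones((5,)), -1)

# The error message (with source location), or None if no check failed.
print(err.get())
```

Calling `err.throw()` instead of `err.get()` would raise the recorded error as a Python exception.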
The assert-like `check` API by itself is not functionally pure: it can raise a Python exception as a side effect, just like `assert`. So it can't be staged out with `jit`, `pmap`, `pjit`, or `scan`:
```python
jax.jit(f)(jnp.ones((5,)), -1) # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.
```
But the checkify transformation functionalizes (or discharges) these effects. A checkify-transformed function returns an error _value_ as a new output and remains functionally pure. That functionalization means checkify-transformed functions can be composed with staging/transforms however we like:
```python
err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
```
Under some JAX transformations you can express runtime error checks with ordinary Python assertions, for example when only using `jax.grad` and `jax.numpy`.
But ordinary assertions don't work inside `jit`, `pmap`, `pjit`, or `scan`. In those cases, numeric computations are staged out rather than evaluated eagerly during Python execution, and as a result numeric values aren't available:
```python
jax.jit(f)(jnp.ones((5,)), 2)  # assumes `f` checks `i` with a plain Python `assert`
# ConcretizationTypeError: "Abstract tracer value encountered ..."
```
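Here is a minimal sketch of that contrast, using a hypothetical function `log_positive` guarded by an ordinary `assert`. Under `jax.grad` the primal value is still concrete, so the assertion can be evaluated. Under `jax.jit` the comparison produces an abstract tracer, and converting it to a Python `bool` fails. Recent JAX versions raise `TracerBoolConversionError` here, which is a subclass of `jax.errors.ConcretizationTypeError`.

```python
import jax
import jax.numpy as jnp

def log_positive(x):
  # An ordinary Python assertion on the input value.
  assert x > 0, "x must be positive"
  return jnp.log(x)

# jax.grad evaluates the primal computation eagerly, so the assert can inspect x.
print(jax.grad(log_positive)(2.0))

# jax.jit stages the function out: `x > 0` is an abstract tracer, and the
# assert's implicit bool() conversion fails.
try:
  jax.jit(log_positive)(2.0)
except jax.errors.ConcretizationTypeError as e:
  print("jit failed with", type(e).__name__)
```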
JAX transformation semantics rely on functional purity, especially when composing multiple transformations, so how can we provide an error mechanism without disrupting all that?
Beyond needing a new API, the situation is trickier still:
XLA HLO doesn't support assertions or throwing errors, so even if we had a JAX API which was able to stage out assertions, how would we lower these assertions to XLA?
You could imagine manually adding run-time checks to your function, rewriting it into a version `f_checked` that computes and returns values representing errors.
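Below is a minimal sketch of that manual rewrite. The error value is a single boolean flag, and `call_checked` is a hypothetical caller that decides whether to raise.

```python
import jax
import jax.numpy as jnp

def f_checked(x, i):
  # The run-time check is plain boolean arithmetic on traced values,
  # returned as an ordinary output alongside the result.
  error = i < 0
  return error, x[i]

def call_checked(x, i):
  # f_checked is pure, so it can be jitted like any other function.
  error, y = jax.jit(f_checked)(x, i)
  # Outside of jit the flag is a concrete value, so Python can raise on it.
  if error:
    raise ValueError("index needs to be non-negative!")
  return y

print(call_checked(jnp.ones((5,)), 2))
```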
The error is a regular value computed by the function, and it is raised by the caller, outside of `f_checked`. `f_checked` is functionally pure, so we know by construction that it'll already work with `jit`, `pmap`, `pjit`, `scan`, and all of JAX's transformations. The only problem is that this plumbing can be a pain!
`checkify` does this rewrite for you. This includes plumbing the error value through the function, rewriting checks to boolean operations and merging the result with the tracked error value, and returning the final error value as an output of the checkified function.
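This sketch uses the same check as the manual version above, this time written with `checkify.check` and discharged by `checkify.checkify`. `err.throw()` raises `checkify.JaxRuntimeError` in Python, outside of the jitted computation.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x, i):
  # The same condition as the manual version, written assert-style.
  checkify.check(i >= 0, "index needs to be non-negative!")
  return x[i]

# checkify plumbs the error value through f and returns it as an extra output;
# the checkified function is pure, so jit applies as usual.
f_checkified = jax.jit(checkify.checkify(f))

err, y = f_checkified(jnp.ones((5,)), -1)
try:
  err.throw()  # raises here, not inside the staged-out computation
except checkify.JaxRuntimeError as e:
  print("caught:", e)

err, y = f_checkified(jnp.ones((5,)), 2)
err.throw()  # no check failed, so this does nothing
```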
We call this functionalizing or discharging the effect introduced by calling `check`. (In the manual approach above, the error value is just a boolean. `checkify`'s error values are conceptually similar but also track error messages and expose `throw` and `get` methods; see {mod}`jax.experimental.checkify`.) `checkify.check` also lets you include run-time values in the error message by passing them as format arguments.
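A minimal sketch of format arguments, using a hypothetical `safe_log` function. The `{x}` placeholder in the message is filled in with the run-time value of `x` passed as a keyword argument to `checkify.check`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def safe_log(x):
  # `{x}` is formatted with the run-time value of x if the check fails.
  checkify.check(x > 0, "x must be positive, got {x}", x=x)
  return jnp.log(x)

err, _ = jax.jit(checkify.checkify(safe_log))(-3.0)
print(err.get())
```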
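By default, `checkify` only discharges `checkify.check`s and won't do anything to catch errors such as out-of-bounds indexing or NaNs produced by arithmetic. But if you ask it to through its `errors` argument (for example `checkify.index_checks` or `checkify.nan_checks`), `checkify` will also instrument your code with checks automatically.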
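A minimal sketch contrasting the default behavior with automatic index checks. `get_item` is a hypothetical function containing no explicit `check`.

```python
import jax.numpy as jnp
from jax.experimental import checkify

def get_item(x, i):
  return x[i]  # no explicit check; i may be out of bounds

x = jnp.ones((5,))

# Default (user checks only): the out-of-bounds index goes unreported.
err, _ = checkify.checkify(get_item)(x, 5)
print(err.get())

# Opting in to automatic index checks instruments the indexing operation.
checked_get = checkify.checkify(get_item, errors=checkify.index_checks)
err, _ = checked_get(x, 5)
print(err.get())
```

Error sets can be combined with `|`, e.g. `checkify.user_checks | checkify.index_checks`.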
Note that there’s no multiply in `f`, but there is a multiply in its gradient computation (and this is where the NaN is generated!). So apply `checkify` to the gradient function (checkify-of-grad) to add automatic checks to both forward- and backward-pass operations.
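Here is a minimal sketch of checkify-of-grad, using an illustrative function `f` that contains a division but no multiply. Its derivative involves the derivative of `sqrt`, which is infinite at 0. In the backward pass that infinity meets a zero cotangent, so a NaN appears even though the forward pass is NaN-free.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
  # No multiply here, but d/dx sqrt(x) is infinite at x = 0.
  return x / (1 + jnp.sqrt(x))

grad_f = jax.grad(f)

# checkify-of-grad: automatic NaN checks also cover the backward pass.
err, g = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.0)
print(err.get())
```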
`checkify.check`s will only be applied to the primal value of your function. If
you want to use a `check` on a gradient value, use a `custom_vjp`.
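This is a minimal sketch under stated assumptions. `check_grad_negative` is a hypothetical identity function whose custom backward rule checks the incoming cotangent. Wrapping the gradient computation in `checkify.checkify` discharges that check.

```python
import jax
from jax.experimental import checkify

@jax.custom_vjp
def check_grad_negative(x):
  # Identity on the forward pass.
  return x

def _fwd(x):
  return x, None  # no residuals needed

def _bwd(_, g):
  # This check runs on the cotangent flowing back into x.
  checkify.check(g < 0, "gradient needs to be negative, got {g}", g=g)
  return (g,)

check_grad_negative.defvjp(_fwd, _bwd)

# The cotangent reaching check_grad_negative through -x is -1: the check passes.
err, _ = checkify.checkify(jax.grad(lambda x: -check_grad_negative(x)))(3.0)
print(err.get())

# Without the negation the cotangent is +1: the check fails.
err, _ = checkify.checkify(jax.grad(check_grad_negative))(3.0)
print(err.get())
```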