Sharad Vikram 4386a0f909 Add debugging tools under jax.debug and documentation
Co-authored-by: Matthew Johnson <mattjj@google.com>
Co-authored-by: Lena Martens <lenamartens@google.com>
2022-07-28 20:07:26 -07:00


# The `checkify` transformation
**TL;DR** Checkify lets you add `jit`-able runtime error checking (e.g. out-of-bounds indexing) to your JAX code. Use the `checkify.checkify` transformation together with the assert-like `checkify.check` function to add runtime checks to JAX code:
```python
from jax.experimental import checkify
import jax
import jax.numpy as jnp
def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  y = x[i]
  z = jnp.sin(y)
  return z
jittable_f = checkify.checkify(f)
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
print(err.get())
# >> index needs to be non-negative! (check failed at <...>:6 (f))
```
You can also use checkify to automatically add common checks:
```python
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.ones((5,)), -1)
err.throw()
# ValueError: index needs to be non-negative! (check failed at <…>:6 (f))
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
```
## Functionalizing checks
The assert-like `check` API by itself is not functionally pure: it can raise a Python exception as a side effect, just like `assert`. So it can't be staged out with `jit`, `pmap`, `pjit`, or `scan`:
<!-- TODO this error message might need updating -->
```python
jax.jit(f)(jnp.ones((5,)), -1) # checkify transformation not used
# ValueError: Cannot abstractly evaluate a checkify.check which was not functionalized.
```
But the checkify transformation functionalizes (or discharges) these effects. A checkify-transformed function returns an error _value_ as a new output and remains functionally pure. That functionalization means checkify-transformed functions can be composed with staging/transforms however we like:
```python
err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
err.throw()
"""
ValueError:
.. at mapped index 0: index needs to be non-negative! (check failed at <...>:6 (f))
.. at mapped index 2: out-of-bounds indexing at <..>:7 (f)
"""
```
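`jax.pmap` needs one device per mapped element, so the example above assumes at least three devices. A minimal single-device sketch of the same idea uses `jax.vmap` instead; `f` is redefined here so the block stands alone, and the exact layout of the per-index messages may vary between JAX releases.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x, i):
  checkify.check(i >= 0, "index needs to be non-negative!")
  y = x[i]
  return jnp.sin(y)

errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)

# vmap maps over the leading axis on one device. The returned error value is
# batched as well, so each mapped index keeps its own error.
err, z = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
print(err.get())  # expected to report mapped indices 0 and 2
```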
## Why does JAX need checkify?
In some circumstances you can express runtime error checks with ordinary Python assertions, for example when only using `jax.grad` and `jax.numpy`:
```python
def f(x):
  assert x > 0., "must be positive!"
  return jnp.log(x)
jax.grad(f)(0.)
# AssertionError: must be positive!
```
But ordinary assertions don't work inside `jit`, `pmap`, `pjit`, or `scan`. In those cases, numeric computations are staged out (i.e. JAX will call the function with abstracted values) rather than evaluated eagerly during Python execution, and as a result numeric values aren't available:
```python
jax.jit(f)(0.)
# ConcretizationTypeError: "Abstract tracer value encountered ..."
```
So we can't use ordinary Python assertions everywhere. But beyond needing a new API, the situation is trickier still:
* XLA HLO doesn't have assertions or error semantics, so even if we had an API which worked with JAX's tracing to stage out assertions, what would we lower them to?
* JAX transformation semantics rely on functional purity, especially when composing multiple transformations, so how can we provide an error mechanism without disrupting all that?
### How can we add functional runtime checks based on error values?
We could manually add runtime checks to our function and plumb out values representing errors:
```python
def f_checked(x):
  error = x <= 0.
  result = jnp.log(x)
  return error, result
err, y = jax.jit(f_checked)(0.)
if err:
  raise ValueError("must be positive!")
# ValueError: must be positive!
```
Here `f_checked` is functionally pure, since there are no exceptions involved and only ordinary values, so we know by construction that it'll already work with `jit`, `pmap`, `pjit`, `scan`, and all of JAX's transformations. The only problem is that this plumbing can be a pain!
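To make that pain concrete, here is a hand-written sketch with two checks that keeps only the first failure. The integer error codes (`NO_ERROR`, `NOT_POSITIVE`, `NOT_FINITE`), the `MESSAGES` table, and the `record` helper are hypothetical names invented for this illustration, not part of JAX.

```python
import jax
import jax.numpy as jnp

# Hypothetical error codes for this sketch; 0 means "no error so far".
NO_ERROR, NOT_POSITIVE, NOT_FINITE = 0, 1, 2
MESSAGES = {NOT_POSITIVE: "must be positive!", NOT_FINITE: "log must be finite!"}

def record(code, failed, new_code):
  # Only record `new_code` if no earlier check has failed (first error wins).
  return jnp.where((code == NO_ERROR) & failed, new_code, code)

def f_checked(x):
  code = jnp.int32(NO_ERROR)
  code = record(code, x <= 0., NOT_POSITIVE)
  y = jnp.log(x)
  code = record(code, ~jnp.isfinite(y), NOT_FINITE)
  return code, y

code, y = jax.jit(f_checked)(0.)
code = int(code)  # the error code is a concrete value once we're outside jit
if code != NO_ERROR:
  print(MESSAGES[code])  # prints "must be positive!"
```

Every additional check means another code, another `record` call, and one more output for every caller to thread through.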
checkify does this plumbing for you. When it sees a check application, checkify rewrites it to a functional form and plumbs the resulting error value through your function. That includes plumbing the error value through subsequent check calls (so that only the first is reported) and then ultimately returning it as an additional output of the checkify-transformed function:
```python
def f(x):
  checkify.check(x > 0., "must be positive!")  # convenient but effectful API
  return jnp.log(x)
f_checked = checkify.checkify(f)
err, x = jax.jit(f_checked)(0.)
err.throw()
# ValueError: must be positive! (check failed at <...>:2 (f))
```
We call this functionalizing or discharging the effect introduced by calling `check`. (In the "manual" example above the error value is just a boolean. checkify's error values are conceptually similar but expose `throw` and `get` methods; see {mod}`jax.experimental.checkify`.)
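A short sketch of the two methods, reusing the positivity check and adding a second, hypothetical check on the result of `jnp.log`. With `x = 0.` both checks fail, but the error value keeps only the first failure. The exception class raised by `throw` has changed across JAX releases, so the sketch catches `Exception` broadly.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def f(x):
  checkify.check(x > 0., "must be positive!")
  y = jnp.log(x)
  checkify.check(jnp.isfinite(y), "log must be finite!")  # hypothetical second check
  return y

f_checked = jax.jit(checkify.checkify(f))

# No check fails: get() returns None and throw() is a no-op.
err, y = f_checked(1.)
print(err.get())  # None
err.throw()

# Both checks fail for x = 0., but only the first failure is recorded.
err, y = f_checked(0.)
print(err.get())  # mentions "must be positive!" but not "log must be finite!"
try:
  err.throw()
except Exception as e:  # exact exception class varies by JAX release
  print("caught:", e)
```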
## Checkify can also automatically add common checks
In addition to functionalizing check calls, checkify can add error checks automatically, e.g. checks for NaNs and out-of-bounds indexing. Consider these error cases:
```python
jnp.arange(3)[5] # out of bounds
jnp.sin(jnp.inf) # nan generated
jnp.ones((5,)) / jnp.arange(5) # division by zero
```
By default checkify only discharges check applications, and won't do anything to catch errors like the above. But if you ask it to, checkify will also instrument your code with checks for these common issues automatically.
```python
def f(x, i):
  y = x[i]
  z = jnp.sin(y)
  return z
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
checked_f = checkify.checkify(f, errors=errors)
err, z = checked_f(jnp.ones((5,)), 100)
err.throw()
# ValueError: out-of-bounds indexing at <..>:7 (f)
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
err.throw()
# ValueError: nan generated by primitive sin at <...>:8 (f)
```
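These automatically added checks include source line information in their error messages.
The API for selecting which automatic checks to enable is based on Python sets: error categories such as `checkify.user_checks`, `checkify.index_checks`, and `checkify.float_checks` are combined with the set-union operator `|`, as in the example above. See {mod}`jax.experimental.checkify` for more details.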
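For instance, `checkify.nan_checks` is a narrower set than the `float_checks` used above, which makes it easy to compare the same function with and without NaN checking. In this sketch, index checks alone say nothing about the NaN that `sin(inf)` produces; adding `nan_checks` to the union catches it.

```python
import jax.numpy as jnp
from jax.experimental import checkify

def f(x, i):
  y = x[i]
  return jnp.sin(y)

x = jnp.array([jnp.inf, 1.0])

# Error sets are combined with Python's set-union operator `|`.
index_only = checkify.checkify(f, errors=checkify.index_checks)
index_and_nan = checkify.checkify(f, errors=checkify.index_checks | checkify.nan_checks)

err, _ = index_only(x, 0)
print(err.get())  # None: sin(inf) yields NaN, but NaN checks are not enabled

err, _ = index_and_nan(x, 0)
print(err.get())  # reports the NaN produced by sin
```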
## Strengths and limitations of `jax.experimental.checkify`
<!-- TODO gotta finish this -->
### Strengths
* You can use it everywhere (errors are values and behave intuitively under transformations)
* Automatic instrumentation: you don't need to make local modifications to your code. Instead, `checkify` can instrument all of it!
### Limitations
* Adding a lot of runtime checks can be expensive
* Requires threading error values in and out of transformations
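As a sketch of the first strength, checks also work inside staged-out control flow such as `jax.lax.scan` once the whole function is checkified. `cumulative_log` is a hypothetical example function, not part of JAX.

```python
import jax
import jax.numpy as jnp
from jax.experimental import checkify

def cumulative_log(xs):
  def body(total, x):
    # This check runs inside the scan body, i.e. in staged-out code.
    checkify.check(x > 0., "must be positive!")
    total = total + jnp.log(x)
    return total, total

  _, partial_sums = jax.lax.scan(body, jnp.float32(0.), xs)
  return partial_sums

checked = jax.jit(checkify.checkify(cumulative_log))

err, out = checked(jnp.array([1., 2., 3.]))
print(err.get())  # None

err, out = checked(jnp.array([1., 2., -3.]))
print(err.get())  # reports "must be positive!"
```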