mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Merge pull request #11665 from LenaMartens:docs
PiperOrigin-RevId: 464113785
This commit is contained in:
commit
cc19c94d36
@ -37,6 +37,9 @@ err.throw()
|
||||
err, z = checked_f(jnp.array([jnp.inf, 1]), 0)
|
||||
err.throw()
|
||||
# ValueError: nan generated by primitive sin at <...>:8 (f)
|
||||
|
||||
err, z = checked_f(jnp.array([5, 1]), 0)
|
||||
err.throw() # if no error occurred, throw does nothing!
|
||||
```
|
||||
|
||||
## Functionalizing checks
|
||||
@ -63,7 +66,7 @@ ValueError:
|
||||
|
||||
## Why does JAX need checkify?
|
||||
|
||||
In some circumstances you can express runtime error checks with ordinary Python assertions, for example when only using `jax.grad` and `jax.numpy`:
|
||||
Under some JAX transformations you can express runtime error checks with ordinary Python assertions, for example when only using `jax.grad` and `jax.numpy`:
|
||||
|
||||
```python
|
||||
def f(x):
|
||||
@ -74,17 +77,18 @@ jax.grad(f)(0.)
|
||||
# ValueError: "must be positive!"
|
||||
```
|
||||
|
||||
But ordinary assertions don't work inside `jit`, `pmap`, `pjit`, or `scan`. In those cases, numeric computations are staged out (i.e. JAX will call the function with abstracted values) rather than evaluated eagerly during Python execution, and as a result numeric values aren't available:
|
||||
But ordinary assertions don't work inside `jit`, `pmap`, `pjit`, or `scan`. In those cases, numeric computations are staged out rather than evaluated eagerly during Python execution, and as a result numeric values aren't available:
|
||||
|
||||
```python
|
||||
jax.jit(f)(0.)
|
||||
# ConcretizationTypeError: "Abstract tracer value encountered ..."
|
||||
```
|
||||
|
||||
So we can't use ordinary Python assertions everywhere. But beyond needing a new API, the situation is trickier still:
|
||||
XLA HLO doesn't have assertions or error semantics, so even if we had an API which worked with JAX's tracing to stage out assertions, what would we lower them to?
|
||||
JAX transformation semantics rely on functional purity, especially when composing multiple transformations, so how can we provide an error mechanism without disrupting all that?
|
||||
How can we add functional runtime checks, based on error values?
|
||||
We could manually add run-time checks to our function and plumb out values representing errors:
|
||||
Beyond needing a new API, the situation is trickier still:
|
||||
XLA HLO doesn't support assertions or throwing errors, so even if we had a JAX API which was able to stage out assertions, how would we lower these assertions to XLA?
|
||||
|
||||
You could imagine manually adding run-time checks to your function and plumbing out values representing errors:
|
||||
|
||||
```python
|
||||
def f_checked(x):
|
||||
@ -98,9 +102,9 @@ if err:
|
||||
# ValueError: "must be positive!"
|
||||
```
|
||||
|
||||
Here `f_checked` is functionally pure, since there are no Exceptions involved and only ordinary values, so we know by construction that it'll already work with `jit`, pmap, pjit, scan, and all of JAX's transformations. The only problem is that this plumbing can be a pain!
|
||||
The error is a regular value computed by the function, and the error is raised outside of `f_checked`. `f_checked` is functionally pure, so we know by construction that it'll already work with `jit`, pmap, pjit, scan, and all of JAX's transformations. The only problem is that this plumbing can be a pain!
|
||||
|
||||
checkify does this plumbing for you. When it sees a check application, checkify rewrites it to a functional form and plumbs the resulting error value through your function. That includes plumbing the error value through subsequent check calls (so that only the first is reported) and then ultimately returning it as an additional output of the checkify-transformed function:
|
||||
`checkify` does this rewrite for you: that includes plumbing the error value through the function, rewriting checks to boolean operations and merging the result with the tracked error value, and returning the final error value as an output to the checkified function:
|
||||
|
||||
```python
|
||||
def f(x):
|
||||
@ -114,23 +118,23 @@ err.throw()
|
||||
# ValueError: must be positive! (check failed at <...>:2 (f))
|
||||
```
|
||||
|
||||
We call this functionalizing or discharging the effect introduced by calling check. (In the "manual" example above the error value is just a boolean. checkify's error values are conceptually similar but expose throw and get methods; see {mod}`jax.experimental.checkify`.
|
||||
We call this functionalizing or discharging the effect introduced by calling check. (In the "manual" example above the error value is just a boolean. checkify's error values are conceptually similar but also track error messages and expose throw and get methods; see {mod}`jax.experimental.checkify`).
|
||||
|
||||
Checkify can also automatically add common checks
|
||||
In addition to functionalizing check calls, checkify can add error checks automatically, e.g. checks for NaNs and out-of-bounds indexing. Consider these error cases:
|
||||
You could now instrument your code with run-time checks, but `checkify` can also automatically add checks for common errors!
|
||||
Consider these error cases:
|
||||
|
||||
```python
|
||||
jnp.arange(3)[5] # out of bounds
|
||||
jnp.sin(jnp.inf) # nan generated
|
||||
jnp.sin(jnp.inf) # NaN generated
|
||||
jnp.ones((5,)) / jnp.arange(5) # division by zero
|
||||
```
|
||||
|
||||
By default checkify only discharges check applications, and won't do anything to catch errors like the above. But if you ask it to, checkify will also instrument your code with checks for these common issues automatically.
|
||||
By default `checkify` only discharges `checkify.check`s, and won't do anything to catch errors like the above. But if you ask it to, `checkify` will also instrument your code with checks automatically.
|
||||
|
||||
```python
|
||||
def f(x, i):
|
||||
y = x[i]
|
||||
z = jnp.sin(y)
|
||||
y = x[i] # i could be out of bounds.
|
||||
z = jnp.sin(y) # z could become NaN
|
||||
return z
|
||||
|
||||
errors = checkify.user_checks | checkify.index_checks | checkify.float_checks
|
||||
@ -145,14 +149,119 @@ err.throw()
|
||||
# ValueError: nan generated by primitive sin at <...>:8 (f)
|
||||
```
|
||||
|
||||
These automatically-added checks include source line information in their error messages.
|
||||
The API for selecting which automatic checks to enable is based on Sets. See {mod}`jax.experimental.checkify` for more details.
|
||||
|
||||
## `checkify` under JAX transformations.
|
||||
|
||||
As demonstrated in the examples above, a checkified function can be happily
|
||||
jitted. Here's a few more examples of `checkify` with other JAX
|
||||
transformations. Note that checkified functions are functionally pure, and
|
||||
should trivially compose with all JAX transformations!
|
||||
|
||||
### `vmap`/`pmap`
|
||||
|
||||
Mapping a checkified function will give you a mapped error, which can contain
|
||||
different errors for every element of the mapped dimension.
|
||||
|
||||
```python
|
||||
def f(x, i):
|
||||
checkify.check(i >= 0, "index needs to be non-negative!")
|
||||
return x[i]
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.all_errors)
|
||||
errs, out = jax.vmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
|
||||
errs.throw()
|
||||
"""
|
||||
ValueError:
|
||||
at mapped index 0: index needs to be non-negative! (check failed at <...>:2 (f))
|
||||
at mapped index 2: out-of-bounds indexing at <...>:3 (f)
|
||||
"""
|
||||
```
|
||||
|
||||
However, a checkify-of-vmap will produce a single (unmapped) error!
|
||||
|
||||
```python
|
||||
@jax.vmap
|
||||
def f(x, i):
|
||||
checkify.check(i >= 0, "index needs to be non-negative!")
|
||||
return x[i]
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.all_errors)
|
||||
err, out = checked_f(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
|
||||
err.throw()
|
||||
# ValueError: index needs to be non-negative! (check failed at <...>:2 (f))
|
||||
```
|
||||
|
||||
### `pjit`
|
||||
|
||||
`pjit` of a checkified function _just works_, you only need to specify an
|
||||
additional `out_axis_resources` of `None` for the error value output.
|
||||
|
||||
```python
|
||||
def f(x):
|
||||
return x / x
|
||||
|
||||
f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
f = pjit(
|
||||
f,
|
||||
in_axis_resources=PartitionSpec('x', None),
|
||||
out_axis_resources=(None, PartitionSpec('x', None)))
|
||||
|
||||
with maps.Mesh(mesh.devices, mesh.axis_names):
|
||||
err, data = f(input_data)
|
||||
err.throw()
|
||||
# ValueError: divided by zero at <...>:4 (f)
|
||||
```
|
||||
|
||||
### `grad`
|
||||
|
||||
Your gradient computation will also be instrumented if you checkify-of-grad:
|
||||
|
||||
```python
|
||||
def f(x):
|
||||
return x / (1 + jnp.sqrt(x))
|
||||
|
||||
grad_f = jax.grad(f)
|
||||
|
||||
err, _ = checkify.checkify(grad_f, errors=checkify.nan_checks)(0.)
|
||||
print(err.get())
|
||||
>> nan generated by primitive mul at <...>:3 (f)
|
||||
```
|
||||
|
||||
Note that there’s no multiply in `f`, but there is a multiply in its gradient computation (and this is where the NaN is generated!). So use checkify-of-grad to add automatic checks to both forward and backward pass operations.
|
||||
|
||||
`checkify.check`s will only be applied to the primal value of your function. If
|
||||
you want to use a `check` on a gradient value, use a `custom_vjp`:
|
||||
|
||||
```python
|
||||
@jax.custom_vjp
|
||||
def assert_gradient_negative(x):
|
||||
return x
|
||||
|
||||
def fwd(x):
|
||||
return assert_gradient_negative(x), None
|
||||
|
||||
def bwd(_, grad):
|
||||
checkify.check(grad < 0, "gradient needs to be negative!")
|
||||
return (grad,)
|
||||
|
||||
assert_gradient_negative.defvjp(fwd, bwd)
|
||||
|
||||
jax.grad(assert_gradient_negative)(-1.)
|
||||
# ValueError: gradient needs to be negative!
|
||||
```
|
||||
|
||||
<!-- TODO: scan? -->
|
||||
<!-- TODO: check\_error -->
|
||||
|
||||
## Strengths and limitations of `jax.experimental.checkify`
|
||||
<!-- TODO gotta finish this -->
|
||||
### Strengths
|
||||
* You can use it everywhere (errors are values and behave intuitively under transformations)
|
||||
* You can use it everywhere (errors are "just values" and behave intuitively under transformations like other values)
|
||||
* Automatic instrumentation: you don't need to make local modifications to your code. Instead, `checkify` can instrument all of it!
|
||||
### Limitations
|
||||
* Adding a lot of runtime checks can be expensive
|
||||
* Requires threading error values in and out of transformations
|
||||
* Adding a lot of runtime checks can be expensive (eg. adding a NaN check to
|
||||
every primitive will add a lot of operations to your computation)
|
||||
* Requires threading error values out of functions and manually throwing the
|
||||
error. If the error is not explicitly thrown, you might miss out on errors!
|
||||
* Throwing an error value will materialize that error value on the host, meaning
|
||||
it's a blocking operation which defeats JAX's async run-ahead.
|
||||
|
Loading…
x
Reference in New Issue
Block a user