mirror of
https://github.com/ROCm/jax.git
synced 2025-04-17 12:26:07 +00:00
Checkify: Add documentation for adding run-time values to error message.
This commit is contained in:
parent
942aa7a907
commit
3db909ee8d
@ -8,16 +8,16 @@ import jax
|
||||
import jax.numpy as jnp
|
||||
|
||||
def f(x, i):
|
||||
checkify.check(i >= 0, "index needs to be non-negative!")
|
||||
checkify.check(i >= 0, "index needs to be non-negative, got {i}", i=i)
|
||||
y = x[i]
|
||||
z = jnp.sin(y)
|
||||
return z
|
||||
|
||||
jittable_f = checkify.checkify(f)
|
||||
|
||||
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -1)
|
||||
err, z = jax.jit(jittable_f)(jnp.ones((5,)), -2)
|
||||
print(err.get())
|
||||
# >> index needs to be non-negative! (check failed at <...>:6 (f))
|
||||
# >> index needs to be non-negative, got -2! (check failed at <...>:6 (f))
|
||||
```
|
||||
|
||||
You can also use checkify to automatically add common checks:
|
||||
@ -58,7 +58,7 @@ But the checkify transformation functionalizes (or discharges) these effects. A
|
||||
err, z = jax.pmap(checked_f)(jnp.ones((3, 5)), jnp.array([-1, 2, 100]))
|
||||
err.throw()
|
||||
"""
|
||||
ValueError:
|
||||
ValueError:
|
||||
.. at mapped index 0: index needs to be non-negative! (check failed at :6 (f))
|
||||
.. at mapped index 2: out-of-bounds indexing at <..>:7 (f)
|
||||
"""
|
||||
@ -108,19 +108,19 @@ The error is a regular value computed by the function, and the error is raised o
|
||||
|
||||
```python
|
||||
def f(x):
|
||||
checkify.check(x > 0., "must be positive!") # convenient but effectful API
|
||||
checkify.check(x > 0., "{} must be positive!", x) # convenient but effectful API
|
||||
return jnp.log(x)
|
||||
|
||||
f_checked = checkify(f)
|
||||
|
||||
err, x = jax.jit(f_checked)(0.)
|
||||
err, x = jax.jit(f_checked)(-1.)
|
||||
err.throw()
|
||||
# ValueError: must be positive! (check failed at <...>:2 (f))
|
||||
# ValueError: -1. must be positive! (check failed at <...>:2 (f))
|
||||
```
|
||||
|
||||
We call this functionalizing or discharging the effect introduced by calling check. (In the "manual" example above the error value is just a boolean. checkify's error values are conceptually similar but also track error messages and expose throw and get methods; see {mod}`jax.experimental.checkify`).
|
||||
We call this functionalizing or discharging the effect introduced by calling check. (In the "manual" example above the error value is just a boolean. checkify's error values are conceptually similar but also track error messages and expose throw and get methods; see {mod}`jax.experimental.checkify`). `checkify.check` also allows you to add run-time values to your error message by providing them as format arguments to the error message.
|
||||
|
||||
You could now instrument your code with run-time checks, but `checkify` can also automatically add checks for common errors!
|
||||
You could now manually instrument your code with run-time checks, but `checkify` can also automatically add checks for common errors!
|
||||
Consider these error cases:
|
||||
|
||||
```python
|
||||
@ -158,10 +158,29 @@ jitted. Here's a few more examples of `checkify` with other JAX
|
||||
transformations. Note that checkified functions are functionally pure, and
|
||||
should trivially compose with all JAX transformations!
|
||||
|
||||
### `jit`
|
||||
|
||||
You can safely add `jax.jit` to a checkified function, or `checkify` a jitted
|
||||
function, both will work.
|
||||
|
||||
```python
|
||||
def f(x, i):
|
||||
return x[i]
|
||||
|
||||
checkify_of_jit = checkify.checkify(jax.jit(f))
|
||||
jit_of_checkify = jax.jit(checkify.checkify(f))
|
||||
err, _ = checkify_of_jit(jnp.ones((5,)), 100)
|
||||
err.get()
|
||||
# out-of-bounds indexing at <..>:2 (f)
|
||||
err, _ = jit_of_checkify(jnp.ones((5,)), 100)
|
||||
# out-of-bounds indexing at <..>:2 (f)
|
||||
```
|
||||
|
||||
### `vmap`/`pmap`
|
||||
|
||||
Mapping a checkified function will give you a mapped error, which can contain
|
||||
different errors for every element of the mapped dimension.
|
||||
You can `vmap` and `pmap` checkified functions (or `checkify` mapped functions).
|
||||
Mapping a checkified function will give you a mapped error, which can contain
|
||||
different errors for every element of the mapped dimension.
|
||||
|
||||
```python
|
||||
def f(x, i):
|
||||
@ -206,7 +225,7 @@ f = pjit(
|
||||
f,
|
||||
in_axis_resources=PartitionSpec('x', None),
|
||||
out_axis_resources=(None, PartitionSpec('x', None)))
|
||||
|
||||
|
||||
with maps.Mesh(mesh.devices, mesh.axis_names):
|
||||
err, data = f(input_data)
|
||||
err.throw()
|
||||
@ -264,4 +283,4 @@ jax.grad(assert_gradient_negative)(-1.)
|
||||
* Requires threading error values out of functions and manually throwing the
|
||||
error. If the error is not explicitly thrown, you might miss out on errors!
|
||||
* Throwing an error value will materialize that error value on the host, meaning
|
||||
it's a blocking operation which defeats JAX's async run-ahead.
|
||||
it's a blocking operation which defeats JAX's async run-ahead.
|
||||
|
@ -598,7 +598,7 @@ def checkify_fun_to_jaxpr(
|
||||
|
||||
|
||||
|
||||
def check(pred: Bool, msg: str, *args, **kwargs) -> None:
|
||||
def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
|
||||
"""Check a predicate, add an error with msg if predicate is False.
|
||||
|
||||
This is an effectful operation, and can't be staged (jitted/scanned/...).
|
||||
@ -606,7 +606,14 @@ def check(pred: Bool, msg: str, *args, **kwargs) -> None:
|
||||
|
||||
Args:
|
||||
pred: if False, an error is added.
|
||||
msg: error message if error is added.
|
||||
msg: error message if error is added. Can be a format string.
|
||||
fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
|
||||
`msg`, eg.:
|
||||
``check(.., "check failed on values {} and {named_arg}", x, named_arg=y)``
|
||||
Note that these arguments can be traced values allowing you to add
|
||||
run-time values to the error message.
|
||||
Note that tracking these run-time arrays will increase your memory usage,
|
||||
even if no error happens.
|
||||
|
||||
For example:
|
||||
|
||||
@ -614,22 +621,23 @@ def check(pred: Bool, msg: str, *args, **kwargs) -> None:
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from jax.experimental import checkify
|
||||
>>> def f(x):
|
||||
... checkify.check(x!=0, "cannot be zero!")
|
||||
... checkify.check(x>0, "{x} needs to be positive!", x=x)
|
||||
... return 1/x
|
||||
>>> checked_f = checkify.checkify(f)
|
||||
>>> err, out = jax.jit(checked_f)(0)
|
||||
>>> err, out = jax.jit(checked_f)(-3.)
|
||||
>>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
jax._src.checkify.JaxRuntimeError: cannot be zero!
|
||||
jax._src.checkify.JaxRuntimeError: -3. needs to be positive!
|
||||
|
||||
"""
|
||||
if not is_scalar_pred(pred):
|
||||
raise TypeError(f'check takes a scalar pred as argument, got {pred}')
|
||||
new_error = FailedCheckError(summary(), msg, *args, **kwargs)
|
||||
new_error = FailedCheckError(summary(), msg, *fmt_args, **fmt_kwargs)
|
||||
error = assert_func(init_error, jnp.logical_not(pred), new_error)
|
||||
return check_error(error)
|
||||
|
||||
|
||||
def is_scalar_pred(pred) -> bool:
|
||||
return (isinstance(pred, bool) or
|
||||
isinstance(pred, jnp.ndarray) and pred.shape == () and
|
||||
|
Loading…
x
Reference in New Issue
Block a user