Checkify: Fix docstring formatting and polish enabled_errors sets.

This commit is contained in:
Lena Martens 2022-02-10 14:28:46 +00:00 committed by lenamartens
parent 1340fbbc09
commit 2ba8aec274
2 changed files with 74 additions and 65 deletions

View File

@ -315,7 +315,7 @@ def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals):
## assert primitive
def check(pred: Bool, msg: str) -> None:
"""Check a condition, add an error with msg if condition is False.
"""Check a predicate, add an error with msg if predicate is False.
This is an effectful operation, and can't be staged (jitted/scanned/...).
Before staging a function with checks, ``checkify`` it!
@ -352,40 +352,42 @@ def is_scalar_pred(pred) -> bool:
pred.dtype == jnp.dtype('bool'))
def check_error(error: Error) -> None:
"""Raise an Exception if `error` represents a failure. Functionalized by `checkify`.
"""Raise an Exception if ``error`` represents a failure. Functionalized by ``checkify``.
The semantics of this function are equivalent to:
def check_error(err: Error) -> None:
err.throw() # can raise ValueError Exception
>>> def check_error(err: Error) -> None:
... err.throw() # can raise ValueError
But unlike that implementation, `check_error` can be functionalized using
the `checkify` transformation.
But unlike that implementation, ``check_error`` can be functionalized using
the ``checkify`` transformation.
This function is similar to `check` but with a different signature: whereas
`check` takes as arguments a boolean predicate and a new error message
string, this function takes an `Error` value as argument. Both `check` and
this function raise a Python Exception on failure (a side-effect), and thus
cannot be staged out by `jit`, `pmap`, `scan`, etc. Both also can be
functionalized by using `checkify`.
This function is similar to ``check`` but with a different signature: whereas
``check`` takes as arguments a boolean predicate and a new error message
string, this function takes an ``Error`` value as argument. Both ``check``
and this function raise a Python Exception on failure (a side-effect), and
thus cannot be staged out by ``jit``, ``pmap``, ``scan``, etc. Both also can
be functionalized by using ``checkify``.
But unlike `check`, this function is like a direct inverse of `checkify`:
whereas `checkify` takes as input a function which can raise a Python
Exception and produces a new function without that effect but which
produces an `Error` value as output, this `check_error` function can accept
an `Error` value as input and can produce the side-effect of raising an
Exception. That is, while `checkify` goes from functionalizable Exception
effect to error value, this `check_error` goes from error value to
But unlike ``check``, this function is like a direct inverse of ``checkify``:
whereas ``checkify`` takes as input a function which can raise a Python
Exception and produces a new function without that effect but which produces
an ``Error`` value as output, this ``check_error`` function can accept an
``Error`` value as input and can produce the side-effect of raising an
Exception. That is, while ``checkify`` goes from functionalizable Exception
effect to error value, this ``check_error`` goes from error value to
functionalizable Exception effect.
`check_error` is useful when you want to turn checks represented by an
`Error` value (produced by functionalizing `check`s via `checkify`) back
into Python Exceptions.
``check_error`` is useful when you want to turn checks represented by an
``Error`` value (produced by functionalizing ``checks`` via ``checkify``)
back into Python Exceptions.
Args:
error: Error to check
error: Error to check.
For example:
For example, you might want to functionalize part of your program through
checkify, stage out your functionalized code through ``jit``, then re-inject
your error value outside of the ``jit``:
>>> import jax
>>> from jax.experimental import checkify
@ -663,10 +665,13 @@ error_checks[assert_p] = assert_discharge_rule
ErrorCategory = enum.Enum('ErrorCategory', ['NAN', 'OOB', 'DIV', 'USER_CHECK'])
float_errors = frozenset({ErrorCategory.NAN, ErrorCategory.DIV})
index_errors = frozenset({ErrorCategory.OOB})
automatic_errors = float_errors | index_errors
user_checks = frozenset({ErrorCategory.USER_CHECK})
nan_checks = frozenset({ErrorCategory.NAN})
index_checks = frozenset({ErrorCategory.OOB})
div_checks = frozenset({ErrorCategory.DIV})
float_checks = nan_checks | div_checks
automatic_checks = float_checks | index_checks
all_checks = automatic_checks | user_checks
Out = TypeVar('Out')
@ -683,31 +688,35 @@ def checkify(fun: Callable[..., Out],
The returned function will return an Error object `err` along with the output
of the original function. ``err.get()`` will either return ``None`` (if no
error occurred) or a string containing an error message. This error message
will correspond to the first error which occurred.
will correspond to the first error which occurred. ``err.throw()`` will raise
a ValueError with the error message if an error occurred.
The kinds of errors are:
- ErrorCategory.USER_CHECK: a ``checkify.check`` predicate evaluated
to False.
- ErrorCategory.NAN: a floating-point operation generated a NaN value
By default only user-added ``checkify.check`` assertions are enabled. You can
enable automatic checks through the ``errors`` argument.
The automatic check sets which can be enabled, and when an error is generated:
- ``user_checks``: a ``checkify.check`` evaluated to False.
- ``nan_checks``: a floating-point operation generated a NaN value
as output.
- ErrorCategory.DIV: division by zero
- ErrorCategory.OOB: an indexing operation was out-of-bounds
- ``div_checks``: a division by zero.
- ``index_checks``: an index was out-of-bounds.
Multiple categories can be enabled together by creating a `Set` (eg.
``errors={ErrorCategory.NAN, ErrorCategory.OOB}``).
``errors={ErrorCategory.NAN, ErrorCategory.OOB}``). Multiple sets can be
re-combined (eg. ``errors=float_checks|user_checks``)
Args:
fun: Callable which can contain user checks (see ``check``).
errors: A set of ErrorCategory values which defines the set of enabled
checks. By default only explicit ``check``s are enabled
(``{ErrorCategory.USER_CHECK}``). You can also for example enable NAN and
DIV errors through passing the ``checkify.float_errors`` set, or for
checks. By default only explicit ``checks`` are enabled
(``user_checks``). You can also for example enable NAN and
DIV errors by passing the ``float_checks`` set, or for
example combine multiple sets through set operations
(``checkify.float_errors|checkify.user_checks``)
(``float_checks | user_checks``)
Returns:
A function which accepts the same arguments as ``fun`` and returns as output
a pair where the first element is an ``Error`` value, representing any
failed ``check``s, and the second element is the original output of ``fun``.
a pair where the first element is an ``Error`` value, representing the first
failed ``check``, and the second element is the original output of ``fun``.
For example:
@ -719,7 +728,7 @@ def checkify(fun: Callable[..., Out],
... def f(x):
... y = jnp.sin(x)
... return x+y
>>> err, out = checkify.checkify(f, errors=checkify.float_errors)(jnp.inf)
>>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
>>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...

View File

@ -41,7 +41,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
return y1 + y2
f = jax.jit(f) if jit else f
checked_f = checkify.checkify(f, errors=checkify.float_errors)
checked_f = checkify.checkify(f, errors=checkify.float_checks)
err, _ = checked_f(3., 4.)
self.assertIs(err.get(), None)
@ -61,7 +61,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
return w
f = jax.jit(f) if jit else f
checked_f = checkify.checkify(f, errors=checkify.index_errors)
checked_f = checkify.checkify(f, errors=checkify.index_checks)
err, _ = checked_f(jnp.arange(3), 2)
self.assertIs(err.get(), None)
@ -79,7 +79,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
return getattr(x.at[i], update_fn)(1.)
f = jax.jit(f)
checked_f = checkify.checkify(f, errors=checkify.index_errors)
checked_f = checkify.checkify(f, errors=checkify.index_checks)
err, _ = checked_f(jnp.arange(3), 2)
self.assertIs(err.get(), None)
@ -96,7 +96,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
return x/y
f = jax.jit(f) if jit else f
checked_f = checkify.checkify(f, errors=checkify.float_errors)
checked_f = checkify.checkify(f, errors=checkify.float_checks)
err, _ = checked_f(jnp.ones((3,)), jnp.ones((3,)))
self.assertIs(err.get(), None)
@ -120,7 +120,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
return z
f = jax.jit(f) if jit else f
checked_f = checkify.checkify(f, errors=checkify.automatic_errors)
checked_f = checkify.checkify(f, errors=checkify.automatic_checks)
# no error
err, _ = checked_f(jnp.array([0., jnp.inf, 2.]), 2)
@ -146,7 +146,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
return y * z
f = jax.jit(f) if jit else f
checked_f = checkify.checkify(f, errors=checkify.automatic_errors)
checked_f = checkify.checkify(f, errors=checkify.automatic_checks)
# both oob and nan error, but oob happens first
err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 5)
@ -163,7 +163,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
y1 = jnp.sin(x1)
y2 = jnp.sin(x2)
return y1 + y2
checked_f = checkify.checkify(f, errors=checkify.float_errors)
checked_f = checkify.checkify(f, errors=checkify.float_checks)
xs = jnp.array([0., 2.])
err, _ = checked_f(xs, xs)
@ -182,7 +182,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
lambda: jnp.sin(x),
lambda: x)
checked_f = checkify.checkify(f, errors=checkify.float_errors)
checked_f = checkify.checkify(f, errors=checkify.float_checks)
err, _ = checked_f(3.)
self.assertIs(err.get(), None)
@ -203,7 +203,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def f(xs):
return lax.scan(scan_body, None, xs)
checked_f = checkify.checkify(f, errors=checkify.float_errors)
checked_f = checkify.checkify(f, errors=checkify.float_checks)
xs = jnp.array([0., 2.])
err, (_, ch_outs) = checked_f(xs)
@ -229,7 +229,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def f(carry, xs):
return lax.scan(scan_body, carry, xs)
checked_f = checkify.checkify(f, errors=checkify.float_errors)
checked_f = checkify.checkify(f, errors=checkify.float_checks)
carry, xs = 3., jnp.ones((2,))
err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
@ -271,7 +271,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def f(init_val):
return lax.while_loop(while_cond, while_body, (init_val, 0.))
checked_f = checkify.checkify(f, errors=checkify.float_errors)
checked_f = checkify.checkify(f, errors=checkify.float_checks)
init_val = 1.
err, ch_out = checked_f(init_val)
@ -299,7 +299,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def f(init_val):
return lax.while_loop(while_cond, while_body, init_val)
checked_f = checkify.checkify(f, errors=checkify.float_errors)
checked_f = checkify.checkify(f, errors=checkify.float_checks)
init_val = 1.
err, ch_out = checked_f(init_val)
@ -325,7 +325,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def f(init_val):
return lax.while_loop(while_cond, lambda val: val-1, init_val)
checked_f = checkify.checkify(f, errors=checkify.float_errors)
checked_f = checkify.checkify(f, errors=checkify.float_checks)
# error on first cond
init_val = 0.
@ -355,7 +355,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def f(cond_val, body_val):
return lax.while_loop(while_cond, while_body, (0., cond_val, body_val))
checked_f = checkify.checkify(f, errors=checkify.float_errors)
checked_f = checkify.checkify(f, errors=checkify.float_checks)
cond_val = jnp.inf
body_val = 1.
@ -384,8 +384,8 @@ class CheckifyTransformTests(jtu.JaxTestCase):
("assert", checkify.user_checks, "must be negative!"),
("div", {checkify.ErrorCategory.DIV}, "divided by zero"),
("nan", {checkify.ErrorCategory.NAN}, "nan generated"),
("oob", checkify.index_errors, "out-of-bounds indexing"),
("automatic_errors", checkify.automatic_errors, "divided by zero"),
("oob", checkify.index_checks, "out-of-bounds indexing"),
("automatic_checks", checkify.automatic_checks, "divided by zero"),
)
@jtu.skip_on_devices("tpu")
def test_enabled_errors(self, error_set, expected_error):
@ -403,7 +403,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu")
def test_post_process_call(self):
@partial(checkify.checkify, errors=checkify.float_errors)
@partial(checkify.checkify, errors=checkify.float_checks)
def g(x):
@jax.jit
def f(y):
@ -415,7 +415,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
@jtu.skip_on_devices("tpu")
def test_post_process_map(self):
@partial(checkify.checkify, errors=checkify.float_errors)
@partial(checkify.checkify, errors=checkify.float_checks)
def g(x):
@jax.pmap
def f(y):
@ -436,7 +436,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
(x,), (xdot,) = primals, tangents
return sin(x), jnp.cos(x) * xdot
f = checkify.checkify(sin, errors=checkify.float_errors)
f = checkify.checkify(sin, errors=checkify.float_checks)
err, y = f(3.)
self.assertIsNone(err.get())
@ -459,7 +459,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
# Checkify-of-jvp adds checks (unlike jvp-of-checkify above).
g = checkify.checkify(lambda x, xdot: jax.jvp(sin, (x,), (xdot,)),
errors=checkify.float_errors)
errors=checkify.float_checks)
err, (y, ydot) = g(3., 1.) # doesn't crash
self.assertIsNone(err.get()) # no error
self.assertNotEmpty(err.msgs) # but checks were added!
@ -481,7 +481,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
return jnp.cos(x2 / 2.) * g,
sin.defvjp(sin_fwd, sin_bwd)
f = checkify.checkify(sin, errors=checkify.float_errors)
f = checkify.checkify(sin, errors=checkify.float_checks)
# no differentiation, no error
err, y = f(3.)
@ -498,11 +498,11 @@ class CheckifyTransformTests(jtu.JaxTestCase):
self.assertEmpty(err.msgs) # and no checks were added!
# Checkify-of-vjp adds checks (unlike vjp-of-checkify above).
err, y = checkify.checkify(jax.grad(sin), errors=checkify.float_errors)(3.)
err, y = checkify.checkify(jax.grad(sin), errors=checkify.float_checks)(3.)
self.assertIsNone(err.get()) # no error
self.assertNotEmpty(err.msgs) # but checks were added!
err, y = checkify.checkify(jax.grad(sin),
errors=checkify.float_errors)(jnp.inf)
errors=checkify.float_checks)(jnp.inf)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'nan generated by primitive sin')