diff --git a/jax/experimental/checkify.py b/jax/experimental/checkify.py index aedf9ab94..e5bbc873d 100644 --- a/jax/experimental/checkify.py +++ b/jax/experimental/checkify.py @@ -315,7 +315,7 @@ def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals): ## assert primitive def check(pred: Bool, msg: str) -> None: - """Check a condition, add an error with msg if condition is False. + """Check a predicate, add an error with msg if predicate is False. This is an effectful operation, and can't be staged (jitted/scanned/...). Before staging a function with checks, ``checkify`` it! @@ -352,40 +352,42 @@ def is_scalar_pred(pred) -> bool: pred.dtype == jnp.dtype('bool')) def check_error(error: Error) -> None: - """Raise an Exception if `error` represents a failure. Functionalized by `checkify`. + """Raise an Exception if ``error`` represents a failure. Functionalized by ``checkify``. The semantics of this function are equivalent to: - def check_error(err: Error) -> None: - err.throw() # can raise ValueError Exception + >>> def check_error(err: Error) -> None: + ... err.throw() # can raise ValueError - But unlike that implementation, `check_error` can be functionalized using - the `checkify` transformation. + But unlike that implementation, ``check_error`` can be functionalized using + the ``checkify`` transformation. - This function is similar to `check` but with a different signature: whereas - `check` takes as arguments a boolean predicate and a new error message - string, this function takes an `Error` value as argument. Both `check` and - this function raise a Python Exception on failure (a side-effect), and thus - cannot be staged out by `jit`, `pmap`, `scan`, etc. Both also can be - functionalized by using `checkify`. + This function is similar to ``check`` but with a different signature: whereas + ``check`` takes as arguments a boolean predicate and a new error message + string, this function takes an ``Error`` value as argument. Both ``check`` + and this function raise a Python Exception on failure (a side-effect), and + thus cannot be staged out by ``jit``, ``pmap``, ``scan``, etc. Both also can + be functionalized by using ``checkify``. - But unlike `check`, this function is like a direct inverse of `checkify`: - whereas `checkify` takes as input a function which can raise a Python - Exception and produces a new function without that effect but which - produces an `Error` value as output, this `check_error` function can accept - an `Error` value as input and can produce the side-effect of raising an - Exception. That is, while `checkify` goes from functionalizable Exception - effect to error value, this `check_error` goes from error value to + But unlike ``check``, this function is like a direct inverse of ``checkify``: + whereas ``checkify`` takes as input a function which can raise a Python + Exception and produces a new function without that effect but which produces + an ``Error`` value as output, this ``check_error`` function can accept an + ``Error`` value as input and can produce the side-effect of raising an + Exception. That is, while ``checkify`` goes from functionalizable Exception + effect to error value, this ``check_error`` goes from error value to functionalizable Exception effect. - `check_error` is useful when you want to turn checks represented by an - `Error` value (produced by functionalizing `check`s via `checkify`) back - into Python Exceptions. + ``check_error`` is useful when you want to turn checks represented by an + ``Error`` value (produced by functionalizing ``checks`` via ``checkify``) + back into Python Exceptions. Args: - error: Error to check + error: Error to check. - For example: + For example, you might want to functionalize part of your program through + checkify, stage out your functionalized code through ``jit``, then re-inject + your error value outside of the ``jit``: >>> import jax >>> from jax.experimental import checkify @@ -663,10 +665,13 @@ error_checks[assert_p] = assert_discharge_rule ErrorCategory = enum.Enum('ErrorCategory', ['NAN', 'OOB', 'DIV', 'USER_CHECK']) -float_errors = frozenset({ErrorCategory.NAN, ErrorCategory.DIV}) -index_errors = frozenset({ErrorCategory.OOB}) -automatic_errors = float_errors | index_errors user_checks = frozenset({ErrorCategory.USER_CHECK}) +nan_checks = frozenset({ErrorCategory.NAN}) +index_checks = frozenset({ErrorCategory.OOB}) +div_checks = frozenset({ErrorCategory.DIV}) +float_checks = nan_checks | div_checks +automatic_checks = float_checks | index_checks +all_checks = automatic_checks | user_checks Out = TypeVar('Out') @@ -683,31 +688,35 @@ def checkify(fun: Callable[..., Out], The returned function will return an Error object `err` along with the output of the original function. ``err.get()`` will either return ``None`` (if no error occurred) or a string containing an error message. This error message - will correspond to the first error which occurred. + will correspond to the first error which occurred. ``err.throw()`` will raise + a ValueError with the error message if an error occurred. - The kinds of errors are: - - ErrorCategory.USER_CHECK: a ``checkify.check`` predicate evaluated - to False. - - ErrorCategory.NAN: a floating-point operation generated a NaN value + By default only user-added ``checkify.check`` assertions are enabled. You can + enable automatic checks through the ``errors`` argument. + + The automatic check sets which can be enabled, and when an error is generated: + - ``user_checks``: a ``checkify.check`` evaluated to False. + - ``nan_checks``: a floating-point operation generated a NaN value as output. - - ErrorCategory.DIV: division by zero - - ErrorCategory.OOB: an indexing operation was out-of-bounds + - ``div_checks``: a division by zero. + - ``index_checks``: an index was out-of-bounds. Multiple categories can be enabled together by creating a `Set` (eg. - ``errors={ErrorCategory.NAN, ErrorCategory.OOB}``). + ``errors={ErrorCategory.NAN, ErrorCategory.OOB}``). Multiple sets can be + re-combined (eg. ``errors=float_checks|user_checks``) Args: fun: Callable which can contain user checks (see ``check``). errors: A set of ErrorCategory values which defines the set of enabled - checks. By default only explicit ``check``s are enabled - (``{ErrorCategory.USER_CHECK}``). You can also for example enable NAN and - DIV errors through passing the ``checkify.float_errors`` set, or for + checks. By default only explicit ``checks`` are enabled + (``user_checks``). You can also for example enable NAN and + DIV errors by passing the ``float_checks`` set, or for example combine multiple sets through set operations - (``checkify.float_errors|checkify.user_checks``) + (``float_checks | user_checks``) Returns: A function which accepts the same arguments as ``fun`` and returns as output - a pair where the first element is an ``Error`` value, representing any - failed ``check``s, and the second element is the original output of ``fun``. + a pair where the first element is an ``Error`` value, representing the first + failed ``check``, and the second element is the original output of ``fun``. For example: @@ -719,7 +728,7 @@ def checkify(fun: Callable[..., Out], ... def f(x): ... y = jnp.sin(x) ... return x+y - >>> err, out = checkify.checkify(f, errors=checkify.float_errors)(jnp.inf) + >>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf) >>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL Traceback (most recent call last): ... diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 3ca8b2e34..0b6b1df8d 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -41,7 +41,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): return y1 + y2 f = jax.jit(f) if jit else f - checked_f = checkify.checkify(f, errors=checkify.float_errors) + checked_f = checkify.checkify(f, errors=checkify.float_checks) err, _ = checked_f(3., 4.) self.assertIs(err.get(), None) @@ -61,7 +61,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): return w f = jax.jit(f) if jit else f - checked_f = checkify.checkify(f, errors=checkify.index_errors) + checked_f = checkify.checkify(f, errors=checkify.index_checks) err, _ = checked_f(jnp.arange(3), 2) self.assertIs(err.get(), None) @@ -79,7 +79,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): return getattr(x.at[i], update_fn)(1.) f = jax.jit(f) - checked_f = checkify.checkify(f, errors=checkify.index_errors) + checked_f = checkify.checkify(f, errors=checkify.index_checks) err, _ = checked_f(jnp.arange(3), 2) self.assertIs(err.get(), None) @@ -96,7 +96,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): return x/y f = jax.jit(f) if jit else f - checked_f = checkify.checkify(f, errors=checkify.float_errors) + checked_f = checkify.checkify(f, errors=checkify.float_checks) err, _ = checked_f(jnp.ones((3,)), jnp.ones((3,))) self.assertIs(err.get(), None) @@ -120,7 +120,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): return z f = jax.jit(f) if jit else f - checked_f = checkify.checkify(f, errors=checkify.automatic_errors) + checked_f = checkify.checkify(f, errors=checkify.automatic_checks) # no error err, _ = checked_f(jnp.array([0., jnp.inf, 2.]), 2) @@ -146,7 +146,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): return y * z f = jax.jit(f) if jit else f - checked_f = checkify.checkify(f, errors=checkify.automatic_errors) + checked_f = checkify.checkify(f, errors=checkify.automatic_checks) # both oob and nan error, but oob happens first err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 5) @@ -163,7 +163,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): y1 = jnp.sin(x1) y2 = jnp.sin(x2) return y1 + y2 - checked_f = checkify.checkify(f, errors=checkify.float_errors) + checked_f = checkify.checkify(f, errors=checkify.float_checks) xs = jnp.array([0., 2.]) err, _ = checked_f(xs, xs) @@ -182,7 +182,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): lambda: jnp.sin(x), lambda: x) - checked_f = checkify.checkify(f, errors=checkify.float_errors) + checked_f = checkify.checkify(f, errors=checkify.float_checks) err, _ = checked_f(3.) self.assertIs(err.get(), None) @@ -203,7 +203,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): def f(xs): return lax.scan(scan_body, None, xs) - checked_f = checkify.checkify(f, errors=checkify.float_errors) + checked_f = checkify.checkify(f, errors=checkify.float_checks) xs = jnp.array([0., 2.]) err, (_, ch_outs) = checked_f(xs) @@ -229,7 +229,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): def f(carry, xs): return lax.scan(scan_body, carry, xs) - checked_f = checkify.checkify(f, errors=checkify.float_errors) + checked_f = checkify.checkify(f, errors=checkify.float_checks) carry, xs = 3., jnp.ones((2,)) err, (ch_out_carry, ch_outs) = checked_f(carry, xs) @@ -271,7 +271,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): def f(init_val): return lax.while_loop(while_cond, while_body, (init_val, 0.)) - checked_f = checkify.checkify(f, errors=checkify.float_errors) + checked_f = checkify.checkify(f, errors=checkify.float_checks) init_val = 1. err, ch_out = checked_f(init_val) @@ -299,7 +299,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): def f(init_val): return lax.while_loop(while_cond, while_body, init_val) - checked_f = checkify.checkify(f, errors=checkify.float_errors) + checked_f = checkify.checkify(f, errors=checkify.float_checks) init_val = 1. err, ch_out = checked_f(init_val) @@ -325,7 +325,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): def f(init_val): return lax.while_loop(while_cond, lambda val: val-1, init_val) - checked_f = checkify.checkify(f, errors=checkify.float_errors) + checked_f = checkify.checkify(f, errors=checkify.float_checks) # error on first cond init_val = 0. @@ -355,7 +355,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): def f(cond_val, body_val): return lax.while_loop(while_cond, while_body, (0., cond_val, body_val)) - checked_f = checkify.checkify(f, errors=checkify.float_errors) + checked_f = checkify.checkify(f, errors=checkify.float_checks) cond_val = jnp.inf body_val = 1. @@ -384,8 +384,8 @@ class CheckifyTransformTests(jtu.JaxTestCase): ("assert", checkify.user_checks, "must be negative!"), ("div", {checkify.ErrorCategory.DIV}, "divided by zero"), ("nan", {checkify.ErrorCategory.NAN}, "nan generated"), - ("oob", checkify.index_errors, "out-of-bounds indexing"), - ("automatic_errors", checkify.automatic_errors, "divided by zero"), + ("oob", checkify.index_checks, "out-of-bounds indexing"), + ("automatic_checks", checkify.automatic_checks, "divided by zero"), ) @jtu.skip_on_devices("tpu") def test_enabled_errors(self, error_set, expected_error): @@ -403,7 +403,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): @jtu.skip_on_devices("tpu") def test_post_process_call(self): - @partial(checkify.checkify, errors=checkify.float_errors) + @partial(checkify.checkify, errors=checkify.float_checks) def g(x): @jax.jit def f(y): @@ -415,7 +415,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): @jtu.skip_on_devices("tpu") def test_post_process_map(self): - @partial(checkify.checkify, errors=checkify.float_errors) + @partial(checkify.checkify, errors=checkify.float_checks) def g(x): @jax.pmap def f(y): @@ -436,7 +436,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): (x,), (xdot,) = primals, tangents return sin(x), jnp.cos(x) * xdot - f = checkify.checkify(sin, errors=checkify.float_errors) + f = checkify.checkify(sin, errors=checkify.float_checks) err, y = f(3.) self.assertIsNone(err.get()) @@ -459,7 +459,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): # Checkify-of-jvp adds checks (unlike jvp-of-checkify above). g = checkify.checkify(lambda x, xdot: jax.jvp(sin, (x,), (xdot,)), - errors=checkify.float_errors) + errors=checkify.float_checks) err, (y, ydot) = g(3., 1.) # doesn't crash self.assertIsNone(err.get()) # no error self.assertNotEmpty(err.msgs) # but checks were added! @@ -481,7 +481,7 @@ class CheckifyTransformTests(jtu.JaxTestCase): return jnp.cos(x2 / 2.) * g, sin.defvjp(sin_fwd, sin_bwd) - f = checkify.checkify(sin, errors=checkify.float_errors) + f = checkify.checkify(sin, errors=checkify.float_checks) # no differentiation, no error err, y = f(3.) @@ -498,11 +498,11 @@ class CheckifyTransformTests(jtu.JaxTestCase): self.assertEmpty(err.msgs) # and no checks were added! # Checkify-of-vjp adds checks (unlike vjp-of-checkify above). - err, y = checkify.checkify(jax.grad(sin), errors=checkify.float_errors)(3.) + err, y = checkify.checkify(jax.grad(sin), errors=checkify.float_checks)(3.) self.assertIsNone(err.get()) # no error self.assertNotEmpty(err.msgs) # but checks were added! err, y = checkify.checkify(jax.grad(sin), - errors=checkify.float_errors)(jnp.inf) + errors=checkify.float_checks)(jnp.inf) self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), 'nan generated by primitive sin')