diff --git a/jax/experimental/checkify.py b/jax/experimental/checkify.py index 2624f0c3f..ac4880fcd 100644 --- a/jax/experimental/checkify.py +++ b/jax/experimental/checkify.py @@ -275,15 +275,33 @@ def is_scalar_pred(pred) -> bool: pred.dtype == jnp.dtype('bool')) def check_error(error: Error) -> None: - """Check if an error has occurred. + """Raise an Exception if `error` represents a failure. Functionalized by `checkify`. - When running in an un-staged function, this will throw the error - if an error has occured. - When running in a staged function, the error will be threaded back - into the function, so it can be functionalized later. + The semantics of this function are equivalent to: + + def check_error(err: Error) -> None: + err.throw() # can raise ValueError Exception + + But unlike that implementation, `check_error` can be functionalized using + the `checkify` transformation. - This is useful when you want to re-check an error after you've - functionalized a function. + This function is similar to `check` but with a different signature: whereas + `check` takes as arguments a boolean predicate and a new error message + string, this function takes an `Error` value as argument. Both `check` and + this function raise a Python Exception on failure (a side-effect), and thus + cannot be staged out by `jit`, `pmap`, `scan`, etc. Both also can be + functionalized by using `checkify`. + + But unlike `check`, this function is like a direct inverse of `checkify`: whereas + `checkify` takes as input a function which can raise a Python Exception and + produces a new function without that effect but which produces an `Error` + value as output, this `check_error` function can accept an `Error` value as + input and can produce the side-effect of raising an Exception. That is, while + `checkify` goes from functionalizable Exception effect to error value, this + `check_error` goes from error value to functionalizable Exception effect. + + `check_error` is useful when you want to turn checks represented by an `Error` value + (produced by functionalizing `check`s via `checkify`) back into Python Exceptions. Args: error: Error to check @@ -575,16 +593,16 @@ Out = TypeVar('Out') def checkify(fun: Callable[..., Out], - errors: FrozenSet[ErrorCategory] = user_checks + errors: Set[ErrorCategory] = user_checks ) -> Callable[..., Tuple[Error, Out]]: - """Check for run-time errors in ``fun``. + """Functionalize `check` calls in `fun`, and optionally add run-time error checks. Run-time errors are either user-added ``checkify.check`` assertions, or automatically added checks like NaN checks, depending on the ``errors`` argument. - The returned function will return an Error object along with the output of - the function. ``err.get()`` will either return ``None`` (if no error occurred) + The returned function will return an Error object `err` along with the output of + the original function. ``err.get()`` will either return ``None`` (if no error occurred) or a string containing an error message. This error message will correspond to the first error which occurred. @@ -596,7 +614,7 @@ def checkify(fun: Callable[..., Out], - ErrorCategory.DIV: division by zero - ErrorCategory.OOB: an indexing operation was out-of-bounds - Multiple categories can be enabled together (eg. ``errors={ErrorCategory.NAN, + Multiple categories can be enabled together by creating a `Set` (eg. ``errors={ErrorCategory.NAN, ErrorCategory.OOB}``). Args: @@ -604,7 +622,9 @@ def checkify(fun: Callable[..., Out], errors: Enabled errors. Options are NAN, OOB, DIV. By default USER_CHECK is enabled. Returns: - A function with output signature (Error, out) + A function which accepts the same arguments as ``fun`` and returns as output + a pair where the first element is an ``Error`` value, representing any failed + ``check``s, and the second element is the original output of ``fun``. For example: