Apply suggestions from code review

Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
Lena Martens 2022-02-08 20:23:40 +00:00 committed by GitHub
parent 0042edb5f4
commit b2cf12aa7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -275,15 +275,33 @@ def is_scalar_pred(pred) -> bool:
pred.dtype == jnp.dtype('bool'))
def check_error(error: Error) -> None:
"""Check if an error has occurred.
"""Raise an Exception if `error` represents a failure. Functionalized by `checkify`.
When running in an un-staged function, this will throw the error
if an error has occured.
When running in a staged function, the error will be threaded back
into the function, so it can be functionalized later.
The semantics of this function are equivalent to:
def check_error(err: Error) -> None:
err.throw() # can raise ValueError Exception
But unlike that implementation, `check_error` can be functionalized using
the `checkify` transformation.
This is useful when you want to re-check an error after you've
functionalized a function.
This function is similar to `check` but with a different signature: whereas
`check` takes as arguments a boolean predicate and a new error message
string, this function takes an `Error` value as argument. Both `check` and
this function raise a Python Exception on failure (a side-effect), and thus
cannot be staged out by `jit`, `pmap`, `scan`, etc. Both also can be
functionalized by using `checkify`.
But unlike `check`, this function is like a direct inverse of `checkify`: whereas
`checkify` takes as input a function which can raise a Python Exception and
produces a new function without that effect but which produces an `Error`
value as output, this `check_error` function can accept an `Error` value as
input and can produce the side-effect of raising an Exception. That is, while
`checkify` goes from functionalizable Exception effect to error value, this
`check_error` goes from error value to functionalizable Exception effect.
`check_error` is useful when you want to turn checks represented by an `Error` value
(produced by functionalizing `check`s via `checkify`) back into Python Exceptions.
Args:
error: Error to check
@ -575,16 +593,16 @@ Out = TypeVar('Out')
def checkify(fun: Callable[..., Out],
errors: FrozenSet[ErrorCategory] = user_checks
errors: Set[ErrorCategory] = user_checks
) -> Callable[..., Tuple[Error, Out]]:
"""Check for run-time errors in ``fun``.
"""Functionalize `check` calls in `fun`, and optionally add run-time error checks.
Run-time errors are either user-added ``checkify.check`` assertions, or
automatically added checks like NaN checks, depending on the ``errors``
argument.
The returned function will return an Error object along with the output of
the function. ``err.get()`` will either return ``None`` (if no error occurred)
The returned function will return an Error object `err` along with the output of
the original function. ``err.get()`` will either return ``None`` (if no error occurred)
or a string containing an error message. This error message will correspond to
the first error which occurred.
@ -596,7 +614,7 @@ def checkify(fun: Callable[..., Out],
- ErrorCategory.DIV: division by zero
- ErrorCategory.OOB: an indexing operation was out-of-bounds
Multiple categories can be enabled together (eg. ``errors={ErrorCategory.NAN,
Multiple categories can be enabled together by creating a `Set` (eg. ``errors={ErrorCategory.NAN,
ErrorCategory.OOB}``).
Args:
@ -604,7 +622,9 @@ def checkify(fun: Callable[..., Out],
errors: Enabled errors. Options are NAN, OOB, DIV. By default USER_CHECK is
enabled.
Returns:
A function with output signature (Error, out)
A function which accepts the same arguments as ``fun`` and returns as output
a pair where the first element is an ``Error`` value, representing any failed
``check``s, and the second element is the original output of ``fun``.
For example: