mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Apply suggestions from code review
Co-authored-by: Matthew Johnson <mattjj@google.com>
This commit is contained in:
parent
0042edb5f4
commit
b2cf12aa7e
@ -275,15 +275,33 @@ def is_scalar_pred(pred) -> bool:
|
||||
pred.dtype == jnp.dtype('bool'))
|
||||
|
||||
def check_error(error: Error) -> None:
|
||||
"""Check if an error has occurred.
|
||||
"""Raise an Exception if `error` represents a failure. Functionalized by `checkify`.
|
||||
|
||||
When running in an un-staged function, this will throw the error
|
||||
if an error has occured.
|
||||
When running in a staged function, the error will be threaded back
|
||||
into the function, so it can be functionalized later.
|
||||
The semantics of this function are equivalent to:
|
||||
|
||||
def check_error(err: Error) -> None:
|
||||
err.throw() # can raise ValueError Exception
|
||||
|
||||
But unlike that implementation, `check_error` can be functionalized using
|
||||
the `checkify` transformation.
|
||||
|
||||
This is useful when you want to re-check an error after you've
|
||||
functionalized a function.
|
||||
This function is similar to `check` but with a different signature: whereas
|
||||
`check` takes as arguments a boolean predicate and a new error message
|
||||
string, this function takes an `Error` value as argument. Both `check` and
|
||||
this function raise a Python Exception on failure (a side-effect), and thus
|
||||
cannot be staged out by `jit`, `pmap`, `scan`, etc. Both also can be
|
||||
functionalized by using `checkify`.
|
||||
|
||||
But unlike `check`, this function is like a direct inverse of `checkify`: whereas
|
||||
`checkify` takes as input a function which can raise a Python Exception and
|
||||
produces a new function without that effect but which produces an `Error`
|
||||
value as output, this `check_error` function can accept an `Error` value as
|
||||
input and can produce the side-effect of raising an Exception. That is, while
|
||||
`checkify` goes from functionalizable Exception effect to error value, this
|
||||
`check_error` goes from error value to functionalizable Exception effect.
|
||||
|
||||
`check_error` is useful when you want to turn checks represented by an `Error` value
|
||||
(produced by functionalizing `check`s via `checkify`) back into Python Exceptions.
|
||||
|
||||
Args:
|
||||
error: Error to check
|
||||
@ -575,16 +593,16 @@ Out = TypeVar('Out')
|
||||
|
||||
|
||||
def checkify(fun: Callable[..., Out],
|
||||
errors: FrozenSet[ErrorCategory] = user_checks
|
||||
errors: Set[ErrorCategory] = user_checks
|
||||
) -> Callable[..., Tuple[Error, Out]]:
|
||||
"""Check for run-time errors in ``fun``.
|
||||
"""Functionalize `check` calls in `fun`, and optionally add run-time error checks.
|
||||
|
||||
Run-time errors are either user-added ``checkify.check`` assertions, or
|
||||
automatically added checks like NaN checks, depending on the ``errors``
|
||||
argument.
|
||||
|
||||
The returned function will return an Error object along with the output of
|
||||
the function. ``err.get()`` will either return ``None`` (if no error occurred)
|
||||
The returned function will return an Error object `err` along with the output of
|
||||
the original function. ``err.get()`` will either return ``None`` (if no error occurred)
|
||||
or a string containing an error message. This error message will correspond to
|
||||
the first error which occurred.
|
||||
|
||||
@ -596,7 +614,7 @@ def checkify(fun: Callable[..., Out],
|
||||
- ErrorCategory.DIV: division by zero
|
||||
- ErrorCategory.OOB: an indexing operation was out-of-bounds
|
||||
|
||||
Multiple categories can be enabled together (eg. ``errors={ErrorCategory.NAN,
|
||||
Multiple categories can be enabled together by creating a `Set` (eg. ``errors={ErrorCategory.NAN,
|
||||
ErrorCategory.OOB}``).
|
||||
|
||||
Args:
|
||||
@ -604,7 +622,9 @@ def checkify(fun: Callable[..., Out],
|
||||
errors: Enabled errors. Options are NAN, OOB, DIV. By default USER_CHECK is
|
||||
enabled.
|
||||
Returns:
|
||||
A function with output signature (Error, out)
|
||||
A function which accepts the same arguments as ``fun`` and returns as output
|
||||
a pair where the first element is an ``Error`` value, representing any failed
|
||||
``check``s, and the second element is the original output of ``fun``.
|
||||
|
||||
For example:
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user