mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Checkify: Fix docstring formatting and polish enabled_errors sets.
This commit is contained in:
parent
1340fbbc09
commit
2ba8aec274
@ -315,7 +315,7 @@ def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals):
|
||||
## assert primitive
|
||||
|
||||
def check(pred: Bool, msg: str) -> None:
|
||||
"""Check a condition, add an error with msg if condition is False.
|
||||
"""Check a predicate, add an error with msg if predicate is False.
|
||||
|
||||
This is an effectful operation, and can't be staged (jitted/scanned/...).
|
||||
Before staging a function with checks, ``checkify`` it!
|
||||
@ -352,40 +352,42 @@ def is_scalar_pred(pred) -> bool:
|
||||
pred.dtype == jnp.dtype('bool'))
|
||||
|
||||
def check_error(error: Error) -> None:
|
||||
"""Raise an Exception if `error` represents a failure. Functionalized by `checkify`.
|
||||
"""Raise an Exception if ``error`` represents a failure. Functionalized by ``checkify``.
|
||||
|
||||
The semantics of this function are equivalent to:
|
||||
|
||||
def check_error(err: Error) -> None:
|
||||
err.throw() # can raise ValueError Exception
|
||||
>>> def check_error(err: Error) -> None:
|
||||
... err.throw() # can raise ValueError
|
||||
|
||||
But unlike that implementation, `check_error` can be functionalized using
|
||||
the `checkify` transformation.
|
||||
But unlike that implementation, ``check_error`` can be functionalized using
|
||||
the ``checkify`` transformation.
|
||||
|
||||
This function is similar to `check` but with a different signature: whereas
|
||||
`check` takes as arguments a boolean predicate and a new error message
|
||||
string, this function takes an `Error` value as argument. Both `check` and
|
||||
this function raise a Python Exception on failure (a side-effect), and thus
|
||||
cannot be staged out by `jit`, `pmap`, `scan`, etc. Both also can be
|
||||
functionalized by using `checkify`.
|
||||
This function is similar to ``check`` but with a different signature: whereas
|
||||
``check`` takes as arguments a boolean predicate and a new error message
|
||||
string, this function takes an ``Error`` value as argument. Both ``check``
|
||||
and this function raise a Python Exception on failure (a side-effect), and
|
||||
thus cannot be staged out by ``jit``, ``pmap``, ``scan``, etc. Both also can
|
||||
be functionalized by using ``checkify``.
|
||||
|
||||
But unlike `check`, this function is like a direct inverse of `checkify`:
|
||||
whereas `checkify` takes as input a function which can raise a Python
|
||||
Exception and produces a new function without that effect but which
|
||||
produces an `Error` value as output, this `check_error` function can accept
|
||||
an `Error` value as input and can produce the side-effect of raising an
|
||||
Exception. That is, while `checkify` goes from functionalizable Exception
|
||||
effect to error value, this `check_error` goes from error value to
|
||||
But unlike ``check``, this function is like a direct inverse of ``checkify``:
|
||||
whereas ``checkify`` takes as input a function which can raise a Python
|
||||
Exception and produces a new function without that effect but which produces
|
||||
an ``Error`` value as output, this ``check_error`` function can accept an
|
||||
``Error`` value as input and can produce the side-effect of raising an
|
||||
Exception. That is, while ``checkify`` goes from functionalizable Exception
|
||||
effect to error value, this ``check_error`` goes from error value to
|
||||
functionalizable Exception effect.
|
||||
|
||||
`check_error` is useful when you want to turn checks represented by an
|
||||
`Error` value (produced by functionalizing `check`s via `checkify`) back
|
||||
into Python Exceptions.
|
||||
``check_error`` is useful when you want to turn checks represented by an
|
||||
``Error`` value (produced by functionalizing ``checks`` via ``checkify``)
|
||||
back into Python Exceptions.
|
||||
|
||||
Args:
|
||||
error: Error to check
|
||||
error: Error to check.
|
||||
|
||||
For example:
|
||||
For example, you might want to functionalize part of your program through
|
||||
checkify, stage out your functionalized code through ``jit``, then re-inject
|
||||
your error value outside of the ``jit``:
|
||||
|
||||
>>> import jax
|
||||
>>> from jax.experimental import checkify
|
||||
@ -663,10 +665,13 @@ error_checks[assert_p] = assert_discharge_rule
|
||||
|
||||
ErrorCategory = enum.Enum('ErrorCategory', ['NAN', 'OOB', 'DIV', 'USER_CHECK'])
|
||||
|
||||
float_errors = frozenset({ErrorCategory.NAN, ErrorCategory.DIV})
|
||||
index_errors = frozenset({ErrorCategory.OOB})
|
||||
automatic_errors = float_errors | index_errors
|
||||
user_checks = frozenset({ErrorCategory.USER_CHECK})
|
||||
nan_checks = frozenset({ErrorCategory.NAN})
|
||||
index_checks = frozenset({ErrorCategory.OOB})
|
||||
div_checks = frozenset({ErrorCategory.DIV})
|
||||
float_checks = nan_checks | div_checks
|
||||
automatic_checks = float_checks | index_checks
|
||||
all_checks = automatic_checks | user_checks
|
||||
|
||||
Out = TypeVar('Out')
|
||||
|
||||
@ -683,31 +688,35 @@ def checkify(fun: Callable[..., Out],
|
||||
The returned function will return an Error object `err` along with the output
|
||||
of the original function. ``err.get()`` will either return ``None`` (if no
|
||||
error occurred) or a string containing an error message. This error message
|
||||
will correspond to the first error which occurred.
|
||||
will correspond to the first error which occurred. ``err.throw()`` will raise
|
||||
a ValueError with the error message if an error occurred.
|
||||
|
||||
The kinds of errors are:
|
||||
- ErrorCategory.USER_CHECK: a ``checkify.check`` predicate evaluated
|
||||
to False.
|
||||
- ErrorCategory.NAN: a floating-point operation generated a NaN value
|
||||
By default only user-added ``checkify.check`` assertions are enabled. You can
|
||||
enable automatic checks through the ``errors`` argument.
|
||||
|
||||
The automatic check sets which can be enabled, and when an error is generated:
|
||||
- ``user_checks``: a ``checkify.check`` evaluated to False.
|
||||
- ``nan_checks``: a floating-point operation generated a NaN value
|
||||
as output.
|
||||
- ErrorCategory.DIV: division by zero
|
||||
- ErrorCategory.OOB: an indexing operation was out-of-bounds
|
||||
- ``div_checks``: a division by zero.
|
||||
- ``index_checks``: an index was out-of-bounds.
|
||||
|
||||
Multiple categories can be enabled together by creating a `Set` (eg.
|
||||
``errors={ErrorCategory.NAN, ErrorCategory.OOB}``).
|
||||
``errors={ErrorCategory.NAN, ErrorCategory.OOB}``). Multiple sets can be
|
||||
re-combined (eg. ``errors=float_checks|user_checks``)
|
||||
|
||||
Args:
|
||||
fun: Callable which can contain user checks (see ``check``).
|
||||
errors: A set of ErrorCategory values which defines the set of enabled
|
||||
checks. By default only explicit ``check``s are enabled
|
||||
(``{ErrorCategory.USER_CHECK}``). You can also for example enable NAN and
|
||||
DIV errors through passing the ``checkify.float_errors`` set, or for
|
||||
checks. By default only explicit ``checks`` are enabled
|
||||
(``user_checks``). You can also for example enable NAN and
|
||||
DIV errors by passing the ``float_checks`` set, or for
|
||||
example combine multiple sets through set operations
|
||||
(``checkify.float_errors|checkify.user_checks``)
|
||||
(``float_checks | user_checks``)
|
||||
Returns:
|
||||
A function which accepts the same arguments as ``fun`` and returns as output
|
||||
a pair where the first element is an ``Error`` value, representing any
|
||||
failed ``check``s, and the second element is the original output of ``fun``.
|
||||
a pair where the first element is an ``Error`` value, representing the first
|
||||
failed ``check``, and the second element is the original output of ``fun``.
|
||||
|
||||
For example:
|
||||
|
||||
@ -719,7 +728,7 @@ def checkify(fun: Callable[..., Out],
|
||||
... def f(x):
|
||||
... y = jnp.sin(x)
|
||||
... return x+y
|
||||
>>> err, out = checkify.checkify(f, errors=checkify.float_errors)(jnp.inf)
|
||||
>>> err, out = checkify.checkify(f, errors=checkify.float_checks)(jnp.inf)
|
||||
>>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
|
@ -41,7 +41,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
return y1 + y2
|
||||
|
||||
f = jax.jit(f) if jit else f
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
|
||||
err, _ = checked_f(3., 4.)
|
||||
self.assertIs(err.get(), None)
|
||||
@ -61,7 +61,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
return w
|
||||
|
||||
f = jax.jit(f) if jit else f
|
||||
checked_f = checkify.checkify(f, errors=checkify.index_errors)
|
||||
checked_f = checkify.checkify(f, errors=checkify.index_checks)
|
||||
|
||||
err, _ = checked_f(jnp.arange(3), 2)
|
||||
self.assertIs(err.get(), None)
|
||||
@ -79,7 +79,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
return getattr(x.at[i], update_fn)(1.)
|
||||
|
||||
f = jax.jit(f)
|
||||
checked_f = checkify.checkify(f, errors=checkify.index_errors)
|
||||
checked_f = checkify.checkify(f, errors=checkify.index_checks)
|
||||
|
||||
err, _ = checked_f(jnp.arange(3), 2)
|
||||
self.assertIs(err.get(), None)
|
||||
@ -96,7 +96,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
return x/y
|
||||
|
||||
f = jax.jit(f) if jit else f
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
|
||||
err, _ = checked_f(jnp.ones((3,)), jnp.ones((3,)))
|
||||
self.assertIs(err.get(), None)
|
||||
@ -120,7 +120,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
return z
|
||||
|
||||
f = jax.jit(f) if jit else f
|
||||
checked_f = checkify.checkify(f, errors=checkify.automatic_errors)
|
||||
checked_f = checkify.checkify(f, errors=checkify.automatic_checks)
|
||||
|
||||
# no error
|
||||
err, _ = checked_f(jnp.array([0., jnp.inf, 2.]), 2)
|
||||
@ -146,7 +146,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
return y * z
|
||||
|
||||
f = jax.jit(f) if jit else f
|
||||
checked_f = checkify.checkify(f, errors=checkify.automatic_errors)
|
||||
checked_f = checkify.checkify(f, errors=checkify.automatic_checks)
|
||||
|
||||
# both oob and nan error, but oob happens first
|
||||
err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 5)
|
||||
@ -163,7 +163,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
y1 = jnp.sin(x1)
|
||||
y2 = jnp.sin(x2)
|
||||
return y1 + y2
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
|
||||
xs = jnp.array([0., 2.])
|
||||
err, _ = checked_f(xs, xs)
|
||||
@ -182,7 +182,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
lambda: jnp.sin(x),
|
||||
lambda: x)
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
|
||||
err, _ = checked_f(3.)
|
||||
self.assertIs(err.get(), None)
|
||||
@ -203,7 +203,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def f(xs):
|
||||
return lax.scan(scan_body, None, xs)
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
|
||||
xs = jnp.array([0., 2.])
|
||||
err, (_, ch_outs) = checked_f(xs)
|
||||
@ -229,7 +229,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def f(carry, xs):
|
||||
return lax.scan(scan_body, carry, xs)
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
|
||||
carry, xs = 3., jnp.ones((2,))
|
||||
err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
|
||||
@ -271,7 +271,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def f(init_val):
|
||||
return lax.while_loop(while_cond, while_body, (init_val, 0.))
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
|
||||
init_val = 1.
|
||||
err, ch_out = checked_f(init_val)
|
||||
@ -299,7 +299,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def f(init_val):
|
||||
return lax.while_loop(while_cond, while_body, init_val)
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
|
||||
init_val = 1.
|
||||
err, ch_out = checked_f(init_val)
|
||||
@ -325,7 +325,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def f(init_val):
|
||||
return lax.while_loop(while_cond, lambda val: val-1, init_val)
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
|
||||
# error on first cond
|
||||
init_val = 0.
|
||||
@ -355,7 +355,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def f(cond_val, body_val):
|
||||
return lax.while_loop(while_cond, while_body, (0., cond_val, body_val))
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
|
||||
cond_val = jnp.inf
|
||||
body_val = 1.
|
||||
@ -384,8 +384,8 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
("assert", checkify.user_checks, "must be negative!"),
|
||||
("div", {checkify.ErrorCategory.DIV}, "divided by zero"),
|
||||
("nan", {checkify.ErrorCategory.NAN}, "nan generated"),
|
||||
("oob", checkify.index_errors, "out-of-bounds indexing"),
|
||||
("automatic_errors", checkify.automatic_errors, "divided by zero"),
|
||||
("oob", checkify.index_checks, "out-of-bounds indexing"),
|
||||
("automatic_checks", checkify.automatic_checks, "divided by zero"),
|
||||
)
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_enabled_errors(self, error_set, expected_error):
|
||||
@ -403,7 +403,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_post_process_call(self):
|
||||
@partial(checkify.checkify, errors=checkify.float_errors)
|
||||
@partial(checkify.checkify, errors=checkify.float_checks)
|
||||
def g(x):
|
||||
@jax.jit
|
||||
def f(y):
|
||||
@ -415,7 +415,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_post_process_map(self):
|
||||
@partial(checkify.checkify, errors=checkify.float_errors)
|
||||
@partial(checkify.checkify, errors=checkify.float_checks)
|
||||
def g(x):
|
||||
@jax.pmap
|
||||
def f(y):
|
||||
@ -436,7 +436,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
(x,), (xdot,) = primals, tangents
|
||||
return sin(x), jnp.cos(x) * xdot
|
||||
|
||||
f = checkify.checkify(sin, errors=checkify.float_errors)
|
||||
f = checkify.checkify(sin, errors=checkify.float_checks)
|
||||
|
||||
err, y = f(3.)
|
||||
self.assertIsNone(err.get())
|
||||
@ -459,7 +459,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
|
||||
# Checkify-of-jvp adds checks (unlike jvp-of-checkify above).
|
||||
g = checkify.checkify(lambda x, xdot: jax.jvp(sin, (x,), (xdot,)),
|
||||
errors=checkify.float_errors)
|
||||
errors=checkify.float_checks)
|
||||
err, (y, ydot) = g(3., 1.) # doesn't crash
|
||||
self.assertIsNone(err.get()) # no error
|
||||
self.assertNotEmpty(err.msgs) # but checks were added!
|
||||
@ -481,7 +481,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
return jnp.cos(x2 / 2.) * g,
|
||||
sin.defvjp(sin_fwd, sin_bwd)
|
||||
|
||||
f = checkify.checkify(sin, errors=checkify.float_errors)
|
||||
f = checkify.checkify(sin, errors=checkify.float_checks)
|
||||
|
||||
# no differentiation, no error
|
||||
err, y = f(3.)
|
||||
@ -498,11 +498,11 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
self.assertEmpty(err.msgs) # and no checks were added!
|
||||
|
||||
# Checkify-of-vjp adds checks (unlike vjp-of-checkify above).
|
||||
err, y = checkify.checkify(jax.grad(sin), errors=checkify.float_errors)(3.)
|
||||
err, y = checkify.checkify(jax.grad(sin), errors=checkify.float_checks)(3.)
|
||||
self.assertIsNone(err.get()) # no error
|
||||
self.assertNotEmpty(err.msgs) # but checks were added!
|
||||
err, y = checkify.checkify(jax.grad(sin),
|
||||
errors=checkify.float_errors)(jnp.inf)
|
||||
errors=checkify.float_checks)(jnp.inf)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user