Merge pull request #13618 from LenaMartens:debcheck

PiperOrigin-RevId: 495271080
This commit is contained in:
jax authors 2022-12-14 03:54:13 -08:00
commit 5832dfd812
3 changed files with 134 additions and 21 deletions

View File

@ -243,7 +243,7 @@ class Error:
return None
def throw(self):
check_error(self)
_check_error(self)
def __str__(self):
return f'Error({self.get()})'
@ -605,7 +605,7 @@ def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
Before staging a function with checks, :func:`~checkify` it!
Args:
pred: if False, an error is added.
pred: if False, a FailedCheckError error is added.
msg: error message if error is added. Can be a format string.
fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
`msg`, eg.:
@ -631,11 +631,23 @@ def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
jax._src.checkify.JaxRuntimeError: -3. needs to be positive!
"""
_check(pred, msg, False, *fmt_args, **fmt_kwargs)
def _check(pred, msg, debug, *fmt_args, **fmt_kwargs):
if not is_scalar_pred(pred):
raise TypeError(f'check takes a scalar pred as argument, got {pred}')
prim_name = 'debug_check' if debug else 'check'
raise TypeError(f'{prim_name} takes a scalar pred as argument, got {pred}')
new_error = FailedCheckError(summary(), msg, *fmt_args, **fmt_kwargs)
error = assert_func(init_error, jnp.logical_not(pred), new_error)
return check_error(error)
_check_error(error, debug=debug)
def _check_error(error, *, debug=False):
error = tree_map(core.raise_as_much_as_possible, error)
if any(map(np.shape, error._pred.values())):
error = _reduce_any_error(error)
err_args, tree_def = tree_flatten(error)
return check_p.bind(*err_args, err_tree=tree_def, debug=debug)
def is_scalar_pred(pred) -> bool:
@ -643,6 +655,44 @@ def is_scalar_pred(pred) -> bool:
isinstance(pred, jnp.ndarray) and pred.shape == () and
pred.dtype == jnp.dtype('bool'))
def debug_check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
"""Check a predicate when running under checkify, otherwise is a no-op.
A `debug_check` will only be run if it is transformed by :func:`~checkify`,
otherwise the check will be dropped.
Args:
pred: if False, a FailedCheckError error is added.
msg: error message if error is added.
fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
`msg`, eg.:
``debug_check(.., "check failed on values {} and {named}", x, named=y)``
Note that these arguments can be traced values allowing you to add
run-time values to the error message.
Note that tracking these run-time arrays will increase your memory usage,
even if no error happens.
For example:
>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental import checkify
>>> def f(x):
... checkify.debug_check(x!=0, "cannot be zero!")
... return x
>>> _ = f(0) # running without checkify means no debug_check is run.
>>> checked_f = checkify.checkify(f)
>>> err, out = jax.jit(checked_f)(0) # running with checkify runs debug_check.
>>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
jax._src.checkify.JaxRuntimeError: cannot be zero!
"""
_check(pred, msg, True, *fmt_args, **fmt_kwargs)
def check_error(error: Error) -> None:
"""Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.
@ -708,12 +758,8 @@ def check_error(error: Error) -> None:
raise ValueError('check_error takes an Error as argument, '
f'got type {type(error)} instead.')
error = tree_map(core.raise_as_much_as_possible, error)
if any(map(np.shape, error._pred.values())):
error = _reduce_any_error(error)
err_args, tree_def = tree_flatten(error)
_check_error(error, debug=False)
return check_p.bind(*err_args, err_tree=tree_def)
## check primitive
@ -725,7 +771,10 @@ class JaxRuntimeError(ValueError):
pass
@check_p.def_impl
def check_impl(*args, err_tree):
def check_impl(*args, err_tree, debug):
if debug:
# NOOP (check will only trigger when discharged)
return []
error = tree_unflatten(err_tree, args)
exc = error.get_exception()
if exc:
@ -733,7 +782,8 @@ def check_impl(*args, err_tree):
return []
@check_p.def_effectful_abstract_eval
def check_abstract_eval(*args, err_tree):
def check_abstract_eval(*args, err_tree, debug):
del debug
return [], set(tree_unflatten(err_tree, args)._pred.keys())
# TODO(lenamartens) add in-depth error explanation to link to in module docs.
@ -744,7 +794,10 @@ functionalization_error = ValueError(
' through `checkify.checkify`.'
)
def check_lowering_rule(ctx, *args, err_tree):
def check_lowering_rule(ctx, *args, err_tree, debug):
if debug:
# NOOP (check will only trigger when discharged)
return []
if not config.jax_experimental_unsafe_xla_runtime_errors:
raise functionalization_error
@ -758,12 +811,14 @@ def check_lowering_rule(ctx, *args, err_tree):
ctx.module_context.add_keepalive(keep_alive)
return out_op
def check_lowering_rule_unsupported(*a, **k):
def check_lowering_rule_unsupported(*a, debug, **k):
if debug:
return []
raise functionalization_error
def python_err(err_tree, *args):
error = tree_unflatten(err_tree, args)
check_error(error)
_check_error(error)
return []
mlir.register_lowering(check_p, check_lowering_rule_unsupported,
@ -773,22 +828,20 @@ mlir.register_lowering(check_p, check_lowering_rule,
mlir.register_lowering(check_p, check_lowering_rule,
platform='gpu')
def check_batching_rule(batched_args, batch_dims, *, err_tree):
def check_batching_rule(batched_args, batch_dims, *, err_tree, debug):
size = next(x.shape[dim] for x, dim in zip(batched_args, batch_dims)
if dim is not batching.not_mapped)
batched_args = (batching.bdim_at_front(a, d, size)
for a, d in zip(batched_args, batch_dims))
err = tree_unflatten(err_tree, batched_args)
check_error(err)
_check_error(err, debug=debug)
return [], []
batching.primitive_batchers[check_p] = check_batching_rule
def check_jvp_rule(primals, _, *, err_tree):
def check_jvp_rule(primals, _, *, err_tree, debug):
# Check primals, discard tangents.
check_p.bind(*primals, err_tree=err_tree)
check_p.bind(*primals, err_tree=err_tree, debug=debug)
return [], []
ad.primitive_jvps[check_p] = check_jvp_rule
## checkify rules
@ -1106,7 +1159,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
error_checks[pjit.pjit_p] = pjit_error_check
def check_discharge_rule(error, enabled_errors, *args, err_tree):
def check_discharge_rule(error, enabled_errors, *args, err_tree, debug):
del debug
new_error = tree_unflatten(err_tree, args)
# Split up new_error into error to be functionalized if it's included in
# enabled_errors (=discharged_error) and an error to be defunctionalized if

View File

@ -21,6 +21,7 @@ from jax._src.checkify import (
check as check,
check_error as check_error,
checkify as checkify,
debug_check as debug_check,
div_checks as div_checks,
float_checks as float_checks,
index_checks as index_checks,

View File

@ -1092,6 +1092,64 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "hi!")
def test_debug_check_noop(self):
def f(x):
checkify.debug_check(jnp.all(x != x), "{x} cannot be {x}", x=x)
return x
x = jnp.ones(())
f(x) # no error.
jax.jit(f)(x) # no error.
jax.vmap(f)(jnp.ones((2,))) # no error.
jax.grad(f)(x) # no error.
@parameterized.named_parameters(("with_jit", True), ("without_jit", False))
def test_debug_check_nonscalar_pred(self, with_jit):
def f(x):
checkify.debug_check(x != x, "{x} cannot be {x}", x=x)
return x
checked_f = checkify.checkify(f)
if with_jit:
checked_f = jax.jit(checked_f)
with self.assertRaisesRegex(TypeError, "debug_check takes a scalar pred"):
checked_f(jnp.ones((5,)))
@parameterized.named_parameters(("with_jit", True), ("without_jit", False))
def test_debug_check(self, with_jit):
def f(x):
checkify.debug_check(jnp.all(x != x), "{x} cannot be {x}", x=x)
return x
checked_f = checkify.checkify(f)
if with_jit:
checked_f = jax.jit(checked_f)
err, _ = checked_f(jnp.ones(()))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "1.0 cannot be 1.0")
@parameterized.named_parameters(("with_jit", True), ("without_jit", False))
def test_debug_check_disabled_errors(self, with_jit):
def f(x):
checkify.debug_check(jnp.all(x != x), "{x} cannot be {x}", x=x)
return x
checked_f = checkify.checkify(f, errors={})
if with_jit:
checked_f = jax.jit(checked_f)
err, _ = checked_f(jnp.ones((1,)))
self.assertIsNone(err.get())
def test_debug_check_jaxpr_roundtrip(self):
def f(x):
checkify.debug_check(jnp.all(x != x), "{x} cannot be {x}", x=x)
return x
x = jnp.ones(())
jaxpr = jax.make_jaxpr(f)(x)
roundtrip_f = partial(jax.core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)
checked_f = checkify.checkify(jax.jit(roundtrip_f))
err, _ = checked_f(jnp.ones(()))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "1.0 cannot be 1.0")
class LowerableChecksTest(jtu.JaxTestCase):
def setUp(self):