mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Merge pull request #13618 from LenaMartens:debcheck
PiperOrigin-RevId: 495271080
This commit is contained in:
commit
5832dfd812
@ -243,7 +243,7 @@ class Error:
|
||||
return None
|
||||
|
||||
def throw(self):
|
||||
check_error(self)
|
||||
_check_error(self)
|
||||
|
||||
def __str__(self):
|
||||
return f'Error({self.get()})'
|
||||
@ -605,7 +605,7 @@ def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
|
||||
Before staging a function with checks, :func:`~checkify` it!
|
||||
|
||||
Args:
|
||||
pred: if False, an error is added.
|
||||
pred: if False, a FailedCheckError error is added.
|
||||
msg: error message if error is added. Can be a format string.
|
||||
fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
|
||||
`msg`, eg.:
|
||||
@ -631,11 +631,23 @@ def check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
|
||||
jax._src.checkify.JaxRuntimeError: -3. needs to be positive!
|
||||
|
||||
"""
|
||||
_check(pred, msg, False, *fmt_args, **fmt_kwargs)
|
||||
|
||||
def _check(pred, msg, debug, *fmt_args, **fmt_kwargs):
|
||||
if not is_scalar_pred(pred):
|
||||
raise TypeError(f'check takes a scalar pred as argument, got {pred}')
|
||||
prim_name = 'debug_check' if debug else 'check'
|
||||
raise TypeError(f'{prim_name} takes a scalar pred as argument, got {pred}')
|
||||
new_error = FailedCheckError(summary(), msg, *fmt_args, **fmt_kwargs)
|
||||
error = assert_func(init_error, jnp.logical_not(pred), new_error)
|
||||
return check_error(error)
|
||||
_check_error(error, debug=debug)
|
||||
|
||||
def _check_error(error, *, debug=False):
|
||||
error = tree_map(core.raise_as_much_as_possible, error)
|
||||
if any(map(np.shape, error._pred.values())):
|
||||
error = _reduce_any_error(error)
|
||||
err_args, tree_def = tree_flatten(error)
|
||||
|
||||
return check_p.bind(*err_args, err_tree=tree_def, debug=debug)
|
||||
|
||||
|
||||
def is_scalar_pred(pred) -> bool:
|
||||
@ -643,6 +655,44 @@ def is_scalar_pred(pred) -> bool:
|
||||
isinstance(pred, jnp.ndarray) and pred.shape == () and
|
||||
pred.dtype == jnp.dtype('bool'))
|
||||
|
||||
|
||||
def debug_check(pred: Bool, msg: str, *fmt_args, **fmt_kwargs) -> None:
|
||||
"""Check a predicate when running under checkify, otherwise is a no-op.
|
||||
|
||||
A `debug_check` will only be run if it is transformed by :func:`~checkify`,
|
||||
otherwise the check will be dropped.
|
||||
|
||||
Args:
|
||||
pred: if False, a FailedCheckError error is added.
|
||||
msg: error message if error is added.
|
||||
fmt_args, fmt_kwargs: Positional and keyword formatting arguments for
|
||||
`msg`, eg.:
|
||||
``debug_check(.., "check failed on values {} and {named}", x, named=y)``
|
||||
Note that these arguments can be traced values allowing you to add
|
||||
run-time values to the error message.
|
||||
Note that tracking these run-time arrays will increase your memory usage,
|
||||
even if no error happens.
|
||||
|
||||
For example:
|
||||
|
||||
>>> import jax
|
||||
>>> import jax.numpy as jnp
|
||||
>>> from jax.experimental import checkify
|
||||
>>> def f(x):
|
||||
... checkify.debug_check(x!=0, "cannot be zero!")
|
||||
... return x
|
||||
>>> _ = f(0) # running without checkify means no debug_check is run.
|
||||
>>> checked_f = checkify.checkify(f)
|
||||
>>> err, out = jax.jit(checked_f)(0) # running with checkify runs debug_check.
|
||||
>>> err.throw() # doctest: +IGNORE_EXCEPTION_DETAIL
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
jax._src.checkify.JaxRuntimeError: cannot be zero!
|
||||
|
||||
"""
|
||||
_check(pred, msg, True, *fmt_args, **fmt_kwargs)
|
||||
|
||||
|
||||
def check_error(error: Error) -> None:
|
||||
"""Raise an Exception if ``error`` represents a failure. Functionalized by :func:`~checkify`.
|
||||
|
||||
@ -708,12 +758,8 @@ def check_error(error: Error) -> None:
|
||||
raise ValueError('check_error takes an Error as argument, '
|
||||
f'got type {type(error)} instead.')
|
||||
|
||||
error = tree_map(core.raise_as_much_as_possible, error)
|
||||
if any(map(np.shape, error._pred.values())):
|
||||
error = _reduce_any_error(error)
|
||||
err_args, tree_def = tree_flatten(error)
|
||||
_check_error(error, debug=False)
|
||||
|
||||
return check_p.bind(*err_args, err_tree=tree_def)
|
||||
|
||||
## check primitive
|
||||
|
||||
@ -725,7 +771,10 @@ class JaxRuntimeError(ValueError):
|
||||
pass
|
||||
|
||||
@check_p.def_impl
|
||||
def check_impl(*args, err_tree):
|
||||
def check_impl(*args, err_tree, debug):
|
||||
if debug:
|
||||
# NOOP (check will only trigger when discharged)
|
||||
return []
|
||||
error = tree_unflatten(err_tree, args)
|
||||
exc = error.get_exception()
|
||||
if exc:
|
||||
@ -733,7 +782,8 @@ def check_impl(*args, err_tree):
|
||||
return []
|
||||
|
||||
@check_p.def_effectful_abstract_eval
|
||||
def check_abstract_eval(*args, err_tree):
|
||||
def check_abstract_eval(*args, err_tree, debug):
|
||||
del debug
|
||||
return [], set(tree_unflatten(err_tree, args)._pred.keys())
|
||||
|
||||
# TODO(lenamartens) add in-depth error explanation to link to in module docs.
|
||||
@ -744,7 +794,10 @@ functionalization_error = ValueError(
|
||||
' through `checkify.checkify`.'
|
||||
)
|
||||
|
||||
def check_lowering_rule(ctx, *args, err_tree):
|
||||
def check_lowering_rule(ctx, *args, err_tree, debug):
|
||||
if debug:
|
||||
# NOOP (check will only trigger when discharged)
|
||||
return []
|
||||
if not config.jax_experimental_unsafe_xla_runtime_errors:
|
||||
raise functionalization_error
|
||||
|
||||
@ -758,12 +811,14 @@ def check_lowering_rule(ctx, *args, err_tree):
|
||||
ctx.module_context.add_keepalive(keep_alive)
|
||||
return out_op
|
||||
|
||||
def check_lowering_rule_unsupported(*a, **k):
|
||||
def check_lowering_rule_unsupported(*a, debug, **k):
|
||||
if debug:
|
||||
return []
|
||||
raise functionalization_error
|
||||
|
||||
def python_err(err_tree, *args):
|
||||
error = tree_unflatten(err_tree, args)
|
||||
check_error(error)
|
||||
_check_error(error)
|
||||
return []
|
||||
|
||||
mlir.register_lowering(check_p, check_lowering_rule_unsupported,
|
||||
@ -773,22 +828,20 @@ mlir.register_lowering(check_p, check_lowering_rule,
|
||||
mlir.register_lowering(check_p, check_lowering_rule,
|
||||
platform='gpu')
|
||||
|
||||
def check_batching_rule(batched_args, batch_dims, *, err_tree):
|
||||
def check_batching_rule(batched_args, batch_dims, *, err_tree, debug):
|
||||
size = next(x.shape[dim] for x, dim in zip(batched_args, batch_dims)
|
||||
if dim is not batching.not_mapped)
|
||||
batched_args = (batching.bdim_at_front(a, d, size)
|
||||
for a, d in zip(batched_args, batch_dims))
|
||||
err = tree_unflatten(err_tree, batched_args)
|
||||
check_error(err)
|
||||
_check_error(err, debug=debug)
|
||||
return [], []
|
||||
|
||||
batching.primitive_batchers[check_p] = check_batching_rule
|
||||
|
||||
def check_jvp_rule(primals, _, *, err_tree):
|
||||
def check_jvp_rule(primals, _, *, err_tree, debug):
|
||||
# Check primals, discard tangents.
|
||||
check_p.bind(*primals, err_tree=err_tree)
|
||||
check_p.bind(*primals, err_tree=err_tree, debug=debug)
|
||||
return [], []
|
||||
|
||||
ad.primitive_jvps[check_p] = check_jvp_rule
|
||||
|
||||
## checkify rules
|
||||
@ -1106,7 +1159,8 @@ def pjit_error_check(error, enabled_errors, *vals_in, jaxpr,
|
||||
error_checks[pjit.pjit_p] = pjit_error_check
|
||||
|
||||
|
||||
def check_discharge_rule(error, enabled_errors, *args, err_tree):
|
||||
def check_discharge_rule(error, enabled_errors, *args, err_tree, debug):
|
||||
del debug
|
||||
new_error = tree_unflatten(err_tree, args)
|
||||
# Split up new_error into error to be functionalized if it's included in
|
||||
# enabled_errors (=discharged_error) and an error to be defunctionalized if
|
||||
|
@ -21,6 +21,7 @@ from jax._src.checkify import (
|
||||
check as check,
|
||||
check_error as check_error,
|
||||
checkify as checkify,
|
||||
debug_check as debug_check,
|
||||
div_checks as div_checks,
|
||||
float_checks as float_checks,
|
||||
index_checks as index_checks,
|
||||
|
@ -1092,6 +1092,64 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "hi!")
|
||||
|
||||
def test_debug_check_noop(self):
|
||||
def f(x):
|
||||
checkify.debug_check(jnp.all(x != x), "{x} cannot be {x}", x=x)
|
||||
return x
|
||||
x = jnp.ones(())
|
||||
f(x) # no error.
|
||||
jax.jit(f)(x) # no error.
|
||||
jax.vmap(f)(jnp.ones((2,))) # no error.
|
||||
jax.grad(f)(x) # no error.
|
||||
|
||||
@parameterized.named_parameters(("with_jit", True), ("without_jit", False))
|
||||
def test_debug_check_nonscalar_pred(self, with_jit):
|
||||
def f(x):
|
||||
checkify.debug_check(x != x, "{x} cannot be {x}", x=x)
|
||||
return x
|
||||
checked_f = checkify.checkify(f)
|
||||
if with_jit:
|
||||
checked_f = jax.jit(checked_f)
|
||||
|
||||
with self.assertRaisesRegex(TypeError, "debug_check takes a scalar pred"):
|
||||
checked_f(jnp.ones((5,)))
|
||||
|
||||
|
||||
@parameterized.named_parameters(("with_jit", True), ("without_jit", False))
|
||||
def test_debug_check(self, with_jit):
|
||||
def f(x):
|
||||
checkify.debug_check(jnp.all(x != x), "{x} cannot be {x}", x=x)
|
||||
return x
|
||||
checked_f = checkify.checkify(f)
|
||||
if with_jit:
|
||||
checked_f = jax.jit(checked_f)
|
||||
err, _ = checked_f(jnp.ones(()))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "1.0 cannot be 1.0")
|
||||
|
||||
@parameterized.named_parameters(("with_jit", True), ("without_jit", False))
|
||||
def test_debug_check_disabled_errors(self, with_jit):
|
||||
def f(x):
|
||||
checkify.debug_check(jnp.all(x != x), "{x} cannot be {x}", x=x)
|
||||
return x
|
||||
checked_f = checkify.checkify(f, errors={})
|
||||
if with_jit:
|
||||
checked_f = jax.jit(checked_f)
|
||||
err, _ = checked_f(jnp.ones((1,)))
|
||||
self.assertIsNone(err.get())
|
||||
|
||||
def test_debug_check_jaxpr_roundtrip(self):
|
||||
def f(x):
|
||||
checkify.debug_check(jnp.all(x != x), "{x} cannot be {x}", x=x)
|
||||
return x
|
||||
x = jnp.ones(())
|
||||
jaxpr = jax.make_jaxpr(f)(x)
|
||||
roundtrip_f = partial(jax.core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)
|
||||
checked_f = checkify.checkify(jax.jit(roundtrip_f))
|
||||
err, _ = checked_f(jnp.ones(()))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "1.0 cannot be 1.0")
|
||||
|
||||
|
||||
class LowerableChecksTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user