diff --git a/jax/experimental/checkify.py b/jax/experimental/checkify.py index d36908234..ba7fd96d8 100644 --- a/jax/experimental/checkify.py +++ b/jax/experimental/checkify.py @@ -29,11 +29,15 @@ from jax.interpreters import partial_eval as pe from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node from jax._src import source_info_util, traceback_util from jax import lax -from jax._src.util import as_hashable_function, unzip2, split_list +from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map, + safe_zip) source_info_util.register_exclusion(__file__) traceback_util.register_exclusion(__file__) +map, unsafe_map = safe_map, map +zip, unsafe_zip = safe_zip, zip + ## Utils @@ -112,7 +116,8 @@ class CheckifyTrace(core.Trace): in_vals = [t.val for t in tracers] rule = error_checks.get(primitive) if rule: - out, self.main.error = rule(self.main.error, self.main.enabled_errors, *in_vals, **params) # type: ignore + out, self.main.error = rule(self.main.error, self.main.enabled_errors, # type: ignore + *in_vals, **params) else: out = primitive.bind(*in_vals, **params) if primitive.multiple_results: @@ -149,22 +154,25 @@ class CheckifyTrace(core.Trace): def post_process_call(self, primitive, tracers, params): vals = [t.val for t in tracers] main = self.main - e = popattr(self.main, 'error') + e = popattr(main, 'error') err, code, main.msgs = e.err, e.code, e.msgs def todo(vals): - trace = main.with_cur_sublevel() err, code, *vals = vals + setnewattr(main, 'error', Error(err, code, popattr(main, 'msgs'))) + trace = main.with_cur_sublevel() return [CheckifyTracer(trace, x) for x in vals] return (err, code, *vals), todo def post_process_map(self, primitive, tracers, params): vals = [t.val for t in tracers] main = self.main - e = popattr(self.main, 'error') + e = popattr(main, 'error') err, code, main.msgs = e.err, e.code, e.msgs def todo(vals): + errs, codes, *vals = vals + err, code = _reduce_any_error(errs, codes) + setnewattr(main, 'error', Error(err, code, popattr(main, 'msgs'))) trace = main.with_cur_sublevel() - err, code, *vals = vals return [CheckifyTracer(trace, x) for x in vals] def out_axes_transform(out_axes): return (0, 0, *out_axes) @@ -174,10 +182,11 @@ def _reduce_any_error(errs, codes): errs_, codes_ = lax.sort_key_val(errs, codes, dimension=0) return errs_[-1], codes_[-1] -ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error) +ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error) error_checks: Dict[core.Primitive, ErrorCheckRule] = {} -def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'], *args): +def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'], + *args): fun, msgs = checkify_subtrace(fun) fun = checkify_traceable(fun, tuple(init_error.msgs.items()), enabled_errors) err, code, *outvals = fun.call_wrapped(init_error.err, init_error.code, *args) @@ -341,7 +350,8 @@ error_checks[lax.scatter_min_p] = partial(scatter_error_check, lax.scatter_min_p error_checks[lax.scatter_max_p] = partial(scatter_error_check, lax.scatter_max_p) def cond_error_check(error, enabled_errors, index, *ops, branches, linear): - new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error, enabled_errors) for jxpr in branches) + new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error, enabled_errors) + for jxpr in branches) new_linear = (False, False, *linear) err, code, *outs = lax.cond_p.bind( index, error.err, error.code, *ops, @@ -350,7 +360,8 @@ def cond_error_check(error, enabled_errors, index, *ops, branches, linear): return outs, Error(err, code, new_msgs) error_checks[lax.cond_p] = cond_error_check -def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr, num_consts, num_carry, linear, unroll): +def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr, + num_consts, num_carry, linear, unroll): consts, carry, xs = split_list(in_flat, [num_consts, num_carry]) checked_jaxpr, msgs_ = checkify_jaxpr(jaxpr, error, enabled_errors) new_linear = (False, False, *linear) @@ -371,7 +382,8 @@ def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors): out = body_f(*vals) _ = cond_f(*out) # this checks if the next cond application will error return out - return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors, body_jaxpr.in_avals) + return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors, + body_jaxpr.in_avals) def ignore_errors_jaxpr(jaxpr, error): """Constructs a jaxpr which takes two extra args but ignores them.""" @@ -385,13 +397,15 @@ def ignore_errors_jaxpr(jaxpr, error): jaxpr.outvars, jaxpr.eqns) return core.ClosedJaxpr(new_jaxpr, consts) -def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr): - checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error, enabled_errors) - checked_cond_fun = core.jaxpr_as_fun(checked_cond_jaxpr) +def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, + cond_jaxpr, body_nconsts, body_jaxpr): + cond_jaxpr_, msgs_cond = checkify_jaxpr(cond_jaxpr, error, enabled_errors) + checked_cond_fun = core.jaxpr_as_fun(cond_jaxpr_) # Check if the first cond application will error. cond_err, cond_code, _ = checked_cond_fun(error.err, error.code, *in_flat) - checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors) + checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr( + cond_jaxpr, body_jaxpr, error, enabled_errors) compat_cond_jaxpr = ignore_errors_jaxpr(cond_jaxpr, error) c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts]) new_in_flat = [*c_consts, *b_consts, cond_err, cond_code, *carry] @@ -489,7 +503,8 @@ automatic_errors = float_errors | index_errors user_asserts = {ErrorCategory.ASSERT} Out = TypeVar('Out') -def checkify(fun: Callable[..., Out], errors: Set[ErrorCategory] = user_asserts) -> Callable[..., Tuple[Error, Out]]: +def checkify(fun: Callable[..., Out], errors: Set[ErrorCategory] = user_asserts + ) -> Callable[..., Tuple[Error, Out]]: if not errors: raise ValueError('Checkify needs to be called with at least one enabled' ' ErrorCategory, was called with an empty errors set.') diff --git a/tests/checkify_test.py b/tests/checkify_test.py index 0a9e1238f..e0a35a52e 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from functools import partial import unittest from absl.testing import absltest @@ -192,7 +193,6 @@ class CheckifyTransformTests(jtu.JaxTestCase): err, y = checked_f(-jnp.inf) self.assertIs(err.get(), None) - @jtu.skip_on_devices('tpu') def test_scan_map(self): def scan_body(_, x): @@ -400,6 +400,30 @@ class CheckifyTransformTests(jtu.JaxTestCase): self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), expected_error) + @jtu.skip_on_devices('tpu') + def test_post_process_call(self): + @partial(checkify.checkify, errors=checkify.float_errors) + def g(x): + @jax.jit + def f(y): + return jnp.sin(x * y) + return f(jnp.inf) + err, _ = g(2.) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), 'nan generated by primitive sin') + + @jtu.skip_on_devices('tpu') + def test_post_process_map(self): + @partial(checkify.checkify, errors=checkify.float_errors) + def g(x): + @jax.pmap + def f(y): + return jnp.sin(x * y) + return f(jnp.array([jnp.inf]))[0] + err, _ = g(2.) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), 'nan generated by primitive sin') + class AssertPrimitiveTests(jtu.JaxTestCase): def test_assert_primitive_impl(self):