diff --git a/jax/experimental/checkify.py b/jax/experimental/checkify.py index 4365f4c64..39d314f77 100644 --- a/jax/experimental/checkify.py +++ b/jax/experimental/checkify.py @@ -15,7 +15,7 @@ from dataclasses import dataclass from functools import partial import itertools as it -from typing import Union, Optional, Callable, Dict +from typing import Union, Optional, Callable, Dict, Tuple, TypeVar import numpy as np @@ -86,7 +86,7 @@ def assert_func(error: Error, pred: Bool, msg: str) -> Error: ## Checkify transformation for plumbing functional error values. -class ErrorTracer(core.Tracer): +class CheckifyTracer(core.Tracer): def __init__(self, trace, val): self._trace = trace self.val = val @@ -94,11 +94,11 @@ class ErrorTracer(core.Tracer): aval = property(lambda self: core.get_aval(self.val)) full_lower = lambda self: self -class ErrorTrace(core.Trace): - pure = lift = lambda self, val: ErrorTracer(self, val) +class CheckifyTrace(core.Trace): + pure = lift = lambda self, val: CheckifyTracer(self, val) def sublift(self, tracer): - return ErrorTracer(self, tracer.val) + return CheckifyTracer(self, tracer.val) def process_primitive(self, primitive, tracers, params): in_vals = [t.val for t in tracers] @@ -108,23 +108,23 @@ class ErrorTrace(core.Trace): else: out = primitive.bind(*in_vals, **params) if primitive.multiple_results: - return [ErrorTracer(self, x) for x in out] + return [CheckifyTracer(self, x) for x in out] else: - return ErrorTracer(self, out) + return CheckifyTracer(self, out) def process_call(self, primitive, f, tracers, params): in_vals = [t.val for t in tracers] e = popattr(self.main, 'error') - f, msgs = check_errors_subtrace(f, self.main, tuple(e.msgs.items())) + f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items())) params_ = dict(params, donated_invars=(False, False, *params['donated_invars'])) err, code, *out_vals = primitive.bind(f, e.err, e.code, *in_vals, **params_) setnewattr(self.main, 'error', Error(err, code, msgs())) - return [ErrorTracer(self, x) for x in out_vals] + return [CheckifyTracer(self, x) for x in out_vals] def process_map(self, primitive, f, tracers, params): in_vals = [t.val for t in tracers] e = popattr(self.main, 'error') - f, msgs = check_errors_subtrace(f, self.main, tuple(e.msgs.items())) + f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items())) @as_hashable_function(closure=params['out_axes_thunk']) def new_out_axes_thunk(): @@ -136,7 +136,7 @@ class ErrorTrace(core.Trace): errs, codes, *outs = primitive.bind(f, e.err, e.code, *in_vals, **params_) err, code = _reduce_any_error(errs, codes) setnewattr(self.main, 'error', Error(err, code, msgs())) - return [ErrorTracer(self, x) for x in outs] + return [CheckifyTracer(self, x) for x in outs] def post_process_call(self, primitive, tracers, params): vals = [t.val for t in tracers] @@ -146,7 +146,7 @@ class ErrorTrace(core.Trace): def todo(vals): trace = main.with_cur_sublevel() err, code, *vals = vals - return [ErrorTracer(trace, x) for x in vals] + return [CheckifyTracer(trace, x) for x in vals] return (err, code, *vals), todo def post_process_map(self, primitive, tracers, params): @@ -157,7 +157,7 @@ class ErrorTrace(core.Trace): def todo(vals): trace = main.with_cur_sublevel() err, code, *vals = vals - return [ErrorTracer(trace, x) for x in vals] + return [CheckifyTracer(trace, x) for x in vals] def out_axes_transform(out_axes): return (0, 0, *out_axes) return (err, code, *vals), (todo, out_axes_transform) @@ -169,26 +169,24 @@ def _reduce_any_error(errs, codes): ErrorCheckRule = Callable error_checks: Dict[core.Primitive, ErrorCheckRule] = {} -def check_errors_flat(fun: lu.WrappedFun, *args): - fun, msgs = check_errors_subtrace(fun) - fun = check_errors_toplevel(fun) - err, code, *out_vals = fun.call_wrapped(*args) - return (err, code, out_vals), msgs() +def checkify_flat(fun: lu.WrappedFun, *args): + fun, msgs = checkify_subtrace(fun) + fun = checkify_traceable(fun, tuple(init_error.msgs.items())) + err, code, *outvals = fun.call_wrapped(init_error.err, init_error.code, *args) + return (err, code, outvals), msgs() @lu.transformation -def check_errors_toplevel(*args): - error = init_error - with core.new_main(ErrorTrace) as main: - msgs = tuple(error.msgs.items()) - outs = yield (main, msgs, error.err, error.code, *args), {} +def checkify_traceable(msgs, err, code, *args): + with core.new_main(CheckifyTrace) as main: + outs = yield (main, msgs, err, code, *args), {} del main yield outs @lu.transformation_with_aux -def check_errors_subtrace(main, msgs, err, code, *args): +def checkify_subtrace(main, msgs, err, code, *args): setnewattr(main, 'error', Error(err, code, dict(msgs))) trace = main.with_cur_sublevel() - in_tracers = [ErrorTracer(trace, x) for x in args] + in_tracers = [CheckifyTracer(trace, x) for x in args] out = yield in_tracers, {} out_tracers = map(trace.full_raise, out) out_vals = [t.val for t in out_tracers] @@ -196,27 +194,20 @@ def check_errors_subtrace(main, msgs, err, code, *args): del main.error yield (err, code, *out_vals), msgs -def checkify_fun_to_jaxpr(f, error, in_avals): - f, msgs = check_errors_subtrace(f) - f = check_errors_traceable(f, tuple(error.msgs.items())) - err_aval = core.raise_to_shaped(core.get_aval(error.err)) - code_aval = core.raise_to_shaped(core.get_aval(error.code)) - avals_in = [err_aval, code_aval, *in_avals] - jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in) - return core.ClosedJaxpr(jaxpr_out, literals_out), msgs() # TODO take (error_aval, code_aval) instead of error here? def checkify_jaxpr(jaxpr, error): f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) return checkify_fun_to_jaxpr(f, error, jaxpr.in_avals) -# TODO dedup with check_errors_toplevel -@lu.transformation -def check_errors_traceable(msgs, err, code, *args): - with core.new_main(ErrorTrace) as main: - outs = yield (main, msgs, err, code, *args), {} - del main - yield outs +def checkify_fun_to_jaxpr(f, error, in_avals): + f, msgs = checkify_subtrace(f) + f = checkify_traceable(f, tuple(error.msgs.items())) + err_aval = core.raise_to_shaped(core.get_aval(error.err)) + code_aval = core.raise_to_shaped(core.get_aval(error.code)) + avals_in = [err_aval, code_aval, *in_avals] + jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in) + return core.ClosedJaxpr(jaxpr_out, literals_out), msgs() ## assert primitive @@ -398,30 +389,6 @@ def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr, body_ncons return out, Error(err, code, new_msgs) error_checks[lax.while_p] = while_loop_error_check -# TODO(mattjj,lenamartens): currently we bundle effectful-assert-discharging -# with the error-check-adding transformation (checkify), but they could be -# separated into two orthogonal transformations. -def assert_discharge_rule(error, pred, code, *, msgs): - out_err = error.err | jnp.logical_not(pred) - out_code = lax.select(error.err, error.code, code) - return [], Error(out_err, out_code, {**error.msgs, **msgs}) -error_checks[assert_p] = assert_discharge_rule - - -## checkify api - -def checkify(fun: Callable) -> Callable: - @traceback_util.api_boundary - def checked_fun(*args, **kwargs): - args_flat, in_tree = tree_flatten((args, kwargs)) - f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree) - (err, code, out_flat), msgs = check_errors_flat(f, *args_flat) - out = tree_unflatten(out_tree(), out_flat) - return Error(err, code, msgs), out - return checked_fun - -## NaN error rule table - def add_nan_check(prim): error_checks[prim] = partial(nan_error_check, prim) @@ -485,3 +452,23 @@ add_nan_check(lax.abs_p) add_nan_check(lax.select_p) add_nan_check(lax.max_p) add_nan_check(lax.min_p) + +def assert_discharge_rule(error, pred, code, *, msgs): + out_err = error.err | jnp.logical_not(pred) + out_code = lax.select(error.err, error.code, code) + return [], Error(out_err, out_code, {**error.msgs, **msgs}) +error_checks[assert_p] = assert_discharge_rule + + +## checkify api + +Out = TypeVar('Out') +def checkify(fun: Callable[..., Out]) -> Callable[..., Tuple[Error, Out]]: + @traceback_util.api_boundary + def checked_fun(*args, **kwargs): + args_flat, in_tree = tree_flatten((args, kwargs)) + f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree) + (err, code, out_flat), msgs = checkify_flat(f, *args_flat) + out = tree_unflatten(out_tree(), out_flat) + return Error(err, code, msgs), out + return checked_fun