checkify: tweak some organization and names

This commit is contained in:
Matthew Johnson 2022-01-10 21:29:12 -08:00
parent 7bc51879d4
commit 6850833c3a

View File

@ -15,7 +15,7 @@
from dataclasses import dataclass
from functools import partial
import itertools as it
from typing import Union, Optional, Callable, Dict
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar
import numpy as np
@ -86,7 +86,7 @@ def assert_func(error: Error, pred: Bool, msg: str) -> Error:
## Checkify transformation for plumbing functional error values.
class ErrorTracer(core.Tracer):
class CheckifyTracer(core.Tracer):
def __init__(self, trace, val):
self._trace = trace
self.val = val
@ -94,11 +94,11 @@ class ErrorTracer(core.Tracer):
aval = property(lambda self: core.get_aval(self.val))
full_lower = lambda self: self
class ErrorTrace(core.Trace):
pure = lift = lambda self, val: ErrorTracer(self, val)
class CheckifyTrace(core.Trace):
pure = lift = lambda self, val: CheckifyTracer(self, val)
def sublift(self, tracer):
return ErrorTracer(self, tracer.val)
return CheckifyTracer(self, tracer.val)
def process_primitive(self, primitive, tracers, params):
in_vals = [t.val for t in tracers]
@ -108,23 +108,23 @@ class ErrorTrace(core.Trace):
else:
out = primitive.bind(*in_vals, **params)
if primitive.multiple_results:
return [ErrorTracer(self, x) for x in out]
return [CheckifyTracer(self, x) for x in out]
else:
return ErrorTracer(self, out)
return CheckifyTracer(self, out)
def process_call(self, primitive, f, tracers, params):
in_vals = [t.val for t in tracers]
e = popattr(self.main, 'error')
f, msgs = check_errors_subtrace(f, self.main, tuple(e.msgs.items()))
f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items()))
params_ = dict(params, donated_invars=(False, False, *params['donated_invars']))
err, code, *out_vals = primitive.bind(f, e.err, e.code, *in_vals, **params_)
setnewattr(self.main, 'error', Error(err, code, msgs()))
return [ErrorTracer(self, x) for x in out_vals]
return [CheckifyTracer(self, x) for x in out_vals]
def process_map(self, primitive, f, tracers, params):
in_vals = [t.val for t in tracers]
e = popattr(self.main, 'error')
f, msgs = check_errors_subtrace(f, self.main, tuple(e.msgs.items()))
f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items()))
@as_hashable_function(closure=params['out_axes_thunk'])
def new_out_axes_thunk():
@ -136,7 +136,7 @@ class ErrorTrace(core.Trace):
errs, codes, *outs = primitive.bind(f, e.err, e.code, *in_vals, **params_)
err, code = _reduce_any_error(errs, codes)
setnewattr(self.main, 'error', Error(err, code, msgs()))
return [ErrorTracer(self, x) for x in outs]
return [CheckifyTracer(self, x) for x in outs]
def post_process_call(self, primitive, tracers, params):
vals = [t.val for t in tracers]
@ -146,7 +146,7 @@ class ErrorTrace(core.Trace):
def todo(vals):
trace = main.with_cur_sublevel()
err, code, *vals = vals
return [ErrorTracer(trace, x) for x in vals]
return [CheckifyTracer(trace, x) for x in vals]
return (err, code, *vals), todo
def post_process_map(self, primitive, tracers, params):
@ -157,7 +157,7 @@ class ErrorTrace(core.Trace):
def todo(vals):
trace = main.with_cur_sublevel()
err, code, *vals = vals
return [ErrorTracer(trace, x) for x in vals]
return [CheckifyTracer(trace, x) for x in vals]
def out_axes_transform(out_axes):
return (0, 0, *out_axes)
return (err, code, *vals), (todo, out_axes_transform)
@ -169,26 +169,24 @@ def _reduce_any_error(errs, codes):
ErrorCheckRule = Callable
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
def check_errors_flat(fun: lu.WrappedFun, *args):
fun, msgs = check_errors_subtrace(fun)
fun = check_errors_toplevel(fun)
err, code, *out_vals = fun.call_wrapped(*args)
return (err, code, out_vals), msgs()
def checkify_flat(fun: lu.WrappedFun, *args):
fun, msgs = checkify_subtrace(fun)
fun = checkify_traceable(fun, tuple(init_error.msgs.items()))
err, code, *outvals = fun.call_wrapped(init_error.err, init_error.code, *args)
return (err, code, outvals), msgs()
@lu.transformation
def check_errors_toplevel(*args):
error = init_error
with core.new_main(ErrorTrace) as main:
msgs = tuple(error.msgs.items())
outs = yield (main, msgs, error.err, error.code, *args), {}
def checkify_traceable(msgs, err, code, *args):
with core.new_main(CheckifyTrace) as main:
outs = yield (main, msgs, err, code, *args), {}
del main
yield outs
@lu.transformation_with_aux
def check_errors_subtrace(main, msgs, err, code, *args):
def checkify_subtrace(main, msgs, err, code, *args):
setnewattr(main, 'error', Error(err, code, dict(msgs)))
trace = main.with_cur_sublevel()
in_tracers = [ErrorTracer(trace, x) for x in args]
in_tracers = [CheckifyTracer(trace, x) for x in args]
out = yield in_tracers, {}
out_tracers = map(trace.full_raise, out)
out_vals = [t.val for t in out_tracers]
@ -196,27 +194,20 @@ def check_errors_subtrace(main, msgs, err, code, *args):
del main.error
yield (err, code, *out_vals), msgs
def checkify_fun_to_jaxpr(f, error, in_avals):
f, msgs = check_errors_subtrace(f)
f = check_errors_traceable(f, tuple(error.msgs.items()))
err_aval = core.raise_to_shaped(core.get_aval(error.err))
code_aval = core.raise_to_shaped(core.get_aval(error.code))
avals_in = [err_aval, code_aval, *in_avals]
jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
# TODO take (error_aval, code_aval) instead of error here?
def checkify_jaxpr(jaxpr, error):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
return checkify_fun_to_jaxpr(f, error, jaxpr.in_avals)
# TODO dedup with check_errors_toplevel
@lu.transformation
def check_errors_traceable(msgs, err, code, *args):
with core.new_main(ErrorTrace) as main:
outs = yield (main, msgs, err, code, *args), {}
del main
yield outs
def checkify_fun_to_jaxpr(f, error, in_avals):
f, msgs = checkify_subtrace(f)
f = checkify_traceable(f, tuple(error.msgs.items()))
err_aval = core.raise_to_shaped(core.get_aval(error.err))
code_aval = core.raise_to_shaped(core.get_aval(error.code))
avals_in = [err_aval, code_aval, *in_avals]
jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
## assert primitive
@ -398,30 +389,6 @@ def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr, body_ncons
return out, Error(err, code, new_msgs)
error_checks[lax.while_p] = while_loop_error_check
# TODO(mattjj,lenamartens): currently we bundle effectful-assert-discharging
# with the error-check-adding transformation (checkify), but they could be
# separated into two orthogonal transformations.
def assert_discharge_rule(error, pred, code, *, msgs):
out_err = error.err | jnp.logical_not(pred)
out_code = lax.select(error.err, error.code, code)
return [], Error(out_err, out_code, {**error.msgs, **msgs})
error_checks[assert_p] = assert_discharge_rule
## checkify api
def checkify(fun: Callable) -> Callable:
@traceback_util.api_boundary
def checked_fun(*args, **kwargs):
args_flat, in_tree = tree_flatten((args, kwargs))
f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
(err, code, out_flat), msgs = check_errors_flat(f, *args_flat)
out = tree_unflatten(out_tree(), out_flat)
return Error(err, code, msgs), out
return checked_fun
## NaN error rule table
def add_nan_check(prim):
error_checks[prim] = partial(nan_error_check, prim)
@ -485,3 +452,23 @@ add_nan_check(lax.abs_p)
add_nan_check(lax.select_p)
add_nan_check(lax.max_p)
add_nan_check(lax.min_p)
def assert_discharge_rule(error, pred, code, *, msgs):
out_err = error.err | jnp.logical_not(pred)
out_code = lax.select(error.err, error.code, code)
return [], Error(out_err, out_code, {**error.msgs, **msgs})
error_checks[assert_p] = assert_discharge_rule
## checkify api
Out = TypeVar('Out')
def checkify(fun: Callable[..., Out]) -> Callable[..., Tuple[Error, Out]]:
@traceback_util.api_boundary
def checked_fun(*args, **kwargs):
args_flat, in_tree = tree_flatten((args, kwargs))
f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
(err, code, out_flat), msgs = checkify_flat(f, *args_flat)
out = tree_unflatten(out_tree(), out_flat)
return Error(err, code, msgs), out
return checked_fun