checkify: tweak some organization and names

This commit is contained in:
Matthew Johnson 2022-01-10 21:29:12 -08:00
parent 7bc51879d4
commit 6850833c3a

View File

@ -15,7 +15,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
import itertools as it import itertools as it
from typing import Union, Optional, Callable, Dict from typing import Union, Optional, Callable, Dict, Tuple, TypeVar
import numpy as np import numpy as np
@ -86,7 +86,7 @@ def assert_func(error: Error, pred: Bool, msg: str) -> Error:
## Checkify transformation for plumbing functional error values. ## Checkify transformation for plumbing functional error values.
class ErrorTracer(core.Tracer): class CheckifyTracer(core.Tracer):
def __init__(self, trace, val): def __init__(self, trace, val):
self._trace = trace self._trace = trace
self.val = val self.val = val
@ -94,11 +94,11 @@ class ErrorTracer(core.Tracer):
aval = property(lambda self: core.get_aval(self.val)) aval = property(lambda self: core.get_aval(self.val))
full_lower = lambda self: self full_lower = lambda self: self
class ErrorTrace(core.Trace): class CheckifyTrace(core.Trace):
pure = lift = lambda self, val: ErrorTracer(self, val) pure = lift = lambda self, val: CheckifyTracer(self, val)
def sublift(self, tracer): def sublift(self, tracer):
return ErrorTracer(self, tracer.val) return CheckifyTracer(self, tracer.val)
def process_primitive(self, primitive, tracers, params): def process_primitive(self, primitive, tracers, params):
in_vals = [t.val for t in tracers] in_vals = [t.val for t in tracers]
@ -108,23 +108,23 @@ class ErrorTrace(core.Trace):
else: else:
out = primitive.bind(*in_vals, **params) out = primitive.bind(*in_vals, **params)
if primitive.multiple_results: if primitive.multiple_results:
return [ErrorTracer(self, x) for x in out] return [CheckifyTracer(self, x) for x in out]
else: else:
return ErrorTracer(self, out) return CheckifyTracer(self, out)
def process_call(self, primitive, f, tracers, params): def process_call(self, primitive, f, tracers, params):
in_vals = [t.val for t in tracers] in_vals = [t.val for t in tracers]
e = popattr(self.main, 'error') e = popattr(self.main, 'error')
f, msgs = check_errors_subtrace(f, self.main, tuple(e.msgs.items())) f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items()))
params_ = dict(params, donated_invars=(False, False, *params['donated_invars'])) params_ = dict(params, donated_invars=(False, False, *params['donated_invars']))
err, code, *out_vals = primitive.bind(f, e.err, e.code, *in_vals, **params_) err, code, *out_vals = primitive.bind(f, e.err, e.code, *in_vals, **params_)
setnewattr(self.main, 'error', Error(err, code, msgs())) setnewattr(self.main, 'error', Error(err, code, msgs()))
return [ErrorTracer(self, x) for x in out_vals] return [CheckifyTracer(self, x) for x in out_vals]
def process_map(self, primitive, f, tracers, params): def process_map(self, primitive, f, tracers, params):
in_vals = [t.val for t in tracers] in_vals = [t.val for t in tracers]
e = popattr(self.main, 'error') e = popattr(self.main, 'error')
f, msgs = check_errors_subtrace(f, self.main, tuple(e.msgs.items())) f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items()))
@as_hashable_function(closure=params['out_axes_thunk']) @as_hashable_function(closure=params['out_axes_thunk'])
def new_out_axes_thunk(): def new_out_axes_thunk():
@ -136,7 +136,7 @@ class ErrorTrace(core.Trace):
errs, codes, *outs = primitive.bind(f, e.err, e.code, *in_vals, **params_) errs, codes, *outs = primitive.bind(f, e.err, e.code, *in_vals, **params_)
err, code = _reduce_any_error(errs, codes) err, code = _reduce_any_error(errs, codes)
setnewattr(self.main, 'error', Error(err, code, msgs())) setnewattr(self.main, 'error', Error(err, code, msgs()))
return [ErrorTracer(self, x) for x in outs] return [CheckifyTracer(self, x) for x in outs]
def post_process_call(self, primitive, tracers, params): def post_process_call(self, primitive, tracers, params):
vals = [t.val for t in tracers] vals = [t.val for t in tracers]
@ -146,7 +146,7 @@ class ErrorTrace(core.Trace):
def todo(vals): def todo(vals):
trace = main.with_cur_sublevel() trace = main.with_cur_sublevel()
err, code, *vals = vals err, code, *vals = vals
return [ErrorTracer(trace, x) for x in vals] return [CheckifyTracer(trace, x) for x in vals]
return (err, code, *vals), todo return (err, code, *vals), todo
def post_process_map(self, primitive, tracers, params): def post_process_map(self, primitive, tracers, params):
@ -157,7 +157,7 @@ class ErrorTrace(core.Trace):
def todo(vals): def todo(vals):
trace = main.with_cur_sublevel() trace = main.with_cur_sublevel()
err, code, *vals = vals err, code, *vals = vals
return [ErrorTracer(trace, x) for x in vals] return [CheckifyTracer(trace, x) for x in vals]
def out_axes_transform(out_axes): def out_axes_transform(out_axes):
return (0, 0, *out_axes) return (0, 0, *out_axes)
return (err, code, *vals), (todo, out_axes_transform) return (err, code, *vals), (todo, out_axes_transform)
@ -169,26 +169,24 @@ def _reduce_any_error(errs, codes):
ErrorCheckRule = Callable ErrorCheckRule = Callable
error_checks: Dict[core.Primitive, ErrorCheckRule] = {} error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
def check_errors_flat(fun: lu.WrappedFun, *args): def checkify_flat(fun: lu.WrappedFun, *args):
fun, msgs = check_errors_subtrace(fun) fun, msgs = checkify_subtrace(fun)
fun = check_errors_toplevel(fun) fun = checkify_traceable(fun, tuple(init_error.msgs.items()))
err, code, *out_vals = fun.call_wrapped(*args) err, code, *outvals = fun.call_wrapped(init_error.err, init_error.code, *args)
return (err, code, out_vals), msgs() return (err, code, outvals), msgs()
@lu.transformation @lu.transformation
def check_errors_toplevel(*args): def checkify_traceable(msgs, err, code, *args):
error = init_error with core.new_main(CheckifyTrace) as main:
with core.new_main(ErrorTrace) as main: outs = yield (main, msgs, err, code, *args), {}
msgs = tuple(error.msgs.items())
outs = yield (main, msgs, error.err, error.code, *args), {}
del main del main
yield outs yield outs
@lu.transformation_with_aux @lu.transformation_with_aux
def check_errors_subtrace(main, msgs, err, code, *args): def checkify_subtrace(main, msgs, err, code, *args):
setnewattr(main, 'error', Error(err, code, dict(msgs))) setnewattr(main, 'error', Error(err, code, dict(msgs)))
trace = main.with_cur_sublevel() trace = main.with_cur_sublevel()
in_tracers = [ErrorTracer(trace, x) for x in args] in_tracers = [CheckifyTracer(trace, x) for x in args]
out = yield in_tracers, {} out = yield in_tracers, {}
out_tracers = map(trace.full_raise, out) out_tracers = map(trace.full_raise, out)
out_vals = [t.val for t in out_tracers] out_vals = [t.val for t in out_tracers]
@ -196,27 +194,20 @@ def check_errors_subtrace(main, msgs, err, code, *args):
del main.error del main.error
yield (err, code, *out_vals), msgs yield (err, code, *out_vals), msgs
def checkify_fun_to_jaxpr(f, error, in_avals):
f, msgs = check_errors_subtrace(f)
f = check_errors_traceable(f, tuple(error.msgs.items()))
err_aval = core.raise_to_shaped(core.get_aval(error.err))
code_aval = core.raise_to_shaped(core.get_aval(error.code))
avals_in = [err_aval, code_aval, *in_avals]
jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
# TODO take (error_aval, code_aval) instead of error here? # TODO take (error_aval, code_aval) instead of error here?
def checkify_jaxpr(jaxpr, error): def checkify_jaxpr(jaxpr, error):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
return checkify_fun_to_jaxpr(f, error, jaxpr.in_avals) return checkify_fun_to_jaxpr(f, error, jaxpr.in_avals)
# TODO dedup with check_errors_toplevel def checkify_fun_to_jaxpr(f, error, in_avals):
@lu.transformation f, msgs = checkify_subtrace(f)
def check_errors_traceable(msgs, err, code, *args): f = checkify_traceable(f, tuple(error.msgs.items()))
with core.new_main(ErrorTrace) as main: err_aval = core.raise_to_shaped(core.get_aval(error.err))
outs = yield (main, msgs, err, code, *args), {} code_aval = core.raise_to_shaped(core.get_aval(error.code))
del main avals_in = [err_aval, code_aval, *in_avals]
yield outs jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
## assert primitive ## assert primitive
@ -398,30 +389,6 @@ def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr, body_ncons
return out, Error(err, code, new_msgs) return out, Error(err, code, new_msgs)
error_checks[lax.while_p] = while_loop_error_check error_checks[lax.while_p] = while_loop_error_check
# TODO(mattjj,lenamartens): currently we bundle effectful-assert-discharging
# with the error-check-adding transformation (checkify), but they could be
# separated into two orthogonal transformations.
def assert_discharge_rule(error, pred, code, *, msgs):
out_err = error.err | jnp.logical_not(pred)
out_code = lax.select(error.err, error.code, code)
return [], Error(out_err, out_code, {**error.msgs, **msgs})
error_checks[assert_p] = assert_discharge_rule
## checkify api
def checkify(fun: Callable) -> Callable:
@traceback_util.api_boundary
def checked_fun(*args, **kwargs):
args_flat, in_tree = tree_flatten((args, kwargs))
f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
(err, code, out_flat), msgs = check_errors_flat(f, *args_flat)
out = tree_unflatten(out_tree(), out_flat)
return Error(err, code, msgs), out
return checked_fun
## NaN error rule table
def add_nan_check(prim): def add_nan_check(prim):
error_checks[prim] = partial(nan_error_check, prim) error_checks[prim] = partial(nan_error_check, prim)
@ -485,3 +452,23 @@ add_nan_check(lax.abs_p)
add_nan_check(lax.select_p) add_nan_check(lax.select_p)
add_nan_check(lax.max_p) add_nan_check(lax.max_p)
add_nan_check(lax.min_p) add_nan_check(lax.min_p)
def assert_discharge_rule(error, pred, code, *, msgs):
out_err = error.err | jnp.logical_not(pred)
out_code = lax.select(error.err, error.code, code)
return [], Error(out_err, out_code, {**error.msgs, **msgs})
error_checks[assert_p] = assert_discharge_rule
## checkify api
Out = TypeVar('Out')
def checkify(fun: Callable[..., Out]) -> Callable[..., Tuple[Error, Out]]:
@traceback_util.api_boundary
def checked_fun(*args, **kwargs):
args_flat, in_tree = tree_flatten((args, kwargs))
f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
(err, code, out_flat), msgs = checkify_flat(f, *args_flat)
out = tree_unflatten(out_tree(), out_flat)
return Error(err, code, msgs), out
return checked_fun