mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
checkify: tweak some organization and names
This commit is contained in:
parent
7bc51879d4
commit
6850833c3a
@ -15,7 +15,7 @@
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
from typing import Union, Optional, Callable, Dict
|
||||
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -86,7 +86,7 @@ def assert_func(error: Error, pred: Bool, msg: str) -> Error:
|
||||
|
||||
## Checkify transformation for plumbing functional error values.
|
||||
|
||||
class ErrorTracer(core.Tracer):
|
||||
class CheckifyTracer(core.Tracer):
|
||||
def __init__(self, trace, val):
|
||||
self._trace = trace
|
||||
self.val = val
|
||||
@ -94,11 +94,11 @@ class ErrorTracer(core.Tracer):
|
||||
aval = property(lambda self: core.get_aval(self.val))
|
||||
full_lower = lambda self: self
|
||||
|
||||
class ErrorTrace(core.Trace):
|
||||
pure = lift = lambda self, val: ErrorTracer(self, val)
|
||||
class CheckifyTrace(core.Trace):
|
||||
pure = lift = lambda self, val: CheckifyTracer(self, val)
|
||||
|
||||
def sublift(self, tracer):
|
||||
return ErrorTracer(self, tracer.val)
|
||||
return CheckifyTracer(self, tracer.val)
|
||||
|
||||
def process_primitive(self, primitive, tracers, params):
|
||||
in_vals = [t.val for t in tracers]
|
||||
@ -108,23 +108,23 @@ class ErrorTrace(core.Trace):
|
||||
else:
|
||||
out = primitive.bind(*in_vals, **params)
|
||||
if primitive.multiple_results:
|
||||
return [ErrorTracer(self, x) for x in out]
|
||||
return [CheckifyTracer(self, x) for x in out]
|
||||
else:
|
||||
return ErrorTracer(self, out)
|
||||
return CheckifyTracer(self, out)
|
||||
|
||||
def process_call(self, primitive, f, tracers, params):
|
||||
in_vals = [t.val for t in tracers]
|
||||
e = popattr(self.main, 'error')
|
||||
f, msgs = check_errors_subtrace(f, self.main, tuple(e.msgs.items()))
|
||||
f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items()))
|
||||
params_ = dict(params, donated_invars=(False, False, *params['donated_invars']))
|
||||
err, code, *out_vals = primitive.bind(f, e.err, e.code, *in_vals, **params_)
|
||||
setnewattr(self.main, 'error', Error(err, code, msgs()))
|
||||
return [ErrorTracer(self, x) for x in out_vals]
|
||||
return [CheckifyTracer(self, x) for x in out_vals]
|
||||
|
||||
def process_map(self, primitive, f, tracers, params):
|
||||
in_vals = [t.val for t in tracers]
|
||||
e = popattr(self.main, 'error')
|
||||
f, msgs = check_errors_subtrace(f, self.main, tuple(e.msgs.items()))
|
||||
f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items()))
|
||||
|
||||
@as_hashable_function(closure=params['out_axes_thunk'])
|
||||
def new_out_axes_thunk():
|
||||
@ -136,7 +136,7 @@ class ErrorTrace(core.Trace):
|
||||
errs, codes, *outs = primitive.bind(f, e.err, e.code, *in_vals, **params_)
|
||||
err, code = _reduce_any_error(errs, codes)
|
||||
setnewattr(self.main, 'error', Error(err, code, msgs()))
|
||||
return [ErrorTracer(self, x) for x in outs]
|
||||
return [CheckifyTracer(self, x) for x in outs]
|
||||
|
||||
def post_process_call(self, primitive, tracers, params):
|
||||
vals = [t.val for t in tracers]
|
||||
@ -146,7 +146,7 @@ class ErrorTrace(core.Trace):
|
||||
def todo(vals):
|
||||
trace = main.with_cur_sublevel()
|
||||
err, code, *vals = vals
|
||||
return [ErrorTracer(trace, x) for x in vals]
|
||||
return [CheckifyTracer(trace, x) for x in vals]
|
||||
return (err, code, *vals), todo
|
||||
|
||||
def post_process_map(self, primitive, tracers, params):
|
||||
@ -157,7 +157,7 @@ class ErrorTrace(core.Trace):
|
||||
def todo(vals):
|
||||
trace = main.with_cur_sublevel()
|
||||
err, code, *vals = vals
|
||||
return [ErrorTracer(trace, x) for x in vals]
|
||||
return [CheckifyTracer(trace, x) for x in vals]
|
||||
def out_axes_transform(out_axes):
|
||||
return (0, 0, *out_axes)
|
||||
return (err, code, *vals), (todo, out_axes_transform)
|
||||
@ -169,26 +169,24 @@ def _reduce_any_error(errs, codes):
|
||||
ErrorCheckRule = Callable
|
||||
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
|
||||
|
||||
def check_errors_flat(fun: lu.WrappedFun, *args):
|
||||
fun, msgs = check_errors_subtrace(fun)
|
||||
fun = check_errors_toplevel(fun)
|
||||
err, code, *out_vals = fun.call_wrapped(*args)
|
||||
return (err, code, out_vals), msgs()
|
||||
def checkify_flat(fun: lu.WrappedFun, *args):
|
||||
fun, msgs = checkify_subtrace(fun)
|
||||
fun = checkify_traceable(fun, tuple(init_error.msgs.items()))
|
||||
err, code, *outvals = fun.call_wrapped(init_error.err, init_error.code, *args)
|
||||
return (err, code, outvals), msgs()
|
||||
|
||||
@lu.transformation
|
||||
def check_errors_toplevel(*args):
|
||||
error = init_error
|
||||
with core.new_main(ErrorTrace) as main:
|
||||
msgs = tuple(error.msgs.items())
|
||||
outs = yield (main, msgs, error.err, error.code, *args), {}
|
||||
def checkify_traceable(msgs, err, code, *args):
|
||||
with core.new_main(CheckifyTrace) as main:
|
||||
outs = yield (main, msgs, err, code, *args), {}
|
||||
del main
|
||||
yield outs
|
||||
|
||||
@lu.transformation_with_aux
|
||||
def check_errors_subtrace(main, msgs, err, code, *args):
|
||||
def checkify_subtrace(main, msgs, err, code, *args):
|
||||
setnewattr(main, 'error', Error(err, code, dict(msgs)))
|
||||
trace = main.with_cur_sublevel()
|
||||
in_tracers = [ErrorTracer(trace, x) for x in args]
|
||||
in_tracers = [CheckifyTracer(trace, x) for x in args]
|
||||
out = yield in_tracers, {}
|
||||
out_tracers = map(trace.full_raise, out)
|
||||
out_vals = [t.val for t in out_tracers]
|
||||
@ -196,27 +194,20 @@ def check_errors_subtrace(main, msgs, err, code, *args):
|
||||
del main.error
|
||||
yield (err, code, *out_vals), msgs
|
||||
|
||||
def checkify_fun_to_jaxpr(f, error, in_avals):
|
||||
f, msgs = check_errors_subtrace(f)
|
||||
f = check_errors_traceable(f, tuple(error.msgs.items()))
|
||||
err_aval = core.raise_to_shaped(core.get_aval(error.err))
|
||||
code_aval = core.raise_to_shaped(core.get_aval(error.code))
|
||||
avals_in = [err_aval, code_aval, *in_avals]
|
||||
jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
|
||||
return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
|
||||
|
||||
# TODO take (error_aval, code_aval) instead of error here?
|
||||
def checkify_jaxpr(jaxpr, error):
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
|
||||
return checkify_fun_to_jaxpr(f, error, jaxpr.in_avals)
|
||||
|
||||
# TODO dedup with check_errors_toplevel
|
||||
@lu.transformation
|
||||
def check_errors_traceable(msgs, err, code, *args):
|
||||
with core.new_main(ErrorTrace) as main:
|
||||
outs = yield (main, msgs, err, code, *args), {}
|
||||
del main
|
||||
yield outs
|
||||
def checkify_fun_to_jaxpr(f, error, in_avals):
|
||||
f, msgs = checkify_subtrace(f)
|
||||
f = checkify_traceable(f, tuple(error.msgs.items()))
|
||||
err_aval = core.raise_to_shaped(core.get_aval(error.err))
|
||||
code_aval = core.raise_to_shaped(core.get_aval(error.code))
|
||||
avals_in = [err_aval, code_aval, *in_avals]
|
||||
jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
|
||||
return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
|
||||
|
||||
|
||||
## assert primitive
|
||||
@ -398,30 +389,6 @@ def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr, body_ncons
|
||||
return out, Error(err, code, new_msgs)
|
||||
error_checks[lax.while_p] = while_loop_error_check
|
||||
|
||||
# TODO(mattjj,lenamartens): currently we bundle effectful-assert-discharging
|
||||
# with the error-check-adding transformation (checkify), but they could be
|
||||
# separated into two orthogonal transformations.
|
||||
def assert_discharge_rule(error, pred, code, *, msgs):
|
||||
out_err = error.err | jnp.logical_not(pred)
|
||||
out_code = lax.select(error.err, error.code, code)
|
||||
return [], Error(out_err, out_code, {**error.msgs, **msgs})
|
||||
error_checks[assert_p] = assert_discharge_rule
|
||||
|
||||
|
||||
## checkify api
|
||||
|
||||
def checkify(fun: Callable) -> Callable:
|
||||
@traceback_util.api_boundary
|
||||
def checked_fun(*args, **kwargs):
|
||||
args_flat, in_tree = tree_flatten((args, kwargs))
|
||||
f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
(err, code, out_flat), msgs = check_errors_flat(f, *args_flat)
|
||||
out = tree_unflatten(out_tree(), out_flat)
|
||||
return Error(err, code, msgs), out
|
||||
return checked_fun
|
||||
|
||||
## NaN error rule table
|
||||
|
||||
def add_nan_check(prim):
|
||||
error_checks[prim] = partial(nan_error_check, prim)
|
||||
|
||||
@ -485,3 +452,23 @@ add_nan_check(lax.abs_p)
|
||||
add_nan_check(lax.select_p)
|
||||
add_nan_check(lax.max_p)
|
||||
add_nan_check(lax.min_p)
|
||||
|
||||
def assert_discharge_rule(error, pred, code, *, msgs):
|
||||
out_err = error.err | jnp.logical_not(pred)
|
||||
out_code = lax.select(error.err, error.code, code)
|
||||
return [], Error(out_err, out_code, {**error.msgs, **msgs})
|
||||
error_checks[assert_p] = assert_discharge_rule
|
||||
|
||||
|
||||
## checkify api
|
||||
|
||||
Out = TypeVar('Out')
|
||||
def checkify(fun: Callable[..., Out]) -> Callable[..., Tuple[Error, Out]]:
|
||||
@traceback_util.api_boundary
|
||||
def checked_fun(*args, **kwargs):
|
||||
args_flat, in_tree = tree_flatten((args, kwargs))
|
||||
f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
(err, code, out_flat), msgs = checkify_flat(f, *args_flat)
|
||||
out = tree_unflatten(out_tree(), out_flat)
|
||||
return Error(err, code, msgs), out
|
||||
return checked_fun
|
||||
|
Loading…
x
Reference in New Issue
Block a user