mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 11:56:07 +00:00
checkify: tweak some organization and names
This commit is contained in:
parent
7bc51879d4
commit
6850833c3a
@ -15,7 +15,7 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import itertools as it
|
import itertools as it
|
||||||
from typing import Union, Optional, Callable, Dict
|
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
@ -86,7 +86,7 @@ def assert_func(error: Error, pred: Bool, msg: str) -> Error:
|
|||||||
|
|
||||||
## Checkify transformation for plumbing functional error values.
|
## Checkify transformation for plumbing functional error values.
|
||||||
|
|
||||||
class ErrorTracer(core.Tracer):
|
class CheckifyTracer(core.Tracer):
|
||||||
def __init__(self, trace, val):
|
def __init__(self, trace, val):
|
||||||
self._trace = trace
|
self._trace = trace
|
||||||
self.val = val
|
self.val = val
|
||||||
@ -94,11 +94,11 @@ class ErrorTracer(core.Tracer):
|
|||||||
aval = property(lambda self: core.get_aval(self.val))
|
aval = property(lambda self: core.get_aval(self.val))
|
||||||
full_lower = lambda self: self
|
full_lower = lambda self: self
|
||||||
|
|
||||||
class ErrorTrace(core.Trace):
|
class CheckifyTrace(core.Trace):
|
||||||
pure = lift = lambda self, val: ErrorTracer(self, val)
|
pure = lift = lambda self, val: CheckifyTracer(self, val)
|
||||||
|
|
||||||
def sublift(self, tracer):
|
def sublift(self, tracer):
|
||||||
return ErrorTracer(self, tracer.val)
|
return CheckifyTracer(self, tracer.val)
|
||||||
|
|
||||||
def process_primitive(self, primitive, tracers, params):
|
def process_primitive(self, primitive, tracers, params):
|
||||||
in_vals = [t.val for t in tracers]
|
in_vals = [t.val for t in tracers]
|
||||||
@ -108,23 +108,23 @@ class ErrorTrace(core.Trace):
|
|||||||
else:
|
else:
|
||||||
out = primitive.bind(*in_vals, **params)
|
out = primitive.bind(*in_vals, **params)
|
||||||
if primitive.multiple_results:
|
if primitive.multiple_results:
|
||||||
return [ErrorTracer(self, x) for x in out]
|
return [CheckifyTracer(self, x) for x in out]
|
||||||
else:
|
else:
|
||||||
return ErrorTracer(self, out)
|
return CheckifyTracer(self, out)
|
||||||
|
|
||||||
def process_call(self, primitive, f, tracers, params):
|
def process_call(self, primitive, f, tracers, params):
|
||||||
in_vals = [t.val for t in tracers]
|
in_vals = [t.val for t in tracers]
|
||||||
e = popattr(self.main, 'error')
|
e = popattr(self.main, 'error')
|
||||||
f, msgs = check_errors_subtrace(f, self.main, tuple(e.msgs.items()))
|
f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items()))
|
||||||
params_ = dict(params, donated_invars=(False, False, *params['donated_invars']))
|
params_ = dict(params, donated_invars=(False, False, *params['donated_invars']))
|
||||||
err, code, *out_vals = primitive.bind(f, e.err, e.code, *in_vals, **params_)
|
err, code, *out_vals = primitive.bind(f, e.err, e.code, *in_vals, **params_)
|
||||||
setnewattr(self.main, 'error', Error(err, code, msgs()))
|
setnewattr(self.main, 'error', Error(err, code, msgs()))
|
||||||
return [ErrorTracer(self, x) for x in out_vals]
|
return [CheckifyTracer(self, x) for x in out_vals]
|
||||||
|
|
||||||
def process_map(self, primitive, f, tracers, params):
|
def process_map(self, primitive, f, tracers, params):
|
||||||
in_vals = [t.val for t in tracers]
|
in_vals = [t.val for t in tracers]
|
||||||
e = popattr(self.main, 'error')
|
e = popattr(self.main, 'error')
|
||||||
f, msgs = check_errors_subtrace(f, self.main, tuple(e.msgs.items()))
|
f, msgs = checkify_subtrace(f, self.main, tuple(e.msgs.items()))
|
||||||
|
|
||||||
@as_hashable_function(closure=params['out_axes_thunk'])
|
@as_hashable_function(closure=params['out_axes_thunk'])
|
||||||
def new_out_axes_thunk():
|
def new_out_axes_thunk():
|
||||||
@ -136,7 +136,7 @@ class ErrorTrace(core.Trace):
|
|||||||
errs, codes, *outs = primitive.bind(f, e.err, e.code, *in_vals, **params_)
|
errs, codes, *outs = primitive.bind(f, e.err, e.code, *in_vals, **params_)
|
||||||
err, code = _reduce_any_error(errs, codes)
|
err, code = _reduce_any_error(errs, codes)
|
||||||
setnewattr(self.main, 'error', Error(err, code, msgs()))
|
setnewattr(self.main, 'error', Error(err, code, msgs()))
|
||||||
return [ErrorTracer(self, x) for x in outs]
|
return [CheckifyTracer(self, x) for x in outs]
|
||||||
|
|
||||||
def post_process_call(self, primitive, tracers, params):
|
def post_process_call(self, primitive, tracers, params):
|
||||||
vals = [t.val for t in tracers]
|
vals = [t.val for t in tracers]
|
||||||
@ -146,7 +146,7 @@ class ErrorTrace(core.Trace):
|
|||||||
def todo(vals):
|
def todo(vals):
|
||||||
trace = main.with_cur_sublevel()
|
trace = main.with_cur_sublevel()
|
||||||
err, code, *vals = vals
|
err, code, *vals = vals
|
||||||
return [ErrorTracer(trace, x) for x in vals]
|
return [CheckifyTracer(trace, x) for x in vals]
|
||||||
return (err, code, *vals), todo
|
return (err, code, *vals), todo
|
||||||
|
|
||||||
def post_process_map(self, primitive, tracers, params):
|
def post_process_map(self, primitive, tracers, params):
|
||||||
@ -157,7 +157,7 @@ class ErrorTrace(core.Trace):
|
|||||||
def todo(vals):
|
def todo(vals):
|
||||||
trace = main.with_cur_sublevel()
|
trace = main.with_cur_sublevel()
|
||||||
err, code, *vals = vals
|
err, code, *vals = vals
|
||||||
return [ErrorTracer(trace, x) for x in vals]
|
return [CheckifyTracer(trace, x) for x in vals]
|
||||||
def out_axes_transform(out_axes):
|
def out_axes_transform(out_axes):
|
||||||
return (0, 0, *out_axes)
|
return (0, 0, *out_axes)
|
||||||
return (err, code, *vals), (todo, out_axes_transform)
|
return (err, code, *vals), (todo, out_axes_transform)
|
||||||
@ -169,26 +169,24 @@ def _reduce_any_error(errs, codes):
|
|||||||
ErrorCheckRule = Callable
|
ErrorCheckRule = Callable
|
||||||
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
|
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
|
||||||
|
|
||||||
def check_errors_flat(fun: lu.WrappedFun, *args):
|
def checkify_flat(fun: lu.WrappedFun, *args):
|
||||||
fun, msgs = check_errors_subtrace(fun)
|
fun, msgs = checkify_subtrace(fun)
|
||||||
fun = check_errors_toplevel(fun)
|
fun = checkify_traceable(fun, tuple(init_error.msgs.items()))
|
||||||
err, code, *out_vals = fun.call_wrapped(*args)
|
err, code, *outvals = fun.call_wrapped(init_error.err, init_error.code, *args)
|
||||||
return (err, code, out_vals), msgs()
|
return (err, code, outvals), msgs()
|
||||||
|
|
||||||
@lu.transformation
|
@lu.transformation
|
||||||
def check_errors_toplevel(*args):
|
def checkify_traceable(msgs, err, code, *args):
|
||||||
error = init_error
|
with core.new_main(CheckifyTrace) as main:
|
||||||
with core.new_main(ErrorTrace) as main:
|
outs = yield (main, msgs, err, code, *args), {}
|
||||||
msgs = tuple(error.msgs.items())
|
|
||||||
outs = yield (main, msgs, error.err, error.code, *args), {}
|
|
||||||
del main
|
del main
|
||||||
yield outs
|
yield outs
|
||||||
|
|
||||||
@lu.transformation_with_aux
|
@lu.transformation_with_aux
|
||||||
def check_errors_subtrace(main, msgs, err, code, *args):
|
def checkify_subtrace(main, msgs, err, code, *args):
|
||||||
setnewattr(main, 'error', Error(err, code, dict(msgs)))
|
setnewattr(main, 'error', Error(err, code, dict(msgs)))
|
||||||
trace = main.with_cur_sublevel()
|
trace = main.with_cur_sublevel()
|
||||||
in_tracers = [ErrorTracer(trace, x) for x in args]
|
in_tracers = [CheckifyTracer(trace, x) for x in args]
|
||||||
out = yield in_tracers, {}
|
out = yield in_tracers, {}
|
||||||
out_tracers = map(trace.full_raise, out)
|
out_tracers = map(trace.full_raise, out)
|
||||||
out_vals = [t.val for t in out_tracers]
|
out_vals = [t.val for t in out_tracers]
|
||||||
@ -196,27 +194,20 @@ def check_errors_subtrace(main, msgs, err, code, *args):
|
|||||||
del main.error
|
del main.error
|
||||||
yield (err, code, *out_vals), msgs
|
yield (err, code, *out_vals), msgs
|
||||||
|
|
||||||
def checkify_fun_to_jaxpr(f, error, in_avals):
|
|
||||||
f, msgs = check_errors_subtrace(f)
|
|
||||||
f = check_errors_traceable(f, tuple(error.msgs.items()))
|
|
||||||
err_aval = core.raise_to_shaped(core.get_aval(error.err))
|
|
||||||
code_aval = core.raise_to_shaped(core.get_aval(error.code))
|
|
||||||
avals_in = [err_aval, code_aval, *in_avals]
|
|
||||||
jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
|
|
||||||
return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
|
|
||||||
|
|
||||||
# TODO take (error_aval, code_aval) instead of error here?
|
# TODO take (error_aval, code_aval) instead of error here?
|
||||||
def checkify_jaxpr(jaxpr, error):
|
def checkify_jaxpr(jaxpr, error):
|
||||||
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
|
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
|
||||||
return checkify_fun_to_jaxpr(f, error, jaxpr.in_avals)
|
return checkify_fun_to_jaxpr(f, error, jaxpr.in_avals)
|
||||||
|
|
||||||
# TODO dedup with check_errors_toplevel
|
def checkify_fun_to_jaxpr(f, error, in_avals):
|
||||||
@lu.transformation
|
f, msgs = checkify_subtrace(f)
|
||||||
def check_errors_traceable(msgs, err, code, *args):
|
f = checkify_traceable(f, tuple(error.msgs.items()))
|
||||||
with core.new_main(ErrorTrace) as main:
|
err_aval = core.raise_to_shaped(core.get_aval(error.err))
|
||||||
outs = yield (main, msgs, err, code, *args), {}
|
code_aval = core.raise_to_shaped(core.get_aval(error.code))
|
||||||
del main
|
avals_in = [err_aval, code_aval, *in_avals]
|
||||||
yield outs
|
jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(f, avals_in)
|
||||||
|
return core.ClosedJaxpr(jaxpr_out, literals_out), msgs()
|
||||||
|
|
||||||
|
|
||||||
## assert primitive
|
## assert primitive
|
||||||
@ -398,30 +389,6 @@ def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr, body_ncons
|
|||||||
return out, Error(err, code, new_msgs)
|
return out, Error(err, code, new_msgs)
|
||||||
error_checks[lax.while_p] = while_loop_error_check
|
error_checks[lax.while_p] = while_loop_error_check
|
||||||
|
|
||||||
# TODO(mattjj,lenamartens): currently we bundle effectful-assert-discharging
|
|
||||||
# with the error-check-adding transformation (checkify), but they could be
|
|
||||||
# separated into two orthogonal transformations.
|
|
||||||
def assert_discharge_rule(error, pred, code, *, msgs):
|
|
||||||
out_err = error.err | jnp.logical_not(pred)
|
|
||||||
out_code = lax.select(error.err, error.code, code)
|
|
||||||
return [], Error(out_err, out_code, {**error.msgs, **msgs})
|
|
||||||
error_checks[assert_p] = assert_discharge_rule
|
|
||||||
|
|
||||||
|
|
||||||
## checkify api
|
|
||||||
|
|
||||||
def checkify(fun: Callable) -> Callable:
|
|
||||||
@traceback_util.api_boundary
|
|
||||||
def checked_fun(*args, **kwargs):
|
|
||||||
args_flat, in_tree = tree_flatten((args, kwargs))
|
|
||||||
f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
|
||||||
(err, code, out_flat), msgs = check_errors_flat(f, *args_flat)
|
|
||||||
out = tree_unflatten(out_tree(), out_flat)
|
|
||||||
return Error(err, code, msgs), out
|
|
||||||
return checked_fun
|
|
||||||
|
|
||||||
## NaN error rule table
|
|
||||||
|
|
||||||
def add_nan_check(prim):
|
def add_nan_check(prim):
|
||||||
error_checks[prim] = partial(nan_error_check, prim)
|
error_checks[prim] = partial(nan_error_check, prim)
|
||||||
|
|
||||||
@ -485,3 +452,23 @@ add_nan_check(lax.abs_p)
|
|||||||
add_nan_check(lax.select_p)
|
add_nan_check(lax.select_p)
|
||||||
add_nan_check(lax.max_p)
|
add_nan_check(lax.max_p)
|
||||||
add_nan_check(lax.min_p)
|
add_nan_check(lax.min_p)
|
||||||
|
|
||||||
|
def assert_discharge_rule(error, pred, code, *, msgs):
|
||||||
|
out_err = error.err | jnp.logical_not(pred)
|
||||||
|
out_code = lax.select(error.err, error.code, code)
|
||||||
|
return [], Error(out_err, out_code, {**error.msgs, **msgs})
|
||||||
|
error_checks[assert_p] = assert_discharge_rule
|
||||||
|
|
||||||
|
|
||||||
|
## checkify api
|
||||||
|
|
||||||
|
Out = TypeVar('Out')
|
||||||
|
def checkify(fun: Callable[..., Out]) -> Callable[..., Tuple[Error, Out]]:
|
||||||
|
@traceback_util.api_boundary
|
||||||
|
def checked_fun(*args, **kwargs):
|
||||||
|
args_flat, in_tree = tree_flatten((args, kwargs))
|
||||||
|
f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||||
|
(err, code, out_flat), msgs = checkify_flat(f, *args_flat)
|
||||||
|
out = tree_unflatten(out_tree(), out_flat)
|
||||||
|
return Error(err, code, msgs), out
|
||||||
|
return checked_fun
|
||||||
|
Loading…
x
Reference in New Issue
Block a user