generalize assert primitive, allow recharging

This commit is contained in:
Matthew Johnson 2021-12-02 14:26:58 -08:00
parent 40912d5d96
commit c1f71d17c0
2 changed files with 68 additions and 21 deletions

View File

@ -88,8 +88,19 @@ def assert_func(error: Error, pred: Bool, msg: str) -> Error:
## Checkify transformation for plumbing functional error values.
class ErrorTracer(core.Tracer):
def __init__(self, trace, val):
self._trace = trace
self.val = val
core.get_aval(val), val
aval = property(lambda self: core.get_aval(self.val))
full_lower = lambda self: self
class ErrorTrace(core.Trace):
pure = lift = sublift = lambda self, val: ErrorTracer(self, val)
pure = lift = lambda self, val: ErrorTracer(self, val)
def sublift(self, tracer):
return ErrorTracer(self, tracer.val)
def process_primitive(self, primitive, tracers, params):
in_vals = [t.val for t in tracers]
@ -160,14 +171,6 @@ def _reduce_any_error(errs, codes):
ErrorCheckRule = Callable
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
class ErrorTracer(core.Tracer):
def __init__(self, trace, val):
self._trace = trace
self.val = val
core.get_aval(val), val
aval = property(lambda self: core.get_aval(self.val))
full_lower = lambda self: self
def check_errors_flat(fun: lu.WrappedFun, *args):
fun, msgs = check_errors_subtrace(fun)
fun = check_errors_toplevel(fun)
@ -217,28 +220,34 @@ def check_errors_traceable(msgs, err, code, *args):
## assert primitive
def assert_(pred: Bool, msg: str) -> None:
return assert_p.bind(pred, msg=msg)
code = next_code()
return assert2_(pred, code, {code: msg})
def assert2_(pred: Bool, code: Int, msgs: Dict[int, str]) -> None:
return assert_p.bind(pred, code, msgs=msgs)
assert_p = core.Primitive('assert')
assert_p.multiple_results = True # zero results
@assert_p.def_impl
def assert_impl(pred, *, msg):
assert pred, msg
def assert_impl(pred, code, *, msgs):
assert pred, msgs[int(code)]
return []
@assert_p.def_abstract_eval
def assert_abstract_eval(pred, *, msg):
def assert_abstract_eval(pred, code, *, msgs):
raise Exception("can't be staged!")
## checkify rules
def summary() -> str:
return str(source_info_util.summarize(source_info_util.current()))
def nan_error_check(prim, error, *in_vals, **params):
out = prim.bind(*in_vals, **params)
no_nans = jnp.logical_not(jnp.any(jnp.isnan(out)))
summary = source_info_util.summarize(source_info_util.current())
msg = f"nan generated by primitive {prim.name} at {summary}"
msg = f"nan generated by primitive {prim.name} at {summary()}"
return out, assert_func(error, no_nans, msg)
error_checks[lax.sin_p] = partial(nan_error_check, lax.sin_p)
@ -260,8 +269,7 @@ def gather_error_check(error, operand, start_indices, *,
upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)]
all_inbounds = jnp.all((start_indices >= 0) & (start_indices <= upper_bound))
summary = source_info_util.summarize(source_info_util.current())
msg = f"out-of-bounds indexing at {summary}"
msg = f"out-of-bounds indexing at {summary()}"
return out, assert_func(error, all_inbounds, msg)
error_checks[slicing.gather_p] = gather_error_check
@ -276,10 +284,12 @@ def cond_error_check(error, index, *ops, branches, linear):
error_checks[control_flow.cond_p] = cond_error_check
# TODO(mattjj,lenamartens): currently we bundle effectful-assert-discharging
# with the error-check-adding transformation (checkify), but the two could be
# made orthogonal.
def assert_discharge_rule(err: Error, pred: Bool, *, msg: str):
return [], assert_func(err, pred, msg)
# with the error-check-adding transformation (checkify), but they could be
# separated into two orthogonal transformations.
def assert_discharge_rule(error, pred, code, *, msgs):
out_err = error.err | jnp.logical_not(pred)
out_code = lax.select(error.err, error.code, code)
return [], Error(out_err, out_code, {**error.msgs, **msgs})
error_checks[assert_p] = assert_discharge_rule

View File

@ -185,6 +185,43 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "must be positive")
def test_assert2(self):
def f(pred): # note: data dependence needed!
checkify.assert2_(pred, 0, {0: "hi"})
with self.assertRaisesRegex(AssertionError, "hi"):
f(False)
f = checkify.checkify(f)
err, none = f(False)
self.assertIsNone(none)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "hi")
def test_discharge_recharge(self):
def ejit(f):
f = checkify.checkify(f)
f = jax.jit(f)
def jitted_f(*args):
err, out = f(*args)
checkify.assert2_(~err.err, err.code, err.msgs)
return out
return jitted_f
@ejit
def f(pred):
assert python_should_be_running
checkify.assert_(pred, "foo")
python_should_be_running = True
f(True)
python_should_be_running = False
f(True)
with self.assertRaisesRegex(AssertionError, "foo"):
f(False)
if __name__ == "__main__":
absltest.main(testLoader=jtu.JaxTestLoader())