mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
generalize assert primitive, allow recharging
This commit is contained in:
parent
40912d5d96
commit
c1f71d17c0
@ -88,8 +88,19 @@ def assert_func(error: Error, pred: Bool, msg: str) -> Error:
|
||||
|
||||
## Checkify transformation for plumbing functional error values.
|
||||
|
||||
class ErrorTracer(core.Tracer):
|
||||
def __init__(self, trace, val):
|
||||
self._trace = trace
|
||||
self.val = val
|
||||
core.get_aval(val), val
|
||||
aval = property(lambda self: core.get_aval(self.val))
|
||||
full_lower = lambda self: self
|
||||
|
||||
class ErrorTrace(core.Trace):
|
||||
pure = lift = sublift = lambda self, val: ErrorTracer(self, val)
|
||||
pure = lift = lambda self, val: ErrorTracer(self, val)
|
||||
|
||||
def sublift(self, tracer):
|
||||
return ErrorTracer(self, tracer.val)
|
||||
|
||||
def process_primitive(self, primitive, tracers, params):
|
||||
in_vals = [t.val for t in tracers]
|
||||
@ -160,14 +171,6 @@ def _reduce_any_error(errs, codes):
|
||||
ErrorCheckRule = Callable
|
||||
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
|
||||
|
||||
class ErrorTracer(core.Tracer):
|
||||
def __init__(self, trace, val):
|
||||
self._trace = trace
|
||||
self.val = val
|
||||
core.get_aval(val), val
|
||||
aval = property(lambda self: core.get_aval(self.val))
|
||||
full_lower = lambda self: self
|
||||
|
||||
def check_errors_flat(fun: lu.WrappedFun, *args):
|
||||
fun, msgs = check_errors_subtrace(fun)
|
||||
fun = check_errors_toplevel(fun)
|
||||
@ -217,28 +220,34 @@ def check_errors_traceable(msgs, err, code, *args):
|
||||
## assert primitive
|
||||
|
||||
def assert_(pred: Bool, msg: str) -> None:
|
||||
return assert_p.bind(pred, msg=msg)
|
||||
code = next_code()
|
||||
return assert2_(pred, code, {code: msg})
|
||||
|
||||
def assert2_(pred: Bool, code: Int, msgs: Dict[int, str]) -> None:
|
||||
return assert_p.bind(pred, code, msgs=msgs)
|
||||
|
||||
assert_p = core.Primitive('assert')
|
||||
assert_p.multiple_results = True # zero results
|
||||
|
||||
@assert_p.def_impl
|
||||
def assert_impl(pred, *, msg):
|
||||
assert pred, msg
|
||||
def assert_impl(pred, code, *, msgs):
|
||||
assert pred, msgs[int(code)]
|
||||
return []
|
||||
|
||||
@assert_p.def_abstract_eval
|
||||
def assert_abstract_eval(pred, *, msg):
|
||||
def assert_abstract_eval(pred, code, *, msgs):
|
||||
raise Exception("can't be staged!")
|
||||
|
||||
|
||||
## checkify rules
|
||||
|
||||
def summary() -> str:
|
||||
return str(source_info_util.summarize(source_info_util.current()))
|
||||
|
||||
def nan_error_check(prim, error, *in_vals, **params):
|
||||
out = prim.bind(*in_vals, **params)
|
||||
no_nans = jnp.logical_not(jnp.any(jnp.isnan(out)))
|
||||
summary = source_info_util.summarize(source_info_util.current())
|
||||
msg = f"nan generated by primitive {prim.name} at {summary}"
|
||||
msg = f"nan generated by primitive {prim.name} at {summary()}"
|
||||
return out, assert_func(error, no_nans, msg)
|
||||
|
||||
error_checks[lax.sin_p] = partial(nan_error_check, lax.sin_p)
|
||||
@ -260,8 +269,7 @@ def gather_error_check(error, operand, start_indices, *,
|
||||
upper_bound -= np.array(slice_sizes)[np.array(dnums.start_index_map)]
|
||||
all_inbounds = jnp.all((start_indices >= 0) & (start_indices <= upper_bound))
|
||||
|
||||
summary = source_info_util.summarize(source_info_util.current())
|
||||
msg = f"out-of-bounds indexing at {summary}"
|
||||
msg = f"out-of-bounds indexing at {summary()}"
|
||||
return out, assert_func(error, all_inbounds, msg)
|
||||
error_checks[slicing.gather_p] = gather_error_check
|
||||
|
||||
@ -276,10 +284,12 @@ def cond_error_check(error, index, *ops, branches, linear):
|
||||
error_checks[control_flow.cond_p] = cond_error_check
|
||||
|
||||
# TODO(mattjj,lenamartens): currently we bundle effectful-assert-discharging
|
||||
# with the error-check-adding transformation (checkify), but the two could be
|
||||
# made orthogonal.
|
||||
def assert_discharge_rule(err: Error, pred: Bool, *, msg: str):
|
||||
return [], assert_func(err, pred, msg)
|
||||
# with the error-check-adding transformation (checkify), but they could be
|
||||
# separated into two orthogonal transformations.
|
||||
def assert_discharge_rule(error, pred, code, *, msgs):
|
||||
out_err = error.err | jnp.logical_not(pred)
|
||||
out_code = lax.select(error.err, error.code, code)
|
||||
return [], Error(out_err, out_code, {**error.msgs, **msgs})
|
||||
error_checks[assert_p] = assert_discharge_rule
|
||||
|
||||
|
||||
|
@ -185,6 +185,43 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "must be positive")
|
||||
|
||||
def test_assert2(self):
|
||||
def f(pred): # note: data dependence needed!
|
||||
checkify.assert2_(pred, 0, {0: "hi"})
|
||||
|
||||
with self.assertRaisesRegex(AssertionError, "hi"):
|
||||
f(False)
|
||||
|
||||
f = checkify.checkify(f)
|
||||
err, none = f(False)
|
||||
|
||||
self.assertIsNone(none)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "hi")
|
||||
|
||||
def test_discharge_recharge(self):
|
||||
def ejit(f):
|
||||
f = checkify.checkify(f)
|
||||
f = jax.jit(f)
|
||||
def jitted_f(*args):
|
||||
err, out = f(*args)
|
||||
checkify.assert2_(~err.err, err.code, err.msgs)
|
||||
return out
|
||||
return jitted_f
|
||||
|
||||
@ejit
|
||||
def f(pred):
|
||||
assert python_should_be_running
|
||||
checkify.assert_(pred, "foo")
|
||||
|
||||
python_should_be_running = True
|
||||
f(True)
|
||||
|
||||
python_should_be_running = False
|
||||
f(True)
|
||||
with self.assertRaisesRegex(AssertionError, "foo"):
|
||||
f(False)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
absltest.main(testLoader=jtu.JaxTestLoader())
|
||||
|
Loading…
x
Reference in New Issue
Block a user