checkify: fix and test post_process_call/map

This commit is contained in:
Matthew Johnson 2022-01-18 22:22:57 -08:00
parent eddea68a9a
commit 9488c5ae72
2 changed files with 56 additions and 17 deletions

View File

@ -29,11 +29,15 @@ from jax.interpreters import partial_eval as pe
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
from jax._src import source_info_util, traceback_util
from jax import lax
from jax._src.util import as_hashable_function, unzip2, split_list
from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map,
safe_zip)
source_info_util.register_exclusion(__file__)
traceback_util.register_exclusion(__file__)
map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip
## Utils
@ -112,7 +116,8 @@ class CheckifyTrace(core.Trace):
in_vals = [t.val for t in tracers]
rule = error_checks.get(primitive)
if rule:
out, self.main.error = rule(self.main.error, self.main.enabled_errors, *in_vals, **params) # type: ignore
out, self.main.error = rule(self.main.error, self.main.enabled_errors, # type: ignore
*in_vals, **params)
else:
out = primitive.bind(*in_vals, **params)
if primitive.multiple_results:
@ -149,22 +154,25 @@ class CheckifyTrace(core.Trace):
def post_process_call(self, primitive, tracers, params):
vals = [t.val for t in tracers]
main = self.main
e = popattr(self.main, 'error')
e = popattr(main, 'error')
err, code, main.msgs = e.err, e.code, e.msgs
def todo(vals):
trace = main.with_cur_sublevel()
err, code, *vals = vals
setnewattr(main, 'error', Error(err, code, popattr(main, 'msgs')))
trace = main.with_cur_sublevel()
return [CheckifyTracer(trace, x) for x in vals]
return (err, code, *vals), todo
def post_process_map(self, primitive, tracers, params):
vals = [t.val for t in tracers]
main = self.main
e = popattr(self.main, 'error')
e = popattr(main, 'error')
err, code, main.msgs = e.err, e.code, e.msgs
def todo(vals):
errs, codes, *vals = vals
err, code = _reduce_any_error(errs, codes)
setnewattr(main, 'error', Error(err, code, popattr(main, 'msgs')))
trace = main.with_cur_sublevel()
err, code, *vals = vals
return [CheckifyTracer(trace, x) for x in vals]
def out_axes_transform(out_axes):
return (0, 0, *out_axes)
@ -174,10 +182,11 @@ def _reduce_any_error(errs, codes):
errs_, codes_ = lax.sort_key_val(errs, codes, dimension=0)
return errs_[-1], codes_[-1]
ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'], *args):
def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'],
*args):
fun, msgs = checkify_subtrace(fun)
fun = checkify_traceable(fun, tuple(init_error.msgs.items()), enabled_errors)
err, code, *outvals = fun.call_wrapped(init_error.err, init_error.code, *args)
@ -341,7 +350,8 @@ error_checks[lax.scatter_min_p] = partial(scatter_error_check, lax.scatter_min_p
error_checks[lax.scatter_max_p] = partial(scatter_error_check, lax.scatter_max_p)
def cond_error_check(error, enabled_errors, index, *ops, branches, linear):
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error, enabled_errors) for jxpr in branches)
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error, enabled_errors)
for jxpr in branches)
new_linear = (False, False, *linear)
err, code, *outs = lax.cond_p.bind(
index, error.err, error.code, *ops,
@ -350,7 +360,8 @@ def cond_error_check(error, enabled_errors, index, *ops, branches, linear):
return outs, Error(err, code, new_msgs)
error_checks[lax.cond_p] = cond_error_check
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr, num_consts, num_carry, linear, unroll):
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
num_consts, num_carry, linear, unroll):
consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
checked_jaxpr, msgs_ = checkify_jaxpr(jaxpr, error, enabled_errors)
new_linear = (False, False, *linear)
@ -371,7 +382,8 @@ def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors):
out = body_f(*vals)
_ = cond_f(*out) # this checks if the next cond application will error
return out
return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors, body_jaxpr.in_avals)
return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors,
body_jaxpr.in_avals)
def ignore_errors_jaxpr(jaxpr, error):
"""Constructs a jaxpr which takes two extra args but ignores them."""
@ -385,13 +397,15 @@ def ignore_errors_jaxpr(jaxpr, error):
jaxpr.outvars, jaxpr.eqns)
return core.ClosedJaxpr(new_jaxpr, consts)
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error, enabled_errors)
checked_cond_fun = core.jaxpr_as_fun(checked_cond_jaxpr)
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
cond_jaxpr, body_nconsts, body_jaxpr):
cond_jaxpr_, msgs_cond = checkify_jaxpr(cond_jaxpr, error, enabled_errors)
checked_cond_fun = core.jaxpr_as_fun(cond_jaxpr_)
# Check if the first cond application will error.
cond_err, cond_code, _ = checked_cond_fun(error.err, error.code, *in_flat)
checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors)
checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(
cond_jaxpr, body_jaxpr, error, enabled_errors)
compat_cond_jaxpr = ignore_errors_jaxpr(cond_jaxpr, error)
c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
new_in_flat = [*c_consts, *b_consts, cond_err, cond_code, *carry]
@ -489,7 +503,8 @@ automatic_errors = float_errors | index_errors
user_asserts = {ErrorCategory.ASSERT}
Out = TypeVar('Out')
def checkify(fun: Callable[..., Out], errors: Set[ErrorCategory] = user_asserts) -> Callable[..., Tuple[Error, Out]]:
def checkify(fun: Callable[..., Out], errors: Set[ErrorCategory] = user_asserts
) -> Callable[..., Tuple[Error, Out]]:
if not errors:
raise ValueError('Checkify needs to be called with at least one enabled'
' ErrorCategory, was called with an empty errors set.')

View File

@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import unittest
from absl.testing import absltest
@ -192,7 +193,6 @@ class CheckifyTransformTests(jtu.JaxTestCase):
err, y = checked_f(-jnp.inf)
self.assertIs(err.get(), None)
@jtu.skip_on_devices('tpu')
def test_scan_map(self):
def scan_body(_, x):
@ -400,6 +400,30 @@ class CheckifyTransformTests(jtu.JaxTestCase):
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), expected_error)
@jtu.skip_on_devices('tpu')
def test_post_process_call(self):
@partial(checkify.checkify, errors=checkify.float_errors)
def g(x):
@jax.jit
def f(y):
return jnp.sin(x * y)
return f(jnp.inf)
err, _ = g(2.)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
@jtu.skip_on_devices('tpu')
def test_post_process_map(self):
@partial(checkify.checkify, errors=checkify.float_errors)
def g(x):
@jax.pmap
def f(y):
return jnp.sin(x * y)
return f(jnp.array([jnp.inf]))[0]
err, _ = g(2.)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
class AssertPrimitiveTests(jtu.JaxTestCase):
def test_assert_primitive_impl(self):