mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
checkify: fix and test post_process_call/map
This commit is contained in:
parent
eddea68a9a
commit
9488c5ae72
@ -29,11 +29,15 @@ from jax.interpreters import partial_eval as pe
|
||||
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
|
||||
from jax._src import source_info_util, traceback_util
|
||||
from jax import lax
|
||||
from jax._src.util import as_hashable_function, unzip2, split_list
|
||||
from jax._src.util import (as_hashable_function, unzip2, split_list, safe_map,
|
||||
safe_zip)
|
||||
|
||||
source_info_util.register_exclusion(__file__)
|
||||
traceback_util.register_exclusion(__file__)
|
||||
|
||||
map, unsafe_map = safe_map, map
|
||||
zip, unsafe_zip = safe_zip, zip
|
||||
|
||||
|
||||
## Utils
|
||||
|
||||
@ -112,7 +116,8 @@ class CheckifyTrace(core.Trace):
|
||||
in_vals = [t.val for t in tracers]
|
||||
rule = error_checks.get(primitive)
|
||||
if rule:
|
||||
out, self.main.error = rule(self.main.error, self.main.enabled_errors, *in_vals, **params) # type: ignore
|
||||
out, self.main.error = rule(self.main.error, self.main.enabled_errors, # type: ignore
|
||||
*in_vals, **params)
|
||||
else:
|
||||
out = primitive.bind(*in_vals, **params)
|
||||
if primitive.multiple_results:
|
||||
@ -149,22 +154,25 @@ class CheckifyTrace(core.Trace):
|
||||
def post_process_call(self, primitive, tracers, params):
|
||||
vals = [t.val for t in tracers]
|
||||
main = self.main
|
||||
e = popattr(self.main, 'error')
|
||||
e = popattr(main, 'error')
|
||||
err, code, main.msgs = e.err, e.code, e.msgs
|
||||
def todo(vals):
|
||||
trace = main.with_cur_sublevel()
|
||||
err, code, *vals = vals
|
||||
setnewattr(main, 'error', Error(err, code, popattr(main, 'msgs')))
|
||||
trace = main.with_cur_sublevel()
|
||||
return [CheckifyTracer(trace, x) for x in vals]
|
||||
return (err, code, *vals), todo
|
||||
|
||||
def post_process_map(self, primitive, tracers, params):
|
||||
vals = [t.val for t in tracers]
|
||||
main = self.main
|
||||
e = popattr(self.main, 'error')
|
||||
e = popattr(main, 'error')
|
||||
err, code, main.msgs = e.err, e.code, e.msgs
|
||||
def todo(vals):
|
||||
errs, codes, *vals = vals
|
||||
err, code = _reduce_any_error(errs, codes)
|
||||
setnewattr(main, 'error', Error(err, code, popattr(main, 'msgs')))
|
||||
trace = main.with_cur_sublevel()
|
||||
err, code, *vals = vals
|
||||
return [CheckifyTracer(trace, x) for x in vals]
|
||||
def out_axes_transform(out_axes):
|
||||
return (0, 0, *out_axes)
|
||||
@ -174,10 +182,11 @@ def _reduce_any_error(errs, codes):
|
||||
errs_, codes_ = lax.sort_key_val(errs, codes, dimension=0)
|
||||
return errs_[-1], codes_[-1]
|
||||
|
||||
ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
|
||||
ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
|
||||
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
|
||||
|
||||
def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'], *args):
|
||||
def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'],
|
||||
*args):
|
||||
fun, msgs = checkify_subtrace(fun)
|
||||
fun = checkify_traceable(fun, tuple(init_error.msgs.items()), enabled_errors)
|
||||
err, code, *outvals = fun.call_wrapped(init_error.err, init_error.code, *args)
|
||||
@ -341,7 +350,8 @@ error_checks[lax.scatter_min_p] = partial(scatter_error_check, lax.scatter_min_p
|
||||
error_checks[lax.scatter_max_p] = partial(scatter_error_check, lax.scatter_max_p)
|
||||
|
||||
def cond_error_check(error, enabled_errors, index, *ops, branches, linear):
|
||||
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error, enabled_errors) for jxpr in branches)
|
||||
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error, enabled_errors)
|
||||
for jxpr in branches)
|
||||
new_linear = (False, False, *linear)
|
||||
err, code, *outs = lax.cond_p.bind(
|
||||
index, error.err, error.code, *ops,
|
||||
@ -350,7 +360,8 @@ def cond_error_check(error, enabled_errors, index, *ops, branches, linear):
|
||||
return outs, Error(err, code, new_msgs)
|
||||
error_checks[lax.cond_p] = cond_error_check
|
||||
|
||||
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr, num_consts, num_carry, linear, unroll):
|
||||
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr,
|
||||
num_consts, num_carry, linear, unroll):
|
||||
consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
|
||||
checked_jaxpr, msgs_ = checkify_jaxpr(jaxpr, error, enabled_errors)
|
||||
new_linear = (False, False, *linear)
|
||||
@ -371,7 +382,8 @@ def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors):
|
||||
out = body_f(*vals)
|
||||
_ = cond_f(*out) # this checks if the next cond application will error
|
||||
return out
|
||||
return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors, body_jaxpr.in_avals)
|
||||
return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors,
|
||||
body_jaxpr.in_avals)
|
||||
|
||||
def ignore_errors_jaxpr(jaxpr, error):
|
||||
"""Constructs a jaxpr which takes two extra args but ignores them."""
|
||||
@ -385,13 +397,15 @@ def ignore_errors_jaxpr(jaxpr, error):
|
||||
jaxpr.outvars, jaxpr.eqns)
|
||||
return core.ClosedJaxpr(new_jaxpr, consts)
|
||||
|
||||
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
|
||||
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error, enabled_errors)
|
||||
checked_cond_fun = core.jaxpr_as_fun(checked_cond_jaxpr)
|
||||
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts,
|
||||
cond_jaxpr, body_nconsts, body_jaxpr):
|
||||
cond_jaxpr_, msgs_cond = checkify_jaxpr(cond_jaxpr, error, enabled_errors)
|
||||
checked_cond_fun = core.jaxpr_as_fun(cond_jaxpr_)
|
||||
# Check if the first cond application will error.
|
||||
cond_err, cond_code, _ = checked_cond_fun(error.err, error.code, *in_flat)
|
||||
|
||||
checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors)
|
||||
checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(
|
||||
cond_jaxpr, body_jaxpr, error, enabled_errors)
|
||||
compat_cond_jaxpr = ignore_errors_jaxpr(cond_jaxpr, error)
|
||||
c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
|
||||
new_in_flat = [*c_consts, *b_consts, cond_err, cond_code, *carry]
|
||||
@ -489,7 +503,8 @@ automatic_errors = float_errors | index_errors
|
||||
user_asserts = {ErrorCategory.ASSERT}
|
||||
|
||||
Out = TypeVar('Out')
|
||||
def checkify(fun: Callable[..., Out], errors: Set[ErrorCategory] = user_asserts) -> Callable[..., Tuple[Error, Out]]:
|
||||
def checkify(fun: Callable[..., Out], errors: Set[ErrorCategory] = user_asserts
|
||||
) -> Callable[..., Tuple[Error, Out]]:
|
||||
if not errors:
|
||||
raise ValueError('Checkify needs to be called with at least one enabled'
|
||||
' ErrorCategory, was called with an empty errors set.')
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
import unittest
|
||||
|
||||
from absl.testing import absltest
|
||||
@ -192,7 +193,6 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
err, y = checked_f(-jnp.inf)
|
||||
self.assertIs(err.get(), None)
|
||||
|
||||
|
||||
@jtu.skip_on_devices('tpu')
|
||||
def test_scan_map(self):
|
||||
def scan_body(_, x):
|
||||
@ -400,6 +400,30 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), expected_error)
|
||||
|
||||
@jtu.skip_on_devices('tpu')
|
||||
def test_post_process_call(self):
|
||||
@partial(checkify.checkify, errors=checkify.float_errors)
|
||||
def g(x):
|
||||
@jax.jit
|
||||
def f(y):
|
||||
return jnp.sin(x * y)
|
||||
return f(jnp.inf)
|
||||
err, _ = g(2.)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
|
||||
|
||||
@jtu.skip_on_devices('tpu')
|
||||
def test_post_process_map(self):
|
||||
@partial(checkify.checkify, errors=checkify.float_errors)
|
||||
def g(x):
|
||||
@jax.pmap
|
||||
def f(y):
|
||||
return jnp.sin(x * y)
|
||||
return f(jnp.array([jnp.inf]))[0]
|
||||
err, _ = g(2.)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
|
||||
|
||||
|
||||
class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
def test_assert_primitive_impl(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user