mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Checkify: add way to disable categories of errors.
By default only user_asserts are lifted into the checked function.
This commit is contained in:
parent
6411f8a033
commit
8ea85769ea
@ -12,10 +12,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass
|
||||
from functools import partial
|
||||
import itertools as it
|
||||
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar
|
||||
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, Set, FrozenSet
|
||||
|
||||
import numpy as np
|
||||
|
||||
@ -97,6 +98,13 @@ class CheckifyTracer(core.Tracer):
|
||||
class CheckifyTrace(core.Trace):
|
||||
pure = lift = lambda self, val: CheckifyTracer(self, val)
|
||||
|
||||
def __init__(self, main: core.MainTrace, sublevel: core.Sublevel,
|
||||
enabled_errors: FrozenSet['ErrorCategory']) -> None:
|
||||
self.main = main
|
||||
self.level = main.level
|
||||
self.sublevel = sublevel
|
||||
self.main.enabled_errors = enabled_errors
|
||||
|
||||
def sublift(self, tracer):
|
||||
return CheckifyTracer(self, tracer.val)
|
||||
|
||||
@ -104,7 +112,7 @@ class CheckifyTrace(core.Trace):
|
||||
in_vals = [t.val for t in tracers]
|
||||
rule = error_checks.get(primitive)
|
||||
if rule:
|
||||
out, self.main.error = rule(self.main.error, *in_vals, **params) # type: ignore
|
||||
out, self.main.error = rule(self.main.error, self.main.enabled_errors, *in_vals, **params) # type: ignore
|
||||
else:
|
||||
out = primitive.bind(*in_vals, **params)
|
||||
if primitive.multiple_results:
|
||||
@ -166,18 +174,18 @@ def _reduce_any_error(errs, codes):
|
||||
errs_, codes_ = lax.sort_key_val(errs, codes, dimension=0)
|
||||
return errs_[-1], codes_[-1]
|
||||
|
||||
ErrorCheckRule = Callable
|
||||
ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
|
||||
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
|
||||
|
||||
def checkify_flat(fun: lu.WrappedFun, *args):
|
||||
def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'], *args):
|
||||
fun, msgs = checkify_subtrace(fun)
|
||||
fun = checkify_traceable(fun, tuple(init_error.msgs.items()))
|
||||
fun = checkify_traceable(fun, tuple(init_error.msgs.items()), enabled_errors)
|
||||
err, code, *outvals = fun.call_wrapped(init_error.err, init_error.code, *args)
|
||||
return (err, code, outvals), msgs()
|
||||
|
||||
@lu.transformation
|
||||
def checkify_traceable(msgs, err, code, *args):
|
||||
with core.new_main(CheckifyTrace) as main:
|
||||
def checkify_traceable(msgs, enabled_errors, err, code, *args):
|
||||
with core.new_main(CheckifyTrace, enabled_errors=enabled_errors) as main:
|
||||
outs = yield (main, msgs, err, code, *args), {}
|
||||
del main
|
||||
yield outs
|
||||
@ -196,13 +204,13 @@ def checkify_subtrace(main, msgs, err, code, *args):
|
||||
|
||||
|
||||
# TODO take (error_aval, code_aval) instead of error here?
|
||||
def checkify_jaxpr(jaxpr, error):
|
||||
def checkify_jaxpr(jaxpr, error, enabled_errors):
|
||||
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
|
||||
return checkify_fun_to_jaxpr(f, error, jaxpr.in_avals)
|
||||
return checkify_fun_to_jaxpr(f, error, enabled_errors, jaxpr.in_avals)
|
||||
|
||||
def checkify_fun_to_jaxpr(f, error, in_avals):
|
||||
def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals):
|
||||
f, msgs = checkify_subtrace(f)
|
||||
f = checkify_traceable(f, tuple(error.msgs.items()))
|
||||
f = checkify_traceable(f, tuple(error.msgs.items()), enabled_errors)
|
||||
err_aval = core.raise_to_shaped(core.get_aval(error.err))
|
||||
code_aval = core.raise_to_shaped(core.get_aval(error.code))
|
||||
avals_in = [err_aval, code_aval, *in_avals]
|
||||
@ -244,13 +252,15 @@ def assert_abstract_eval(pred, code, *, msgs):
|
||||
def summary() -> str:
|
||||
return str(source_info_util.summarize(source_info_util.current()))
|
||||
|
||||
def nan_error_check(prim, error, *in_vals, **params):
|
||||
def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
|
||||
out = prim.bind(*in_vals, **params)
|
||||
if ErrorCategory.NAN not in enabled_errors:
|
||||
return out, error
|
||||
no_nans = jnp.logical_not(jnp.any(jnp.isnan(out)))
|
||||
msg = f"nan generated by primitive {prim.name} at {summary()}"
|
||||
return out, assert_func(error, no_nans, msg)
|
||||
|
||||
def gather_error_check(error, operand, start_indices, *,
|
||||
def gather_error_check(error, enabled_errors, operand, start_indices, *,
|
||||
dimension_numbers, slice_sizes, unique_indices,
|
||||
indices_are_sorted, mode, fill_value):
|
||||
out = lax.gather_p.bind(
|
||||
@ -258,6 +268,9 @@ def gather_error_check(error, operand, start_indices, *,
|
||||
slice_sizes=slice_sizes, unique_indices=unique_indices,
|
||||
indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)
|
||||
|
||||
if ErrorCategory.OOB not in enabled_errors:
|
||||
return out, error
|
||||
|
||||
# compare to OOB masking logic in lax._gather_translation_rule
|
||||
dnums = dimension_numbers
|
||||
operand_dims = np.array(operand.shape)
|
||||
@ -270,12 +283,13 @@ def gather_error_check(error, operand, start_indices, *,
|
||||
return out, assert_func(error, all_inbounds, msg)
|
||||
error_checks[lax.gather_p] = gather_error_check
|
||||
|
||||
def div_error_check(error, x, y):
|
||||
def div_error_check(error, enabled_errors, x, y):
|
||||
"""Checks for division by zero and NaN."""
|
||||
all_nonzero = jnp.logical_not(jnp.any(jnp.equal(y, 0)))
|
||||
msg = f'divided by zero at {summary()}'
|
||||
div_by_zero_err = assert_func(error, all_nonzero, msg)
|
||||
return nan_error_check(lax.div_p, div_by_zero_err, x, y)
|
||||
if ErrorCategory.DIV in enabled_errors:
|
||||
all_nonzero = jnp.logical_not(jnp.any(jnp.equal(y, 0)))
|
||||
msg = f'divided by zero at {summary()}'
|
||||
error = assert_func(error, all_nonzero, msg)
|
||||
return nan_error_check(lax.div_p, error, enabled_errors, x, y)
|
||||
error_checks[lax.div_p] = div_error_check
|
||||
|
||||
def scatter_in_bounds(operand, indices, updates, dnums):
|
||||
@ -300,10 +314,9 @@ def scatter_in_bounds(operand, indices, updates, dnums):
|
||||
upper_in_bounds = jnp.all(jnp.less_equal(indices, upper_bound))
|
||||
return jnp.logical_and(lower_in_bounds, upper_in_bounds)
|
||||
|
||||
def scatter_error_check(prim, error, operand, indices, updates, *,
|
||||
update_jaxpr, update_consts,
|
||||
dimension_numbers, indices_are_sorted,
|
||||
unique_indices, mode):
|
||||
def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
|
||||
*, update_jaxpr, update_consts, dimension_numbers,
|
||||
indices_are_sorted, unique_indices, mode):
|
||||
"""Checks if indices are within bounds and update does not generate NaN."""
|
||||
out = prim.bind(
|
||||
operand, indices, updates, update_jaxpr=update_jaxpr,
|
||||
@ -311,6 +324,9 @@ def scatter_error_check(prim, error, operand, indices, updates, *,
|
||||
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
|
||||
mode=mode)
|
||||
|
||||
if ErrorCategory.OOB not in enabled_errors:
|
||||
return out, error
|
||||
|
||||
in_bounds = scatter_in_bounds(operand, indices, updates, dimension_numbers)
|
||||
oob_msg = f'out-of-bounds indexing while updating at {summary()}'
|
||||
oob_error = assert_func(error, in_bounds, oob_msg)
|
||||
@ -324,8 +340,8 @@ error_checks[lax.scatter_mul_p] = partial(scatter_error_check, lax.scatter_mul_p
|
||||
error_checks[lax.scatter_min_p] = partial(scatter_error_check, lax.scatter_min_p)
|
||||
error_checks[lax.scatter_max_p] = partial(scatter_error_check, lax.scatter_max_p)
|
||||
|
||||
def cond_error_check(error, index, *ops, branches, linear):
|
||||
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error) for jxpr in branches)
|
||||
def cond_error_check(error, enabled_errors, index, *ops, branches, linear):
|
||||
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error, enabled_errors) for jxpr in branches)
|
||||
new_linear = (False, False, *linear)
|
||||
err, code, *outs = lax.cond_p.bind(
|
||||
index, error.err, error.code, *ops,
|
||||
@ -334,9 +350,9 @@ def cond_error_check(error, index, *ops, branches, linear):
|
||||
return outs, Error(err, code, new_msgs)
|
||||
error_checks[lax.cond_p] = cond_error_check
|
||||
|
||||
def scan_error_check(error, *in_flat, reverse, length, jaxpr, num_consts, num_carry, linear, unroll):
|
||||
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr, num_consts, num_carry, linear, unroll):
|
||||
consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
|
||||
checked_jaxpr, msgs_ = checkify_jaxpr(jaxpr, error)
|
||||
checked_jaxpr, msgs_ = checkify_jaxpr(jaxpr, error, enabled_errors)
|
||||
new_linear = (False, False, *linear)
|
||||
new_in_flat = [*consts, error.err, error.code, *carry, *xs]
|
||||
err, code, *outs = lax.scan_p.bind(
|
||||
@ -348,14 +364,14 @@ def scan_error_check(error, *in_flat, reverse, length, jaxpr, num_consts, num_ca
|
||||
return outs, Error(err, code, new_msgs)
|
||||
error_checks[lax.scan_p] = scan_error_check
|
||||
|
||||
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error):
|
||||
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors):
|
||||
cond_f = core.jaxpr_as_fun(cond_jaxpr)
|
||||
body_f = core.jaxpr_as_fun(body_jaxpr)
|
||||
def new_body_f(*vals):
|
||||
out = body_f(*vals)
|
||||
_ = cond_f(*out) # this checks if the next cond application will error
|
||||
return out
|
||||
return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, body_jaxpr.in_avals)
|
||||
return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors, body_jaxpr.in_avals)
|
||||
|
||||
def ignore_errors_jaxpr(jaxpr, error):
|
||||
"""Constructs a jaxpr which takes two extra args but ignores them."""
|
||||
@ -369,13 +385,13 @@ def ignore_errors_jaxpr(jaxpr, error):
|
||||
jaxpr.outvars, jaxpr.eqns)
|
||||
return core.ClosedJaxpr(new_jaxpr, consts)
|
||||
|
||||
def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
|
||||
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error)
|
||||
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
|
||||
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error, enabled_errors)
|
||||
checked_cond_fun = core.jaxpr_as_fun(checked_cond_jaxpr)
|
||||
# Check if the first cond application will error.
|
||||
cond_err, cond_code, _ = checked_cond_fun(error.err, error.code, *in_flat)
|
||||
|
||||
checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error)
|
||||
checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors)
|
||||
compat_cond_jaxpr = ignore_errors_jaxpr(cond_jaxpr, error)
|
||||
c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
|
||||
new_in_flat = [*c_consts, *b_consts, cond_err, cond_code, *carry]
|
||||
@ -453,7 +469,10 @@ add_nan_check(lax.select_p)
|
||||
add_nan_check(lax.max_p)
|
||||
add_nan_check(lax.min_p)
|
||||
|
||||
def assert_discharge_rule(error, pred, code, *, msgs):
|
||||
def assert_discharge_rule(error, enabled_errors, pred, code, *, msgs):
|
||||
if ErrorCategory.ASSERT not in enabled_errors:
|
||||
return [], error
|
||||
|
||||
out_err = error.err | jnp.logical_not(pred)
|
||||
out_code = lax.select(error.err, error.code, code)
|
||||
return [], Error(out_err, out_code, {**error.msgs, **msgs})
|
||||
@ -462,13 +481,24 @@ error_checks[assert_p] = assert_discharge_rule
|
||||
|
||||
## checkify api
|
||||
|
||||
ErrorCategory = enum.Enum('ErrorCategory', ['NAN', 'OOB', 'DIV', 'ASSERT'])
|
||||
|
||||
float_errors = {ErrorCategory.NAN, ErrorCategory.DIV}
|
||||
index_errors = {ErrorCategory.OOB}
|
||||
automatic_errors = float_errors | index_errors
|
||||
user_asserts = {ErrorCategory.ASSERT}
|
||||
|
||||
Out = TypeVar('Out')
|
||||
def checkify(fun: Callable[..., Out]) -> Callable[..., Tuple[Error, Out]]:
|
||||
def checkify(fun: Callable[..., Out], errors: Set[ErrorCategory] = user_asserts) -> Callable[..., Tuple[Error, Out]]:
|
||||
if not errors:
|
||||
raise ValueError('Checkify needs to be called with at least one enabled'
|
||||
' ErrorCategory, was called with an empty errors set.')
|
||||
|
||||
@traceback_util.api_boundary
|
||||
def checked_fun(*args, **kwargs):
|
||||
args_flat, in_tree = tree_flatten((args, kwargs))
|
||||
f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
|
||||
(err, code, out_flat), msgs = checkify_flat(f, *args_flat)
|
||||
(err, code, out_flat), msgs = checkify_flat(f, frozenset(errors), *args_flat)
|
||||
out = tree_unflatten(out_tree(), out_flat)
|
||||
return Error(err, code, msgs), out
|
||||
return checked_fun
|
||||
|
@ -39,11 +39,12 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
return y1 + y2
|
||||
|
||||
f = jax.jit(f) if jit else f
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
|
||||
err, _ = checkify.checkify(f)(3., 4.)
|
||||
err, _ = checked_f(3., 4.)
|
||||
self.assertIs(err.get(), None)
|
||||
|
||||
err, _ = checkify.checkify(f)(3., jnp.inf)
|
||||
err, _ = checked_f(3., jnp.inf)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
|
||||
|
||||
@ -58,16 +59,17 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
return w
|
||||
|
||||
f = jax.jit(f) if jit else f
|
||||
checked_f = checkify.checkify(f, errors=checkify.index_errors)
|
||||
|
||||
err, _ = checkify.checkify(f)(jnp.arange(3), 2)
|
||||
err, _ = checked_f(jnp.arange(3), 2)
|
||||
self.assertIs(err.get(), None)
|
||||
|
||||
err, _ = checkify.checkify(f)(jnp.arange(3), 5)
|
||||
err, _ = checked_f(jnp.arange(3), 5)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'out-of-bounds indexing')
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_update={update_fn}", "update_fn": update_fn}
|
||||
{"testcase_name": f"_updatefn={update_fn}", "update_fn": update_fn}
|
||||
for update_fn in ["set", "add", "multiply", "divide", "power", "min",
|
||||
"max", "get"])
|
||||
def test_jit_oob_update(self, update_fn):
|
||||
@ -75,11 +77,12 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
return getattr(x.at[i], update_fn)(1.)
|
||||
|
||||
f = jax.jit(f)
|
||||
checked_f = checkify.checkify(f, errors=checkify.index_errors)
|
||||
|
||||
err, _ = checkify.checkify(f)(jnp.arange(3), 2)
|
||||
err, _ = checked_f(jnp.arange(3), 2)
|
||||
self.assertIs(err.get(), None)
|
||||
|
||||
err, _ = checkify.checkify(f)(jnp.arange(3), 3)
|
||||
err, _ = checked_f(jnp.arange(3), 3)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'out-of-bounds indexing')
|
||||
|
||||
@ -91,15 +94,16 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
return x/y
|
||||
|
||||
f = jax.jit(f) if jit else f
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
|
||||
err, _ = checkify.checkify(f)(jnp.ones((3,)), jnp.ones((3,)))
|
||||
err, _ = checked_f(jnp.ones((3,)), jnp.ones((3,)))
|
||||
self.assertIs(err.get(), None)
|
||||
|
||||
err, _ = checkify.checkify(f)(jnp.ones((3,)), jnp.array([1, 0, 1]))
|
||||
err, _ = checked_f(jnp.ones((3,)), jnp.array([1, 0, 1]))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
|
||||
err, _ = checkify.checkify(f)(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1]))
|
||||
err, _ = checked_f(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1]))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive div')
|
||||
|
||||
@ -114,18 +118,19 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
return z
|
||||
|
||||
f = jax.jit(f) if jit else f
|
||||
checked_f = checkify.checkify(f, errors=checkify.automatic_errors)
|
||||
|
||||
# no error
|
||||
err, _ = checkify.checkify(f)(jnp.array([0., jnp.inf, 2.]), 2)
|
||||
err, _ = checked_f(jnp.array([0., jnp.inf, 2.]), 2)
|
||||
self.assertIs(err.get(), None)
|
||||
|
||||
# oob error
|
||||
err, _ = checkify.checkify(f)(jnp.array([0., 1., 2.]), 5)
|
||||
err, _ = checked_f(jnp.array([0., 1., 2.]), 5)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'out-of-bounds indexing')
|
||||
|
||||
# nan error
|
||||
err, _ = checkify.checkify(f)(jnp.array([0., 1., jnp.inf]), 2)
|
||||
err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 2)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive cos')
|
||||
|
||||
@ -139,9 +144,10 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
return y * z
|
||||
|
||||
f = jax.jit(f) if jit else f
|
||||
checked_f = checkify.checkify(f, errors=checkify.automatic_errors)
|
||||
|
||||
# both oob and nan error, but oob happens first
|
||||
err, _ = checkify.checkify(f)(jnp.array([0., 1., jnp.inf]), 5)
|
||||
err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 5)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'out-of-bounds indexing')
|
||||
|
||||
@ -155,13 +161,14 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
y1 = jnp.sin(x1)
|
||||
y2 = jnp.sin(x2)
|
||||
return y1 + y2
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
|
||||
xs = jnp.array([0., 2.])
|
||||
err, _ = checkify.checkify(f)(xs, xs)
|
||||
err, _ = checked_f(xs, xs)
|
||||
self.assertIs(err.get(), None)
|
||||
|
||||
ys = jnp.array([3., jnp.inf])
|
||||
err, _ = checkify.checkify(f)(xs, ys)
|
||||
err, _ = checked_f(xs, ys)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
|
||||
|
||||
@ -173,14 +180,16 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
lambda: jnp.sin(x),
|
||||
lambda: x)
|
||||
|
||||
err, y = checkify.checkify(f)(3.)
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
|
||||
err, y = checked_f(3.)
|
||||
self.assertIs(err.get(), None)
|
||||
|
||||
err, y = checkify.checkify(f)(jnp.inf)
|
||||
err, y = checked_f(jnp.inf)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
|
||||
|
||||
err, y = checkify.checkify(f)(-jnp.inf)
|
||||
err, y = checked_f(-jnp.inf)
|
||||
self.assertIs(err.get(), None)
|
||||
|
||||
|
||||
@ -193,14 +202,16 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def f(xs):
|
||||
return lax.scan(scan_body, None, xs)
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
|
||||
xs = jnp.array([0., 2.])
|
||||
err, (_, ch_outs) = checkify.checkify(f)(xs)
|
||||
err, (_, ch_outs) = checked_f(xs)
|
||||
_, outs = f(xs)
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertArraysEqual(ch_outs, outs)
|
||||
|
||||
xs = jnp.array([3., jnp.inf])
|
||||
err, (_, ch_outs) = checkify.checkify(f)(xs)
|
||||
err, (_, ch_outs) = checked_f(xs)
|
||||
_, outs = f(xs)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
@ -217,8 +228,10 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def f(carry, xs):
|
||||
return lax.scan(scan_body, carry, xs)
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
|
||||
carry, xs = 3., jnp.ones((2,))
|
||||
err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
|
||||
err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
|
||||
out_carry, outs = f(carry, xs)
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertArraysEqual(ch_outs, outs)
|
||||
@ -226,7 +239,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
|
||||
# error happens on first iteration
|
||||
carry, xs = 1., jnp.ones((2,))
|
||||
err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
|
||||
err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
|
||||
out_carry, outs = f(carry, xs)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
@ -235,7 +248,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
|
||||
# error happens on second iteration
|
||||
carry, xs = 2., jnp.ones((4,))
|
||||
err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
|
||||
err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
|
||||
out_carry, outs = f(carry, xs)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
@ -257,14 +270,16 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def f(init_val):
|
||||
return lax.while_loop(while_cond, while_body, (init_val, 0.))
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
|
||||
init_val = 1.
|
||||
err, ch_out = checkify.checkify(f)(init_val)
|
||||
err, ch_out = checked_f(init_val)
|
||||
out = f(init_val)
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertArraysEqual(ch_out, out)
|
||||
|
||||
init_val = 0.
|
||||
err, ch_out = checkify.checkify(f)(init_val)
|
||||
err, ch_out = checked_f(init_val)
|
||||
out = f(init_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
@ -283,14 +298,16 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def f(init_val):
|
||||
return lax.while_loop(while_cond, while_body, init_val)
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
|
||||
init_val = 1.
|
||||
err, ch_out = checkify.checkify(f)(init_val)
|
||||
err, ch_out = checked_f(init_val)
|
||||
out = f(init_val)
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertArraysEqual(ch_out, out)
|
||||
|
||||
init_val = 0.
|
||||
err, ch_out = checkify.checkify(f)(init_val)
|
||||
err, ch_out = checked_f(init_val)
|
||||
out = f(init_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
@ -307,15 +324,17 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def f(init_val):
|
||||
return lax.while_loop(while_cond, lambda val: val-1, init_val)
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
|
||||
# error on first cond
|
||||
init_val = 0.
|
||||
err, _ = checkify.checkify(f)(init_val)
|
||||
err, _ = checked_f(init_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
|
||||
# error on second cond
|
||||
init_val = 1.
|
||||
err, _ = checkify.checkify(f)(init_val)
|
||||
err, _ = checked_f(init_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "divided by zero")
|
||||
|
||||
@ -335,25 +354,53 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def f(cond_val, body_val):
|
||||
return lax.while_loop(while_cond, while_body, (0., cond_val, body_val))
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_errors)
|
||||
|
||||
cond_val = jnp.inf
|
||||
body_val = 1.
|
||||
err, _ = checkify.checkify(f)(cond_val, body_val)
|
||||
err, _ = checked_f(cond_val, body_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
|
||||
cond_val = 1.
|
||||
body_val = jnp.inf
|
||||
err, _ = checkify.checkify(f)(cond_val, body_val)
|
||||
err, _ = checked_f(cond_val, body_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive cos")
|
||||
|
||||
cond_val = jnp.inf
|
||||
body_val = jnp.inf
|
||||
err, _ = checkify.checkify(f)(cond_val, body_val)
|
||||
err, _ = checked_f(cond_val, body_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
# first error which occurs is in cond
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
|
||||
def test_empty_enabled_errors(self):
|
||||
with self.assertRaisesRegex(ValueError, 'called with an empty errors set'):
|
||||
checkify.checkify(lambda x: x, errors={})
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("assert", checkify.user_asserts, "must be negative!"),
|
||||
("div", {checkify.ErrorCategory.DIV}, "divided by zero"),
|
||||
("nan", {checkify.ErrorCategory.NAN}, "nan generated"),
|
||||
("oob", checkify.index_errors, "out-of-bounds indexing"),
|
||||
("automatic_errors", checkify.automatic_errors, "divided by zero"),
|
||||
)
|
||||
@jtu.skip_on_devices('tpu')
|
||||
def test_enabled_errors(self, error_set, expected_error):
|
||||
def multi_errors(x):
|
||||
x = x/0 # DIV
|
||||
x = jnp.sin(x) # NAN
|
||||
x = x[500] # OOB
|
||||
checkify.assert_(x < 0, "must be negative!") # ASSERT
|
||||
return x
|
||||
|
||||
x = jnp.ones((2,))
|
||||
err, _ = checkify.checkify(multi_errors, errors=error_set)(x)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), expected_error)
|
||||
|
||||
|
||||
class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
def test_assert_primitive_impl(self):
|
||||
def f():
|
||||
|
Loading…
x
Reference in New Issue
Block a user