Checkify: add way to disable categories of errors.

By default only user_asserts are lifted into the checked function.
This commit is contained in:
Lena Martens 2022-01-10 18:21:41 +00:00 committed by lenamartens
parent 6411f8a033
commit 8ea85769ea
2 changed files with 144 additions and 67 deletions

View File

@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import enum
from dataclasses import dataclass
from functools import partial
import itertools as it
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar
from typing import Union, Optional, Callable, Dict, Tuple, TypeVar, Set, FrozenSet
import numpy as np
@ -97,6 +98,13 @@ class CheckifyTracer(core.Tracer):
class CheckifyTrace(core.Trace):
pure = lift = lambda self, val: CheckifyTracer(self, val)
def __init__(self, main: core.MainTrace, sublevel: core.Sublevel,
enabled_errors: FrozenSet['ErrorCategory']) -> None:
self.main = main
self.level = main.level
self.sublevel = sublevel
self.main.enabled_errors = enabled_errors
def sublift(self, tracer):
return CheckifyTracer(self, tracer.val)
@ -104,7 +112,7 @@ class CheckifyTrace(core.Trace):
in_vals = [t.val for t in tracers]
rule = error_checks.get(primitive)
if rule:
out, self.main.error = rule(self.main.error, *in_vals, **params) # type: ignore
out, self.main.error = rule(self.main.error, self.main.enabled_errors, *in_vals, **params) # type: ignore
else:
out = primitive.bind(*in_vals, **params)
if primitive.multiple_results:
@ -166,18 +174,18 @@ def _reduce_any_error(errs, codes):
errs_, codes_ = lax.sort_key_val(errs, codes, dimension=0)
return errs_[-1], codes_[-1]
ErrorCheckRule = Callable
ErrorCheckRule = Callable # (Error, FrozenSet[ErrorCategory], *in_vals, **params) -> (Any, Error)
error_checks: Dict[core.Primitive, ErrorCheckRule] = {}
def checkify_flat(fun: lu.WrappedFun, *args):
def checkify_flat(fun: lu.WrappedFun, enabled_errors: FrozenSet['ErrorCategory'], *args):
fun, msgs = checkify_subtrace(fun)
fun = checkify_traceable(fun, tuple(init_error.msgs.items()))
fun = checkify_traceable(fun, tuple(init_error.msgs.items()), enabled_errors)
err, code, *outvals = fun.call_wrapped(init_error.err, init_error.code, *args)
return (err, code, outvals), msgs()
@lu.transformation
def checkify_traceable(msgs, err, code, *args):
with core.new_main(CheckifyTrace) as main:
def checkify_traceable(msgs, enabled_errors, err, code, *args):
with core.new_main(CheckifyTrace, enabled_errors=enabled_errors) as main:
outs = yield (main, msgs, err, code, *args), {}
del main
yield outs
@ -196,13 +204,13 @@ def checkify_subtrace(main, msgs, err, code, *args):
# TODO take (error_aval, code_aval) instead of error here?
def checkify_jaxpr(jaxpr, error):
def checkify_jaxpr(jaxpr, error, enabled_errors):
f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
return checkify_fun_to_jaxpr(f, error, jaxpr.in_avals)
return checkify_fun_to_jaxpr(f, error, enabled_errors, jaxpr.in_avals)
def checkify_fun_to_jaxpr(f, error, in_avals):
def checkify_fun_to_jaxpr(f, error, enabled_errors, in_avals):
f, msgs = checkify_subtrace(f)
f = checkify_traceable(f, tuple(error.msgs.items()))
f = checkify_traceable(f, tuple(error.msgs.items()), enabled_errors)
err_aval = core.raise_to_shaped(core.get_aval(error.err))
code_aval = core.raise_to_shaped(core.get_aval(error.code))
avals_in = [err_aval, code_aval, *in_avals]
@ -244,13 +252,15 @@ def assert_abstract_eval(pred, code, *, msgs):
def summary() -> str:
return str(source_info_util.summarize(source_info_util.current()))
def nan_error_check(prim, error, *in_vals, **params):
def nan_error_check(prim, error, enabled_errors, *in_vals, **params):
out = prim.bind(*in_vals, **params)
if ErrorCategory.NAN not in enabled_errors:
return out, error
no_nans = jnp.logical_not(jnp.any(jnp.isnan(out)))
msg = f"nan generated by primitive {prim.name} at {summary()}"
return out, assert_func(error, no_nans, msg)
def gather_error_check(error, operand, start_indices, *,
def gather_error_check(error, enabled_errors, operand, start_indices, *,
dimension_numbers, slice_sizes, unique_indices,
indices_are_sorted, mode, fill_value):
out = lax.gather_p.bind(
@ -258,6 +268,9 @@ def gather_error_check(error, operand, start_indices, *,
slice_sizes=slice_sizes, unique_indices=unique_indices,
indices_are_sorted=indices_are_sorted, mode=mode, fill_value=fill_value)
if ErrorCategory.OOB not in enabled_errors:
return out, error
# compare to OOB masking logic in lax._gather_translation_rule
dnums = dimension_numbers
operand_dims = np.array(operand.shape)
@ -270,12 +283,13 @@ def gather_error_check(error, operand, start_indices, *,
return out, assert_func(error, all_inbounds, msg)
error_checks[lax.gather_p] = gather_error_check
def div_error_check(error, x, y):
def div_error_check(error, enabled_errors, x, y):
"""Checks for division by zero and NaN."""
all_nonzero = jnp.logical_not(jnp.any(jnp.equal(y, 0)))
msg = f'divided by zero at {summary()}'
div_by_zero_err = assert_func(error, all_nonzero, msg)
return nan_error_check(lax.div_p, div_by_zero_err, x, y)
if ErrorCategory.DIV in enabled_errors:
all_nonzero = jnp.logical_not(jnp.any(jnp.equal(y, 0)))
msg = f'divided by zero at {summary()}'
error = assert_func(error, all_nonzero, msg)
return nan_error_check(lax.div_p, error, enabled_errors, x, y)
error_checks[lax.div_p] = div_error_check
def scatter_in_bounds(operand, indices, updates, dnums):
@ -300,10 +314,9 @@ def scatter_in_bounds(operand, indices, updates, dnums):
upper_in_bounds = jnp.all(jnp.less_equal(indices, upper_bound))
return jnp.logical_and(lower_in_bounds, upper_in_bounds)
def scatter_error_check(prim, error, operand, indices, updates, *,
update_jaxpr, update_consts,
dimension_numbers, indices_are_sorted,
unique_indices, mode):
def scatter_error_check(prim, error, enabled_errors, operand, indices, updates,
*, update_jaxpr, update_consts, dimension_numbers,
indices_are_sorted, unique_indices, mode):
"""Checks if indices are within bounds and update does not generate NaN."""
out = prim.bind(
operand, indices, updates, update_jaxpr=update_jaxpr,
@ -311,6 +324,9 @@ def scatter_error_check(prim, error, operand, indices, updates, *,
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
mode=mode)
if ErrorCategory.OOB not in enabled_errors:
return out, error
in_bounds = scatter_in_bounds(operand, indices, updates, dimension_numbers)
oob_msg = f'out-of-bounds indexing while updating at {summary()}'
oob_error = assert_func(error, in_bounds, oob_msg)
@ -324,8 +340,8 @@ error_checks[lax.scatter_mul_p] = partial(scatter_error_check, lax.scatter_mul_p
error_checks[lax.scatter_min_p] = partial(scatter_error_check, lax.scatter_min_p)
error_checks[lax.scatter_max_p] = partial(scatter_error_check, lax.scatter_max_p)
def cond_error_check(error, index, *ops, branches, linear):
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error) for jxpr in branches)
def cond_error_check(error, enabled_errors, index, *ops, branches, linear):
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error, enabled_errors) for jxpr in branches)
new_linear = (False, False, *linear)
err, code, *outs = lax.cond_p.bind(
index, error.err, error.code, *ops,
@ -334,9 +350,9 @@ def cond_error_check(error, index, *ops, branches, linear):
return outs, Error(err, code, new_msgs)
error_checks[lax.cond_p] = cond_error_check
def scan_error_check(error, *in_flat, reverse, length, jaxpr, num_consts, num_carry, linear, unroll):
def scan_error_check(error, enabled_errors, *in_flat, reverse, length, jaxpr, num_consts, num_carry, linear, unroll):
consts, carry, xs = split_list(in_flat, [num_consts, num_carry])
checked_jaxpr, msgs_ = checkify_jaxpr(jaxpr, error)
checked_jaxpr, msgs_ = checkify_jaxpr(jaxpr, error, enabled_errors)
new_linear = (False, False, *linear)
new_in_flat = [*consts, error.err, error.code, *carry, *xs]
err, code, *outs = lax.scan_p.bind(
@ -348,14 +364,14 @@ def scan_error_check(error, *in_flat, reverse, length, jaxpr, num_consts, num_ca
return outs, Error(err, code, new_msgs)
error_checks[lax.scan_p] = scan_error_check
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error):
def checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors):
cond_f = core.jaxpr_as_fun(cond_jaxpr)
body_f = core.jaxpr_as_fun(body_jaxpr)
def new_body_f(*vals):
out = body_f(*vals)
_ = cond_f(*out) # this checks if the next cond application will error
return out
return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, body_jaxpr.in_avals)
return checkify_fun_to_jaxpr(lu.wrap_init(new_body_f), error, enabled_errors, body_jaxpr.in_avals)
def ignore_errors_jaxpr(jaxpr, error):
"""Constructs a jaxpr which takes two extra args but ignores them."""
@ -369,13 +385,13 @@ def ignore_errors_jaxpr(jaxpr, error):
jaxpr.outvars, jaxpr.eqns)
return core.ClosedJaxpr(new_jaxpr, consts)
def while_loop_error_check(error, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error)
def while_loop_error_check(error, enabled_errors, *in_flat, cond_nconsts, cond_jaxpr, body_nconsts, body_jaxpr):
checked_cond_jaxpr, msgs_cond = checkify_jaxpr(cond_jaxpr, error, enabled_errors)
checked_cond_fun = core.jaxpr_as_fun(checked_cond_jaxpr)
# Check if the first cond application will error.
cond_err, cond_code, _ = checked_cond_fun(error.err, error.code, *in_flat)
checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error)
checked_body_jaxpr, msgs_body = checkify_while_body_jaxpr(cond_jaxpr, body_jaxpr, error, enabled_errors)
compat_cond_jaxpr = ignore_errors_jaxpr(cond_jaxpr, error)
c_consts, b_consts, carry = split_list(in_flat, [cond_nconsts, body_nconsts])
new_in_flat = [*c_consts, *b_consts, cond_err, cond_code, *carry]
@ -453,7 +469,10 @@ add_nan_check(lax.select_p)
add_nan_check(lax.max_p)
add_nan_check(lax.min_p)
def assert_discharge_rule(error, pred, code, *, msgs):
def assert_discharge_rule(error, enabled_errors, pred, code, *, msgs):
if ErrorCategory.ASSERT not in enabled_errors:
return [], error
out_err = error.err | jnp.logical_not(pred)
out_code = lax.select(error.err, error.code, code)
return [], Error(out_err, out_code, {**error.msgs, **msgs})
@ -462,13 +481,24 @@ error_checks[assert_p] = assert_discharge_rule
## checkify api
ErrorCategory = enum.Enum('ErrorCategory', ['NAN', 'OOB', 'DIV', 'ASSERT'])
float_errors = {ErrorCategory.NAN, ErrorCategory.DIV}
index_errors = {ErrorCategory.OOB}
automatic_errors = float_errors | index_errors
user_asserts = {ErrorCategory.ASSERT}
Out = TypeVar('Out')
def checkify(fun: Callable[..., Out]) -> Callable[..., Tuple[Error, Out]]:
def checkify(fun: Callable[..., Out], errors: Set[ErrorCategory] = user_asserts) -> Callable[..., Tuple[Error, Out]]:
if not errors:
raise ValueError('Checkify needs to be called with at least one enabled'
' ErrorCategory, was called with an empty errors set.')
@traceback_util.api_boundary
def checked_fun(*args, **kwargs):
args_flat, in_tree = tree_flatten((args, kwargs))
f, out_tree = flatten_fun(lu.wrap_init(fun), in_tree)
(err, code, out_flat), msgs = checkify_flat(f, *args_flat)
(err, code, out_flat), msgs = checkify_flat(f, frozenset(errors), *args_flat)
out = tree_unflatten(out_tree(), out_flat)
return Error(err, code, msgs), out
return checked_fun

View File

@ -39,11 +39,12 @@ class CheckifyTransformTests(jtu.JaxTestCase):
return y1 + y2
f = jax.jit(f) if jit else f
checked_f = checkify.checkify(f, errors=checkify.float_errors)
err, _ = checkify.checkify(f)(3., 4.)
err, _ = checked_f(3., 4.)
self.assertIs(err.get(), None)
err, _ = checkify.checkify(f)(3., jnp.inf)
err, _ = checked_f(3., jnp.inf)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
@ -58,16 +59,17 @@ class CheckifyTransformTests(jtu.JaxTestCase):
return w
f = jax.jit(f) if jit else f
checked_f = checkify.checkify(f, errors=checkify.index_errors)
err, _ = checkify.checkify(f)(jnp.arange(3), 2)
err, _ = checked_f(jnp.arange(3), 2)
self.assertIs(err.get(), None)
err, _ = checkify.checkify(f)(jnp.arange(3), 5)
err, _ = checked_f(jnp.arange(3), 5)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'out-of-bounds indexing')
@parameterized.named_parameters(
{"testcase_name": f"_update={update_fn}", "update_fn": update_fn}
{"testcase_name": f"_updatefn={update_fn}", "update_fn": update_fn}
for update_fn in ["set", "add", "multiply", "divide", "power", "min",
"max", "get"])
def test_jit_oob_update(self, update_fn):
@ -75,11 +77,12 @@ class CheckifyTransformTests(jtu.JaxTestCase):
return getattr(x.at[i], update_fn)(1.)
f = jax.jit(f)
checked_f = checkify.checkify(f, errors=checkify.index_errors)
err, _ = checkify.checkify(f)(jnp.arange(3), 2)
err, _ = checked_f(jnp.arange(3), 2)
self.assertIs(err.get(), None)
err, _ = checkify.checkify(f)(jnp.arange(3), 3)
err, _ = checked_f(jnp.arange(3), 3)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'out-of-bounds indexing')
@ -91,15 +94,16 @@ class CheckifyTransformTests(jtu.JaxTestCase):
return x/y
f = jax.jit(f) if jit else f
checked_f = checkify.checkify(f, errors=checkify.float_errors)
err, _ = checkify.checkify(f)(jnp.ones((3,)), jnp.ones((3,)))
err, _ = checked_f(jnp.ones((3,)), jnp.ones((3,)))
self.assertIs(err.get(), None)
err, _ = checkify.checkify(f)(jnp.ones((3,)), jnp.array([1, 0, 1]))
err, _ = checked_f(jnp.ones((3,)), jnp.array([1, 0, 1]))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
err, _ = checkify.checkify(f)(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1]))
err, _ = checked_f(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1]))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'nan generated by primitive div')
@ -114,18 +118,19 @@ class CheckifyTransformTests(jtu.JaxTestCase):
return z
f = jax.jit(f) if jit else f
checked_f = checkify.checkify(f, errors=checkify.automatic_errors)
# no error
err, _ = checkify.checkify(f)(jnp.array([0., jnp.inf, 2.]), 2)
err, _ = checked_f(jnp.array([0., jnp.inf, 2.]), 2)
self.assertIs(err.get(), None)
# oob error
err, _ = checkify.checkify(f)(jnp.array([0., 1., 2.]), 5)
err, _ = checked_f(jnp.array([0., 1., 2.]), 5)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'out-of-bounds indexing')
# nan error
err, _ = checkify.checkify(f)(jnp.array([0., 1., jnp.inf]), 2)
err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 2)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'nan generated by primitive cos')
@ -139,9 +144,10 @@ class CheckifyTransformTests(jtu.JaxTestCase):
return y * z
f = jax.jit(f) if jit else f
checked_f = checkify.checkify(f, errors=checkify.automatic_errors)
# both oob and nan error, but oob happens first
err, _ = checkify.checkify(f)(jnp.array([0., 1., jnp.inf]), 5)
err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 5)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'out-of-bounds indexing')
@ -155,13 +161,14 @@ class CheckifyTransformTests(jtu.JaxTestCase):
y1 = jnp.sin(x1)
y2 = jnp.sin(x2)
return y1 + y2
checked_f = checkify.checkify(f, errors=checkify.float_errors)
xs = jnp.array([0., 2.])
err, _ = checkify.checkify(f)(xs, xs)
err, _ = checked_f(xs, xs)
self.assertIs(err.get(), None)
ys = jnp.array([3., jnp.inf])
err, _ = checkify.checkify(f)(xs, ys)
err, _ = checked_f(xs, ys)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
@ -173,14 +180,16 @@ class CheckifyTransformTests(jtu.JaxTestCase):
lambda: jnp.sin(x),
lambda: x)
err, y = checkify.checkify(f)(3.)
checked_f = checkify.checkify(f, errors=checkify.float_errors)
err, y = checked_f(3.)
self.assertIs(err.get(), None)
err, y = checkify.checkify(f)(jnp.inf)
err, y = checked_f(jnp.inf)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
err, y = checkify.checkify(f)(-jnp.inf)
err, y = checked_f(-jnp.inf)
self.assertIs(err.get(), None)
@ -193,14 +202,16 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def f(xs):
return lax.scan(scan_body, None, xs)
checked_f = checkify.checkify(f, errors=checkify.float_errors)
xs = jnp.array([0., 2.])
err, (_, ch_outs) = checkify.checkify(f)(xs)
err, (_, ch_outs) = checked_f(xs)
_, outs = f(xs)
self.assertIs(err.get(), None)
self.assertArraysEqual(ch_outs, outs)
xs = jnp.array([3., jnp.inf])
err, (_, ch_outs) = checkify.checkify(f)(xs)
err, (_, ch_outs) = checked_f(xs)
_, outs = f(xs)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
@ -217,8 +228,10 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def f(carry, xs):
return lax.scan(scan_body, carry, xs)
checked_f = checkify.checkify(f, errors=checkify.float_errors)
carry, xs = 3., jnp.ones((2,))
err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
out_carry, outs = f(carry, xs)
self.assertIs(err.get(), None)
self.assertArraysEqual(ch_outs, outs)
@ -226,7 +239,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
# error happens on first iteration
carry, xs = 1., jnp.ones((2,))
err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
out_carry, outs = f(carry, xs)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
@ -235,7 +248,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
# error happens on second iteration
carry, xs = 2., jnp.ones((4,))
err, (ch_out_carry, ch_outs) = checkify.checkify(f)(carry, xs)
err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
out_carry, outs = f(carry, xs)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
@ -257,14 +270,16 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def f(init_val):
return lax.while_loop(while_cond, while_body, (init_val, 0.))
checked_f = checkify.checkify(f, errors=checkify.float_errors)
init_val = 1.
err, ch_out = checkify.checkify(f)(init_val)
err, ch_out = checked_f(init_val)
out = f(init_val)
self.assertIs(err.get(), None)
self.assertArraysEqual(ch_out, out)
init_val = 0.
err, ch_out = checkify.checkify(f)(init_val)
err, ch_out = checked_f(init_val)
out = f(init_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
@ -283,14 +298,16 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def f(init_val):
return lax.while_loop(while_cond, while_body, init_val)
checked_f = checkify.checkify(f, errors=checkify.float_errors)
init_val = 1.
err, ch_out = checkify.checkify(f)(init_val)
err, ch_out = checked_f(init_val)
out = f(init_val)
self.assertIs(err.get(), None)
self.assertArraysEqual(ch_out, out)
init_val = 0.
err, ch_out = checkify.checkify(f)(init_val)
err, ch_out = checked_f(init_val)
out = f(init_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
@ -307,15 +324,17 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def f(init_val):
return lax.while_loop(while_cond, lambda val: val-1, init_val)
checked_f = checkify.checkify(f, errors=checkify.float_errors)
# error on first cond
init_val = 0.
err, _ = checkify.checkify(f)(init_val)
err, _ = checked_f(init_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
# error on second cond
init_val = 1.
err, _ = checkify.checkify(f)(init_val)
err, _ = checked_f(init_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "divided by zero")
@ -335,25 +354,53 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def f(cond_val, body_val):
return lax.while_loop(while_cond, while_body, (0., cond_val, body_val))
checked_f = checkify.checkify(f, errors=checkify.float_errors)
cond_val = jnp.inf
body_val = 1.
err, _ = checkify.checkify(f)(cond_val, body_val)
err, _ = checked_f(cond_val, body_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
cond_val = 1.
body_val = jnp.inf
err, _ = checkify.checkify(f)(cond_val, body_val)
err, _ = checked_f(cond_val, body_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive cos")
cond_val = jnp.inf
body_val = jnp.inf
err, _ = checkify.checkify(f)(cond_val, body_val)
err, _ = checked_f(cond_val, body_val)
self.assertIsNotNone(err.get())
# first error which occurs is in cond
self.assertStartsWith(err.get(), "nan generated by primitive sin")
def test_empty_enabled_errors(self):
with self.assertRaisesRegex(ValueError, 'called with an empty errors set'):
checkify.checkify(lambda x: x, errors={})
@parameterized.named_parameters(
("assert", checkify.user_asserts, "must be negative!"),
("div", {checkify.ErrorCategory.DIV}, "divided by zero"),
("nan", {checkify.ErrorCategory.NAN}, "nan generated"),
("oob", checkify.index_errors, "out-of-bounds indexing"),
("automatic_errors", checkify.automatic_errors, "divided by zero"),
)
@jtu.skip_on_devices('tpu')
def test_enabled_errors(self, error_set, expected_error):
def multi_errors(x):
x = x/0 # DIV
x = jnp.sin(x) # NAN
x = x[500] # OOB
checkify.assert_(x < 0, "must be negative!") # ASSERT
return x
x = jnp.ones((2,))
err, _ = checkify.checkify(multi_errors, errors=error_set)(x)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), expected_error)
class AssertPrimitiveTests(jtu.JaxTestCase):
def test_assert_primitive_impl(self):
def f():