Rewrite Checkify to support tracking different error types.

In general, behavior should remain the same and this is not a breaking
change.

There are some minor changes to the API:
  - checkify.ErrorCategory has changed type: it's no longer an Enum, but
    the JaxException type. These have not been exposed as part of the
    public API.
  - some attributes on Error have changed and been made private
  - The raised error has changed type (JaxRuntimeError), and will have a
    different traceback (pointing to the origin of the error + where the
    error value was raised).
  - `checkify.check` now supports formatting error messages with
    variable-size runtime info!

Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
This commit is contained in:
lenamartens 2022-10-19 00:53:24 +01:00
parent ebee4f4bfd
commit e4757e8410
6 changed files with 796 additions and 382 deletions

View File

@ -14,7 +14,7 @@ API
check
check_error
Error
ErrorCategory
JaxRuntimeError
user_checks
nan_checks
index_checks

File diff suppressed because it is too large Load Diff

View File

@ -246,6 +246,10 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
# Raise index in case of effects to allow data-dependence-based discharging
# of those effects (even if they don't have an explicit data dependence).
index = core.raise_as_much_as_possible(index)
false_jaxpr = false_jaxpr.replace(
jaxpr=false_jaxpr.jaxpr.replace(effects=joined_effects))
true_jaxpr = true_jaxpr.replace(
jaxpr=true_jaxpr.jaxpr.replace(effects=joined_effects))
linear = [False] * len(consts) + linear_ops
out = cond_p.bind(

View File

@ -15,6 +15,7 @@
from jax._src.checkify import (
Error as Error,
ErrorCategory as ErrorCategory,
JaxRuntimeError as JaxRuntimeError,
all_checks as all_checks,
automatic_checks as automatic_checks,
check as check,

View File

@ -1219,7 +1219,7 @@ tf_not_yet_impl = [
# Not high priority?
"after_all",
"all_to_all",
"assert",
"check",
"create_token",
"custom_transpose_call",
"custom_vmap_call",

View File

@ -29,12 +29,13 @@ from jax.experimental import pjit
from jax.experimental import maps
from jax._src.sharding import NamedSharding
from jax._src import array
from jax._src.checkify import CheckEffect
from jax._src.checkify import JaxRuntimeError, FailedCheckError, ErrorEffect, OOBError
import jax.numpy as jnp
config.parse_flags_with_absl()
@jtu.with_config(jax_check_tracer_leaks=True)
class CheckifyTransformTests(jtu.JaxTestCase):
@jtu.sample_product(jit=[False, True])
@ -49,11 +50,11 @@ class CheckifyTransformTests(jtu.JaxTestCase):
checked_f = checkify.checkify(f, errors=checkify.float_checks)
err, _ = checked_f(3., 4.)
self.assertIs(err.get(), None)
self.assertIsNone(err.get())
err, _ = checked_f(3., jnp.inf)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
@jtu.sample_product(jit=[False, True])
def test_jit_oob(self, jit):
@ -67,7 +68,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
checked_f = checkify.checkify(f, errors=checkify.index_checks)
err, _ = checked_f(jnp.arange(3), 2)
self.assertIs(err.get(), None)
self.assertIsNone(err.get())
err, _ = checked_f(jnp.arange(3), 5)
self.assertIsNotNone(err.get())
@ -83,7 +84,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
checked_f = checkify.checkify(f, errors=checkify.index_checks)
err, _ = checked_f(jnp.arange(3), 2)
self.assertIs(err.get(), None)
self.assertIsNone(err.get())
err, _ = checked_f(jnp.arange(3), 3)
self.assertIsNotNone(err.get())
@ -98,15 +99,14 @@ class CheckifyTransformTests(jtu.JaxTestCase):
checked_f = checkify.checkify(f, errors=checkify.float_checks)
err, _ = checked_f(jnp.ones((3,)), jnp.ones((3,)))
self.assertIs(err.get(), None)
self.assertIsNone(err.get())
err, _ = checked_f(jnp.ones((3,)), jnp.array([1., 0., 1.]))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "division by zero")
err, _ = checked_f(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1]))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive div")
self.assertStartsWith(err.get(), "nan generated by primitive: div")
@jtu.sample_product(jit=[False, True])
@jtu.skip_on_devices("tpu")
@ -121,7 +121,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
# no error
err, _ = checked_f(jnp.array([0., jnp.inf, 2.]), 2)
self.assertIs(err.get(), None)
self.assertIsNone(err.get())
# oob error
err, _ = checked_f(jnp.array([0., 1., 2.]), 5)
@ -131,11 +131,20 @@ class CheckifyTransformTests(jtu.JaxTestCase):
# nan error
err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 2)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive cos")
self.assertStartsWith(err.get(), "nan generated by primitive: cos")
def test_numpy_indexing_oobs(self):
@parameterized.named_parameters(
("gather", lambda x: x.get()),
("scatter_add", lambda x: x.add(1.)),
("scatter_mul", lambda x: x.multiply(1.)),
("scatter_div", lambda x: x.divide(1.)),
("scatter_pow", lambda x: x.power(1.)),
("scatter_min", lambda x: x.min(1.)),
("scatter_max", lambda x: x.max(1.)),
)
def test_numpy_indexing_oobs(self, update_op):
def raises_oob(fn, idx, *expected_strs):
err, _ = checkify.checkify(fn, errors=checkify.index_checks)(x, idx)
err, _ = checkify.checkify(jax.jit(fn), errors=checkify.index_checks)(x, idx)
error_txt = err.get()
self.assertIsNotNone(error_txt)
self.assertStartsWith(error_txt, "out-of-bounds indexing")
@ -147,7 +156,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
axis1_msg = "axis 1 with size 3"
axis2_msg = "axis 2 with size 7"
single_idx = lambda x, i: x[i]
single_idx = lambda x, i: update_op(x.at[i])
raises_oob(single_idx, 5, "index 5", axis0_msg)
raises_oob(single_idx, -5, "index -3", axis0_msg)
raises_oob(single_idx, (0, 100), "index 100", axis1_msg)
@ -158,7 +167,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
raises_oob(single_idx, (((1, 1), (1, 20)), 3), "index 3", axis1_msg)
raises_oob(single_idx, (((1, 1), (1, 20)), 0), "index 20", axis0_msg)
multi_idx = lambda x, i: x[i[0], :, i[1]]
multi_idx = lambda x, i: update_op(x.at[i[0], :, i[1]])
raises_oob(multi_idx, (0, 9), "index 9", axis2_msg)
# TODO(lenamartens): numpy reports index -5 here, need to normalize?
raises_oob(multi_idx, (-5, 9), "index -3", axis0_msg)
@ -194,32 +203,77 @@ class CheckifyTransformTests(jtu.JaxTestCase):
xs = jnp.array([0., 2.])
err, _ = checked_f(xs, xs)
self.assertIs(err.get(), None)
self.assertIsNone(err.get())
ys = jnp.array([3., jnp.inf])
err, _ = checked_f(xs, ys)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
@jtu.skip_on_devices("tpu")
def test_cond_basic(self):
@jax.jit
def f(x):
return lax.cond(x > 0,
lambda: jnp.sin(x),
lambda: x)
def true_fun(x):
return jnp.sin(x)
def false_fun(x):
checkify.check(x > -1, "oh no")
return x / 0.
return lax.cond(x > 0, true_fun, false_fun, x)
checked_f = checkify.checkify(f, errors=checkify.float_checks)
checked_f = checkify.checkify(f, errors=checkify.all_checks)
err, _ = checked_f(3.)
self.assertIs(err.get(), None)
self.assertIsNone(err.get())
err, _ = checked_f(jnp.inf)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
err, _ = checked_f(-jnp.inf)
self.assertIs(err.get(), None)
self.assertStartsWith(err.get(), "oh no")
err, _ = checked_f(0.)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "division by zero")
def test_cond_different_payloads(self):
@jax.jit
def f(x):
def true_fun(x):
checkify.check(~x, "{one}", one=x)
def false_fun(x):
checkify.check(x, "{one} and {two}", one=x, two=x)
return lax.cond(x, true_fun, false_fun, x)
checked_f = checkify.checkify(f)
err, _ = checked_f(True)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "True")
err, _ = checked_f(False)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "False and False")
def test_cond_nd_payloads(self):
@jax.jit
def f(x):
def true_fun(x):
checkify.check(jnp.all(x > 0), "{one}", one=x)
def false_fun(x):
checkify.check(jnp.all(x < 0), "{one} and {two}", one=x, two=x)
return lax.cond(jnp.all(x < 0), true_fun, false_fun, x)
checked_f = checkify.checkify(f)
err, _ = checked_f(jnp.arange(0, 4))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "[0 1 2 3] and [0 1 2 3]")
err, _ = checked_f(jnp.arange(-4, -1))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "[-4 -3 -2]")
@jtu.skip_on_devices("tpu")
def test_scan_map(self):
@ -235,14 +289,14 @@ class CheckifyTransformTests(jtu.JaxTestCase):
xs = jnp.array([0., 2.])
err, (_, ch_outs) = checked_f(xs)
_, outs = f(xs)
self.assertIs(err.get(), None)
self.assertIsNone(err.get())
self.assertArraysEqual(ch_outs, outs)
xs = jnp.array([3., jnp.inf])
err, (_, ch_outs) = checked_f(xs)
_, outs = f(xs)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
self.assertArraysEqual(ch_outs, outs)
@jtu.skip_on_devices("tpu")
@ -261,7 +315,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
carry, xs = 3., jnp.ones((2,))
err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
out_carry, outs = f(carry, xs)
self.assertIs(err.get(), None)
self.assertIsNone(err.get())
self.assertArraysEqual(ch_outs, outs)
self.assertArraysEqual(ch_out_carry, out_carry)
@ -303,7 +357,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
init_val = 1.
err, ch_out = checked_f(init_val)
out = f(init_val)
self.assertIs(err.get(), None)
self.assertIsNone(err.get())
self.assertArraysEqual(ch_out, out)
init_val = 0.
@ -331,7 +385,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
init_val = 1.
err, ch_out = checked_f(init_val)
out = f(init_val)
self.assertIs(err.get(), None)
self.assertIsNone(err.get())
self.assertArraysEqual(ch_out, out)
init_val = 0.
@ -388,20 +442,20 @@ class CheckifyTransformTests(jtu.JaxTestCase):
body_val = 1.
err, _ = checked_f(cond_val, body_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
cond_val = 1.
body_val = jnp.inf
err, _ = checked_f(cond_val, body_val)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive cos")
self.assertStartsWith(err.get(), "nan generated by primitive: cos")
cond_val = jnp.inf
body_val = jnp.inf
err, _ = checked_f(cond_val, body_val)
self.assertIsNotNone(err.get())
# first error which occurs is in cond
self.assertStartsWith(err.get(), "nan generated by primitive sin")
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
def test_pjit(self):
def f(x):
@ -448,8 +502,8 @@ class CheckifyTransformTests(jtu.JaxTestCase):
@parameterized.named_parameters(
("assert", checkify.user_checks, "must be negative!"),
("div", {checkify.ErrorCategory.DIV}, "division by zero"),
("nan", {checkify.ErrorCategory.NAN}, "nan generated"),
("div", checkify.div_checks, "division by zero"),
("nan", checkify.nan_checks, "nan generated"),
("oob", checkify.index_checks, "out-of-bounds indexing"),
("automatic_checks", checkify.automatic_checks, "division by zero"),
)
@ -477,7 +531,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
return f(jnp.inf)
err, _ = g(2.)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
@jtu.skip_on_devices("tpu")
def test_post_process_map(self):
@ -485,11 +539,11 @@ class CheckifyTransformTests(jtu.JaxTestCase):
def g(x):
@jax.pmap
def f(y):
return jnp.sin(x * y)
return jnp.sin(x * y), jnp.cos(x * y)
return f(jnp.array([jnp.inf]))[0]
err, _ = g(2.)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
@jtu.skip_on_devices("tpu")
def test_custom_jvp(self):
@ -508,13 +562,13 @@ class CheckifyTransformTests(jtu.JaxTestCase):
self.assertIsNone(err.get())
err, y = f(jnp.inf)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
self.assertStartsWith(err.get(), 'nan generated by primitive: sin')
# When we hit the custom jvp rule with jvp-of-checkify, no checks are added.
(err, y), (errdot, ydot) = jax.jvp(f, (3.,), (1.,)) # doesn't crash
self.assertIsNone(err.get()) # no error
self.assertEmpty(err.msgs) # and no checks were added!
self.assertEmpty(errdot.msgs)
self.assertEmpty(err._metadata) # and no checks were added!
self.assertEmpty(errdot._metadata)
y_expected, ydot_expected = jax.jvp(jnp.sin, (3.,), (1.,))
self.assertAllClose(y, y_expected)
self.assertAllClose(ydot, ydot_expected)
@ -528,12 +582,12 @@ class CheckifyTransformTests(jtu.JaxTestCase):
errors=checkify.float_checks)
err, (y, ydot) = g(3., 1.) # doesn't crash
self.assertIsNone(err.get()) # no error
self.assertNotEmpty(err.msgs) # but checks were added!
self.assertNotEmpty(err._metadata) # but checks were added!
self.assertAllClose(y, jnp.sin(3.))
self.assertAllClose(ydot, jnp.cos(3.))
err, _ = g(jnp.inf, 1.)
self.assertIsNotNone(err.get()) # yes error
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
self.assertStartsWith(err.get(), 'nan generated by primitive: sin')
@jtu.skip_on_devices("tpu")
def test_custom_vjp(self):
@ -556,31 +610,31 @@ class CheckifyTransformTests(jtu.JaxTestCase):
# no differentiation, yes error
err, y = f(jnp.inf)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
self.assertStartsWith(err.get(), 'nan generated by primitive: sin')
# When we hit the custom vjp rule with vjp-of-checkify, no checks are added.
(err, y), f_vjp = jax.vjp(f, 3.)
self.assertIsNone(err.get()) # no error
self.assertEmpty(err.msgs) # and no checks were added!
self.assertEmpty(err._metadata) # and no checks were added!
# Checkify-of-vjp adds checks (unlike vjp-of-checkify above).
err, y = checkify.checkify(jax.grad(sin), errors=checkify.float_checks)(3.)
self.assertIsNone(err.get()) # no error
self.assertNotEmpty(err.msgs) # but checks were added!
self.assertNotEmpty(err._metadata) # but checks were added!
err, y = checkify.checkify(jax.grad(sin),
errors=checkify.float_checks)(jnp.inf)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive sin")
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
def test_scan_consts(self):
def f(xs):
def scan_body(carry, _):
# closes over xs
return carry+1, xs[carry]
return lax.scan(scan_body, 1, xs)[1]
return lax.scan(scan_body, 1, xs)
checked_f = checkify.checkify(f, errors=checkify.index_checks)
err, _ = checked_f(jnp.ones((7, 3)))
err, _ = checked_f(jnp.ones((7,)))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "out-of-bounds indexing")
@ -649,14 +703,59 @@ class CheckifyTransformTests(jtu.JaxTestCase):
errors=checkify.float_checks)
cf(jax.random.PRNGKey(123)) # does not crash.
def test_pmap_one_device(self):
@jax.pmap
def f(x, y):
return x/y
cf = checkify.checkify(f, errors=checkify.automatic_checks)
errs, _ = cf(jnp.ones((1,)), jnp.zeros((1,)))
self.assertIsNotNone(errs.get())
self.assertIn("division by zero", errs.get())
def test_psum_nan_check(self):
@partial(jax.vmap, axis_name="i")
def f(x, y):
return lax.psum((x, y), axis_name="i")
cf = checkify.checkify(f, errors=checkify.nan_checks)
err, _ = cf(jnp.array([-jnp.inf, 0, jnp.inf]), jnp.ones((3, 2)))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive: psum")
def test_different_payload_effects(self):
def f(x, y):
x = x[y]
checkify.check(jnp.all(x > 0), "{x}", x=x)
return x
f = checkify.checkify(f, errors=checkify.all_checks)
err, _ = jax.vmap(f)(jnp.ones((2, 3))*-1, jnp.array([0, 5]))
self.assertIsNotNone(err.get())
def test_effects_total_ordering(self):
sds0 = jax.ShapeDtypeStruct((2,), jnp.float32)
sds1 = jax.ShapeDtypeStruct((2,), jnp.int32)
sds2 = jax.ShapeDtypeStruct((3,), jnp.int32)
self.assertTotallyOrdered(
[ErrorEffect(FailedCheckError, (sds0,))],
[ErrorEffect(FailedCheckError, (sds0, sds0))],
[ErrorEffect(FailedCheckError, (sds1,))],
[ErrorEffect(FailedCheckError, (sds1, sds0))],
[ErrorEffect(FailedCheckError, (sds2,))],
[ErrorEffect(OOBError, (sds0,))],
[ErrorEffect(OOBError, (sds0, sds0))],
)
@jtu.with_config(jax_check_tracer_leaks=True)
class AssertPrimitiveTests(jtu.JaxTestCase):
def test_assert_primitive_impl(self):
def f():
checkify.check(False, "hi")
with self.assertRaisesRegex(ValueError, "hi"):
with self.assertRaisesRegex(JaxRuntimeError, "hi"):
f()
def test_assert_primitive_lowering(self):
@ -668,10 +767,24 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
f()
def test_assert_primitive_jaxpr_effects(self):
def f():
checkify.check(False, "hi")
def f(x):
checkify.check(False, "hi: {}", x)
self.assertSetEqual(jax.make_jaxpr(f)().effects, {CheckEffect})
jaxpr = jax.make_jaxpr(f)(jnp.ones(4, jnp.int32))
self.assertSetEqual(jaxpr.effects,
{ErrorEffect(FailedCheckError, (
jax.ShapeDtypeStruct((0,), jnp.int32),
jax.ShapeDtypeStruct((4,), jnp.int32),))})
def g(x, y):
checkify.check(False, "hi: {} {}", x, y)
self.assertSetEqual(
jax.make_jaxpr(g)(
jnp.ones(4, jnp.int32), jnp.ones(2, jnp.float32)).effects,
{ErrorEffect(FailedCheckError, (
jax.ShapeDtypeStruct((0,), jnp.int32),
jax.ShapeDtypeStruct((4,), jnp.int32),
jax.ShapeDtypeStruct((2,), jnp.float32)))})
def test_assert_primitive_eval_shape(self):
# The check is abstractly evaluated but not lowered.
@ -798,10 +911,10 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
mult_batch_fail = jnp.array([[0.5, 0.5], [1, 1], [2, 2]])
f(no_failures)
with self.assertRaisesRegex(ValueError, "x must sum to one."):
with self.assertRaisesRegex(JaxRuntimeError, "x must sum to one."):
f(one_batch_fails)
with self.assertRaisesRegex(ValueError, "x must sum to one."):
with self.assertRaisesRegex(JaxRuntimeError, "x must sum to one."):
f(mult_batch_fail)
checked_f = checkify.checkify(f)
@ -817,10 +930,13 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
self.assertStartsWith(err.get(), "x must sum to one")
def test_check_error(self):
def g():
checkify.check(False, "hi")
def f():
checkify.check_error(checkify.Error(True, 0, {0: "hi"}))
err, _ = checkify.checkify(g)()
checkify.check_error(err)
with self.assertRaisesRegex(ValueError, "hi"):
with self.assertRaisesRegex(JaxRuntimeError, "hi"):
f()
f = checkify.checkify(f)
@ -872,7 +988,7 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
python_should_be_running = False
f(True)
with self.assertRaisesRegex(ValueError, "foo"):
with self.assertRaisesRegex(JaxRuntimeError, "foo"):
f(False)
def test_cond_of_named_call(self):
@ -950,10 +1066,12 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
# self.assertIsNone(err.get())
def test_assert_cond_no_data_dependence(self):
def true_fun():
return checkify.check(False, "hi!")
def false_fun():
return checkify.check(False, "bye!")
def f():
return jax.lax.cond(True,
lambda: checkify.check(False, "hi!"),
lambda: checkify.check(False, "bye!"))
return jax.lax.cond(True, true_fun, false_fun)
f = checkify.checkify(f)
err, _ = f()
@ -973,16 +1091,6 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "hi!")
def test_psum_nan_check(self):
@partial(jax.vmap, axis_name="i")
def f(x, y):
return lax.psum((x, y), axis_name="i")
cf = checkify.checkify(f, errors=checkify.nan_checks)
err, _ = cf(jnp.array([-jnp.inf, 0, jnp.inf]), jnp.ones((3, 2)))
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "nan generated by primitive psum")
class LowerableChecksTest(jtu.JaxTestCase):
def setUp(self):