mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Rewrite Checkify to support tracking different error types.
In general, behavior should remain the same and this is not a breaking change. There are some minor changes to the API: - checkify.ErrorCategory has changed type: it's no longer an Enum, but the JaxException type. These have not been exposed as part of the public API. - some attributes on Error have changed and made private - The raised error has changed type (JaxRuntimeError), and will have a different traceback (pointing to the origin of the error + where the error value was raised). - `checkify.check` now supports formating error message with variable size runtime info! Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
This commit is contained in:
parent
ebee4f4bfd
commit
e4757e8410
@ -14,7 +14,7 @@ API
|
||||
check
|
||||
check_error
|
||||
Error
|
||||
ErrorCategory
|
||||
JaxRuntimeError
|
||||
user_checks
|
||||
nan_checks
|
||||
index_checks
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -246,6 +246,10 @@ def _cond(pred, true_fun: Callable, false_fun: Callable, *operands,
|
||||
# Raise index in case of effects to allow data-dependence-based discharging
|
||||
# of those effects (even if they don't have an explicit data dependence).
|
||||
index = core.raise_as_much_as_possible(index)
|
||||
false_jaxpr = false_jaxpr.replace(
|
||||
jaxpr=false_jaxpr.jaxpr.replace(effects=joined_effects))
|
||||
true_jaxpr = true_jaxpr.replace(
|
||||
jaxpr=true_jaxpr.jaxpr.replace(effects=joined_effects))
|
||||
|
||||
linear = [False] * len(consts) + linear_ops
|
||||
out = cond_p.bind(
|
||||
|
@ -15,6 +15,7 @@
|
||||
from jax._src.checkify import (
|
||||
Error as Error,
|
||||
ErrorCategory as ErrorCategory,
|
||||
JaxRuntimeError as JaxRuntimeError,
|
||||
all_checks as all_checks,
|
||||
automatic_checks as automatic_checks,
|
||||
check as check,
|
||||
|
@ -1219,7 +1219,7 @@ tf_not_yet_impl = [
|
||||
# Not high priority?
|
||||
"after_all",
|
||||
"all_to_all",
|
||||
"assert",
|
||||
"check",
|
||||
"create_token",
|
||||
"custom_transpose_call",
|
||||
"custom_vmap_call",
|
||||
|
@ -29,12 +29,13 @@ from jax.experimental import pjit
|
||||
from jax.experimental import maps
|
||||
from jax._src.sharding import NamedSharding
|
||||
from jax._src import array
|
||||
from jax._src.checkify import CheckEffect
|
||||
from jax._src.checkify import JaxRuntimeError, FailedCheckError, ErrorEffect, OOBError
|
||||
import jax.numpy as jnp
|
||||
|
||||
config.parse_flags_with_absl()
|
||||
|
||||
|
||||
@jtu.with_config(jax_check_tracer_leaks=True)
|
||||
class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
|
||||
@jtu.sample_product(jit=[False, True])
|
||||
@ -49,11 +50,11 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
|
||||
err, _ = checked_f(3., 4.)
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertIsNone(err.get())
|
||||
|
||||
err, _ = checked_f(3., jnp.inf)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
|
||||
|
||||
@jtu.sample_product(jit=[False, True])
|
||||
def test_jit_oob(self, jit):
|
||||
@ -67,7 +68,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
checked_f = checkify.checkify(f, errors=checkify.index_checks)
|
||||
|
||||
err, _ = checked_f(jnp.arange(3), 2)
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertIsNone(err.get())
|
||||
|
||||
err, _ = checked_f(jnp.arange(3), 5)
|
||||
self.assertIsNotNone(err.get())
|
||||
@ -83,7 +84,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
checked_f = checkify.checkify(f, errors=checkify.index_checks)
|
||||
|
||||
err, _ = checked_f(jnp.arange(3), 2)
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertIsNone(err.get())
|
||||
|
||||
err, _ = checked_f(jnp.arange(3), 3)
|
||||
self.assertIsNotNone(err.get())
|
||||
@ -98,15 +99,14 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
|
||||
err, _ = checked_f(jnp.ones((3,)), jnp.ones((3,)))
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertIsNone(err.get())
|
||||
|
||||
err, _ = checked_f(jnp.ones((3,)), jnp.array([1., 0., 1.]))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "division by zero")
|
||||
|
||||
err, _ = checked_f(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1]))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive div")
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive: div")
|
||||
|
||||
@jtu.sample_product(jit=[False, True])
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@ -121,7 +121,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
|
||||
# no error
|
||||
err, _ = checked_f(jnp.array([0., jnp.inf, 2.]), 2)
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertIsNone(err.get())
|
||||
|
||||
# oob error
|
||||
err, _ = checked_f(jnp.array([0., 1., 2.]), 5)
|
||||
@ -131,11 +131,20 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
# nan error
|
||||
err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 2)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive cos")
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive: cos")
|
||||
|
||||
def test_numpy_indexing_oobs(self):
|
||||
@parameterized.named_parameters(
|
||||
("gather", lambda x: x.get()),
|
||||
("scatter_add", lambda x: x.add(1.)),
|
||||
("scatter_mul", lambda x: x.multiply(1.)),
|
||||
("scatter_div", lambda x: x.divide(1.)),
|
||||
("scatter_pow", lambda x: x.power(1.)),
|
||||
("scatter_min", lambda x: x.min(1.)),
|
||||
("scatter_max", lambda x: x.max(1.)),
|
||||
)
|
||||
def test_numpy_indexing_oobs(self, update_op):
|
||||
def raises_oob(fn, idx, *expected_strs):
|
||||
err, _ = checkify.checkify(fn, errors=checkify.index_checks)(x, idx)
|
||||
err, _ = checkify.checkify(jax.jit(fn), errors=checkify.index_checks)(x, idx)
|
||||
error_txt = err.get()
|
||||
self.assertIsNotNone(error_txt)
|
||||
self.assertStartsWith(error_txt, "out-of-bounds indexing")
|
||||
@ -147,7 +156,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
axis1_msg = "axis 1 with size 3"
|
||||
axis2_msg = "axis 2 with size 7"
|
||||
|
||||
single_idx = lambda x, i: x[i]
|
||||
single_idx = lambda x, i: update_op(x.at[i])
|
||||
raises_oob(single_idx, 5, "index 5", axis0_msg)
|
||||
raises_oob(single_idx, -5, "index -3", axis0_msg)
|
||||
raises_oob(single_idx, (0, 100), "index 100", axis1_msg)
|
||||
@ -158,7 +167,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
raises_oob(single_idx, (((1, 1), (1, 20)), 3), "index 3", axis1_msg)
|
||||
raises_oob(single_idx, (((1, 1), (1, 20)), 0), "index 20", axis0_msg)
|
||||
|
||||
multi_idx = lambda x, i: x[i[0], :, i[1]]
|
||||
multi_idx = lambda x, i: update_op(x.at[i[0], :, i[1]])
|
||||
raises_oob(multi_idx, (0, 9), "index 9", axis2_msg)
|
||||
# TODO(lenamartens): numpy reports index -5 here, need to normalize?
|
||||
raises_oob(multi_idx, (-5, 9), "index -3", axis0_msg)
|
||||
@ -194,32 +203,77 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
|
||||
xs = jnp.array([0., 2.])
|
||||
err, _ = checked_f(xs, xs)
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertIsNone(err.get())
|
||||
|
||||
ys = jnp.array([3., jnp.inf])
|
||||
err, _ = checked_f(xs, ys)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_cond_basic(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
return lax.cond(x > 0,
|
||||
lambda: jnp.sin(x),
|
||||
lambda: x)
|
||||
def true_fun(x):
|
||||
return jnp.sin(x)
|
||||
def false_fun(x):
|
||||
checkify.check(x > -1, "oh no")
|
||||
return x / 0.
|
||||
return lax.cond(x > 0, true_fun, false_fun, x)
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.float_checks)
|
||||
checked_f = checkify.checkify(f, errors=checkify.all_checks)
|
||||
|
||||
err, _ = checked_f(3.)
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertIsNone(err.get())
|
||||
|
||||
err, _ = checked_f(jnp.inf)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
|
||||
|
||||
err, _ = checked_f(-jnp.inf)
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertStartsWith(err.get(), "oh no")
|
||||
|
||||
err, _ = checked_f(0.)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "division by zero")
|
||||
|
||||
def test_cond_different_payloads(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
def true_fun(x):
|
||||
checkify.check(~x, "{one}", one=x)
|
||||
def false_fun(x):
|
||||
checkify.check(x, "{one} and {two}", one=x, two=x)
|
||||
return lax.cond(x, true_fun, false_fun, x)
|
||||
|
||||
checked_f = checkify.checkify(f)
|
||||
|
||||
err, _ = checked_f(True)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "True")
|
||||
|
||||
err, _ = checked_f(False)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "False and False")
|
||||
|
||||
def test_cond_nd_payloads(self):
|
||||
@jax.jit
|
||||
def f(x):
|
||||
def true_fun(x):
|
||||
checkify.check(jnp.all(x > 0), "{one}", one=x)
|
||||
def false_fun(x):
|
||||
checkify.check(jnp.all(x < 0), "{one} and {two}", one=x, two=x)
|
||||
return lax.cond(jnp.all(x < 0), true_fun, false_fun, x)
|
||||
|
||||
checked_f = checkify.checkify(f)
|
||||
|
||||
err, _ = checked_f(jnp.arange(0, 4))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "[0 1 2 3] and [0 1 2 3]")
|
||||
|
||||
err, _ = checked_f(jnp.arange(-4, -1))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "[-4 -3 -2]")
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_scan_map(self):
|
||||
@ -235,14 +289,14 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
xs = jnp.array([0., 2.])
|
||||
err, (_, ch_outs) = checked_f(xs)
|
||||
_, outs = f(xs)
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertIsNone(err.get())
|
||||
self.assertArraysEqual(ch_outs, outs)
|
||||
|
||||
xs = jnp.array([3., jnp.inf])
|
||||
err, (_, ch_outs) = checked_f(xs)
|
||||
_, outs = f(xs)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
|
||||
self.assertArraysEqual(ch_outs, outs)
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
@ -261,7 +315,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
carry, xs = 3., jnp.ones((2,))
|
||||
err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
|
||||
out_carry, outs = f(carry, xs)
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertIsNone(err.get())
|
||||
self.assertArraysEqual(ch_outs, outs)
|
||||
self.assertArraysEqual(ch_out_carry, out_carry)
|
||||
|
||||
@ -303,7 +357,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
init_val = 1.
|
||||
err, ch_out = checked_f(init_val)
|
||||
out = f(init_val)
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertIsNone(err.get())
|
||||
self.assertArraysEqual(ch_out, out)
|
||||
|
||||
init_val = 0.
|
||||
@ -331,7 +385,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
init_val = 1.
|
||||
err, ch_out = checked_f(init_val)
|
||||
out = f(init_val)
|
||||
self.assertIs(err.get(), None)
|
||||
self.assertIsNone(err.get())
|
||||
self.assertArraysEqual(ch_out, out)
|
||||
|
||||
init_val = 0.
|
||||
@ -388,20 +442,20 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
body_val = 1.
|
||||
err, _ = checked_f(cond_val, body_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
|
||||
|
||||
cond_val = 1.
|
||||
body_val = jnp.inf
|
||||
err, _ = checked_f(cond_val, body_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive cos")
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive: cos")
|
||||
|
||||
cond_val = jnp.inf
|
||||
body_val = jnp.inf
|
||||
err, _ = checked_f(cond_val, body_val)
|
||||
self.assertIsNotNone(err.get())
|
||||
# first error which occurs is in cond
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
|
||||
|
||||
def test_pjit(self):
|
||||
def f(x):
|
||||
@ -448,8 +502,8 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("assert", checkify.user_checks, "must be negative!"),
|
||||
("div", {checkify.ErrorCategory.DIV}, "division by zero"),
|
||||
("nan", {checkify.ErrorCategory.NAN}, "nan generated"),
|
||||
("div", checkify.div_checks, "division by zero"),
|
||||
("nan", checkify.nan_checks, "nan generated"),
|
||||
("oob", checkify.index_checks, "out-of-bounds indexing"),
|
||||
("automatic_checks", checkify.automatic_checks, "division by zero"),
|
||||
)
|
||||
@ -477,7 +531,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
return f(jnp.inf)
|
||||
err, _ = g(2.)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_post_process_map(self):
|
||||
@ -485,11 +539,11 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
def g(x):
|
||||
@jax.pmap
|
||||
def f(y):
|
||||
return jnp.sin(x * y)
|
||||
return jnp.sin(x * y), jnp.cos(x * y)
|
||||
return f(jnp.array([jnp.inf]))[0]
|
||||
err, _ = g(2.)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_custom_jvp(self):
|
||||
@ -508,13 +562,13 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
self.assertIsNone(err.get())
|
||||
err, y = f(jnp.inf)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive: sin')
|
||||
|
||||
# When we hit the custom jvp rule with jvp-of-checkify, no checks are added.
|
||||
(err, y), (errdot, ydot) = jax.jvp(f, (3.,), (1.,)) # doesn't crash
|
||||
self.assertIsNone(err.get()) # no error
|
||||
self.assertEmpty(err.msgs) # and no checks were added!
|
||||
self.assertEmpty(errdot.msgs)
|
||||
self.assertEmpty(err._metadata) # and no checks were added!
|
||||
self.assertEmpty(errdot._metadata)
|
||||
y_expected, ydot_expected = jax.jvp(jnp.sin, (3.,), (1.,))
|
||||
self.assertAllClose(y, y_expected)
|
||||
self.assertAllClose(ydot, ydot_expected)
|
||||
@ -528,12 +582,12 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
errors=checkify.float_checks)
|
||||
err, (y, ydot) = g(3., 1.) # doesn't crash
|
||||
self.assertIsNone(err.get()) # no error
|
||||
self.assertNotEmpty(err.msgs) # but checks were added!
|
||||
self.assertNotEmpty(err._metadata) # but checks were added!
|
||||
self.assertAllClose(y, jnp.sin(3.))
|
||||
self.assertAllClose(ydot, jnp.cos(3.))
|
||||
err, _ = g(jnp.inf, 1.)
|
||||
self.assertIsNotNone(err.get()) # yes error
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive: sin')
|
||||
|
||||
@jtu.skip_on_devices("tpu")
|
||||
def test_custom_vjp(self):
|
||||
@ -556,31 +610,31 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
# no differentiation, yes error
|
||||
err, y = f(jnp.inf)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive sin')
|
||||
self.assertStartsWith(err.get(), 'nan generated by primitive: sin')
|
||||
|
||||
# When we hit the custom vjp rule with vjp-of-checkify, no checks are added.
|
||||
(err, y), f_vjp = jax.vjp(f, 3.)
|
||||
self.assertIsNone(err.get()) # no error
|
||||
self.assertEmpty(err.msgs) # and no checks were added!
|
||||
self.assertEmpty(err._metadata) # and no checks were added!
|
||||
|
||||
# Checkify-of-vjp adds checks (unlike vjp-of-checkify above).
|
||||
err, y = checkify.checkify(jax.grad(sin), errors=checkify.float_checks)(3.)
|
||||
self.assertIsNone(err.get()) # no error
|
||||
self.assertNotEmpty(err.msgs) # but checks were added!
|
||||
self.assertNotEmpty(err._metadata) # but checks were added!
|
||||
err, y = checkify.checkify(jax.grad(sin),
|
||||
errors=checkify.float_checks)(jnp.inf)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive sin")
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive: sin")
|
||||
|
||||
def test_scan_consts(self):
|
||||
def f(xs):
|
||||
def scan_body(carry, _):
|
||||
# closes oves xs
|
||||
return carry+1, xs[carry]
|
||||
return lax.scan(scan_body, 1, xs)[1]
|
||||
return lax.scan(scan_body, 1, xs)
|
||||
|
||||
checked_f = checkify.checkify(f, errors=checkify.index_checks)
|
||||
err, _ = checked_f(jnp.ones((7, 3)))
|
||||
err, _ = checked_f(jnp.ones((7,)))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "out-of-bounds indexing")
|
||||
|
||||
@ -649,14 +703,59 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
errors=checkify.float_checks)
|
||||
cf(jax.random.PRNGKey(123)) # does not crash.
|
||||
|
||||
def test_pmap_one_device(self):
|
||||
@jax.pmap
|
||||
def f(x, y):
|
||||
return x/y
|
||||
|
||||
cf = checkify.checkify(f, errors=checkify.automatic_checks)
|
||||
errs, _ = cf(jnp.ones((1,)), jnp.zeros((1,)))
|
||||
self.assertIsNotNone(errs.get())
|
||||
self.assertIn("division by zero", errs.get())
|
||||
|
||||
def test_psum_nan_check(self):
|
||||
@partial(jax.vmap, axis_name="i")
|
||||
def f(x, y):
|
||||
return lax.psum((x, y), axis_name="i")
|
||||
|
||||
cf = checkify.checkify(f, errors=checkify.nan_checks)
|
||||
err, _ = cf(jnp.array([-jnp.inf, 0, jnp.inf]), jnp.ones((3, 2)))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive: psum")
|
||||
|
||||
def test_different_payload_effects(self):
|
||||
def f(x, y):
|
||||
x = x[y]
|
||||
checkify.check(jnp.all(x > 0), "{x}", x=x)
|
||||
return x
|
||||
|
||||
f = checkify.checkify(f, errors=checkify.all_checks)
|
||||
err, _ = jax.vmap(f)(jnp.ones((2, 3))*-1, jnp.array([0, 5]))
|
||||
self.assertIsNotNone(err.get())
|
||||
|
||||
def test_effects_total_ordering(self):
|
||||
sds0 = jax.ShapeDtypeStruct((2,), jnp.float32)
|
||||
sds1 = jax.ShapeDtypeStruct((2,), jnp.int32)
|
||||
sds2 = jax.ShapeDtypeStruct((3,), jnp.int32)
|
||||
self.assertTotallyOrdered(
|
||||
[ErrorEffect(FailedCheckError, (sds0,))],
|
||||
[ErrorEffect(FailedCheckError, (sds0, sds0))],
|
||||
[ErrorEffect(FailedCheckError, (sds1,))],
|
||||
[ErrorEffect(FailedCheckError, (sds1, sds0))],
|
||||
[ErrorEffect(FailedCheckError, (sds2,))],
|
||||
[ErrorEffect(OOBError, (sds0,))],
|
||||
[ErrorEffect(OOBError, (sds0, sds0))],
|
||||
)
|
||||
|
||||
|
||||
@jtu.with_config(jax_check_tracer_leaks=True)
|
||||
class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
|
||||
def test_assert_primitive_impl(self):
|
||||
def f():
|
||||
checkify.check(False, "hi")
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "hi"):
|
||||
with self.assertRaisesRegex(JaxRuntimeError, "hi"):
|
||||
f()
|
||||
|
||||
def test_assert_primitive_lowering(self):
|
||||
@ -668,10 +767,24 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
f()
|
||||
|
||||
def test_assert_primitive_jaxpr_effects(self):
|
||||
def f():
|
||||
checkify.check(False, "hi")
|
||||
def f(x):
|
||||
checkify.check(False, "hi: {}", x)
|
||||
|
||||
self.assertSetEqual(jax.make_jaxpr(f)().effects, {CheckEffect})
|
||||
jaxpr = jax.make_jaxpr(f)(jnp.ones(4, jnp.int32))
|
||||
self.assertSetEqual(jaxpr.effects,
|
||||
{ErrorEffect(FailedCheckError, (
|
||||
jax.ShapeDtypeStruct((0,), jnp.int32),
|
||||
jax.ShapeDtypeStruct((4,), jnp.int32),))})
|
||||
def g(x, y):
|
||||
checkify.check(False, "hi: {} {}", x, y)
|
||||
|
||||
self.assertSetEqual(
|
||||
jax.make_jaxpr(g)(
|
||||
jnp.ones(4, jnp.int32), jnp.ones(2, jnp.float32)).effects,
|
||||
{ErrorEffect(FailedCheckError, (
|
||||
jax.ShapeDtypeStruct((0,), jnp.int32),
|
||||
jax.ShapeDtypeStruct((4,), jnp.int32),
|
||||
jax.ShapeDtypeStruct((2,), jnp.float32)))})
|
||||
|
||||
def test_assert_primitive_eval_shape(self):
|
||||
# The check is abstractly evaluated but not lowered.
|
||||
@ -798,10 +911,10 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
mult_batch_fail = jnp.array([[0.5, 0.5], [1, 1], [2, 2]])
|
||||
|
||||
f(no_failures)
|
||||
with self.assertRaisesRegex(ValueError, "x must sum to one."):
|
||||
with self.assertRaisesRegex(JaxRuntimeError, "x must sum to one."):
|
||||
f(one_batch_fails)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "x must sum to one."):
|
||||
with self.assertRaisesRegex(JaxRuntimeError, "x must sum to one."):
|
||||
f(mult_batch_fail)
|
||||
|
||||
checked_f = checkify.checkify(f)
|
||||
@ -817,10 +930,13 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
self.assertStartsWith(err.get(), "x must sum to one")
|
||||
|
||||
def test_check_error(self):
|
||||
def g():
|
||||
checkify.check(False, "hi")
|
||||
def f():
|
||||
checkify.check_error(checkify.Error(True, 0, {0: "hi"}))
|
||||
err, _ = checkify.checkify(g)()
|
||||
checkify.check_error(err)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "hi"):
|
||||
with self.assertRaisesRegex(JaxRuntimeError, "hi"):
|
||||
f()
|
||||
|
||||
f = checkify.checkify(f)
|
||||
@ -872,7 +988,7 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
|
||||
python_should_be_running = False
|
||||
f(True)
|
||||
with self.assertRaisesRegex(ValueError, "foo"):
|
||||
with self.assertRaisesRegex(JaxRuntimeError, "foo"):
|
||||
f(False)
|
||||
|
||||
def test_cond_of_named_call(self):
|
||||
@ -950,10 +1066,12 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
# self.assertIsNone(err.get())
|
||||
|
||||
def test_assert_cond_no_data_dependence(self):
|
||||
def true_fun():
|
||||
return checkify.check(False, "hi!")
|
||||
def false_fun():
|
||||
return checkify.check(False, "bye!")
|
||||
def f():
|
||||
return jax.lax.cond(True,
|
||||
lambda: checkify.check(False, "hi!"),
|
||||
lambda: checkify.check(False, "bye!"))
|
||||
return jax.lax.cond(True, true_fun, false_fun)
|
||||
|
||||
f = checkify.checkify(f)
|
||||
err, _ = f()
|
||||
@ -973,16 +1091,6 @@ class AssertPrimitiveTests(jtu.JaxTestCase):
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "hi!")
|
||||
|
||||
def test_psum_nan_check(self):
|
||||
@partial(jax.vmap, axis_name="i")
|
||||
def f(x, y):
|
||||
return lax.psum((x, y), axis_name="i")
|
||||
|
||||
cf = checkify.checkify(f, errors=checkify.nan_checks)
|
||||
err, _ = cf(jnp.array([-jnp.inf, 0, jnp.inf]), jnp.ones((3, 2)))
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "nan generated by primitive psum")
|
||||
|
||||
|
||||
class LowerableChecksTest(jtu.JaxTestCase):
|
||||
def setUp(self):
|
||||
|
Loading…
x
Reference in New Issue
Block a user