mirror of
https://github.com/ROCm/jax.git
synced 2025-04-14 10:56:06 +00:00

Counting events by monkey-patching global functions has two problems:
* it's not thread-safe, which will become problematic if we run tests with thread parallelism;
* it's not very maintainable.
Instead, add a new util.test_event(...) function that can be called at points of interest in the program. test_utils registers a callback that is invoked when an event is received. This avoids the need for thread-unsafe global monkey patches.
1383 lines
43 KiB
Python
1383 lines
43 KiB
Python
# Copyright 2021 The JAX Authors.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# https://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
|
|
from functools import partial
|
|
import unittest
|
|
|
|
from absl.testing import absltest
|
|
from absl.testing import parameterized
|
|
import numpy as np
|
|
|
|
import jax
|
|
from jax import lax
|
|
from jax.experimental import checkify
|
|
from jax.experimental import pjit
|
|
from jax.experimental import shard_map
|
|
from jax.sharding import NamedSharding
|
|
from jax._src import array
|
|
from jax._src import config
|
|
from jax._src import core
|
|
from jax._src import test_util as jtu
|
|
from jax._src.checkify import JaxRuntimeError, FailedCheckError, ErrorEffect, OOBError
|
|
from jax._src.lib import xla_extension
|
|
import jax.numpy as jnp
|
|
|
|
# Module-level setup: hook JAX's config up to absl command-line flag parsing
# (per the helper's name; NOTE(review): see jax._src.config for exact effect).
config.parse_flags_with_absl()
|
|
|
|
|
|
@jtu.with_config(jax_check_tracer_leaks=True)
class CheckifyTransformTests(jtu.JaxTestCase):
  """Tests for the `checkify.checkify` transformation.

  Each test checkifies a function with some set of enabled error checks
  (NaN, division by zero, out-of-bounds indexing, user `checkify.check`s)
  and asserts on the message returned by `err.get()` (None means no error).
  Coverage includes jit, pmap, vmap, cond/scan/while_loop, custom_jvp,
  custom_vjp, remat, pjit and shard_map. Tracer-leak checking is enabled for
  every test in this class via the class decorator.
  """

  @jtu.sample_product(jit=[False, True])
  @jtu.skip_on_devices("tpu")
  def test_jit_nan(self, jit):
    def f(x1, x2):
      y1 = jnp.sin(x1)
      y2 = jnp.sin(x2)
      return y1 + y2

    f = jax.jit(f) if jit else f
    checked_f = checkify.checkify(f, errors=checkify.float_checks)

    err, _ = checked_f(3., 4.)
    self.assertIsNone(err.get())

    err, _ = checked_f(3., jnp.inf)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive: sin")

  @jtu.sample_product(jit=[False, True])
  def test_jit_oob(self, jit):
    def f(x, i):
      y = jnp.sin(x)
      z = y[i]
      w = jnp.cos(z)
      return w

    f = jax.jit(f) if jit else f
    checked_f = checkify.checkify(f, errors=checkify.index_checks)

    err, _ = checked_f(jnp.arange(3), 2)
    self.assertIsNone(err.get())

    err, _ = checked_f(jnp.arange(3), 5)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "out-of-bounds indexing")

  @parameterized.named_parameters(
      ("get", lambda x: x.get()),
      ("set", lambda x: x.set(1)),
      ("add", lambda x: x.add(1)),
      ("mul", lambda x: x.multiply(1)),
      ("div", lambda x: x.divide(1)),
      ("pow", lambda x: x.power(1)),
      ("min", lambda x: x.min(1)),
      ("max", lambda x: x.max(1)),
  )
  def test_jit_oob_update(self, update_fn):
    f = jax.jit(lambda x, i: update_fn(x.at[i]))
    checked_f = checkify.checkify(f, errors=checkify.index_checks)

    err, _ = checked_f(jnp.arange(3), 2)
    self.assertIsNone(err.get())

    err, _ = checked_f(jnp.arange(3), 3)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "out-of-bounds indexing")

  @jtu.sample_product(jit=[False, True])
  @jax.numpy_dtype_promotion('standard')
  def test_jit_div_errors(self, jit):
    def f(x, y):
      return x / y

    f = jax.jit(f) if jit else f
    checked_f = checkify.checkify(f, errors=checkify.float_checks)

    err, _ = checked_f(jnp.ones((3,)), jnp.ones((3,)))
    self.assertIsNone(err.get())

    err, _ = checked_f(jnp.ones((3,)), jnp.array([1., 0., 1.]))
    self.assertIsNotNone(err.get())

    # inf / inf produces a NaN rather than a division-by-zero error.
    err, _ = checked_f(jnp.array([1, jnp.inf, 1]), jnp.array([1, jnp.inf, 1]))
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive: div")

  @jtu.sample_product(jit=[False, True])
  @jtu.skip_on_devices("tpu")
  def test_jit_multi(self, jit):
    def f(x, i):
      y = x[i]
      z = jnp.cos(y)
      return z

    f = jax.jit(f) if jit else f
    checked_f = checkify.checkify(f, errors=checkify.automatic_checks)

    # no error
    err, _ = checked_f(jnp.array([0., jnp.inf, 2.]), 2)
    self.assertIsNone(err.get())

    # oob error
    err, _ = checked_f(jnp.array([0., 1., 2.]), 5)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "out-of-bounds indexing")

    # nan error
    err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 2)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive: cos")

  @parameterized.named_parameters(
      ("gather", lambda x: x.get()),
      ("scatter_update", lambda x: x.set(1.)),
      ("scatter_add", lambda x: x.add(1.)),
      ("scatter_mul", lambda x: x.multiply(1.)),
      ("scatter_div", lambda x: x.divide(1.)),
      ("scatter_pow", lambda x: x.power(1.)),
      ("scatter_min", lambda x: x.min(1.)),
      ("scatter_max", lambda x: x.max(1.)),
  )
  def test_numpy_indexing_oobs(self, update_op):
    # Helper: assert that `fn(x, idx)` reports an OOB error whose message
    # contains every string in `expected_strs`. `x` is read from the
    # enclosing scope (defined below, before the first call).
    def raises_oob(fn, idx, *expected_strs):
      err, _ = checkify.checkify(jax.jit(fn), errors=checkify.index_checks)(x, idx)
      error_txt = err.get()
      self.assertIsNotNone(error_txt)
      self.assertStartsWith(error_txt, "out-of-bounds indexing")
      for s in expected_strs:
        self.assertIn(s, error_txt)

    x = jnp.ones((2, 3, 7))
    axis0_msg = "axis 0 with size 2"
    axis1_msg = "axis 1 with size 3"
    axis2_msg = "axis 2 with size 7"

    single_idx = lambda x, i: update_op(x.at[i])
    raises_oob(single_idx, 5, "index 5", axis0_msg)
    raises_oob(single_idx, -5, "index -3", axis0_msg)
    raises_oob(single_idx, (0, 100), "index 100", axis1_msg)
    raises_oob(single_idx, (0, 5, 100), "index 5", axis1_msg)
    raises_oob(single_idx, (0, 0, 100), "index 100", axis2_msg)
    raises_oob(single_idx, ((1, 20), (1, 4)), "index 20", axis0_msg)
    raises_oob(single_idx, ((1, 20), (3, 4)), "index 3", axis1_msg)
    raises_oob(single_idx, (((1, 1), (1, 20)), 3), "index 3", axis1_msg)
    raises_oob(single_idx, (((1, 1), (1, 20)), 0), "index 20", axis0_msg)

    multi_idx = lambda x, i: update_op(x.at[i[0], :, i[1]])
    raises_oob(multi_idx, (0, 9), "index 9", axis2_msg)
    # TODO(lenamartens): numpy reports index -5 here, need to normalize?
    raises_oob(multi_idx, (-5, 9), "index -3", axis0_msg)
    raises_oob(multi_idx, (5, -9), "index 5", axis0_msg)
    raises_oob(multi_idx, ((0, 9), 0), "index 9", axis0_msg)

  def test_dynamic_slice_oobs(self):
    def raises_oob(fn, x, idx, *expected_strs):
      err, _ = checkify.checkify(jax.jit(fn), errors=checkify.index_checks)(x, idx)
      error_txt = err.get()
      self.assertIsNotNone(error_txt)
      self.assertStartsWith(error_txt, "out-of-bounds indexing")
      for s in expected_strs:
        self.assertIn(s, error_txt)

    x = jnp.ones((2, 3, 7))
    raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (2, 0, 0), 'index 2')
    raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (-3, 0, 0), 'index -1')
    raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 3, 0), 'index 3')
    raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, -5, 0), 'index -2')
    raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 1, 8), 'index 8')
    raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 1, -10), 'index -3')

  def test_dynamic_update_slice_oobs(self):
    def raises_oob(fn, x, y, idx, *expected_strs):
      err, _ = checkify.checkify(jax.jit(fn), errors=checkify.index_checks)(x, y, idx)
      error_txt = err.get()
      self.assertIsNotNone(error_txt)
      self.assertStartsWith(error_txt, "out-of-bounds indexing")
      for s in expected_strs:
        self.assertIn(s, error_txt)

    x = jnp.ones((2, 3, 7))
    y = jnp.zeros((1, 1, 1))
    raises_oob(lax.dynamic_update_slice, x, y, (2, 0, 0), 'index 2')
    raises_oob(lax.dynamic_update_slice, x, y, (-3, 0, 0), 'index -1')
    raises_oob(lax.dynamic_update_slice, x, y, (0, 3, 0), 'index 3')
    raises_oob(lax.dynamic_update_slice, x, y, (0, -5, 0), 'index -2')
    raises_oob(lax.dynamic_update_slice, x, y, (0, 1, 8), 'index 8')
    raises_oob(lax.dynamic_update_slice, x, y, (0, 1, -10), 'index -3')

  @jtu.sample_product(jit=[False, True])
  def test_jit_ordering(self, jit):
    def f(x, i):
      y = x[i]
      z = jnp.sin(x)
      return y * z

    f = jax.jit(f) if jit else f
    checked_f = checkify.checkify(f, errors=checkify.automatic_checks)

    # both oob and nan error, but oob happens first
    err, _ = checked_f(jnp.array([0., 1., jnp.inf]), 5)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "out-of-bounds indexing")

  def test_pmap_basic(self):
    if len(jax.devices()) < 2:
      raise unittest.SkipTest("requires at least 2 devices")

    @jax.pmap
    def f(x):
      y1 = jnp.sin(1./x)
      y2 = jnp.sin(x)
      return y1 + y2
    checked_f = checkify.checkify(f, errors=checkify.nan_checks)

    xs = jnp.array([1., 2.])
    err, _ = checked_f(xs)
    self.assertIsNone(err.get())

    xs = jnp.array([3., 0.])
    err, _ = checked_f(xs)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive: sin")

  def test_pmap_collectives(self):
    # The input below has 4 elements mapped one per device, so 4 are needed.
    if len(jax.devices()) < 4:
      raise unittest.SkipTest("requires at least 4 devices")

    @partial(jax.pmap, axis_name="i")
    def f(x1):
      return jax.lax.all_gather(x1, axis_name="i")
    checked_f = checkify.checkify(f, errors=checkify.float_checks)

    xs = jnp.array([0., 2., 3., 6.])
    err, _ = checked_f(xs)
    self.assertIsNone(err.get())

  @jtu.skip_on_devices("tpu")
  def test_cond_basic(self):
    @jax.jit
    def f(x):
      def true_fun(x):
        return jnp.sin(x)
      def false_fun(x):
        checkify.check(x > -1, "oh no")
        return x / 0.
      return lax.cond(x > 0, true_fun, false_fun, x)

    checked_f = checkify.checkify(f, errors=checkify.all_checks)

    err, _ = checked_f(3.)
    self.assertIsNone(err.get())

    err, _ = checked_f(jnp.inf)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive: sin")

    # The user check fires before the division in false_fun.
    err, _ = checked_f(-jnp.inf)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "oh no")

    err, _ = checked_f(0.)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "division by zero")

  def test_cond_different_payloads(self):
    """The two cond branches carry checks with different payload counts."""
    @jax.jit
    def f(x):
      def true_fun(x):
        checkify.check(~x, "{one}", one=x)
      def false_fun(x):
        checkify.check(x, "{one} and {two}", one=x, two=x)
      return lax.cond(x, true_fun, false_fun, x)

    checked_f = checkify.checkify(f)

    err, _ = checked_f(True)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "True")

    err, _ = checked_f(False)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "False and False")

  def test_cond_nd_payloads(self):
    """Like test_cond_different_payloads, but with array-valued payloads."""
    @jax.jit
    def f(x):
      def true_fun(x):
        checkify.check(jnp.all(x > 0), "{one}", one=x)
      def false_fun(x):
        checkify.check(jnp.all(x < 0), "{one} and {two}", one=x, two=x)
      return lax.cond(jnp.all(x < 0), true_fun, false_fun, x)

    checked_f = checkify.checkify(f)

    err, _ = checked_f(jnp.arange(0, 4))
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "[0 1 2 3] and [0 1 2 3]")

    err, _ = checked_f(jnp.arange(-4, -1))
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "[-4 -3 -2]")

  @jtu.skip_on_devices("tpu")
  def test_scan_map(self):
    def scan_body(_, x):
      return None, jnp.sin(x)

    @jax.jit
    def f(xs):
      return lax.scan(scan_body, None, xs)

    checked_f = checkify.checkify(f, errors=checkify.float_checks)

    xs = jnp.array([0., 2.])
    err, (_, ch_outs) = checked_f(xs)
    _, outs = f(xs)
    self.assertIsNone(err.get())
    self.assertArraysEqual(ch_outs, outs)

    xs = jnp.array([3., jnp.inf])
    err, (_, ch_outs) = checked_f(xs)
    _, outs = f(xs)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive: sin")
    # checkify must not change the computed outputs, even when it errors.
    self.assertArraysEqual(ch_outs, outs)

  @jtu.skip_on_devices("tpu")
  def test_scan_carry(self):
    def scan_body(carry, x):
      carry = carry-1.
      possible_nan = jnp.sin(1./carry)
      return carry, x+possible_nan

    @jax.jit
    def f(carry, xs):
      return lax.scan(scan_body, carry, xs)

    checked_f = checkify.checkify(f, errors=checkify.float_checks)

    carry, xs = 3., jnp.ones((2,))
    err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
    out_carry, outs = f(carry, xs)
    self.assertIsNone(err.get())
    self.assertArraysEqual(ch_outs, outs)
    self.assertArraysEqual(ch_out_carry, out_carry)

    # error happens on first iteration
    carry, xs = 1., jnp.ones((2,))
    err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
    out_carry, outs = f(carry, xs)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "division by zero")
    self.assertArraysEqual(ch_outs, outs)
    self.assertArraysEqual(ch_out_carry, out_carry)

    # error happens on second iteration
    carry, xs = 2., jnp.ones((4,))
    err, (ch_out_carry, ch_outs) = checked_f(carry, xs)
    out_carry, outs = f(carry, xs)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "division by zero")
    self.assertArraysEqual(ch_outs, outs)
    self.assertArraysEqual(ch_out_carry, out_carry)

  @jtu.skip_on_devices("tpu")
  def test_while_loop_body_error(self):
    def while_cond(val):
      i, _ = val
      return i < 2

    def while_body(val):
      i, x = val
      possible_nan = jnp.sin(1./i)
      return i+1., x+possible_nan

    @jax.jit
    def f(init_val):
      return lax.while_loop(while_cond, while_body, (init_val, 0.))

    checked_f = checkify.checkify(f, errors=checkify.float_checks)

    init_val = 1.
    err, ch_out = checked_f(init_val)
    out = f(init_val)
    self.assertIsNone(err.get())
    self.assertArraysEqual(ch_out, out)

    init_val = 0.
    err, ch_out = checked_f(init_val)
    out = f(init_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "division by zero")
    self.assertArraysEqual(ch_out, out)

  @jtu.skip_on_devices("tpu")
  def test_while_loop_cond_error(self):
    def while_cond(val):
      _ = jnp.sin(1./val)
      return val < 2.

    def while_body(val):
      return val+1.

    @jax.jit
    def f(init_val):
      return lax.while_loop(while_cond, while_body, init_val)

    checked_f = checkify.checkify(f, errors=checkify.float_checks)

    init_val = 1.
    err, ch_out = checked_f(init_val)
    out = f(init_val)
    self.assertIsNone(err.get())
    self.assertArraysEqual(ch_out, out)

    init_val = 0.
    err, ch_out = checked_f(init_val)
    out = f(init_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "division by zero")
    self.assertArraysEqual(ch_out, out)

  @jtu.skip_on_devices("tpu")
  def test_while_loop_cond_error_and_false(self):
    # Tests if an error is generated when cond returns False.
    def while_cond(val):
      possible_nan = jnp.sin(1./val)
      return jnp.logical_not(jnp.isnan(possible_nan))

    @jax.jit
    def f(init_val):
      return lax.while_loop(while_cond, lambda val: val-1, init_val)

    checked_f = checkify.checkify(f, errors=checkify.float_checks)

    # error on first cond
    init_val = 0.
    err, _ = checked_f(init_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "division by zero")

    # error on second cond
    init_val = 1.
    err, _ = checked_f(init_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "division by zero")

  @jtu.skip_on_devices("tpu")
  def test_while_loop_body_and_cond_error(self):
    def while_cond(val):
      i, cond_val, _ = val
      _ = jnp.sin(cond_val)
      return i < 2

    def while_body(val):
      i, cond_val, body_val = val
      possible_nan = jnp.cos(body_val)
      return i+1., cond_val, possible_nan

    @jax.jit
    def f(cond_val, body_val):
      return lax.while_loop(while_cond, while_body, (0., cond_val, body_val))

    checked_f = checkify.checkify(f, errors=checkify.float_checks)

    cond_val = jnp.inf
    body_val = 1.
    err, _ = checked_f(cond_val, body_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive: sin")

    cond_val = 1.
    body_val = jnp.inf
    err, _ = checked_f(cond_val, body_val)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive: cos")

    cond_val = jnp.inf
    body_val = jnp.inf
    err, _ = checked_f(cond_val, body_val)
    self.assertIsNotNone(err.get())
    # first error which occurs is in cond
    self.assertStartsWith(err.get(), "nan generated by primitive: sin")

  def test_pjit(self):
    def f(x):
      # unary func
      return x / x

    def g(x, y):
      # binary func
      return x / y

    devices = jax.local_devices()[:8]  # Taking up to 8 devices
    mesh = jax.sharding.Mesh(np.array(devices), ["dev"])
    ps = NamedSharding(mesh, jax.sharding.PartitionSpec("dev"))
    # inp contains a 0, so x / x divides by zero.
    inp = np.arange(8)
    x = array.make_array_from_callback(inp.shape, ps, lambda idx: inp[idx])

    f = pjit.pjit(f, in_shardings=ps, out_shardings=ps)
    f = checkify.checkify(f, errors=checkify.float_checks)
    g = pjit.pjit(g, in_shardings=ps, out_shardings=ps)
    g = checkify.checkify(g, errors=checkify.float_checks)
    with mesh:
      u_err, _ = f(x)
      b_err, _ = g(x, x)

    self.assertIsNotNone(u_err.get())
    self.assertStartsWith(u_err.get(), "division by zero")
    self.assertIsNotNone(b_err.get())
    self.assertStartsWith(b_err.get(), "division by zero")

  @parameterized.parameters(True, False)
  def test_shard_map(self, check_rep):
    def f(x):
      # unary func
      return jax.lax.axis_index("dev") * x / x

    def g(x, y):
      # binary func
      return jax.lax.axis_index("dev") * x / y

    devices = jax.local_devices()[:8]  # Taking up to 8 devices
    mesh = jax.sharding.Mesh(np.array(devices), ["dev"])
    pspec = jax.sharding.PartitionSpec("dev")
    ps = NamedSharding(mesh, pspec)
    inp = np.tile(np.arange(4, dtype=np.int32), 2)
    x = array.make_array_from_callback(inp.shape, ps, lambda idx: inp[idx])

    f = shard_map.shard_map(
        f, mesh, in_specs=pspec, out_specs=pspec, check_rep=check_rep
    )
    f = jax.jit(f, in_shardings=ps, out_shardings=ps)
    f = checkify.checkify(f, errors=checkify.float_checks)
    g = shard_map.shard_map(
        g, mesh, in_specs=(pspec, pspec), out_specs=pspec, check_rep=check_rep
    )
    g = jax.jit(g, in_shardings=(ps, ps), out_shardings=ps)
    g = checkify.checkify(g, errors=checkify.float_checks)
    u_err, _ = f(x)
    b_err, _ = g(x, x)

    # inp is arange(4) tiled twice, so the zeros (which cause the division by
    # zero) sit in the shard at device 0 and, when there is more than one
    # device, in the shard at device len(devices) // 2.
    divbyzero = "division by zero"
    expected_err = f"at mapped index 0: {divbyzero}"
    if (next_device_with_zero := len(devices) // 2) != 0:
      expected_err += f"\nat mapped index {next_device_with_zero}: {divbyzero}"

    self.assertIsNotNone(u_err.get())
    self.assertEqual(u_err.get(), expected_err)
    self.assertIsNotNone(b_err.get())
    self.assertEqual(b_err.get(), expected_err)

  def test_empty_enabled_errors(self):
    """With no error sets enabled, nothing is reported."""
    def multi_errors(x):
      x = x/0  # DIV
      x = jnp.sin(x)  # NAN
      x = x[500]  # OOB
      checkify.check(x < 0, "must be negative!")  # ASSERT
      return x

    x = jnp.ones((2,))
    err, _ = checkify.checkify(multi_errors, errors=set())(x)
    self.assertIsNone(err.get())

  @parameterized.named_parameters(
      ("assert", checkify.user_checks, "must be negative!"),
      ("div", checkify.div_checks, "division by zero"),
      ("nan", checkify.nan_checks, "nan generated"),
      ("oob", checkify.index_checks, "out-of-bounds indexing"),
      ("automatic_checks", checkify.automatic_checks, "division by zero"),
  )
  @jtu.skip_on_devices("tpu")
  def test_enabled_errors(self, error_set, expected_error):
    """Only the errors in the enabled set are reported."""
    def multi_errors(x):
      checkify.check(jnp.all(x < 0), "must be negative!")  # ASSERT
      x = x/0  # DIV
      x = jnp.sin(x)  # NAN
      x = x[500]  # OOB
      return x

    x = jnp.ones((2,))
    err, _ = checkify.checkify(multi_errors, errors=error_set)(x)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), expected_error)

  @jtu.skip_on_devices("tpu")
  def test_post_process_call(self):
    """An inner jit that closes over a checkify-traced value still reports."""
    @partial(checkify.checkify, errors=checkify.float_checks)
    def g(x):
      @jax.jit
      def f(y):
        return jnp.sin(x * y)
      return f(jnp.inf)
    err, _ = g(2.)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive: sin")

  @jtu.skip_on_devices("tpu")
  def test_post_process_map(self):
    """An inner pmap that closes over a checkify-traced value still reports."""
    @partial(checkify.checkify, errors=checkify.float_checks)
    def g(x):
      @jax.pmap
      def f(y):
        return jnp.sin(x * y), jnp.cos(x * y)
      return f(jnp.array([jnp.inf]))[0]
    err, _ = g(2.)
    self.assertIsNotNone(err.get())
    self.assertIn("nan generated by primitive: sin", err.get())

  @jtu.skip_on_devices("tpu")
  def test_custom_jvp(self):
    @jax.custom_jvp
    def sin(x):
      return jnp.sin(x)

    @sin.defjvp
    def sin_jvp(primals, tangents):
      (x,), (xdot,) = primals, tangents
      return sin(x), jnp.cos(x) * xdot

    f = checkify.checkify(sin, errors=checkify.float_checks)

    err, y = f(3.)
    self.assertIsNone(err.get())
    err, y = f(jnp.inf)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'nan generated by primitive: sin')

    # When we hit the custom jvp rule with jvp-of-checkify, no checks are added.
    (err, y), (errdot, ydot) = jax.jvp(f, (3.,), (1.,))  # doesn't crash
    self.assertIsNone(err.get())  # no error
    self.assertEmpty(err._metadata)  # and no checks were added!
    self.assertEmpty(errdot._metadata)
    y_expected, ydot_expected = jax.jvp(jnp.sin, (3.,), (1.,))
    self.assertAllClose(y, y_expected)
    self.assertAllClose(ydot, ydot_expected)

    # Grad-of-checkify doesn't crash either.
    x_bar = jax.grad(lambda x: f(x)[1])(3.)
    self.assertAllClose(x_bar, jnp.cos(3.))

    # Checkify-of-jvp adds checks (unlike jvp-of-checkify above).
    g = checkify.checkify(lambda x, xdot: jax.jvp(sin, (x,), (xdot,)),
                          errors=checkify.float_checks)
    err, (y, ydot) = g(3., 1.)  # doesn't crash
    self.assertIsNone(err.get())  # no error
    self.assertNotEmpty(err._metadata)  # but checks were added!
    self.assertAllClose(y, jnp.sin(3.))
    self.assertAllClose(ydot, jnp.cos(3.))
    err, _ = g(jnp.inf, 1.)
    self.assertIsNotNone(err.get())  # yes error
    self.assertStartsWith(err.get(), 'nan generated by primitive: sin')

  @jtu.skip_on_devices("tpu")
  def test_custom_vjp(self):
    @jax.custom_vjp
    def sin(x):
      return jnp.sin(x)

    def sin_fwd(x):
      return jnp.sin(x), 2. * x
    def sin_bwd(x2, g):
      return jnp.cos(x2 / 2.) * g,
    sin.defvjp(sin_fwd, sin_bwd)

    f = checkify.checkify(sin, errors=checkify.float_checks)

    # no differentiation, no error
    err, y = f(3.)
    self.assertIsNone(err.get())

    # no differentiation, yes error
    err, y = f(jnp.inf)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), 'nan generated by primitive: sin')

    # When we hit the custom vjp rule with vjp-of-checkify, no checks are added.
    (err, y), f_vjp = jax.vjp(f, 3.)
    self.assertIsNone(err.get())  # no error
    self.assertEmpty(err._metadata)  # and no checks were added!

    # Checkify-of-vjp adds checks (unlike vjp-of-checkify above).
    err, y = checkify.checkify(jax.grad(sin), errors=checkify.float_checks)(3.)
    self.assertIsNone(err.get())  # no error
    self.assertNotEmpty(err._metadata)  # but checks were added!
    err, y = checkify.checkify(jax.grad(sin),
                               errors=checkify.float_checks)(jnp.inf)
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive: sin")

  def test_scan_consts(self):
    def f(xs):
      def scan_body(carry, _):
        # closes over xs
        return carry+1, xs[carry]
      return lax.scan(scan_body, 1, xs)

    # carry starts at 1 and runs for len(xs) steps, so the last step indexes
    # xs[len(xs)], which is out of bounds.
    checked_f = checkify.checkify(f, errors=checkify.index_checks)
    err, _ = checked_f(jnp.ones((7,)))
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "out-of-bounds indexing")

  def test_scan_consts2(self):
    def f(xs):
      def scan_body(carry, _):
        # add more consts!
        _ = xs[carry], xs[carry], jnp.sin(np.arange(11.))
        return carry+1, xs[carry]
      return lax.scan(scan_body, 1, xs)[1]

    checked_f = checkify.checkify(f, errors=checkify.index_checks)
    err, _ = checked_f(jnp.ones((7, 3)))
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "out-of-bounds indexing")

  def test_while_consts(self):
    def f(xs):
      def while_cond(carry):
        i, _ = carry
        _ = xs[i], jnp.sin(np.arange(11.))
        return i > -1

      def while_body(carry):
        i, _ = carry
        x = xs[i]
        return i - 1, x/i

      return lax.while_loop(while_cond, while_body, (0, jnp.zeros_like(xs[0])))

    # The first body iteration runs with i == 0, so x / i divides by zero.
    checked_f = checkify.checkify(f, errors=checkify.float_checks)
    err, _ = checked_f(jnp.ones((7, 3)))
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "division by zero")

  def test_multiple_payloads(self):
    """With two OOB errors, the first one (index 5) is the one reported."""
    def f(x):
      _ = x[5]
      _ = x[6]

    err, _ = checkify.checkify(f, errors=checkify.index_checks)(jnp.ones((2,)))
    self.assertIsNotNone(err.get())
    self.assertIn("index 5", err.get())

  def test_nd_payloads(self):
    """vmap of checkify reports the payload of every failing batch element."""
    cf = checkify.checkify(lambda x, i: x[i], errors=checkify.index_checks)
    errs, _ = jax.vmap(cf)(jnp.ones((3, 2)), jnp.array([5, 0, 100]))
    self.assertIsNotNone(errs.get())
    self.assertIn("index 5", errs.get())
    self.assertIn("index 100", errs.get())

  def test_mapped_error_one_payload(self):
    """Batch elements hitting different error kinds are all reported."""
    def f(x, i):
      x = x[i]
      return x/0

    cf = checkify.checkify(f, errors=checkify.automatic_checks)
    errs, _ = jax.vmap(cf)(jnp.ones((2, 1)), jnp.array([0, 100]))
    self.assertIsNotNone(errs.get())
    self.assertIn("division by zero", errs.get())
    self.assertIn("index 100", errs.get())

  @jax.legacy_prng_key('allow')
  def test_checking_key_split_with_nan_check(self):
    cf = checkify.checkify(
        lambda k: jax.random.permutation(k, jnp.array([0, 1, 2])),
        errors=checkify.float_checks)
    cf(jax.random.PRNGKey(123))  # does not crash.

  def test_pmap_one_device(self):
    @jax.pmap
    def f(x, y):
      return x/y

    cf = checkify.checkify(f, errors=checkify.automatic_checks)
    errs, _ = cf(jnp.ones((1,)), jnp.zeros((1,)))
    self.assertIsNotNone(errs.get())
    self.assertIn("division by zero", errs.get())

  def test_psum_nan_check(self):
    @partial(jax.vmap, axis_name="i")
    def f(x, y):
      return lax.psum((x, y), axis_name="i")

    # -inf + 0 + inf summed across the mapped axis produces a NaN.
    cf = checkify.checkify(f, errors=checkify.nan_checks)
    err, _ = cf(jnp.array([-jnp.inf, 0, jnp.inf]), jnp.ones((3, 2)))
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "nan generated by primitive")

  def test_different_payload_effects(self):
    """An OOB error and a payload-carrying user check in one vmapped function."""
    def f(x, y):
      x = x[y]
      checkify.check(jnp.all(x > 0), "{x}", x=x)
      return x

    f = checkify.checkify(f, errors=checkify.all_checks)
    err, _ = jax.vmap(f)(jnp.ones((2, 3))*-1, jnp.array([0, 5]))
    self.assertIsNotNone(err.get())

  def test_effects_total_ordering(self):
    """ErrorEffects differing in error type and payload avals are ordered."""
    sds0 = jax.ShapeDtypeStruct((2,), jnp.float32)
    sds1 = jax.ShapeDtypeStruct((2,), jnp.int32)
    sds2 = jax.ShapeDtypeStruct((3,), jnp.int32)
    self.assertTotallyOrdered(
        [ErrorEffect(FailedCheckError, (sds0,))],
        [ErrorEffect(FailedCheckError, (sds0, sds0))],
        [ErrorEffect(FailedCheckError, (sds1,))],
        [ErrorEffect(FailedCheckError, (sds1, sds0))],
        [ErrorEffect(FailedCheckError, (sds2,))],
        [ErrorEffect(OOBError, (sds0,))],
        [ErrorEffect(OOBError, (sds0, sds0))],
    )

  def test_scan_xs_mapped_correctly(self):
    """The scan body must receive per-iteration slices of xs."""
    # Each row of the (2, 4) input has 4 elements, so the reshape to (2, 2)
    # only succeeds if the body sees a single row.
    def f(_, x):
      return None, jnp.reshape(x, (2, 2))

    @jax.jit
    def g(x):
      return jax.lax.scan(f, None, x)

    checked_f = checkify.checkify(g)
    checked_f = jax.jit(checked_f)
    err, _ = checked_f(jnp.ones((2, 4)))
    self.assertIsNone(err.get())

  def test_retracing(self):
    """A second call with the same input must not trigger a new lowering."""
    f = checkify.checkify(jax.jit(lambda x: jnp.sin(x) ** 2))
    _ = f(3.)
    with jtu.count_jit_and_pmap_lowerings() as count:
      _ = f(3.)
    self.assertEqual(count(), 0)

  def test_goodfellow_custom_jvp(self):
    """grad of a checkified function using jax.nn.relu does not crash."""
    # NOTE(review): presumably exercises relu's custom JVP rule, per the name.
    def h(fext):
      checkify.check(True, "")
      return jax.nn.relu(fext)

    h = checkify.checkify(h)

    def h_out(fext):
      _, out = h(fext)
      return out

    h_grad = jax.grad(h_out)
    h_grad(0.)  # doesn't crash

  def test_goodfellow_custom_vjp(self):
    """grad of a checkified function calling a custom_vjp does not crash."""
    @jax.custom_vjp
    def sin(x):
      return jnp.sin(x)
    def sin_fwd(x):
      return jnp.sin(x), 2. * x
    def sin_bwd(x2, g):
      return jnp.cos(x2 / 2.) * g,
    sin.defvjp(sin_fwd, sin_bwd)

    def h(fext):
      checkify.check(True, "")
      return sin(fext)

    h = checkify.checkify(h)

    def h_out(fext):
      _, out = h(fext)
      return out

    h_grad = jax.grad(h_out)
    h_grad(0.)  # doesn't crash

  def test_closed_call(self):
    # lots of golfing went into this test
    # checkify(grad(sum(remat(lax.map(...))))) with a closed-over constant y.
    y = jnp.array([3.14])
    summify = lambda f: lambda x: f(x).sum()
    f = checkify.checkify(jax.grad(summify(jax.remat(
        partial(partial, jax.lax.map)(lambda x: jnp.sin(x * y))))))
    f(jnp.array([3.]))  # don't crash

  def test_while_loop_leaks(self):
    """A while_loop cond that closes over a traced value must not leak it."""
    def f(x):
      n = jnp.minimum(1, 2)
      return jax.lax.while_loop(lambda i: i < n, lambda i: i + 1, x)

    jax.jit(checkify.checkify(f))(0)  # Does not crash bc of leaked tracer.

  @parameterized.parameters(True, False)
  def test_remat(self, jit):
    # basic test from https://github.com/jax-ml/jax/issues/23867
    def fn(x: jax.Array):
      checkify.check(jnp.all(x > 0), "x must be positive")
      return x + 1

    fn = jax.remat(fn)
    if jit:
      fn = jax.jit(fn)
    fn = checkify.checkify(fn)
    err, y = fn(jnp.array([1, 2, 3]))
    self.assertIsNone(err.get())
    self.assertAllClose(y, jnp.array([2, 3, 4]), check_dtypes=False)

    err, _ = fn(jnp.array([0, 2, 3]))
    self.assertIsNotNone(err.get())
    self.assertStartsWith(err.get(), "x must be positive")
|
|
|
|
|
|
@jtu.with_config(jax_check_tracer_leaks=True)
|
|
class AssertPrimitiveTests(jtu.JaxTestCase):
|
|
|
|
def test_assert_primitive_impl(self):
  """Calling a failing check outside jit/checkify raises JaxRuntimeError."""
  def always_fails():
    checkify.check(False, "hi")

  self.assertRaisesRegex(JaxRuntimeError, "hi", always_fails)
|
|
|
|
def test_assert_primitive_lowering(self):
  """A failing check under jit, without checkify, raises ValueError."""
  @jax.jit
  def jitted_failure():
    checkify.check(False, "hi")

  self.assertRaisesRegex(ValueError, "Cannot abstractly evaluate",
                         jitted_failure)
|
|
|
|
def test_assert_primitive_jaxpr_effects(self):
  """A check's ErrorEffect records the avals of its formatting arguments."""
  int_payload = jax.ShapeDtypeStruct((4,), jnp.int32)
  float_payload = jax.ShapeDtypeStruct((2,), jnp.float32)

  def one_payload(x):
    checkify.check(False, "hi: {}", x)

  one_payload_effects = jax.make_jaxpr(one_payload)(
      jnp.ones(4, jnp.int32)).effects
  self.assertSetEqual(one_payload_effects,
                      {ErrorEffect(FailedCheckError, (int_payload,))})

  def two_payloads(x, y):
    checkify.check(False, "hi: {} {}", x, y)

  two_payload_effects = jax.make_jaxpr(two_payloads)(
      jnp.ones(4, jnp.int32), jnp.ones(2, jnp.float32)).effects
  self.assertSetEqual(
      two_payload_effects,
      {ErrorEffect(FailedCheckError, (int_payload, float_payload))})
|
|
|
|
def test_assert_primitive_eval_shape(self):
|
|
# The check is abstractly evaluated but not lowered.
|
|
def f():
|
|
checkify.check(False, "hi")
|
|
|
|
jax.eval_shape(f) # does not crash.
|
|
|
|
def test_assert_discharging(self):
|
|
@checkify.checkify
|
|
def f(x):
|
|
checkify.check(x > 0, "must be positive!")
|
|
return jnp.log(x)
|
|
|
|
err, _ = f(1.)
|
|
self.assertIsNone(err.get())
|
|
|
|
err, _ = f(0.)
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "must be positive")
|
|
|
|
f = jax.jit(f)
|
|
|
|
err, _ = f(1.)
|
|
self.assertIsNone(err.get())
|
|
|
|
err, _ = f(0.)
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "must be positive")
|
|
|
|
def test_assert_discharging_no_data_dependence(self):
|
|
@jax.jit
|
|
def g(x):
|
|
@checkify.checkify
|
|
def f():
|
|
# Note that x is not an argument to the checkified function.
|
|
checkify.check(x > 0, "must be positive!")
|
|
return jnp.log(x)
|
|
return f()
|
|
|
|
err, _ = g(1.)
|
|
self.assertIsNone(err.get())
|
|
|
|
err, _ = g(0.)
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "must be positive")
|
|
|
|
def test_assert_discharging_scan(self):
|
|
def body(carry, x):
|
|
checkify.check(jnp.all(x > 0), "must be positive")
|
|
return carry, x
|
|
|
|
def f(x):
|
|
return jax.lax.scan(body, (None,), x)
|
|
|
|
err, _ = checkify.checkify(f)(jnp.array([-1]))
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "must be positive")
|
|
|
|
err, _ = checkify.checkify(f)(jnp.array([1, 0, -1]))
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "must be positive")
|
|
|
|
def test_assert_discharging_while_loop(self):
|
|
def while_cond(val):
|
|
i, _ = val
|
|
checkify.check(i < 0, "i must be negative")
|
|
return i < 2
|
|
|
|
def while_body(val):
|
|
i, x = val
|
|
checkify.check(x < 0, "x must be negative")
|
|
return i+1., x+1
|
|
|
|
@jax.jit
|
|
def f(init_i, init_val):
|
|
return lax.while_loop(while_cond, while_body, (init_i, init_val))
|
|
|
|
checked_f = checkify.checkify(f)
|
|
|
|
err, _ = checked_f(0, 1)
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "i must be negative")
|
|
|
|
err, _ = checked_f(-1, 0)
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "x must be negative")
|
|
|
|
def test_assert_discharging_cond(self):
|
|
def true_branch(x):
|
|
checkify.check(jnp.all(x != 0.), "x cannot be 0")
|
|
return 1/x
|
|
|
|
def false_branch(x):
|
|
checkify.check(jnp.all(x >= 0), "x must be positive")
|
|
return x*2
|
|
|
|
@jax.jit
|
|
def f(pred, x):
|
|
return lax.cond(pred, true_branch, false_branch, x)
|
|
|
|
checked_f = checkify.checkify(f)
|
|
|
|
err, _ = checked_f(True, 0.)
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "x cannot be 0")
|
|
err, _ = checked_f(False, 0.)
|
|
self.assertIsNone(err.get())
|
|
|
|
err, _ = checked_f(False, -1.)
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "x must be positive")
|
|
err, _ = checked_f(True, -1.)
|
|
self.assertIsNone(err.get())
|
|
|
|
def test_assert_batching_rule(self):
|
|
@jax.vmap
|
|
def f(x):
|
|
checkify.check(jnp.sum(x) == 1., "x must sum to one.")
|
|
return x
|
|
|
|
no_failures = jnp.array([[0.5, 0.5], [1., 0.]])
|
|
one_batch_fails = jnp.array([[0.5, 0.5], [1, 1]])
|
|
mult_batch_fail = jnp.array([[0.5, 0.5], [1, 1], [2, 2]])
|
|
|
|
f(no_failures)
|
|
with self.assertRaisesRegex(JaxRuntimeError, "x must sum to one."):
|
|
f(one_batch_fails)
|
|
|
|
with self.assertRaisesRegex(JaxRuntimeError, "x must sum to one."):
|
|
f(mult_batch_fail)
|
|
|
|
checked_f = checkify.checkify(f)
|
|
err, _ = checked_f(no_failures)
|
|
self.assertIsNone(err.get())
|
|
|
|
err, _ = checked_f(one_batch_fails)
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "x must sum to one")
|
|
|
|
err, _ = checked_f(mult_batch_fail)
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "x must sum to one")
|
|
|
|
def test_check_error(self):
|
|
def g():
|
|
checkify.check(False, "hi")
|
|
def f():
|
|
err, _ = checkify.checkify(g)()
|
|
checkify.check_error(err)
|
|
|
|
with self.assertRaisesRegex(JaxRuntimeError, "hi"):
|
|
f()
|
|
|
|
f = checkify.checkify(f)
|
|
err, none = f()
|
|
|
|
self.assertIsNone(none)
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "hi")
|
|
|
|
def test_check_error_scanned(self):
|
|
def body(carry, x):
|
|
checkify.check(jnp.all(x > 0), "should be positive")
|
|
return carry, x
|
|
|
|
def checked_body(carry, x):
|
|
err, (carry, x) = checkify.checkify(body)(carry, x)
|
|
return carry, (x, err)
|
|
|
|
def f(x):
|
|
_, (xs, errs) = jax.lax.scan(checked_body, (None,), x)
|
|
checkify.check_error(errs)
|
|
return xs
|
|
|
|
err, _ = checkify.checkify(f)(jnp.array([-1]))
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "should be positive")
|
|
|
|
err, _ = checkify.checkify(f)(jnp.array([1, 0, -1]))
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "should be positive")
|
|
|
|
def test_discharge_recharge(self):
|
|
def ejit(f):
|
|
f = checkify.checkify(f)
|
|
f = jax.jit(f)
|
|
def jitted_f(*args):
|
|
err, out = f(*args)
|
|
checkify.check_error(err)
|
|
return out
|
|
return jitted_f
|
|
|
|
@ejit
|
|
def f(pred):
|
|
assert python_should_be_running
|
|
checkify.check(pred, "foo")
|
|
|
|
python_should_be_running = True
|
|
f(True)
|
|
|
|
python_should_be_running = False
|
|
f(True)
|
|
with self.assertRaisesRegex(JaxRuntimeError, "foo"):
|
|
f(False)
|
|
|
|
def test_cond_of_named_call(self):
|
|
def g(x):
|
|
branch = jax.named_call(lambda x: x)
|
|
out = jax.lax.cond(True, branch, branch, x)
|
|
return out
|
|
|
|
checkify.checkify(g)(0.) # does not crash
|
|
|
|
def test_grad(self):
|
|
@jax.grad
|
|
def f(x):
|
|
checkify.check(jnp.all(x > 0), "should be positive!")
|
|
return x
|
|
|
|
f = checkify.checkify(f)
|
|
err, _ = f(1.)
|
|
self.assertIsNone(err.get())
|
|
|
|
err, _ = f(0.)
|
|
self.assertIsNotNone(err.get())
|
|
self.assertIn("should be positive", err.get())
|
|
|
|
def test_checkify_of_vmap_of_while_errors(self):
|
|
@jax.vmap
|
|
def fun(n, v):
|
|
def while_cond(s):
|
|
counter, value = s
|
|
checkify.check(value < 6, "value needs to be less than 6!")
|
|
return counter > 0
|
|
|
|
def while_body(s):
|
|
counter, value = s
|
|
checkify.check(value >= 0, "value needs to be positive!")
|
|
return counter/value, value - 1.
|
|
|
|
_, result = jax.lax.while_loop(while_cond, while_body, (n, v))
|
|
return result
|
|
|
|
checked_f = checkify.checkify(fun, errors=checkify.all_checks)
|
|
|
|
with self.assertRaisesRegex(ValueError, "checkify-of-vmap-of-while"):
|
|
checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., 4.]))
|
|
# TODO(lenamartens): reenable assertions below.
|
|
# self.assertIsNotNone(err.get())
|
|
# self.assertStartsWith(err.get(), "division by zero")
|
|
|
|
# err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([5., 2., -4.]))
|
|
# self.assertIsNotNone(err.get())
|
|
# self.assertStartsWith(err.get(), "value needs to be positive")
|
|
|
|
# err, _ = checked_f(jnp.asarray([1., 2., 3.]), jnp.asarray([6., 2., -4.]))
|
|
# self.assertIsNotNone(err.get())
|
|
# self.assertStartsWith(err.get(), "value needs to be less than 6")
|
|
|
|
def test_checkify_of_vmap_of_while_masked_errors(self):
|
|
def cond(x):
|
|
return x < 5
|
|
|
|
def body(x):
|
|
# This will only trigger in the masked portion of the batched while.
|
|
checkify.check(x < 5, "should never happen")
|
|
return x + 1
|
|
|
|
@jax.vmap
|
|
def fun(x):
|
|
return lax.while_loop(cond, body, x)
|
|
|
|
checked_f = checkify.checkify(fun)
|
|
|
|
with self.assertRaisesRegex(ValueError, "checkify-of-vmap-of-while"):
|
|
checked_f(jnp.arange(5))
|
|
# TODO(lenamartens): reenable assertions below.
|
|
# self.assertIsNone(err.get())
|
|
|
|
def test_assert_cond_no_data_dependence(self):
|
|
def true_fun():
|
|
return checkify.check(False, "hi!")
|
|
def false_fun():
|
|
return checkify.check(False, "bye!")
|
|
def f():
|
|
return jax.lax.cond(True, true_fun, false_fun)
|
|
|
|
f = checkify.checkify(f)
|
|
err, _ = f()
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "hi!")
|
|
|
|
def test_assert_switch_no_data_dependence(self):
|
|
def branch():
|
|
checkify.check(False, "hi!")
|
|
|
|
def f():
|
|
return lax.switch(0, [branch]*3)
|
|
|
|
checked_f = checkify.checkify(f)
|
|
|
|
err, _ = checked_f()
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "hi!")
|
|
|
|
def test_debug_check_noop(self):
|
|
def f(x):
|
|
checkify.debug_check(jnp.all(x != x), "{x} cannot be {x}", x=x)
|
|
return x
|
|
x = jnp.ones(())
|
|
f(x) # no error.
|
|
jax.jit(f)(x) # no error.
|
|
jax.vmap(f)(jnp.ones((2,))) # no error.
|
|
jax.grad(f)(x) # no error.
|
|
|
|
@parameterized.named_parameters(("with_jit", True), ("without_jit", False))
|
|
def test_debug_check_nonscalar_pred(self, with_jit):
|
|
def f(x):
|
|
checkify.debug_check(x != x, "{x} cannot be {x}", x=x)
|
|
return x
|
|
checked_f = checkify.checkify(f)
|
|
if with_jit:
|
|
checked_f = jax.jit(checked_f)
|
|
|
|
with self.assertRaisesRegex(TypeError, "debug_check takes a scalar pred"):
|
|
checked_f(jnp.ones((5,)))
|
|
|
|
|
|
@parameterized.named_parameters(("with_jit", True), ("without_jit", False))
|
|
def test_debug_check(self, with_jit):
|
|
def f(x):
|
|
checkify.debug_check(jnp.all(x != x), "{x} cannot be {x}", x=x)
|
|
return x
|
|
checked_f = checkify.checkify(f)
|
|
if with_jit:
|
|
checked_f = jax.jit(checked_f)
|
|
err, _ = checked_f(jnp.ones(()))
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "1.0 cannot be 1.0")
|
|
|
|
@parameterized.named_parameters(("with_jit", True), ("without_jit", False))
|
|
def test_debug_check_disabled_errors(self, with_jit):
|
|
def f(x):
|
|
checkify.debug_check(jnp.all(x != x), "{x} cannot be {x}", x=x)
|
|
return x
|
|
checked_f = checkify.checkify(f, errors={})
|
|
if with_jit:
|
|
checked_f = jax.jit(checked_f)
|
|
err, _ = checked_f(jnp.ones((1,)))
|
|
self.assertIsNone(err.get())
|
|
|
|
def test_debug_check_jaxpr_roundtrip(self):
|
|
def f(x):
|
|
checkify.debug_check(jnp.all(x != x), "{x} cannot be {x}", x=x)
|
|
return x
|
|
x = jnp.ones(())
|
|
jaxpr = jax.make_jaxpr(f)(x)
|
|
roundtrip_f = partial(core.eval_jaxpr, jaxpr.jaxpr, jaxpr.consts)
|
|
checked_f = checkify.checkify(jax.jit(roundtrip_f))
|
|
err, _ = checked_f(jnp.ones(()))
|
|
self.assertIsNotNone(err.get())
|
|
self.assertStartsWith(err.get(), "1.0 cannot be 1.0")
|
|
|
|
def test_fmt_args_array_type_error(self):
|
|
args_error = lambda: checkify.check(False, "{} world", "hello")
|
|
with self.assertRaisesRegex(TypeError, "Formatting arguments"):
|
|
checkify.checkify(args_error)()
|
|
|
|
kwargs_error = lambda: checkify.check(False, "{hello} world", hello="hello")
|
|
with self.assertRaisesRegex(TypeError, "Formatting arguments"):
|
|
checkify.checkify(kwargs_error)()
|
|
|
|
np_arrays_ok = lambda: checkify.check(False, "{} world", np.array(1.))
|
|
checkify.checkify(np_arrays_ok)()
|
|
|
|
trees_ok = lambda: checkify.check(False, "{}", {"hello": jnp.array(1.)})
|
|
checkify.checkify(trees_ok)()
|
|
|
|
def test_checkify_non_jax_type_input(self):
|
|
_ = checkify.checkify(lambda x: 1.)("hi") # does not crash
|
|
|
|
def test_checkify_static_args(self):
|
|
@checkify.checkify
|
|
def f(x):
|
|
if x:
|
|
return
|
|
|
|
_ = jax.jit(f, static_argnums=(0,))(True)
|
|
|
|
def test_check_pp_rule(self):
|
|
jaxpr = jax.make_jaxpr(lambda: checkify.check(False, "hi"))()
|
|
jaxpr.pretty_print(source_info=True, name_stack=True) # Does not crash.
|
|
|
|
|
|
class LowerableChecksTest(jtu.JaxTestCase):
  """With `xla_runtime_errors` enabled, a check under plain jit (no checkify)
  surfaces as an `XlaRuntimeError` at run time."""

  def setUp(self):
    super().setUp()
    # Enable the flag for the duration of each test.
    self.enter_context(config.xla_runtime_errors(True))

  @jtu.run_on_devices("cpu", "gpu")
  def test_jit(self):
    def identity_if_positive(x):
      checkify.check(x > 0, "x needs to be positive")
      return x

    jitted = jax.jit(identity_if_positive)
    with self.assertRaisesRegex(xla_extension.XlaRuntimeError,
                                "x needs to be positive"):
      jitted(-1.)
|
|
|
|
if __name__ == "__main__":
  # Run this module's tests using JAX's test loader.
  loader = jtu.JaxTestLoader()
  absltest.main(testLoader=loader)
|