checkify: dynamic_update_slice OOB index check

This commit is contained in:
Jake VanderPlas 2023-04-17 13:42:32 -07:00
parent cabf8b7302
commit 8d1cf99825
2 changed files with 34 additions and 0 deletions

View File

@ -591,6 +591,22 @@ def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices, sl
return error, out
error_checks[lax.dynamic_slice_p] = dynamic_slice_error_check
def dynamic_update_slice_error_check(error, enabled_errors, operand, update, *start_indices):
out = lax.dynamic_update_slice_p.bind(operand, update, *start_indices)
if OOBError not in enabled_errors:
return error, out
operand_dims = np.array(operand.shape)
update_dims = np.array(update.shape)
start_indices = jnp.array(start_indices)
oob_mask = (start_indices < 0) | (start_indices + update_dims > operand_dims)
payload = oob_payload(oob_mask, start_indices, range(operand.ndim), operand.shape)
error = assert_func(error, jnp.any(oob_mask), OOBError(summary(), "dynamic_update_slice", operand.shape, payload))
return error, out
error_checks[lax.dynamic_update_slice_p] = dynamic_update_slice_error_check
def gather_error_check(error, enabled_errors, operand, start_indices, *,
dimension_numbers, slice_sizes, unique_indices,
indices_are_sorted, mode, fill_value):

View File

@ -198,6 +198,24 @@ class CheckifyTransformTests(jtu.JaxTestCase):
raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 1, 8), 'index 8')
raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 1, -10), 'index -3')
def test_dynamic_update_slice_oobs(self):
def raises_oob(fn, x, y, idx, *expected_strs):
err, _ = checkify.checkify(jax.jit(fn), errors=checkify.index_checks)(x, y, idx)
error_txt = err.get()
self.assertIsNotNone(error_txt)
self.assertStartsWith(error_txt, "out-of-bounds indexing")
for s in expected_strs:
self.assertIn(s, error_txt)
x = jnp.ones((2, 3, 7))
y = jnp.zeros((1, 1, 1))
raises_oob(lax.dynamic_update_slice, x, y, (2, 0, 0), 'index 2')
raises_oob(lax.dynamic_update_slice, x, y, (-3, 0, 0), 'index -1')
raises_oob(lax.dynamic_update_slice, x, y, (0, 3, 0), 'index 3')
raises_oob(lax.dynamic_update_slice, x, y, (0, -5, 0), 'index -2')
raises_oob(lax.dynamic_update_slice, x, y, (0, 1, 8), 'index 8')
raises_oob(lax.dynamic_update_slice, x, y, (0, 1, -10), 'index -3')
@jtu.sample_product(jit=[False, True])
def test_jit_ordering(self, jit):
def f(x, i):