mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
checkify: dynamic_update_slice OOB index check
This commit is contained in:
parent
cabf8b7302
commit
8d1cf99825
@ -591,6 +591,22 @@ def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices, sl
|
||||
return error, out
|
||||
error_checks[lax.dynamic_slice_p] = dynamic_slice_error_check
|
||||
|
||||
def dynamic_update_slice_error_check(error, enabled_errors, operand, update, *start_indices):
|
||||
out = lax.dynamic_update_slice_p.bind(operand, update, *start_indices)
|
||||
|
||||
if OOBError not in enabled_errors:
|
||||
return error, out
|
||||
|
||||
operand_dims = np.array(operand.shape)
|
||||
update_dims = np.array(update.shape)
|
||||
start_indices = jnp.array(start_indices)
|
||||
oob_mask = (start_indices < 0) | (start_indices + update_dims > operand_dims)
|
||||
|
||||
payload = oob_payload(oob_mask, start_indices, range(operand.ndim), operand.shape)
|
||||
error = assert_func(error, jnp.any(oob_mask), OOBError(summary(), "dynamic_update_slice", operand.shape, payload))
|
||||
return error, out
|
||||
error_checks[lax.dynamic_update_slice_p] = dynamic_update_slice_error_check
|
||||
|
||||
def gather_error_check(error, enabled_errors, operand, start_indices, *,
|
||||
dimension_numbers, slice_sizes, unique_indices,
|
||||
indices_are_sorted, mode, fill_value):
|
||||
|
@ -198,6 +198,24 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 1, 8), 'index 8')
|
||||
raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 1, -10), 'index -3')
|
||||
|
||||
def test_dynamic_update_slice_oobs(self):
|
||||
def raises_oob(fn, x, y, idx, *expected_strs):
|
||||
err, _ = checkify.checkify(jax.jit(fn), errors=checkify.index_checks)(x, y, idx)
|
||||
error_txt = err.get()
|
||||
self.assertIsNotNone(error_txt)
|
||||
self.assertStartsWith(error_txt, "out-of-bounds indexing")
|
||||
for s in expected_strs:
|
||||
self.assertIn(s, error_txt)
|
||||
|
||||
x = jnp.ones((2, 3, 7))
|
||||
y = jnp.zeros((1, 1, 1))
|
||||
raises_oob(lax.dynamic_update_slice, x, y, (2, 0, 0), 'index 2')
|
||||
raises_oob(lax.dynamic_update_slice, x, y, (-3, 0, 0), 'index -1')
|
||||
raises_oob(lax.dynamic_update_slice, x, y, (0, 3, 0), 'index 3')
|
||||
raises_oob(lax.dynamic_update_slice, x, y, (0, -5, 0), 'index -2')
|
||||
raises_oob(lax.dynamic_update_slice, x, y, (0, 1, 8), 'index 8')
|
||||
raises_oob(lax.dynamic_update_slice, x, y, (0, 1, -10), 'index -3')
|
||||
|
||||
@jtu.sample_product(jit=[False, True])
|
||||
def test_jit_ordering(self, jit):
|
||||
def f(x, i):
|
||||
|
Loading…
x
Reference in New Issue
Block a user