From 8d1cf998257d3c7ce81e422faff433c4cc5d8c80 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 17 Apr 2023 13:42:32 -0700 Subject: [PATCH] checkify: dynamic_update_slice OOB index check --- jax/_src/checkify.py | 16 ++++++++++++++++ tests/checkify_test.py | 18 ++++++++++++++++++ 2 files changed, 34 insertions(+) diff --git a/jax/_src/checkify.py b/jax/_src/checkify.py index 578db6be3..a49181390 100644 --- a/jax/_src/checkify.py +++ b/jax/_src/checkify.py @@ -591,6 +591,22 @@ def dynamic_slice_error_check(error, enabled_errors, operand, *start_indices, sl return error, out error_checks[lax.dynamic_slice_p] = dynamic_slice_error_check +def dynamic_update_slice_error_check(error, enabled_errors, operand, update, *start_indices): + out = lax.dynamic_update_slice_p.bind(operand, update, *start_indices) + + if OOBError not in enabled_errors: + return error, out + + operand_dims = np.array(operand.shape) + update_dims = np.array(update.shape) + start_indices = jnp.array(start_indices) + oob_mask = (start_indices < 0) | (start_indices + update_dims > operand_dims) + + payload = oob_payload(oob_mask, start_indices, range(operand.ndim), operand.shape) + error = assert_func(error, jnp.any(oob_mask), OOBError(summary(), "dynamic_update_slice", operand.shape, payload)) + return error, out +error_checks[lax.dynamic_update_slice_p] = dynamic_update_slice_error_check + def gather_error_check(error, enabled_errors, operand, start_indices, *, dimension_numbers, slice_sizes, unique_indices, indices_are_sorted, mode, fill_value): diff --git a/tests/checkify_test.py b/tests/checkify_test.py index c6183d229..7eb905a0d 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -198,6 +198,24 @@ class CheckifyTransformTests(jtu.JaxTestCase): raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 1, 8), 'index 8') raises_oob(partial(lax.dynamic_slice, slice_sizes=(1, 1, 1)), x, (0, 1, -10), 'index -3') + def test_dynamic_update_slice_oobs(self): + def raises_oob(fn, x, y, idx, *expected_strs): + err, _ = checkify.checkify(jax.jit(fn), errors=checkify.index_checks)(x, y, idx) + error_txt = err.get() + self.assertIsNotNone(error_txt) + self.assertStartsWith(error_txt, "out-of-bounds indexing") + for s in expected_strs: + self.assertIn(s, error_txt) + + x = jnp.ones((2, 3, 7)) + y = jnp.zeros((1, 1, 1)) + raises_oob(lax.dynamic_update_slice, x, y, (2, 0, 0), 'index 2') + raises_oob(lax.dynamic_update_slice, x, y, (-3, 0, 0), 'index -1') + raises_oob(lax.dynamic_update_slice, x, y, (0, 3, 0), 'index 3') + raises_oob(lax.dynamic_update_slice, x, y, (0, -5, 0), 'index -2') + raises_oob(lax.dynamic_update_slice, x, y, (0, 1, 8), 'index 8') + raises_oob(lax.dynamic_update_slice, x, y, (0, 1, -10), 'index -3') + @jtu.sample_product(jit=[False, True]) def test_jit_ordering(self, jit): def f(x, i):