diff --git a/jax/experimental/checkify.py b/jax/experimental/checkify.py index 410109bb9..4365f4c64 100644 --- a/jax/experimental/checkify.py +++ b/jax/experimental/checkify.py @@ -287,6 +287,52 @@ def div_error_check(error, x, y): return nan_error_check(lax.div_p, div_by_zero_err, x, y) error_checks[lax.div_p] = div_error_check +def scatter_in_bounds(operand, indices, updates, dnums): + # Ref: see clamping code used in scatter_translation_rule + slice_sizes = [] + pos = 0 + for i in range(len(operand.shape)): + if i in dnums.inserted_window_dims: + slice_sizes.append(1) + else: + slice_sizes.append(updates.shape[dnums.update_window_dims[pos]]) + pos += 1 + + upper_bound = np.array([operand.shape[i] - slice_sizes[i] + for i in dnums.scatter_dims_to_operand_dims], + np.int64) + upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max) + upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape, + (len(indices.shape) - 1,)) + + lower_in_bounds = jnp.all(jnp.greater_equal(indices, 0)) + upper_in_bounds = jnp.all(jnp.less_equal(indices, upper_bound)) + return jnp.logical_and(lower_in_bounds, upper_in_bounds) + +def scatter_error_check(prim, error, operand, indices, updates, *, + update_jaxpr, update_consts, + dimension_numbers, indices_are_sorted, + unique_indices, mode): + """Checks if indices are within bounds and update does not generate NaN.""" + out = prim.bind( + operand, indices, updates, update_jaxpr=update_jaxpr, + update_consts=update_consts, dimension_numbers=dimension_numbers, + indices_are_sorted=indices_are_sorted, unique_indices=unique_indices, + mode=mode) + + in_bounds = scatter_in_bounds(operand, indices, updates, dimension_numbers) + oob_msg = f'out-of-bounds indexing while updating at {summary()}' + oob_error = assert_func(error, in_bounds, oob_msg) + + no_nans = jnp.logical_not(jnp.any(jnp.isnan(out))) + nan_msg = f'nan generated by primitive {prim.name} at {summary()}' + return out, assert_func(oob_error, no_nans, nan_msg) +error_checks[lax.scatter_p] = partial(scatter_error_check, lax.scatter_p) +error_checks[lax.scatter_add_p] = partial(scatter_error_check, lax.scatter_add_p) +error_checks[lax.scatter_mul_p] = partial(scatter_error_check, lax.scatter_mul_p) +error_checks[lax.scatter_min_p] = partial(scatter_error_check, lax.scatter_min_p) +error_checks[lax.scatter_max_p] = partial(scatter_error_check, lax.scatter_max_p) + def cond_error_check(error, index, *ops, branches, linear): new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error) for jxpr in branches) new_linear = (False, False, *linear) @@ -439,4 +485,3 @@ add_nan_check(lax.abs_p) add_nan_check(lax.select_p) add_nan_check(lax.max_p) add_nan_check(lax.min_p) -add_nan_check(lax.scatter_add_p) diff --git a/tests/checkify_test.py b/tests/checkify_test.py index a6c1f281e..5c27317d0 100644 --- a/tests/checkify_test.py +++ b/tests/checkify_test.py @@ -66,6 +66,23 @@ class CheckifyTransformTests(jtu.JaxTestCase): self.assertIsNotNone(err.get()) self.assertStartsWith(err.get(), 'out-of-bounds indexing') + @parameterized.named_parameters( + {"testcase_name": f"_update={update_fn}", "update_fn": update_fn} + for update_fn in ["set", "add", "multiply", "divide", "power", "min", + "max", "get"]) + def test_jit_oob_update(self, update_fn): + def f(x, i): + return getattr(x.at[i], update_fn)(1.) + + f = jax.jit(f) + + err, _ = checkify.checkify(f)(jnp.arange(3), 2) + self.assertIs(err.get(), None) + + err, _ = checkify.checkify(f)(jnp.arange(3), 3) + self.assertIsNotNone(err.get()) + self.assertStartsWith(err.get(), 'out-of-bounds indexing') + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_jit={}".format(jit), "jit": jit} for jit in [False, True]))