Add scatter OOB error.

This commit is contained in:
Lena Martens 2021-12-23 15:23:58 +00:00 committed by lenamartens
parent f35014d655
commit 7b5b9cefbd
2 changed files with 63 additions and 1 deletions

View File

@ -287,6 +287,52 @@ def div_error_check(error, x, y):
return nan_error_check(lax.div_p, div_by_zero_err, x, y)
error_checks[lax.div_p] = div_error_check
def scatter_in_bounds(operand, indices, updates, dnums):
# Ref: see clamping code used in scatter_translation_rule
slice_sizes = []
pos = 0
for i in range(len(operand.shape)):
if i in dnums.inserted_window_dims:
slice_sizes.append(1)
else:
slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
pos += 1
upper_bound = np.array([operand.shape[i] - slice_sizes[i]
for i in dnums.scatter_dims_to_operand_dims],
np.int64)
upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
(len(indices.shape) - 1,))
lower_in_bounds = jnp.all(jnp.greater_equal(indices, 0))
upper_in_bounds = jnp.all(jnp.less_equal(indices, upper_bound))
return jnp.logical_and(lower_in_bounds, upper_in_bounds)
def scatter_error_check(prim, error, operand, indices, updates, *,
update_jaxpr, update_consts,
dimension_numbers, indices_are_sorted,
unique_indices, mode):
"""Checks if indices are within bounds and update does not generate NaN."""
out = prim.bind(
operand, indices, updates, update_jaxpr=update_jaxpr,
update_consts=update_consts, dimension_numbers=dimension_numbers,
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
mode=mode)
in_bounds = scatter_in_bounds(operand, indices, updates, dimension_numbers)
oob_msg = f'out-of-bounds indexing while updating at {summary()}'
oob_error = assert_func(error, in_bounds, oob_msg)
no_nans = jnp.logical_not(jnp.any(jnp.isnan(out)))
nan_msg = f'nan generated by primitive {prim.name} at {summary()}'
return out, assert_func(oob_error, no_nans, nan_msg)
error_checks[lax.scatter_p] = partial(scatter_error_check, lax.scatter_p)
error_checks[lax.scatter_add_p] = partial(scatter_error_check, lax.scatter_add_p)
error_checks[lax.scatter_mul_p] = partial(scatter_error_check, lax.scatter_mul_p)
error_checks[lax.scatter_min_p] = partial(scatter_error_check, lax.scatter_min_p)
error_checks[lax.scatter_max_p] = partial(scatter_error_check, lax.scatter_max_p)
def cond_error_check(error, index, *ops, branches, linear):
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error) for jxpr in branches)
new_linear = (False, False, *linear)
@ -439,4 +485,3 @@ add_nan_check(lax.abs_p)
add_nan_check(lax.select_p)
add_nan_check(lax.max_p)
add_nan_check(lax.min_p)
add_nan_check(lax.scatter_add_p)

View File

@ -66,6 +66,23 @@ class CheckifyTransformTests(jtu.JaxTestCase):
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'out-of-bounds indexing')
@parameterized.named_parameters(
{"testcase_name": f"_update={update_fn}", "update_fn": update_fn}
for update_fn in ["set", "add", "multiply", "divide", "power", "min",
"max", "get"])
def test_jit_oob_update(self, update_fn):
def f(x, i):
return getattr(x.at[i], update_fn)(1.)
f = jax.jit(f)
err, _ = checkify.checkify(f)(jnp.arange(3), 2)
self.assertIs(err.get(), None)
err, _ = checkify.checkify(f)(jnp.arange(3), 3)
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), 'out-of-bounds indexing')
@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_jit={}".format(jit), "jit": jit}
for jit in [False, True]))