mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Add scatter OOB error.
This commit is contained in:
parent
f35014d655
commit
7b5b9cefbd
@ -287,6 +287,52 @@ def div_error_check(error, x, y):
|
||||
return nan_error_check(lax.div_p, div_by_zero_err, x, y)
|
||||
error_checks[lax.div_p] = div_error_check
|
||||
|
||||
def scatter_in_bounds(operand, indices, updates, dnums):
|
||||
# Ref: see clamping code used in scatter_translation_rule
|
||||
slice_sizes = []
|
||||
pos = 0
|
||||
for i in range(len(operand.shape)):
|
||||
if i in dnums.inserted_window_dims:
|
||||
slice_sizes.append(1)
|
||||
else:
|
||||
slice_sizes.append(updates.shape[dnums.update_window_dims[pos]])
|
||||
pos += 1
|
||||
|
||||
upper_bound = np.array([operand.shape[i] - slice_sizes[i]
|
||||
for i in dnums.scatter_dims_to_operand_dims],
|
||||
np.int64)
|
||||
upper_bound = np.minimum(upper_bound, np.iinfo(indices.dtype).max)
|
||||
upper_bound = lax.broadcast_in_dim(upper_bound, indices.shape,
|
||||
(len(indices.shape) - 1,))
|
||||
|
||||
lower_in_bounds = jnp.all(jnp.greater_equal(indices, 0))
|
||||
upper_in_bounds = jnp.all(jnp.less_equal(indices, upper_bound))
|
||||
return jnp.logical_and(lower_in_bounds, upper_in_bounds)
|
||||
|
||||
def scatter_error_check(prim, error, operand, indices, updates, *,
|
||||
update_jaxpr, update_consts,
|
||||
dimension_numbers, indices_are_sorted,
|
||||
unique_indices, mode):
|
||||
"""Checks if indices are within bounds and update does not generate NaN."""
|
||||
out = prim.bind(
|
||||
operand, indices, updates, update_jaxpr=update_jaxpr,
|
||||
update_consts=update_consts, dimension_numbers=dimension_numbers,
|
||||
indices_are_sorted=indices_are_sorted, unique_indices=unique_indices,
|
||||
mode=mode)
|
||||
|
||||
in_bounds = scatter_in_bounds(operand, indices, updates, dimension_numbers)
|
||||
oob_msg = f'out-of-bounds indexing while updating at {summary()}'
|
||||
oob_error = assert_func(error, in_bounds, oob_msg)
|
||||
|
||||
no_nans = jnp.logical_not(jnp.any(jnp.isnan(out)))
|
||||
nan_msg = f'nan generated by primitive {prim.name} at {summary()}'
|
||||
return out, assert_func(oob_error, no_nans, nan_msg)
|
||||
error_checks[lax.scatter_p] = partial(scatter_error_check, lax.scatter_p)
|
||||
error_checks[lax.scatter_add_p] = partial(scatter_error_check, lax.scatter_add_p)
|
||||
error_checks[lax.scatter_mul_p] = partial(scatter_error_check, lax.scatter_mul_p)
|
||||
error_checks[lax.scatter_min_p] = partial(scatter_error_check, lax.scatter_min_p)
|
||||
error_checks[lax.scatter_max_p] = partial(scatter_error_check, lax.scatter_max_p)
|
||||
|
||||
def cond_error_check(error, index, *ops, branches, linear):
|
||||
new_branches, msgs_ = unzip2(checkify_jaxpr(jxpr, error) for jxpr in branches)
|
||||
new_linear = (False, False, *linear)
|
||||
@ -439,4 +485,3 @@ add_nan_check(lax.abs_p)
|
||||
add_nan_check(lax.select_p)
|
||||
add_nan_check(lax.max_p)
|
||||
add_nan_check(lax.min_p)
|
||||
add_nan_check(lax.scatter_add_p)
|
||||
|
@ -66,6 +66,23 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'out-of-bounds indexing')
|
||||
|
||||
@parameterized.named_parameters(
|
||||
{"testcase_name": f"_update={update_fn}", "update_fn": update_fn}
|
||||
for update_fn in ["set", "add", "multiply", "divide", "power", "min",
|
||||
"max", "get"])
|
||||
def test_jit_oob_update(self, update_fn):
|
||||
def f(x, i):
|
||||
return getattr(x.at[i], update_fn)(1.)
|
||||
|
||||
f = jax.jit(f)
|
||||
|
||||
err, _ = checkify.checkify(f)(jnp.arange(3), 2)
|
||||
self.assertIs(err.get(), None)
|
||||
|
||||
err, _ = checkify.checkify(f)(jnp.arange(3), 3)
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), 'out-of-bounds indexing')
|
||||
|
||||
@parameterized.named_parameters(jtu.cases_from_list(
|
||||
{"testcase_name": "_jit={}".format(jit), "jit": jit}
|
||||
for jit in [False, True]))
|
||||
|
Loading…
x
Reference in New Issue
Block a user