Merge pull request #14919 from jakevdp:fix-checkify

PiperOrigin-RevId: 515729564
This commit is contained in:
jax authors 2023-03-10 13:53:49 -08:00
commit 054c07b025

View File

@ -74,13 +74,18 @@ class CheckifyTransformTests(jtu.JaxTestCase):
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "out-of-bounds indexing")
@jtu.sample_product(update_fn=["set", "add", "multiply", "divide", "power",
"min", "max", "get"])
@parameterized.named_parameters(
("get", lambda x: x.get()),
("set", lambda x: x.set(1)),
("add", lambda x: x.add(1)),
("mul", lambda x: x.multiply(1)),
("div", lambda x: x.divide(1)),
("pow", lambda x: x.power(1)),
("min", lambda x: x.min(1)),
("max", lambda x: x.max(1)),
)
def test_jit_oob_update(self, update_fn):
def f(x, i):
return getattr(x.at[i], update_fn)(1)
f = jax.jit(f)
f = jax.jit(lambda x, i: update_fn(x.at[i]))
checked_f = checkify.checkify(f, errors=checkify.index_checks)
err, _ = checked_f(jnp.arange(3), 2)
@ -136,6 +141,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
@parameterized.named_parameters(
("gather", lambda x: x.get()),
("scatter_update", lambda x: x.set(1.)),
("scatter_add", lambda x: x.add(1.)),
("scatter_mul", lambda x: x.multiply(1.)),
("scatter_div", lambda x: x.divide(1.)),