checkify_test: avoid passing argument to at[i].get()

This commit is contained in:
Jake VanderPlas 2023-03-10 12:37:33 -08:00
parent 0420192d29
commit f7dec15375

View File

@ -74,13 +74,18 @@ class CheckifyTransformTests(jtu.JaxTestCase):
self.assertIsNotNone(err.get())
self.assertStartsWith(err.get(), "out-of-bounds indexing")
@jtu.sample_product(update_fn=["set", "add", "multiply", "divide", "power",
"min", "max", "get"])
@parameterized.named_parameters(
("get", lambda x: x.get()),
("set", lambda x: x.set(1)),
("add", lambda x: x.add(1)),
("mul", lambda x: x.multiply(1)),
("div", lambda x: x.divide(1)),
("pow", lambda x: x.power(1)),
("min", lambda x: x.min(1)),
("max", lambda x: x.max(1)),
)
def test_jit_oob_update(self, update_fn):
def f(x, i):
return getattr(x.at[i], update_fn)(1)
f = jax.jit(f)
f = jax.jit(lambda x, i: update_fn(x.at[i]))
checked_f = checkify.checkify(f, errors=checkify.index_checks)
err, _ = checked_f(jnp.arange(3), 2)
@ -136,6 +141,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
@parameterized.named_parameters(
("gather", lambda x: x.get()),
("scatter_update", lambda x: x.set(1.)),
("scatter_add", lambda x: x.add(1.)),
("scatter_mul", lambda x: x.multiply(1.)),
("scatter_div", lambda x: x.divide(1.)),