mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
checkify_test: avoid passing argument to at[i].get()
This commit is contained in:
parent
0420192d29
commit
f7dec15375
@ -74,13 +74,18 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
self.assertIsNotNone(err.get())
|
||||
self.assertStartsWith(err.get(), "out-of-bounds indexing")
|
||||
|
||||
@jtu.sample_product(update_fn=["set", "add", "multiply", "divide", "power",
|
||||
"min", "max", "get"])
|
||||
@parameterized.named_parameters(
|
||||
("get", lambda x: x.get()),
|
||||
("set", lambda x: x.set(1)),
|
||||
("add", lambda x: x.add(1)),
|
||||
("mul", lambda x: x.multiply(1)),
|
||||
("div", lambda x: x.divide(1)),
|
||||
("pow", lambda x: x.power(1)),
|
||||
("min", lambda x: x.min(1)),
|
||||
("max", lambda x: x.max(1)),
|
||||
)
|
||||
def test_jit_oob_update(self, update_fn):
|
||||
def f(x, i):
|
||||
return getattr(x.at[i], update_fn)(1)
|
||||
|
||||
f = jax.jit(f)
|
||||
f = jax.jit(lambda x, i: update_fn(x.at[i]))
|
||||
checked_f = checkify.checkify(f, errors=checkify.index_checks)
|
||||
|
||||
err, _ = checked_f(jnp.arange(3), 2)
|
||||
@ -136,6 +141,7 @@ class CheckifyTransformTests(jtu.JaxTestCase):
|
||||
|
||||
@parameterized.named_parameters(
|
||||
("gather", lambda x: x.get()),
|
||||
("scatter_update", lambda x: x.set(1.)),
|
||||
("scatter_add", lambda x: x.add(1.)),
|
||||
("scatter_mul", lambda x: x.multiply(1.)),
|
||||
("scatter_div", lambda x: x.divide(1.)),
|
||||
|
Loading…
x
Reference in New Issue
Block a user