Add item() method to abstract arrays

This commit is contained in:
Jake VanderPlas 2021-12-15 16:22:24 -08:00
parent 0d71bff7b9
commit d2908af8de
2 changed files with 18 additions and 0 deletions

View File

@ -7052,6 +7052,7 @@ def _set_shaped_array_attributes(shaped_array):
setattr(shaped_array, "split", core.aval_method(split))
setattr(shaped_array, "compress", _compress_method)
setattr(shaped_array, "at", core.aval_property(_IndexUpdateHelper))
setattr(shaped_array, "item", core.aval_method(device_array.DeviceArray.item))
_set_shaped_array_attributes(ShapedArray)

View File

@ -825,6 +825,23 @@ class PythonJitTest(CPPJitTest):
class APITest(jtu.JaxTestCase):
def test_grad_item(self):
def f(x):
if x.astype(bool).item():
return x ** 2
else:
return x
out = jax.grad(f)(2.0)
self.assertEqual(out, 4)
def test_jit_item(self):
def f(x):
return x.item()
x = jnp.array(1.0)
self.assertEqual(f(x), x)
with self.assertRaisesRegex(core.ConcretizationTypeError, "Abstract tracer value"):
jax.jit(f)(x)
def test_grad_bad_input(self):
def f(x):
return x