mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Add item() method to abstract arrays
This commit is contained in:
parent
0d71bff7b9
commit
d2908af8de
@ -7052,6 +7052,7 @@ def _set_shaped_array_attributes(shaped_array):
|
||||
setattr(shaped_array, "split", core.aval_method(split))
|
||||
setattr(shaped_array, "compress", _compress_method)
|
||||
setattr(shaped_array, "at", core.aval_property(_IndexUpdateHelper))
|
||||
setattr(shaped_array, "item", core.aval_method(device_array.DeviceArray.item))
|
||||
|
||||
_set_shaped_array_attributes(ShapedArray)
|
||||
|
||||
|
@ -825,6 +825,23 @@ class PythonJitTest(CPPJitTest):
|
||||
|
||||
class APITest(jtu.JaxTestCase):
|
||||
|
||||
def test_grad_item(self):
|
||||
def f(x):
|
||||
if x.astype(bool).item():
|
||||
return x ** 2
|
||||
else:
|
||||
return x
|
||||
out = jax.grad(f)(2.0)
|
||||
self.assertEqual(out, 4)
|
||||
|
||||
def test_jit_item(self):
|
||||
def f(x):
|
||||
return x.item()
|
||||
x = jnp.array(1.0)
|
||||
self.assertEqual(f(x), x)
|
||||
with self.assertRaisesRegex(core.ConcretizationTypeError, "Abstract tracer value"):
|
||||
jax.jit(f)(x)
|
||||
|
||||
def test_grad_bad_input(self):
|
||||
def f(x):
|
||||
return x
|
||||
|
Loading…
x
Reference in New Issue
Block a user