jnp.ndarray.item(): add args support

This commit is contained in:
Jake VanderPlas 2024-01-03 13:03:47 -08:00
parent b6136795dd
commit 47e5c81a2c
3 changed files with 34 additions and 12 deletions

View File

@ -156,6 +156,7 @@ Remember to align the itemized text with the first line of an item within a list
automatically. Currently, NCCL 2.16 or newer is required.
* We now provide Linux aarch64 wheels, both with and without NVIDIA GPU
support.
* {meth}`jax.Array.item` now supports optional index arguments.
* Deprecations
* A number of internal utilities and inadvertent exports in {mod}`jax.lax` have

View File

@ -71,19 +71,12 @@ def _nbytes(arr: ArrayLike) -> int:
return np.size(arr) * dtypes.dtype(arr, canonicalize=True).itemsize
def _item(a: Array) -> Any:
def _item(a: Array, *args) -> bool | int | float | complex:
"""Copy an element of an array to a standard Python scalar and return it."""
if dtypes.issubdtype(a.dtype, np.complexfloating):
return complex(a)
elif dtypes.issubdtype(a.dtype, np.floating):
return float(a)
elif dtypes.issubdtype(a.dtype, np.integer):
return int(a)
elif dtypes.issubdtype(a.dtype, np.bool_):
return bool(a)
else:
raise TypeError(a.dtype)
arr = core.concrete_or_error(np.asarray, a, context="This occurred in the item() method of jax.Array")
if dtypes.issubdtype(a.dtype, dtypes.extended):
raise TypeError(f"No Python scalar type for {a.dtype=}")
return arr.item(*args)
def _itemsize(arr: ArrayLike) -> int:
"""Length of one array element in bytes."""

View File

@ -3622,6 +3622,34 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
self._CompileAndCheck(jnp_op, args_maker)
@jtu.sample_product(
shape=nonempty_array_shapes,
dtype=all_dtypes,
num_args=[0, 1, "all"],
use_tuple=[True, False]
)
def testItem(self, shape, dtype, num_args, use_tuple):
rng = jtu.rand_default(self.rng())
size = math.prod(shape)
if num_args == 0:
args = ()
elif num_args == 1:
args = (self.rng().randint(0, size),)
else:
args = tuple(self.rng().randint(0, s) for s in shape)
args = (args,) if use_tuple else args
np_op = lambda x: np.asarray(x).item(*args)
jnp_op = lambda x: jnp.asarray(x).item(*args)
args_maker = lambda: [rng(shape, dtype)]
if size != 1 and num_args == 0:
with self.assertRaises(ValueError):
jnp_op(*args_maker())
else:
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
@jtu.sample_product(
# Final dimension must be a multiple of 16 to ensure compatibilty of all dtype pairs.
shape=[(0,), (32,), (2, 16)],