diff --git a/CHANGELOG.md b/CHANGELOG.md index 47ac5c746..5a533241f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -156,6 +156,7 @@ Remember to align the itemized text with the first line of an item within a list automatically. Currently, NCCL 2.16 or newer is required. * We now provide Linux aarch64 wheels, both with and without NVIDIA GPU support. + * {meth}`jax.Array.item` now supports optional index arguments. * Deprecations * A number of internal utilities and inadvertent exports in {mod}`jax.lax` have diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 7de5be2bc..fb1e52bd1 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -71,19 +71,12 @@ def _nbytes(arr: ArrayLike) -> int: return np.size(arr) * dtypes.dtype(arr, canonicalize=True).itemsize -def _item(a: Array) -> Any: +def _item(a: Array, *args) -> bool | int | float | complex: """Copy an element of an array to a standard Python scalar and return it.""" - if dtypes.issubdtype(a.dtype, np.complexfloating): - return complex(a) - elif dtypes.issubdtype(a.dtype, np.floating): - return float(a) - elif dtypes.issubdtype(a.dtype, np.integer): - return int(a) - elif dtypes.issubdtype(a.dtype, np.bool_): - return bool(a) - else: - raise TypeError(a.dtype) - + arr = core.concrete_or_error(np.asarray, a, context="This occurred in the item() method of jax.Array") + if dtypes.issubdtype(a.dtype, dtypes.extended): + raise TypeError(f"No Python scalar type for {a.dtype=}") + return arr.item(*args) def _itemsize(arr: ArrayLike) -> int: """Length of one array element in bytes.""" diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 9e83be82c..0774c24bc 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3622,6 +3622,34 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + @jtu.sample_product( + shape=nonempty_array_shapes, + dtype=all_dtypes, + num_args=[0, 1, "all"], + use_tuple=[True, False] + ) + def testItem(self, shape, dtype, num_args, use_tuple): + rng = jtu.rand_default(self.rng()) + size = math.prod(shape) + + if num_args == 0: + args = () + elif num_args == 1: + args = (self.rng().randint(0, size),) + else: + args = tuple(self.rng().randint(0, s) for s in shape) + args = (args,) if use_tuple else args + + np_op = lambda x: np.asarray(x).item(*args) + jnp_op = lambda x: jnp.asarray(x).item(*args) + args_maker = lambda: [rng(shape, dtype)] + + if size != 1 and num_args == 0: + with self.assertRaises(ValueError): + jnp_op(*args_maker()) + else: + self._CheckAgainstNumpy(np_op, jnp_op, args_maker) + @jtu.sample_product( # Final dimension must be a multiple of 16 to ensure compatibilty of all dtype pairs. shape=[(0,), (32,), (2, 16)],