mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
jnp.ndarray.item(): add args support
This commit is contained in:
parent
b6136795dd
commit
47e5c81a2c
@ -156,6 +156,7 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
automatically. Currently, NCCL 2.16 or newer is required.
|
||||
* We now provide Linux aarch64 wheels, both with and without NVIDIA GPU
|
||||
support.
|
||||
* {meth}`jax.Array.item` now supports optional index arguments.
|
||||
|
||||
* Deprecations
|
||||
* A number of internal utilities and inadvertent exports in {mod}`jax.lax` have
|
||||
|
@ -71,19 +71,12 @@ def _nbytes(arr: ArrayLike) -> int:
|
||||
return np.size(arr) * dtypes.dtype(arr, canonicalize=True).itemsize
|
||||
|
||||
|
||||
def _item(a: Array) -> Any:
|
||||
def _item(a: Array, *args) -> bool | int | float | complex:
|
||||
"""Copy an element of an array to a standard Python scalar and return it."""
|
||||
if dtypes.issubdtype(a.dtype, np.complexfloating):
|
||||
return complex(a)
|
||||
elif dtypes.issubdtype(a.dtype, np.floating):
|
||||
return float(a)
|
||||
elif dtypes.issubdtype(a.dtype, np.integer):
|
||||
return int(a)
|
||||
elif dtypes.issubdtype(a.dtype, np.bool_):
|
||||
return bool(a)
|
||||
else:
|
||||
raise TypeError(a.dtype)
|
||||
|
||||
arr = core.concrete_or_error(np.asarray, a, context="This occurred in the item() method of jax.Array")
|
||||
if dtypes.issubdtype(a.dtype, dtypes.extended):
|
||||
raise TypeError(f"No Python scalar type for {a.dtype=}")
|
||||
return arr.item(*args)
|
||||
|
||||
def _itemsize(arr: ArrayLike) -> int:
|
||||
"""Length of one array element in bytes."""
|
||||
|
@ -3622,6 +3622,34 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
self._CompileAndCheck(jnp_op, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
shape=nonempty_array_shapes,
|
||||
dtype=all_dtypes,
|
||||
num_args=[0, 1, "all"],
|
||||
use_tuple=[True, False]
|
||||
)
|
||||
def testItem(self, shape, dtype, num_args, use_tuple):
|
||||
rng = jtu.rand_default(self.rng())
|
||||
size = math.prod(shape)
|
||||
|
||||
if num_args == 0:
|
||||
args = ()
|
||||
elif num_args == 1:
|
||||
args = (self.rng().randint(0, size),)
|
||||
else:
|
||||
args = tuple(self.rng().randint(0, s) for s in shape)
|
||||
args = (args,) if use_tuple else args
|
||||
|
||||
np_op = lambda x: np.asarray(x).item(*args)
|
||||
jnp_op = lambda x: jnp.asarray(x).item(*args)
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
if size != 1 and num_args == 0:
|
||||
with self.assertRaises(ValueError):
|
||||
jnp_op(*args_maker())
|
||||
else:
|
||||
self._CheckAgainstNumpy(np_op, jnp_op, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
# Final dimension must be a multiple of 16 to ensure compatibilty of all dtype pairs.
|
||||
shape=[(0,), (32,), (2, 16)],
|
||||
|
Loading…
x
Reference in New Issue
Block a user