1
0
mirror of https://github.com/ROCm/jax.git synced 2025-04-19 21:36:05 +00:00

Merge pull request from dfm:numpy-nightly-unique

PiperOrigin-RevId: 652520243
This commit is contained in:
jax authors 2024-07-15 10:16:03 -07:00
commit 26ec43f9e5
2 changed files with 6 additions and 8 deletions
jax/_src/numpy
tests

@ -490,8 +490,6 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
inv_idx = inv_idx.at[perm].set(imask)
else:
inv_idx = zeros(ar.shape[axis], dtype=int)
if ar.ndim > 1:
inv_idx = lax.expand_dims(inv_idx, [i for i in range(ar.ndim) if i != axis],)
ret += (inv_idx,)
if return_counts:
if aux.size:
@ -550,7 +548,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
``ar[unique_index]`` is equivalent to ``unique_values``.
- ``unique_inverse``:
*(returned only if return_inverse is True)* An array of shape ``(ar.size,)`` if ``axis``
is None, or of shape ``(1, 1, ..., ar.shape[axis], 1, ... 1)`` if ``axis`` is specified.
is None, or of shape ``(ar.shape[axis],)`` if ``axis`` is specified.
Contains the indices within ``unique_values`` of each value in ``ar``. For 1D inputs,
``unique_values[unique_inverse]`` is equivalent to ``ar``.
- ``unique_counts``:
@ -652,10 +650,10 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
Array(True, dtype=bool)
In multiple dimensions, the input can be reconstructed using
:func:`jax.numpy.take_along_axis`:
:func:`jax.numpy.take`:
>>> values, inverse = jnp.unique(M, axis=0, return_inverse=True)
>>> jnp.all(jnp.take_along_axis(values, inverse, axis=0) == M)
>>> jnp.all(jnp.take(values, inverse, axis=0) == M)
Array(True, dtype=bool)
**Returning counts**

@ -91,15 +91,15 @@ def np_unique_backport(ar, return_index=False, return_inverse=False, return_coun
# Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0
result = np.unique(ar, return_index=return_index, return_inverse=return_inverse,
return_counts=return_counts, axis=axis, **kwds)
if jtu.numpy_version() >= (2, 0, 0) or np.ndim(ar) == 1 or not return_inverse:
if jtu.numpy_version() >= (2, 0, 1) or np.ndim(ar) == 1 or not return_inverse:
return result
idx = 2 if return_index else 1
inverse_indices = result[idx]
if axis is None:
inverse_indices = inverse_indices.reshape(np.shape(ar))
else:
inverse_indices = np.expand_dims(inverse_indices, [i for i in range(np.ndim(ar)) if i != axis])
elif jtu.numpy_version() == (2, 0, 0):
inverse_indices = inverse_indices.reshape(-1)
return (*result[:idx], inverse_indices, *result[idx + 1:])