mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 21:36:05 +00:00
Merge pull request #22445 from dfm:numpy-nightly-unique
PiperOrigin-RevId: 652520243
This commit is contained in:
commit
26ec43f9e5
@ -490,8 +490,6 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
|
||||
inv_idx = inv_idx.at[perm].set(imask)
|
||||
else:
|
||||
inv_idx = zeros(ar.shape[axis], dtype=int)
|
||||
if ar.ndim > 1:
|
||||
inv_idx = lax.expand_dims(inv_idx, [i for i in range(ar.ndim) if i != axis],)
|
||||
ret += (inv_idx,)
|
||||
if return_counts:
|
||||
if aux.size:
|
||||
@ -550,7 +548,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
||||
``ar[unique_index]`` is equivalent to ``unique_values``.
|
||||
- ``unique_inverse``:
|
||||
*(returned only if return_inverse is True)* An array of shape ``(ar.size,)`` if ``axis``
|
||||
is None, or of shape ``(1, 1, ..., ar.shape[axis], 1, ... 1)`` if ``axis`` is specified.
|
||||
is None, or of shape ``(ar.shape[axis],)`` if ``axis`` is specified.
|
||||
Contains the indices within ``unique_values`` of each value in ``ar``. For 1D inputs,
|
||||
``unique_values[unique_inverse]`` is equivalent to ``ar``.
|
||||
- ``unique_counts``:
|
||||
@ -652,10 +650,10 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
||||
Array(True, dtype=bool)
|
||||
|
||||
In multiple dimensions, the input can be reconstructed using
|
||||
:func:`jax.numpy.take_along_axis`:
|
||||
:func:`jax.numpy.take`:
|
||||
|
||||
>>> values, inverse = jnp.unique(M, axis=0, return_inverse=True)
|
||||
>>> jnp.all(jnp.take_along_axis(values, inverse, axis=0) == M)
|
||||
>>> jnp.all(jnp.take(values, inverse, axis=0) == M)
|
||||
Array(True, dtype=bool)
|
||||
|
||||
**Returning counts**
|
||||
|
@ -91,15 +91,15 @@ def np_unique_backport(ar, return_index=False, return_inverse=False, return_coun
|
||||
# Wrapper for np.unique, handling the change to inverse_indices in numpy 2.0
|
||||
result = np.unique(ar, return_index=return_index, return_inverse=return_inverse,
|
||||
return_counts=return_counts, axis=axis, **kwds)
|
||||
if jtu.numpy_version() >= (2, 0, 0) or np.ndim(ar) == 1 or not return_inverse:
|
||||
if jtu.numpy_version() >= (2, 0, 1) or np.ndim(ar) == 1 or not return_inverse:
|
||||
return result
|
||||
|
||||
idx = 2 if return_index else 1
|
||||
inverse_indices = result[idx]
|
||||
if axis is None:
|
||||
inverse_indices = inverse_indices.reshape(np.shape(ar))
|
||||
else:
|
||||
inverse_indices = np.expand_dims(inverse_indices, [i for i in range(np.ndim(ar)) if i != axis])
|
||||
elif jtu.numpy_version() == (2, 0, 0):
|
||||
inverse_indices = inverse_indices.reshape(-1)
|
||||
return (*result[:idx], inverse_indices, *result[idx + 1:])
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user