Merge pull request #26883 from dfm:np-unique-sorted

PiperOrigin-RevId: 733287592
This commit is contained in:
jax authors 2025-03-04 05:19:55 -08:00
commit 8906f281c4
2 changed files with 10 additions and 3 deletions

View File

@ -663,7 +663,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
@export
def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False,
return_counts: bool = False, axis: int | None = None,
*, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None):
*, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None,
sorted: bool = True):
"""Return the unique values from an array.
JAX implementation of :func:`numpy.unique`.
@ -686,6 +687,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
fill_value: when ``size`` is specified and there are fewer than the indicated number of
elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value.
sorted: unused by JAX.
Returns:
An array or tuple of arrays, depending on the values of ``return_index``, ``return_inverse``,
@ -830,6 +832,10 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
>>> print(counts)
[2 1]
"""
# TODO: Investigate if it's possible that we could save some work in
# _unique_sorted_mask when sorting is not requested, but that would require
# refactoring the implementation a bit.
del sorted # unused
arr = ensure_arraylike("unique", ar)
if size is None:
arr = core.concrete_or_error(None, arr,

View File

@ -2127,7 +2127,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
if jtu.numpy_version() < (2, 0, 0):
np_fun = np.unique
else:
np_fun = np.unique_values
np_fun = lambda *args: np.sort(np.unique_values(*args))
self._CheckAgainstNumpy(jnp.unique_values, np_fun, args_maker)
@jtu.sample_product(
@ -6452,9 +6452,10 @@ class NumpySignaturesTest(jtu.JaxTestCase):
'compress': ['size', 'fill_value'],
'einsum': ['subscripts', 'precision'],
'einsum_path': ['subscripts'],
'fill_diagonal': ['inplace'],
'load': ['args', 'kwargs'],
'take_along_axis': ['mode', 'fill_value'],
'fill_diagonal': ['inplace'],
'unique': ['size', 'fill_value'],
}
mismatches = {}