mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #26883 from dfm:np-unique-sorted
PiperOrigin-RevId: 733287592
This commit is contained in:
commit
8906f281c4
@ -663,7 +663,8 @@ def _unique(ar: Array, axis: int, return_index: bool = False, return_inverse: bo
|
||||
@export
|
||||
def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = False,
|
||||
return_counts: bool = False, axis: int | None = None,
|
||||
*, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None):
|
||||
*, equal_nan: bool = True, size: int | None = None, fill_value: ArrayLike | None = None,
|
||||
sorted: bool = True):
|
||||
"""Return the unique values from an array.
|
||||
|
||||
JAX implementation of :func:`numpy.unique`.
|
||||
@ -686,6 +687,7 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
||||
unique elements than ``size`` indicates, the return value will be padded with ``fill_value``.
|
||||
fill_value: when ``size`` is specified and there are fewer than the indicated number of
|
||||
elements, fill the remaining entries ``fill_value``. Defaults to the minimum unique value.
|
||||
sorted: unused by JAX.
|
||||
|
||||
Returns:
|
||||
An array or tuple of arrays, depending on the values of ``return_index``, ``return_inverse``,
|
||||
@ -830,6 +832,10 @@ def unique(ar: ArrayLike, return_index: bool = False, return_inverse: bool = Fal
|
||||
>>> print(counts)
|
||||
[2 1]
|
||||
"""
|
||||
# TODO: Investigate if it's possible that we could save some work in
|
||||
# _unique_sorted_mask when sorting is not requested, but that would require
|
||||
# refactoring the implementation a bit.
|
||||
del sorted # unused
|
||||
arr = ensure_arraylike("unique", ar)
|
||||
if size is None:
|
||||
arr = core.concrete_or_error(None, arr,
|
||||
|
@ -2127,7 +2127,7 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
if jtu.numpy_version() < (2, 0, 0):
|
||||
np_fun = np.unique
|
||||
else:
|
||||
np_fun = np.unique_values
|
||||
np_fun = lambda *args: np.sort(np.unique_values(*args))
|
||||
self._CheckAgainstNumpy(jnp.unique_values, np_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
@ -6452,9 +6452,10 @@ class NumpySignaturesTest(jtu.JaxTestCase):
|
||||
'compress': ['size', 'fill_value'],
|
||||
'einsum': ['subscripts', 'precision'],
|
||||
'einsum_path': ['subscripts'],
|
||||
'fill_diagonal': ['inplace'],
|
||||
'load': ['args', 'kwargs'],
|
||||
'take_along_axis': ['mode', 'fill_value'],
|
||||
'fill_diagonal': ['inplace'],
|
||||
'unique': ['size', 'fill_value'],
|
||||
}
|
||||
|
||||
mismatches = {}
|
||||
|
Loading…
x
Reference in New Issue
Block a user