mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
[array api] add stable & descending params to jnp.sort & jnp.argsort
This commit is contained in:
parent
ebc7af95df
commit
8b62516676
@ -17,6 +17,8 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
devices to create `Sharding`s during lowering.
|
devices to create `Sharding`s during lowering.
|
||||||
This is a temporary state until we can create `Sharding`s without physical
|
This is a temporary state until we can create `Sharding`s without physical
|
||||||
devices.
|
devices.
|
||||||
|
* {func}`jax.numpy.argsort` and {func}`jax.numpy.sort` now support the `stable`
|
||||||
|
and `descending` arguments.
|
||||||
* Deprecations & Removals
|
* Deprecations & Removals
|
||||||
* A number of previously deprecated functions have been removed, following a
|
* A number of previously deprecated functions have been removed, following a
|
||||||
standard 3+ month deprecation cycle (see {ref}`api-compatibility`).
|
standard 3+ month deprecation cycle (see {ref}`api-compatibility`).
|
||||||
|
@ -3901,23 +3901,28 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False):
|
|||||||
|
|
||||||
|
|
||||||
@util._wraps(np.sort)
|
@util._wraps(np.sort)
|
||||||
@partial(jit, static_argnames=('axis', 'kind', 'order'))
|
@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
|
||||||
def sort(
|
def sort(
|
||||||
a: ArrayLike,
|
a: ArrayLike,
|
||||||
axis: int | None = -1,
|
axis: int | None = -1,
|
||||||
kind: str = "quicksort",
|
kind: None = None,
|
||||||
order: None = None,
|
order: None = None, *,
|
||||||
|
stable: bool = True,
|
||||||
|
descending: bool = False,
|
||||||
) -> Array:
|
) -> Array:
|
||||||
util.check_arraylike("sort", a)
|
util.check_arraylike("sort", a)
|
||||||
if kind != 'quicksort':
|
if kind is not None:
|
||||||
warnings.warn("'kind' argument to sort is ignored.")
|
warnings.warn("'kind' argument to sort is ignored.")
|
||||||
if order is not None:
|
if order is not None:
|
||||||
raise ValueError("'order' argument to sort is not supported.")
|
raise ValueError("'order' argument to sort is not supported.")
|
||||||
|
|
||||||
if axis is None:
|
if axis is None:
|
||||||
return lax.sort(ravel(a), dimension=0)
|
arr = ravel(a)
|
||||||
|
axis = 0
|
||||||
else:
|
else:
|
||||||
return lax.sort(asarray(a), dimension=_canonicalize_axis(axis, ndim(a)))
|
arr = asarray(a)
|
||||||
|
dimension = _canonicalize_axis(axis, arr.ndim)
|
||||||
|
result = lax.sort(arr, dimension=dimension, is_stable=stable)
|
||||||
|
return lax.rev(result, dimensions=[dimension]) if descending else result
|
||||||
|
|
||||||
|
|
||||||
@util._wraps(np.sort_complex)
|
@util._wraps(np.sort_complex)
|
||||||
@ -3953,29 +3958,37 @@ a warning and be treated as if they were :code:`'stable'`.
|
|||||||
|
|
||||||
|
|
||||||
@util._wraps(np.argsort, lax_description=_ARGSORT_DOC)
|
@util._wraps(np.argsort, lax_description=_ARGSORT_DOC)
|
||||||
@partial(jit, static_argnames=('axis', 'kind', 'order'))
|
@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
|
||||||
def argsort(
|
def argsort(
|
||||||
a: ArrayLike,
|
a: ArrayLike,
|
||||||
axis: int | None = -1,
|
axis: int | None = -1,
|
||||||
kind: str = "stable",
|
kind: None = None,
|
||||||
order: None = None,
|
order: None = None,
|
||||||
|
*, stable: bool = True,
|
||||||
|
descending: bool = False,
|
||||||
) -> Array:
|
) -> Array:
|
||||||
util.check_arraylike("argsort", a)
|
util.check_arraylike("argsort", a)
|
||||||
arr = asarray(a)
|
arr = asarray(a)
|
||||||
if kind != 'stable':
|
if kind is not None:
|
||||||
warnings.warn("'kind' argument to argsort is ignored; only 'stable' sorts "
|
warnings.warn("'kind' argument to argsort is ignored.")
|
||||||
"are supported.")
|
|
||||||
if order is not None:
|
if order is not None:
|
||||||
raise ValueError("'order' argument to argsort is not supported.")
|
raise ValueError("'order' argument to argsort is not supported.")
|
||||||
|
|
||||||
if axis is None:
|
if axis is None:
|
||||||
return argsort(arr.ravel(), 0)
|
arr = ravel(arr)
|
||||||
|
axis = 0
|
||||||
else:
|
else:
|
||||||
axis_num = _canonicalize_axis(axis, arr.ndim)
|
arr = asarray(a)
|
||||||
use_64bit_index = not core.is_constant_dim(arr.shape[axis_num]) or arr.shape[axis_num] >= (1 << 31)
|
dimension = _canonicalize_axis(axis, arr.ndim)
|
||||||
iota = lax.broadcasted_iota(int64 if use_64bit_index else int_, arr.shape, axis_num)
|
use_64bit_index = not core.is_constant_dim(arr.shape[dimension]) or arr.shape[dimension] >= (1 << 31)
|
||||||
_, perm = lax.sort_key_val(arr, iota, dimension=axis_num)
|
iota = lax.broadcasted_iota(int64 if use_64bit_index else int_, arr.shape, dimension)
|
||||||
return perm
|
# For stable descending sort, we reverse the array and indices to ensure that
|
||||||
|
# duplicates remain in their original order when the final indices are reversed.
|
||||||
|
# For non-stable descending sort, we can avoid these extra operations.
|
||||||
|
if descending and stable:
|
||||||
|
arr = lax.rev(arr, dimensions=[dimension])
|
||||||
|
iota = lax.rev(iota, dimensions=[dimension])
|
||||||
|
_, indices = lax.sort_key_val(arr, iota, dimension=dimension, is_stable=stable)
|
||||||
|
return lax.rev(indices, dimensions=[dimension]) if descending else indices
|
||||||
|
|
||||||
|
|
||||||
@util._wraps(np.partition, lax_description="""
|
@util._wraps(np.partition, lax_description="""
|
||||||
|
@ -19,18 +19,10 @@ from jax import Array
|
|||||||
def argsort(x: Array, /, *, axis: int = -1, descending: bool = False,
|
def argsort(x: Array, /, *, axis: int = -1, descending: bool = False,
|
||||||
stable: bool = True) -> Array:
|
stable: bool = True) -> Array:
|
||||||
"""Returns the indices that sort an array x along a specified axis."""
|
"""Returns the indices that sort an array x along a specified axis."""
|
||||||
del stable # unused
|
return jax.numpy.argsort(x, axis=axis, descending=descending, stable=stable)
|
||||||
if descending:
|
|
||||||
return jax.numpy.argsort(-x, axis=axis)
|
|
||||||
else:
|
|
||||||
return jax.numpy.argsort(x, axis=axis)
|
|
||||||
|
|
||||||
|
|
||||||
def sort(x: Array, /, *, axis: int = -1, descending: bool = False,
|
def sort(x: Array, /, *, axis: int = -1, descending: bool = False,
|
||||||
stable: bool = True) -> Array:
|
stable: bool = True) -> Array:
|
||||||
"""Returns a sorted copy of an input array x."""
|
"""Returns a sorted copy of an input array x."""
|
||||||
del stable # unused
|
return jax.numpy.sort(x, axis=axis, descending=descending, stable=stable)
|
||||||
result = jax.numpy.sort(x, axis=axis)
|
|
||||||
if descending:
|
|
||||||
return jax.lax.rev(result, dimensions=[axis + x.ndim if axis < 0 else axis])
|
|
||||||
return result
|
|
||||||
|
@ -13,8 +13,5 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays
|
|||||||
array_api_tests/test_linalg.py::test_matrix_power
|
array_api_tests/test_linalg.py::test_matrix_power
|
||||||
array_api_tests/test_linalg.py::test_solve
|
array_api_tests/test_linalg.py::test_solve
|
||||||
|
|
||||||
# JAX's NaN sorting doesn't match specification
|
|
||||||
array_api_tests/test_sorting_functions.py::test_argsort
|
|
||||||
|
|
||||||
# fft test suite is buggy as of 83f0bcdc
|
# fft test suite is buggy as of 83f0bcdc
|
||||||
array_api_tests/test_fft.py
|
array_api_tests/test_fft.py
|
||||||
|
@ -74,9 +74,12 @@ def argmin(
|
|||||||
def argpartition(a: ArrayLike, kth: int, axis: int = ...) -> Array: ...
|
def argpartition(a: ArrayLike, kth: int, axis: int = ...) -> Array: ...
|
||||||
def argsort(
|
def argsort(
|
||||||
a: ArrayLike,
|
a: ArrayLike,
|
||||||
axis: Optional[int] = -1,
|
axis: Optional[int] = ...,
|
||||||
kind: str = "stable",
|
kind: None = ...,
|
||||||
order: None = ...,
|
order: None = ...,
|
||||||
|
*,
|
||||||
|
stable: bool = ...,
|
||||||
|
descending: bool = ...,
|
||||||
) -> Array: ...
|
) -> Array: ...
|
||||||
def argwhere(
|
def argwhere(
|
||||||
a: ArrayLike,
|
a: ArrayLike,
|
||||||
@ -701,8 +704,11 @@ sometrue = any
|
|||||||
def sort(
|
def sort(
|
||||||
a: ArrayLike,
|
a: ArrayLike,
|
||||||
axis: Optional[int] = ...,
|
axis: Optional[int] = ...,
|
||||||
kind: str = ...,
|
kind: None = ...,
|
||||||
order: None = ...,
|
order: None = ...,
|
||||||
|
*,
|
||||||
|
stable: bool = ...,
|
||||||
|
descending: bool = ...,
|
||||||
) -> Array: ...
|
) -> Array: ...
|
||||||
def sort_complex(a: ArrayLike) -> Array: ...
|
def sort_complex(a: ArrayLike) -> Array: ...
|
||||||
def split(
|
def split(
|
||||||
|
@ -79,6 +79,7 @@ inexact_dtypes = float_dtypes + complex_dtypes
|
|||||||
number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes
|
number_dtypes = float_dtypes + complex_dtypes + int_dtypes + unsigned_dtypes
|
||||||
all_dtypes = number_dtypes + bool_dtypes
|
all_dtypes = number_dtypes + bool_dtypes
|
||||||
|
|
||||||
|
NO_VALUE = object()
|
||||||
|
|
||||||
python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]
|
python_scalar_dtypes = [jnp.bool_, jnp.int_, jnp.float_, jnp.complex_]
|
||||||
|
|
||||||
@ -3771,21 +3772,41 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
[dict(shape=shape, axis=axis)
|
[dict(shape=shape, axis=axis)
|
||||||
for shape in nonzerodim_shapes
|
for shape in nonzerodim_shapes
|
||||||
for axis in (None, *range(len(shape)))
|
for axis in (NO_VALUE, None, *range(-len(shape), len(shape)))
|
||||||
],
|
],
|
||||||
|
stable=[True, False],
|
||||||
dtype=all_dtypes,
|
dtype=all_dtypes,
|
||||||
)
|
)
|
||||||
def testSort(self, dtype, shape, axis):
|
def testSort(self, dtype, shape, axis, stable):
|
||||||
rng = jtu.rand_some_equal(self.rng())
|
rng = jtu.rand_some_equal(self.rng()) if stable else jtu.rand_some_inf_and_nan(self.rng())
|
||||||
args_maker = lambda: [rng(shape, dtype)]
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
jnp_fun = jnp.sort
|
kwds = {} if axis is NO_VALUE else {'axis': axis}
|
||||||
np_fun = np.sort
|
|
||||||
if axis is not None:
|
def np_fun(arr):
|
||||||
jnp_fun = partial(jnp_fun, axis=axis)
|
# Note: numpy sort fails on NaN and Inf values with bfloat16
|
||||||
np_fun = partial(np_fun, axis=axis)
|
dtype = arr.dtype
|
||||||
|
if arr.dtype == jnp.bfloat16:
|
||||||
|
arr = arr.astype('float32')
|
||||||
|
# TODO(jakevdp): switch to stable=stable when supported by numpy.
|
||||||
|
result = np.sort(arr, kind='stable' if stable else None, **kwds)
|
||||||
|
with jtu.ignore_warning(category=RuntimeWarning, message='invalid value'):
|
||||||
|
return result.astype(dtype)
|
||||||
|
jnp_fun = partial(jnp.sort, stable=stable, **kwds)
|
||||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
|
def testSortStableDescending(self):
|
||||||
|
# TODO(jakevdp): test directly against np.sort when descending is supported.
|
||||||
|
x = jnp.array([0, 1, jnp.nan, 0, 2, jnp.nan, -jnp.inf, jnp.inf])
|
||||||
|
x_sorted = jnp.array([-jnp.inf, 0, 0, 1, 2, jnp.inf, jnp.nan, jnp.nan])
|
||||||
|
argsorted_stable = jnp.array([6, 0, 3, 1, 4, 7, 2, 5])
|
||||||
|
argsorted_rev_stable = jnp.array([2, 5, 7, 4, 1, 0, 3, 6])
|
||||||
|
|
||||||
|
self.assertArraysEqual(jnp.sort(x), x_sorted)
|
||||||
|
self.assertArraysEqual(jnp.sort(x, descending=True), lax.rev(x_sorted, [0]))
|
||||||
|
self.assertArraysEqual(jnp.argsort(x), argsorted_stable)
|
||||||
|
self.assertArraysEqual(jnp.argsort(x, descending=True), argsorted_rev_stable)
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
[dict(shape=shape, axis=axis)
|
[dict(shape=shape, axis=axis)
|
||||||
for shape in one_dim_array_shapes
|
for shape in one_dim_array_shapes
|
||||||
@ -3819,21 +3840,48 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
|||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
[dict(shape=shape, axis=axis)
|
[dict(shape=shape, axis=axis)
|
||||||
for shape in nonzerodim_shapes
|
for shape in nonzerodim_shapes
|
||||||
for axis in (None, *range(len(shape)))
|
for axis in (NO_VALUE, None, *range(-len(shape), len(shape)))
|
||||||
],
|
],
|
||||||
dtype=all_dtypes,
|
dtype=all_dtypes,
|
||||||
)
|
)
|
||||||
def testArgsort(self, dtype, shape, axis):
|
def testArgsort(self, dtype, shape, axis):
|
||||||
rng = jtu.rand_some_equal(self.rng())
|
rng = jtu.rand_some_equal(self.rng())
|
||||||
args_maker = lambda: [rng(shape, dtype)]
|
args_maker = lambda: [rng(shape, dtype)]
|
||||||
jnp_fun = jnp.argsort
|
kwds = {} if axis is NO_VALUE else {'axis': axis}
|
||||||
np_fun = jtu.with_jax_dtype_defaults(np.argsort)
|
|
||||||
if axis is not None:
|
@jtu.with_jax_dtype_defaults
|
||||||
jnp_fun = partial(jnp_fun, axis=axis)
|
def np_fun(arr):
|
||||||
np_fun = partial(np_fun, axis=axis)
|
# Note: numpy sort fails on NaN and Inf values with bfloat16
|
||||||
|
if arr.dtype == jnp.bfloat16:
|
||||||
|
arr = arr.astype('float32')
|
||||||
|
# TODO(jakevdp): switch to stable=True when supported by numpy.
|
||||||
|
return np.argsort(arr, kind='stable', **kwds)
|
||||||
|
jnp_fun = partial(jnp.argsort, stable=True, **kwds)
|
||||||
|
|
||||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||||
self._CompileAndCheck(jnp_fun, args_maker)
|
self._CompileAndCheck(jnp_fun, args_maker)
|
||||||
|
|
||||||
|
@jtu.sample_product(
|
||||||
|
[dict(shape=shape, axis=axis)
|
||||||
|
for shape in nonempty_nonscalar_array_shapes
|
||||||
|
for axis in (NO_VALUE, None, *range(-len(shape), len(shape)))
|
||||||
|
],
|
||||||
|
descending=[True, False],
|
||||||
|
dtype=all_dtypes,
|
||||||
|
)
|
||||||
|
def testArgsortUnstable(self, dtype, shape, axis, descending):
|
||||||
|
# We cannot directly compare unstable argsorts, so instead check that indexed values match.
|
||||||
|
rng = jtu.rand_some_equal(self.rng())
|
||||||
|
x = rng(shape, dtype)
|
||||||
|
kwds = {} if axis is NO_VALUE else {'axis': axis}
|
||||||
|
expected = jnp.sort(x, descending=descending, stable=False, **kwds)
|
||||||
|
indices = jnp.argsort(x, descending=descending, stable=False, **kwds)
|
||||||
|
if axis is None:
|
||||||
|
actual = jnp.ravel(x)[indices]
|
||||||
|
else:
|
||||||
|
actual = jnp.take_along_axis(x, indices, axis=-1 if axis is NO_VALUE else axis)
|
||||||
|
self.assertArraysEqual(actual, expected)
|
||||||
|
|
||||||
@jtu.sample_product(
|
@jtu.sample_product(
|
||||||
[{'shape': shape, 'axis': axis, 'kth': kth}
|
[{'shape': shape, 'axis': axis, 'kth': kth}
|
||||||
for shape in nonzerodim_shapes
|
for shape in nonzerodim_shapes
|
||||||
|
Loading…
x
Reference in New Issue
Block a user