mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #26383 from jakevdp:jnp-sorting
PiperOrigin-RevId: 724381260
This commit is contained in:
commit
3b470b9530
@ -61,6 +61,7 @@ from jax._src.lib import xla_extension_version
|
||||
from jax._src.numpy import reductions
|
||||
from jax._src.numpy import ufuncs
|
||||
from jax._src.numpy import util
|
||||
from jax._src.numpy.sorting import argsort, sort
|
||||
from jax._src.numpy.vectorize import vectorize
|
||||
from jax._src.typing import (
|
||||
Array, ArrayLike,
|
||||
@ -10576,408 +10577,6 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False):
|
||||
return where(reductions.all(nan_mask, axis=axis, keepdims=keepdims), -1, res)
|
||||
|
||||
|
||||
@export
|
||||
@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
|
||||
def sort(
|
||||
a: ArrayLike,
|
||||
axis: int | None = -1,
|
||||
*,
|
||||
kind: None = None,
|
||||
order: None = None,
|
||||
stable: bool = True,
|
||||
descending: bool = False,
|
||||
) -> Array:
|
||||
"""Return a sorted copy of an array.
|
||||
|
||||
JAX implementation of :func:`numpy.sort`.
|
||||
|
||||
Args:
|
||||
a: array to sort
|
||||
axis: integer axis along which to sort. Defaults to ``-1``, i.e. the last
|
||||
axis. If ``None``, then ``a`` is flattened before being sorted.
|
||||
stable: boolean specifying whether a stable sort should be used. Default=True.
|
||||
descending: boolean specifying whether to sort in descending order. Default=False.
|
||||
kind: deprecated; instead specify sort algorithm using stable=True or stable=False.
|
||||
order: not supported by JAX
|
||||
|
||||
Returns:
|
||||
Sorted array of shape ``a.shape`` (if ``axis`` is an integer) or of shape
|
||||
``(a.size,)`` (if ``axis`` is None).
|
||||
|
||||
Examples:
|
||||
Simple 1-dimensional sort
|
||||
|
||||
>>> x = jnp.array([1, 3, 5, 4, 2, 1])
|
||||
>>> jnp.sort(x)
|
||||
Array([1, 1, 2, 3, 4, 5], dtype=int32)
|
||||
|
||||
Sort along the last axis of an array:
|
||||
|
||||
>>> x = jnp.array([[2, 1, 3],
|
||||
... [4, 3, 6]])
|
||||
>>> jnp.sort(x, axis=1)
|
||||
Array([[1, 2, 3],
|
||||
[3, 4, 6]], dtype=int32)
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.argsort`: return indices of sorted values.
|
||||
- :func:`jax.numpy.lexsort`: lexicographical sort of multiple arrays.
|
||||
- :func:`jax.lax.sort`: lower-level function wrapping XLA's Sort operator.
|
||||
"""
|
||||
arr = util.ensure_arraylike("sort", a)
|
||||
if kind is not None:
|
||||
raise TypeError("'kind' argument to sort is not supported. Use"
|
||||
" stable=True or stable=False to specify sort stability.")
|
||||
if order is not None:
|
||||
raise TypeError("'order' argument to sort is not supported.")
|
||||
if axis is None:
|
||||
arr = arr.ravel()
|
||||
axis = 0
|
||||
dimension = _canonicalize_axis(axis, arr.ndim)
|
||||
result = lax.sort(arr, dimension=dimension, is_stable=stable)
|
||||
return lax.rev(result, dimensions=[dimension]) if descending else result
|
||||
|
||||
|
||||
@export
|
||||
@jit
|
||||
def sort_complex(a: ArrayLike) -> Array:
|
||||
"""Return a sorted copy of complex array.
|
||||
|
||||
JAX implementation of :func:`numpy.sort_complex`.
|
||||
|
||||
Complex numbers are sorted lexicographically, meaning by their real part
|
||||
first, and then by their imaginary part if real parts are equal.
|
||||
|
||||
Args:
|
||||
a: input array. If dtype is not complex, the array will be upcast to complex.
|
||||
|
||||
Returns:
|
||||
A sorted array of the same shape and complex dtype as the input. If ``a``
|
||||
is multi-dimensional, it is sorted along the last axis.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.sort`: Return a sorted copy of an array.
|
||||
|
||||
Examples:
|
||||
>>> a = jnp.array([1+2j, 2+4j, 3-1j, 2+3j])
|
||||
>>> jnp.sort_complex(a)
|
||||
Array([1.+2.j, 2.+3.j, 2.+4.j, 3.-1.j], dtype=complex64)
|
||||
|
||||
Multi-dimensional arrays are sorted along the last axis:
|
||||
|
||||
>>> a = jnp.array([[5, 3, 4],
|
||||
... [6, 9, 2]])
|
||||
>>> jnp.sort_complex(a)
|
||||
Array([[3.+0.j, 4.+0.j, 5.+0.j],
|
||||
[2.+0.j, 6.+0.j, 9.+0.j]], dtype=complex64)
|
||||
"""
|
||||
a = util.ensure_arraylike("sort_complex", a)
|
||||
a = lax.sort(a)
|
||||
return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))
|
||||
|
||||
|
||||
@export
|
||||
@partial(jit, static_argnames=('axis',))
|
||||
def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array:
|
||||
"""Sort a sequence of keys in lexicographic order.
|
||||
|
||||
JAX implementation of :func:`numpy.lexsort`.
|
||||
|
||||
Args:
|
||||
keys: a sequence of arrays to sort; all arrays must have the same shape.
|
||||
The last key in the sequence is used as the primary key.
|
||||
axis: the axis along which to sort (default: -1).
|
||||
|
||||
Returns:
|
||||
An array of integers of shape ``keys[0].shape`` giving the indices of the
|
||||
entries in lexicographically-sorted order.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.argsort`: sort a single entry by index.
|
||||
- :func:`jax.lax.sort`: direct XLA sorting API.
|
||||
|
||||
Examples:
|
||||
:func:`lexsort` with a single key is equivalent to :func:`argsort`:
|
||||
|
||||
>>> key1 = jnp.array([4, 2, 3, 2, 5])
|
||||
>>> jnp.lexsort([key1])
|
||||
Array([1, 3, 2, 0, 4], dtype=int32)
|
||||
>>> jnp.argsort(key1)
|
||||
Array([1, 3, 2, 0, 4], dtype=int32)
|
||||
|
||||
With multiple keys, :func:`lexsort` uses the last key as the primary key:
|
||||
|
||||
>>> key2 = jnp.array([2, 1, 1, 2, 2])
|
||||
>>> jnp.lexsort([key1, key2])
|
||||
Array([1, 2, 3, 0, 4], dtype=int32)
|
||||
|
||||
The meaning of the indices become more clear when printing the sorted keys:
|
||||
|
||||
>>> indices = jnp.lexsort([key1, key2])
|
||||
>>> print(f"{key1[indices]}\\n{key2[indices]}")
|
||||
[2 3 2 4 5]
|
||||
[1 1 2 2 2]
|
||||
|
||||
Notice that the elements of ``key2`` appear in order, and within the sequences
|
||||
of duplicated values the corresponding elements of ```key1`` appear in order.
|
||||
|
||||
For multi-dimensional inputs, :func:`lexsort` defaults to sorting along the
|
||||
last axis:
|
||||
|
||||
>>> key1 = jnp.array([[2, 4, 2, 3],
|
||||
... [3, 1, 2, 2]])
|
||||
>>> key2 = jnp.array([[1, 2, 1, 3],
|
||||
... [2, 1, 2, 1]])
|
||||
>>> jnp.lexsort([key1, key2])
|
||||
Array([[0, 2, 1, 3],
|
||||
[1, 3, 2, 0]], dtype=int32)
|
||||
|
||||
A different sort axis can be chosen using the ``axis`` keyword; here we sort
|
||||
along the leading axis:
|
||||
|
||||
>>> jnp.lexsort([key1, key2], axis=0)
|
||||
Array([[0, 1, 0, 1],
|
||||
[1, 0, 1, 0]], dtype=int32)
|
||||
"""
|
||||
key_arrays = util.ensure_arraylike_tuple("lexsort", tuple(keys))
|
||||
if len(key_arrays) == 0:
|
||||
raise TypeError("need sequence of keys with len > 0 in lexsort")
|
||||
if len({shape(key) for key in key_arrays}) > 1:
|
||||
raise ValueError("all keys need to be the same shape")
|
||||
if ndim(key_arrays[0]) == 0:
|
||||
return array(0, dtype=dtypes.canonicalize_dtype(dtypes.int_))
|
||||
axis = _canonicalize_axis(axis, ndim(key_arrays[0]))
|
||||
use_64bit_index = key_arrays[0].shape[axis] >= (1 << 31)
|
||||
iota = lax.broadcasted_iota(np.dtype('int64') if use_64bit_index else dtypes.int_,
|
||||
shape(key_arrays[0]), axis)
|
||||
return lax.sort((*key_arrays[::-1], iota), dimension=axis, num_keys=len(key_arrays))[-1]
|
||||
|
||||
|
||||
@export
|
||||
@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
|
||||
def argsort(
|
||||
a: ArrayLike,
|
||||
axis: int | None = -1,
|
||||
*,
|
||||
kind: None = None,
|
||||
order: None = None,
|
||||
stable: bool = True,
|
||||
descending: bool = False,
|
||||
) -> Array:
|
||||
"""Return indices that sort an array.
|
||||
|
||||
JAX implementation of :func:`numpy.argsort`.
|
||||
|
||||
Args:
|
||||
a: array to sort
|
||||
axis: integer axis along which to sort. Defaults to ``-1``, i.e. the last
|
||||
axis. If ``None``, then ``a`` is flattened before being sorted.
|
||||
stable: boolean specifying whether a stable sort should be used. Default=True.
|
||||
descending: boolean specifying whether to sort in descending order. Default=False.
|
||||
kind: deprecated; instead specify sort algorithm using stable=True or stable=False.
|
||||
order: not supported by JAX
|
||||
|
||||
Returns:
|
||||
Array of indices that sort an array. Returned array will be of shape ``a.shape``
|
||||
(if ``axis`` is an integer) or of shape ``(a.size,)`` (if ``axis`` is None).
|
||||
|
||||
Examples:
|
||||
Simple 1-dimensional sort
|
||||
|
||||
>>> x = jnp.array([1, 3, 5, 4, 2, 1])
|
||||
>>> indices = jnp.argsort(x)
|
||||
>>> indices
|
||||
Array([0, 5, 4, 1, 3, 2], dtype=int32)
|
||||
>>> x[indices]
|
||||
Array([1, 1, 2, 3, 4, 5], dtype=int32)
|
||||
|
||||
Sort along the last axis of an array:
|
||||
|
||||
>>> x = jnp.array([[2, 1, 3],
|
||||
... [6, 4, 3]])
|
||||
>>> indices = jnp.argsort(x, axis=1)
|
||||
>>> indices
|
||||
Array([[1, 0, 2],
|
||||
[2, 1, 0]], dtype=int32)
|
||||
>>> jnp.take_along_axis(x, indices, axis=1)
|
||||
Array([[1, 2, 3],
|
||||
[3, 4, 6]], dtype=int32)
|
||||
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.sort`: return sorted values directly.
|
||||
- :func:`jax.numpy.lexsort`: lexicographical sort of multiple arrays.
|
||||
- :func:`jax.lax.sort`: lower-level function wrapping XLA's Sort operator.
|
||||
"""
|
||||
arr = util.ensure_arraylike("argsort", a)
|
||||
if kind is not None:
|
||||
raise TypeError("'kind' argument to argsort is not supported. Use"
|
||||
" stable=True or stable=False to specify sort stability.")
|
||||
if order is not None:
|
||||
raise TypeError("'order' argument to argsort is not supported.")
|
||||
if axis is None:
|
||||
arr = arr.ravel()
|
||||
axis = 0
|
||||
dimension = _canonicalize_axis(axis, arr.ndim)
|
||||
use_64bit_index = not core.is_constant_dim(arr.shape[dimension]) or arr.shape[dimension] >= (1 << 31)
|
||||
iota = lax.broadcasted_iota(np.dtype('int64') if use_64bit_index else dtypes.int_, arr.shape, dimension)
|
||||
# For stable descending sort, we reverse the array and indices to ensure that
|
||||
# duplicates remain in their original order when the final indices are reversed.
|
||||
# For non-stable descending sort, we can avoid these extra operations.
|
||||
if descending and stable:
|
||||
arr = lax.rev(arr, dimensions=[dimension])
|
||||
iota = lax.rev(iota, dimensions=[dimension])
|
||||
_, indices = lax.sort_key_val(arr, iota, dimension=dimension, is_stable=stable)
|
||||
return lax.rev(indices, dimensions=[dimension]) if descending else indices
|
||||
|
||||
|
||||
@export
|
||||
@partial(jit, static_argnames=['kth', 'axis'])
|
||||
def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
|
||||
"""Returns a partially-sorted copy of an array.
|
||||
|
||||
JAX implementation of :func:`numpy.partition`. The JAX version differs from
|
||||
NumPy in the treatment of NaN entries: NaNs which have the negative bit set
|
||||
are sorted to the beginning of the array.
|
||||
|
||||
Args:
|
||||
a: array to be partitioned.
|
||||
kth: static integer index about which to partition the array.
|
||||
axis: static integer axis along which to partition the array; default is -1.
|
||||
|
||||
Returns:
|
||||
A copy of ``a`` partitioned at the ``kth`` value along ``axis``. The entries
|
||||
before ``kth`` are values smaller than ``take(a, kth, axis)``, and entries
|
||||
after ``kth`` are indices of values larger than ``take(a, kth, axis)``
|
||||
|
||||
Note:
|
||||
The JAX version requires the ``kth`` argument to be a static integer rather than
|
||||
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
|
||||
you're only accessing the top or bottom k values of the output, it may be more
|
||||
efficient to call :func:`jax.lax.top_k` directly.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.numpy.sort`: full sort
|
||||
- :func:`jax.numpy.argpartition`: indirect partial sort
|
||||
- :func:`jax.lax.top_k`: directly find the top k entries
|
||||
- :func:`jax.lax.approx_max_k`: compute the approximate top k entries
|
||||
- :func:`jax.lax.approx_min_k`: compute the approximate bottom k entries
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
|
||||
>>> kth = 4
|
||||
>>> x_partitioned = jnp.partition(x, kth)
|
||||
>>> x_partitioned
|
||||
Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)
|
||||
|
||||
The result is a partially-sorted copy of the input. All values before ``kth``
|
||||
are of smaller than the pivot value, and all values after ``kth`` are larger
|
||||
than the pivot value:
|
||||
|
||||
>>> smallest_values = x_partitioned[:kth]
|
||||
>>> pivot_value = x_partitioned[kth]
|
||||
>>> largest_values = x_partitioned[kth + 1:]
|
||||
>>> print(smallest_values, pivot_value, largest_values)
|
||||
[1 2 3 3] 4 [9 8 7 6 5]
|
||||
|
||||
Notice that among ``smallest_values`` and ``largest_values``, the returned
|
||||
order is arbitrary and implementation-dependent.
|
||||
"""
|
||||
# TODO(jakevdp): handle NaN values like numpy.
|
||||
arr = util.ensure_arraylike("partition", a)
|
||||
if issubdtype(arr.dtype, np.complexfloating):
|
||||
raise NotImplementedError("jnp.partition for complex dtype is not implemented.")
|
||||
axis = _canonicalize_axis(axis, arr.ndim)
|
||||
kth = _canonicalize_axis(kth, arr.shape[axis])
|
||||
|
||||
arr = swapaxes(arr, axis, -1)
|
||||
if dtypes.isdtype(arr.dtype, "unsigned integer"):
|
||||
# Here, we apply a trick to handle correctly 0 values for unsigned integers
|
||||
bottom = -lax.top_k(-(arr + 1), kth + 1)[0] - 1
|
||||
else:
|
||||
bottom = -lax.top_k(-arr, kth + 1)[0]
|
||||
top = lax.top_k(arr, arr.shape[-1] - kth - 1)[0]
|
||||
out = lax.concatenate([bottom, top], dimension=arr.ndim - 1)
|
||||
return swapaxes(out, -1, axis)
|
||||
|
||||
|
||||
@export
|
||||
@partial(jit, static_argnames=['kth', 'axis'])
|
||||
def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
|
||||
"""Returns indices that partially sort an array.
|
||||
|
||||
JAX implementation of :func:`numpy.argpartition`. The JAX version differs from
|
||||
NumPy in the treatment of NaN entries: NaNs which have the negative bit set are
|
||||
sorted to the beginning of the array.
|
||||
|
||||
Args:
|
||||
a: array to be partitioned.
|
||||
kth: static integer index about which to partition the array.
|
||||
axis: static integer axis along which to partition the array; default is -1.
|
||||
|
||||
Returns:
|
||||
Indices which partition ``a`` at the ``kth`` value along ``axis``. The entries
|
||||
before ``kth`` are indices of values smaller than ``take(a, kth, axis)``, and
|
||||
entries after ``kth`` are indices of values larger than ``take(a, kth, axis)``
|
||||
|
||||
Note:
|
||||
The JAX version requires the ``kth`` argument to be a static integer rather than
|
||||
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
|
||||
you're only accessing the top or bottom k values of the output, it may be more
|
||||
efficient to call :func:`jax.lax.top_k` directly.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.numpy.partition`: direct partial sort
|
||||
- :func:`jax.numpy.argsort`: full indirect sort
|
||||
- :func:`jax.lax.top_k`: directly find the top k entries
|
||||
- :func:`jax.lax.approx_max_k`: compute the approximate top k entries
|
||||
- :func:`jax.lax.approx_min_k`: compute the approximate bottom k entries
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
|
||||
>>> kth = 4
|
||||
>>> idx = jnp.argpartition(x, kth)
|
||||
>>> idx
|
||||
Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)
|
||||
|
||||
The result is a sequence of indices that partially sort the input. All indices
|
||||
before ``kth`` are of values smaller than the pivot value, and all indices
|
||||
after ``kth`` are of values larger than the pivot value:
|
||||
|
||||
>>> x_partitioned = x[idx]
|
||||
>>> smallest_values = x_partitioned[:kth]
|
||||
>>> pivot_value = x_partitioned[kth]
|
||||
>>> largest_values = x_partitioned[kth + 1:]
|
||||
>>> print(smallest_values, pivot_value, largest_values)
|
||||
[1 2 3 3] 4 [6 8 9 7 5]
|
||||
|
||||
Notice that among ``smallest_values`` and ``largest_values``, the returned
|
||||
order is arbitrary and implementation-dependent.
|
||||
"""
|
||||
# TODO(jakevdp): handle NaN values like numpy.
|
||||
arr = util.ensure_arraylike("partition", a)
|
||||
if issubdtype(arr.dtype, np.complexfloating):
|
||||
raise NotImplementedError("jnp.argpartition for complex dtype is not implemented.")
|
||||
axis = _canonicalize_axis(axis, arr.ndim)
|
||||
kth = _canonicalize_axis(kth, arr.shape[axis])
|
||||
|
||||
arr = swapaxes(arr, axis, -1)
|
||||
if dtypes.isdtype(arr.dtype, "unsigned integer"):
|
||||
# Here, we apply a trick to handle correctly 0 values for unsigned integers
|
||||
bottom_ind = lax.top_k(-(arr + 1), kth + 1)[1]
|
||||
else:
|
||||
bottom_ind = lax.top_k(-arr, kth + 1)[1]
|
||||
|
||||
# To avoid issues with duplicate values, we compute the top indices via a proxy
|
||||
set_to_zero = lambda a, i: a.at[i].set(0)
|
||||
for _ in range(arr.ndim - 1):
|
||||
set_to_zero = jax.vmap(set_to_zero)
|
||||
proxy = set_to_zero(ones(arr.shape), bottom_ind)
|
||||
top_ind = lax.top_k(proxy, arr.shape[-1] - kth - 1)[1]
|
||||
out = lax.concatenate([bottom_ind, top_ind], dimension=arr.ndim - 1)
|
||||
return swapaxes(out, -1, axis)
|
||||
|
||||
|
||||
@partial(jit, static_argnums=(2,))
|
||||
def _roll_dynamic(a: Array, shift: Array, axis: Sequence[int]) -> Array:
|
||||
b_shape = lax.broadcast_shapes(shift.shape, np.shape(axis))
|
||||
|
@ -29,9 +29,10 @@ from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.lax import lax as lax_internal
|
||||
from jax._src.numpy.lax_numpy import (
|
||||
append, arange, concatenate, diff, empty, full, full_like, lexsort,
|
||||
append, arange, concatenate, diff, empty, full, full_like,
|
||||
moveaxis, nonzero, ones, ravel, sort, where, zeros)
|
||||
from jax._src.numpy.reductions import any, cumsum
|
||||
from jax._src.numpy.sorting import lexsort
|
||||
from jax._src.numpy.ufuncs import isnan
|
||||
from jax._src.numpy.util import ensure_arraylike, promote_dtypes
|
||||
from jax._src.util import canonicalize_axis, set_module
|
||||
|
429
jax/_src/numpy/sorting.py
Normal file
429
jax/_src/numpy/sorting.py
Normal file
@ -0,0 +1,429 @@
|
||||
# Copyright 2025 The JAX Authors.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# https://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import partial
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
|
||||
import jax
|
||||
from jax._src import api
|
||||
from jax._src import core
|
||||
from jax._src import dtypes
|
||||
from jax._src.numpy import util
|
||||
from jax._src.util import canonicalize_axis, set_module
|
||||
from jax._src.typing import Array, ArrayLike
|
||||
from jax import lax
|
||||
|
||||
export = set_module('jax.numpy')
|
||||
|
||||
@export
|
||||
@partial(api.jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
|
||||
def sort(
|
||||
a: ArrayLike,
|
||||
axis: int | None = -1,
|
||||
*,
|
||||
kind: None = None,
|
||||
order: None = None,
|
||||
stable: bool = True,
|
||||
descending: bool = False,
|
||||
) -> Array:
|
||||
"""Return a sorted copy of an array.
|
||||
|
||||
JAX implementation of :func:`numpy.sort`.
|
||||
|
||||
Args:
|
||||
a: array to sort
|
||||
axis: integer axis along which to sort. Defaults to ``-1``, i.e. the last
|
||||
axis. If ``None``, then ``a`` is flattened before being sorted.
|
||||
stable: boolean specifying whether a stable sort should be used. Default=True.
|
||||
descending: boolean specifying whether to sort in descending order. Default=False.
|
||||
kind: deprecated; instead specify sort algorithm using stable=True or stable=False.
|
||||
order: not supported by JAX
|
||||
|
||||
Returns:
|
||||
Sorted array of shape ``a.shape`` (if ``axis`` is an integer) or of shape
|
||||
``(a.size,)`` (if ``axis`` is None).
|
||||
|
||||
Examples:
|
||||
Simple 1-dimensional sort
|
||||
|
||||
>>> x = jnp.array([1, 3, 5, 4, 2, 1])
|
||||
>>> jnp.sort(x)
|
||||
Array([1, 1, 2, 3, 4, 5], dtype=int32)
|
||||
|
||||
Sort along the last axis of an array:
|
||||
|
||||
>>> x = jnp.array([[2, 1, 3],
|
||||
... [4, 3, 6]])
|
||||
>>> jnp.sort(x, axis=1)
|
||||
Array([[1, 2, 3],
|
||||
[3, 4, 6]], dtype=int32)
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.argsort`: return indices of sorted values.
|
||||
- :func:`jax.numpy.lexsort`: lexicographical sort of multiple arrays.
|
||||
- :func:`jax.lax.sort`: lower-level function wrapping XLA's Sort operator.
|
||||
"""
|
||||
arr = util.ensure_arraylike("sort", a)
|
||||
if kind is not None:
|
||||
raise TypeError("'kind' argument to sort is not supported. Use"
|
||||
" stable=True or stable=False to specify sort stability.")
|
||||
if order is not None:
|
||||
raise TypeError("'order' argument to sort is not supported.")
|
||||
if axis is None:
|
||||
arr = arr.ravel()
|
||||
axis = 0
|
||||
dimension = canonicalize_axis(axis, arr.ndim)
|
||||
result = lax.sort(arr, dimension=dimension, is_stable=stable)
|
||||
return lax.rev(result, dimensions=[dimension]) if descending else result
|
||||
|
||||
@export
|
||||
@partial(api.jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
|
||||
def argsort(
|
||||
a: ArrayLike,
|
||||
axis: int | None = -1,
|
||||
*,
|
||||
kind: None = None,
|
||||
order: None = None,
|
||||
stable: bool = True,
|
||||
descending: bool = False,
|
||||
) -> Array:
|
||||
"""Return indices that sort an array.
|
||||
|
||||
JAX implementation of :func:`numpy.argsort`.
|
||||
|
||||
Args:
|
||||
a: array to sort
|
||||
axis: integer axis along which to sort. Defaults to ``-1``, i.e. the last
|
||||
axis. If ``None``, then ``a`` is flattened before being sorted.
|
||||
stable: boolean specifying whether a stable sort should be used. Default=True.
|
||||
descending: boolean specifying whether to sort in descending order. Default=False.
|
||||
kind: deprecated; instead specify sort algorithm using stable=True or stable=False.
|
||||
order: not supported by JAX
|
||||
|
||||
Returns:
|
||||
Array of indices that sort an array. Returned array will be of shape ``a.shape``
|
||||
(if ``axis`` is an integer) or of shape ``(a.size,)`` (if ``axis`` is None).
|
||||
|
||||
Examples:
|
||||
Simple 1-dimensional sort
|
||||
|
||||
>>> x = jnp.array([1, 3, 5, 4, 2, 1])
|
||||
>>> indices = jnp.argsort(x)
|
||||
>>> indices
|
||||
Array([0, 5, 4, 1, 3, 2], dtype=int32)
|
||||
>>> x[indices]
|
||||
Array([1, 1, 2, 3, 4, 5], dtype=int32)
|
||||
|
||||
Sort along the last axis of an array:
|
||||
|
||||
>>> x = jnp.array([[2, 1, 3],
|
||||
... [6, 4, 3]])
|
||||
>>> indices = jnp.argsort(x, axis=1)
|
||||
>>> indices
|
||||
Array([[1, 0, 2],
|
||||
[2, 1, 0]], dtype=int32)
|
||||
>>> jnp.take_along_axis(x, indices, axis=1)
|
||||
Array([[1, 2, 3],
|
||||
[3, 4, 6]], dtype=int32)
|
||||
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.sort`: return sorted values directly.
|
||||
- :func:`jax.numpy.lexsort`: lexicographical sort of multiple arrays.
|
||||
- :func:`jax.lax.sort`: lower-level function wrapping XLA's Sort operator.
|
||||
"""
|
||||
arr = util.ensure_arraylike("argsort", a)
|
||||
if kind is not None:
|
||||
raise TypeError("'kind' argument to argsort is not supported. Use"
|
||||
" stable=True or stable=False to specify sort stability.")
|
||||
if order is not None:
|
||||
raise TypeError("'order' argument to argsort is not supported.")
|
||||
if axis is None:
|
||||
arr = arr.ravel()
|
||||
axis = 0
|
||||
dimension = canonicalize_axis(axis, arr.ndim)
|
||||
use_64bit_index = not core.is_constant_dim(arr.shape[dimension]) or arr.shape[dimension] >= (1 << 31)
|
||||
iota = lax.broadcasted_iota(np.dtype('int64') if use_64bit_index else dtypes.int_, arr.shape, dimension)
|
||||
# For stable descending sort, we reverse the array and indices to ensure that
|
||||
# duplicates remain in their original order when the final indices are reversed.
|
||||
# For non-stable descending sort, we can avoid these extra operations.
|
||||
if descending and stable:
|
||||
arr = lax.rev(arr, dimensions=[dimension])
|
||||
iota = lax.rev(iota, dimensions=[dimension])
|
||||
_, indices = lax.sort_key_val(arr, iota, dimension=dimension, is_stable=stable)
|
||||
return lax.rev(indices, dimensions=[dimension]) if descending else indices
|
||||
|
||||
|
||||
@export
|
||||
@partial(api.jit, static_argnames=['kth', 'axis'])
|
||||
def partition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
|
||||
"""Returns a partially-sorted copy of an array.
|
||||
|
||||
JAX implementation of :func:`numpy.partition`. The JAX version differs from
|
||||
NumPy in the treatment of NaN entries: NaNs which have the negative bit set
|
||||
are sorted to the beginning of the array.
|
||||
|
||||
Args:
|
||||
a: array to be partitioned.
|
||||
kth: static integer index about which to partition the array.
|
||||
axis: static integer axis along which to partition the array; default is -1.
|
||||
|
||||
Returns:
|
||||
A copy of ``a`` partitioned at the ``kth`` value along ``axis``. The entries
|
||||
before ``kth`` are values smaller than ``take(a, kth, axis)``, and entries
|
||||
after ``kth`` are indices of values larger than ``take(a, kth, axis)``
|
||||
|
||||
Note:
|
||||
The JAX version requires the ``kth`` argument to be a static integer rather than
|
||||
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
|
||||
you're only accessing the top or bottom k values of the output, it may be more
|
||||
efficient to call :func:`jax.lax.top_k` directly.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.numpy.sort`: full sort
|
||||
- :func:`jax.numpy.argpartition`: indirect partial sort
|
||||
- :func:`jax.lax.top_k`: directly find the top k entries
|
||||
- :func:`jax.lax.approx_max_k`: compute the approximate top k entries
|
||||
- :func:`jax.lax.approx_min_k`: compute the approximate bottom k entries
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
|
||||
>>> kth = 4
|
||||
>>> x_partitioned = jnp.partition(x, kth)
|
||||
>>> x_partitioned
|
||||
Array([1, 2, 3, 3, 4, 9, 8, 7, 6, 5], dtype=int32)
|
||||
|
||||
The result is a partially-sorted copy of the input. All values before ``kth``
|
||||
are of smaller than the pivot value, and all values after ``kth`` are larger
|
||||
than the pivot value:
|
||||
|
||||
>>> smallest_values = x_partitioned[:kth]
|
||||
>>> pivot_value = x_partitioned[kth]
|
||||
>>> largest_values = x_partitioned[kth + 1:]
|
||||
>>> print(smallest_values, pivot_value, largest_values)
|
||||
[1 2 3 3] 4 [9 8 7 6 5]
|
||||
|
||||
Notice that among ``smallest_values`` and ``largest_values``, the returned
|
||||
order is arbitrary and implementation-dependent.
|
||||
"""
|
||||
# TODO(jakevdp): handle NaN values like numpy.
|
||||
arr = util.ensure_arraylike("partition", a)
|
||||
if dtypes.issubdtype(arr.dtype, np.complexfloating):
|
||||
raise NotImplementedError("jnp.partition for complex dtype is not implemented.")
|
||||
axis = canonicalize_axis(axis, arr.ndim)
|
||||
kth = canonicalize_axis(kth, arr.shape[axis])
|
||||
|
||||
arr = jax.numpy.swapaxes(arr, axis, -1)
|
||||
if dtypes.isdtype(arr.dtype, "unsigned integer"):
|
||||
# Here, we apply a trick to handle correctly 0 values for unsigned integers
|
||||
bottom = -lax.top_k(-(arr + 1), kth + 1)[0] - 1
|
||||
else:
|
||||
bottom = -lax.top_k(-arr, kth + 1)[0]
|
||||
top = lax.top_k(arr, arr.shape[-1] - kth - 1)[0]
|
||||
out = lax.concatenate([bottom, top], dimension=arr.ndim - 1)
|
||||
return jax.numpy.swapaxes(out, -1, axis)
|
||||
|
||||
|
||||
@export
|
||||
@partial(api.jit, static_argnames=['kth', 'axis'])
|
||||
def argpartition(a: ArrayLike, kth: int, axis: int = -1) -> Array:
|
||||
"""Returns indices that partially sort an array.
|
||||
|
||||
JAX implementation of :func:`numpy.argpartition`. The JAX version differs from
|
||||
NumPy in the treatment of NaN entries: NaNs which have the negative bit set are
|
||||
sorted to the beginning of the array.
|
||||
|
||||
Args:
|
||||
a: array to be partitioned.
|
||||
kth: static integer index about which to partition the array.
|
||||
axis: static integer axis along which to partition the array; default is -1.
|
||||
|
||||
Returns:
|
||||
Indices which partition ``a`` at the ``kth`` value along ``axis``. The entries
|
||||
before ``kth`` are indices of values smaller than ``take(a, kth, axis)``, and
|
||||
entries after ``kth`` are indices of values larger than ``take(a, kth, axis)``
|
||||
|
||||
Note:
|
||||
The JAX version requires the ``kth`` argument to be a static integer rather than
|
||||
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
|
||||
you're only accessing the top or bottom k values of the output, it may be more
|
||||
efficient to call :func:`jax.lax.top_k` directly.
|
||||
|
||||
See Also:
|
||||
- :func:`jax.numpy.partition`: direct partial sort
|
||||
- :func:`jax.numpy.argsort`: full indirect sort
|
||||
- :func:`jax.lax.top_k`: directly find the top k entries
|
||||
- :func:`jax.lax.approx_max_k`: compute the approximate top k entries
|
||||
- :func:`jax.lax.approx_min_k`: compute the approximate bottom k entries
|
||||
|
||||
Examples:
|
||||
>>> x = jnp.array([6, 8, 4, 3, 1, 9, 7, 5, 2, 3])
|
||||
>>> kth = 4
|
||||
>>> idx = jnp.argpartition(x, kth)
|
||||
>>> idx
|
||||
Array([4, 8, 3, 9, 2, 0, 1, 5, 6, 7], dtype=int32)
|
||||
|
||||
The result is a sequence of indices that partially sort the input. All indices
|
||||
before ``kth`` are of values smaller than the pivot value, and all indices
|
||||
after ``kth`` are of values larger than the pivot value:
|
||||
|
||||
>>> x_partitioned = x[idx]
|
||||
>>> smallest_values = x_partitioned[:kth]
|
||||
>>> pivot_value = x_partitioned[kth]
|
||||
>>> largest_values = x_partitioned[kth + 1:]
|
||||
>>> print(smallest_values, pivot_value, largest_values)
|
||||
[1 2 3 3] 4 [6 8 9 7 5]
|
||||
|
||||
Notice that among ``smallest_values`` and ``largest_values``, the returned
|
||||
order is arbitrary and implementation-dependent.
|
||||
"""
|
||||
# TODO(jakevdp): handle NaN values like numpy.
|
||||
arr = util.ensure_arraylike("partition", a)
|
||||
if dtypes.issubdtype(arr.dtype, np.complexfloating):
|
||||
raise NotImplementedError("jnp.argpartition for complex dtype is not implemented.")
|
||||
axis = canonicalize_axis(axis, arr.ndim)
|
||||
kth = canonicalize_axis(kth, arr.shape[axis])
|
||||
|
||||
arr = jax.numpy.swapaxes(arr, axis, -1)
|
||||
if dtypes.isdtype(arr.dtype, "unsigned integer"):
|
||||
# Here, we apply a trick to handle correctly 0 values for unsigned integers
|
||||
bottom_ind = lax.top_k(-(arr + 1), kth + 1)[1]
|
||||
else:
|
||||
bottom_ind = lax.top_k(-arr, kth + 1)[1]
|
||||
|
||||
# To avoid issues with duplicate values, we compute the top indices via a proxy
|
||||
set_to_zero = lambda a, i: a.at[i].set(0)
|
||||
for _ in range(arr.ndim - 1):
|
||||
set_to_zero = jax.vmap(set_to_zero)
|
||||
proxy = set_to_zero(jax.numpy.ones(arr.shape), bottom_ind)
|
||||
top_ind = lax.top_k(proxy, arr.shape[-1] - kth - 1)[1]
|
||||
out = lax.concatenate([bottom_ind, top_ind], dimension=arr.ndim - 1)
|
||||
return jax.numpy.swapaxes(out, -1, axis)
|
||||
|
||||
|
||||
@export
|
||||
@api.jit
|
||||
def sort_complex(a: ArrayLike) -> Array:
|
||||
"""Return a sorted copy of complex array.
|
||||
|
||||
JAX implementation of :func:`numpy.sort_complex`.
|
||||
|
||||
Complex numbers are sorted lexicographically, meaning by their real part
|
||||
first, and then by their imaginary part if real parts are equal.
|
||||
|
||||
Args:
|
||||
a: input array. If dtype is not complex, the array will be upcast to complex.
|
||||
|
||||
Returns:
|
||||
A sorted array of the same shape and complex dtype as the input. If ``a``
|
||||
is multi-dimensional, it is sorted along the last axis.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.sort`: Return a sorted copy of an array.
|
||||
|
||||
Examples:
|
||||
>>> a = jnp.array([1+2j, 2+4j, 3-1j, 2+3j])
|
||||
>>> jnp.sort_complex(a)
|
||||
Array([1.+2.j, 2.+3.j, 2.+4.j, 3.-1.j], dtype=complex64)
|
||||
|
||||
Multi-dimensional arrays are sorted along the last axis:
|
||||
|
||||
>>> a = jnp.array([[5, 3, 4],
|
||||
... [6, 9, 2]])
|
||||
>>> jnp.sort_complex(a)
|
||||
Array([[3.+0.j, 4.+0.j, 5.+0.j],
|
||||
[2.+0.j, 6.+0.j, 9.+0.j]], dtype=complex64)
|
||||
"""
|
||||
a = util.ensure_arraylike("sort_complex", a)
|
||||
a = lax.sort(a)
|
||||
return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))
|
||||
|
||||
|
||||
@export
|
||||
@partial(api.jit, static_argnames=('axis',))
|
||||
def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array:
|
||||
"""Sort a sequence of keys in lexicographic order.
|
||||
|
||||
JAX implementation of :func:`numpy.lexsort`.
|
||||
|
||||
Args:
|
||||
keys: a sequence of arrays to sort; all arrays must have the same shape.
|
||||
The last key in the sequence is used as the primary key.
|
||||
axis: the axis along which to sort (default: -1).
|
||||
|
||||
Returns:
|
||||
An array of integers of shape ``keys[0].shape`` giving the indices of the
|
||||
entries in lexicographically-sorted order.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.argsort`: sort a single entry by index.
|
||||
- :func:`jax.lax.sort`: direct XLA sorting API.
|
||||
|
||||
Examples:
|
||||
:func:`lexsort` with a single key is equivalent to :func:`argsort`:
|
||||
|
||||
>>> key1 = jnp.array([4, 2, 3, 2, 5])
|
||||
>>> jnp.lexsort([key1])
|
||||
Array([1, 3, 2, 0, 4], dtype=int32)
|
||||
>>> jnp.argsort(key1)
|
||||
Array([1, 3, 2, 0, 4], dtype=int32)
|
||||
|
||||
With multiple keys, :func:`lexsort` uses the last key as the primary key:
|
||||
|
||||
>>> key2 = jnp.array([2, 1, 1, 2, 2])
|
||||
>>> jnp.lexsort([key1, key2])
|
||||
Array([1, 2, 3, 0, 4], dtype=int32)
|
||||
|
||||
The meaning of the indices become more clear when printing the sorted keys:
|
||||
|
||||
>>> indices = jnp.lexsort([key1, key2])
|
||||
>>> print(f"{key1[indices]}\\n{key2[indices]}")
|
||||
[2 3 2 4 5]
|
||||
[1 1 2 2 2]
|
||||
|
||||
Notice that the elements of ``key2`` appear in order, and within the sequences
|
||||
of duplicated values the corresponding elements of ```key1`` appear in order.
|
||||
|
||||
For multi-dimensional inputs, :func:`lexsort` defaults to sorting along the
|
||||
last axis:
|
||||
|
||||
>>> key1 = jnp.array([[2, 4, 2, 3],
|
||||
... [3, 1, 2, 2]])
|
||||
>>> key2 = jnp.array([[1, 2, 1, 3],
|
||||
... [2, 1, 2, 1]])
|
||||
>>> jnp.lexsort([key1, key2])
|
||||
Array([[0, 2, 1, 3],
|
||||
[1, 3, 2, 0]], dtype=int32)
|
||||
|
||||
A different sort axis can be chosen using the ``axis`` keyword; here we sort
|
||||
along the leading axis:
|
||||
|
||||
>>> jnp.lexsort([key1, key2], axis=0)
|
||||
Array([[0, 1, 0, 1],
|
||||
[1, 0, 1, 0]], dtype=int32)
|
||||
"""
|
||||
key_arrays = util.ensure_arraylike_tuple("lexsort", tuple(keys))
|
||||
if len(key_arrays) == 0:
|
||||
raise TypeError("need sequence of keys with len > 0 in lexsort")
|
||||
if len({np.shape(key) for key in key_arrays}) > 1:
|
||||
raise ValueError("all keys need to be the same shape")
|
||||
if np.ndim(key_arrays[0]) == 0:
|
||||
return jax.numpy.array(0, dtype=dtypes.canonicalize_dtype(dtypes.int_))
|
||||
axis = canonicalize_axis(axis, np.ndim(key_arrays[0]))
|
||||
use_64bit_index = key_arrays[0].shape[axis] >= (1 << 31)
|
||||
iota = lax.broadcasted_iota(np.dtype('int64') if use_64bit_index else dtypes.int_,
|
||||
np.shape(key_arrays[0]), axis)
|
||||
return lax.sort((*key_arrays[::-1], iota), dimension=axis, num_keys=len(key_arrays))[-1]
|
@ -34,7 +34,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
arange as arange,
|
||||
argmax as argmax,
|
||||
argmin as argmin,
|
||||
argsort as argsort,
|
||||
argwhere as argwhere,
|
||||
around as around,
|
||||
array as array,
|
||||
@ -139,7 +138,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
kaiser as kaiser,
|
||||
kron as kron,
|
||||
lcm as lcm,
|
||||
lexsort as lexsort,
|
||||
linspace as linspace,
|
||||
load as load,
|
||||
logspace as logspace,
|
||||
@ -153,7 +151,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
nan_to_num as nan_to_num,
|
||||
nanargmax as nanargmax,
|
||||
nanargmin as nanargmin,
|
||||
argpartition as argpartition,
|
||||
ndim as ndim,
|
||||
newaxis as newaxis,
|
||||
nonzero as nonzero,
|
||||
@ -162,7 +159,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
outer as outer,
|
||||
packbits as packbits,
|
||||
pad as pad,
|
||||
partition as partition,
|
||||
permute_dims as permute_dims,
|
||||
pi as pi,
|
||||
piecewise as piecewise,
|
||||
@ -188,8 +184,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
set_printoptions as set_printoptions,
|
||||
shape as shape,
|
||||
size as size,
|
||||
sort as sort,
|
||||
sort_complex as sort_complex,
|
||||
split as split,
|
||||
squeeze as squeeze,
|
||||
stack as stack,
|
||||
@ -259,6 +253,15 @@ from jax._src.numpy.scalar_types import (
|
||||
uint64 as uint64,
|
||||
)
|
||||
|
||||
from jax._src.numpy.sorting import (
|
||||
argpartition as argpartition,
|
||||
argsort as argsort,
|
||||
lexsort as lexsort,
|
||||
partition as partition,
|
||||
sort as sort,
|
||||
sort_complex as sort_complex,
|
||||
)
|
||||
|
||||
# NumPy generic scalar types:
|
||||
from numpy import (
|
||||
character as character,
|
||||
|
@ -1600,7 +1600,7 @@ class DebugInfoTest(jtu.JaxTestCase):
|
||||
expected_jaxpr_debug_infos=[
|
||||
# TODO(necula): this should not be pointing into the JAX internals
|
||||
re.compile(r"traced_for=jit, fun=checked_fun at .*jax._src.checkify.py:.*, arg_names=args\[0\]"),
|
||||
re.compile(r"traced_for=jit, fun=argsort at .*numpy.lax_numpy.py:.*, arg_names=a, result_paths="),
|
||||
re.compile(r"traced_for=jit, fun=argsort at .*numpy.sorting.py:.*, arg_names=a, result_paths="),
|
||||
"traced_for=pmap, fun=my_f, arg_names=my_x, result_paths=[0]",
|
||||
],
|
||||
expected_tracer_debug_infos=[
|
||||
|
Loading…
x
Reference in New Issue
Block a user