mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Better documentation for jnp.lexsort
This commit is contained in:
parent
5d3cac6603
commit
8800fe2870
@ -10154,9 +10154,69 @@ def sort_complex(a: ArrayLike) -> Array:
|
||||
a = lax.sort(asarray(a))
|
||||
return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))
|
||||
|
||||
@util.implements(np.lexsort)
|
||||
|
||||
@partial(jit, static_argnames=('axis',))
|
||||
def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array:
|
||||
"""Sort a sequence of keys in lexicographic order.
|
||||
|
||||
JAX implementation of :func:`numpy.lexsort`.
|
||||
|
||||
Args:
|
||||
keys: a sequence of arrays to sort; all arrays must have the same shape.
|
||||
The last key in the sequence is used as the primary key.
|
||||
axis: the axis along which to sort (default: -1).
|
||||
|
||||
Returns:
|
||||
An array of integers of shape ``keys[0].shape`` giving the indices of the
|
||||
entries in lexicographically-sorted order.
|
||||
|
||||
See also:
|
||||
- :func:`jax.numpy.argsort`: sort a single entry by index.
|
||||
- :func:`jax.lax.sort`: direct XLA sorting API.
|
||||
|
||||
Examples:
|
||||
:func:`lexsort` with a single key is equivalent to :func:`argsort`:
|
||||
|
||||
>>> key1 = jnp.array([4, 2, 3, 2, 5])
|
||||
>>> jnp.lexsort([key1])
|
||||
Array([1, 3, 2, 0, 4], dtype=int32)
|
||||
>>> jnp.argsort(key1)
|
||||
Array([1, 3, 2, 0, 4], dtype=int32)
|
||||
|
||||
With multiple keys, :func:`lexsort` uses the last key as the primary key:
|
||||
|
||||
>>> key2 = jnp.array([2, 1, 1, 2, 2])
|
||||
>>> jnp.lexsort([key1, key2])
|
||||
Array([1, 2, 3, 0, 4], dtype=int32)
|
||||
|
||||
The meaning of the indices become more clear when printing the sorted keys:
|
||||
|
||||
>>> indices = jnp.lexsort([key1, key2])
|
||||
>>> print(f"{key1[indices]}\\n{key2[indices]}")
|
||||
[2 3 2 4 5]
|
||||
[1 1 2 2 2]
|
||||
|
||||
Notice that the elements of ``key2`` appear in order, and within the sequences
|
||||
of duplicated values the corresponding elements of ```key1`` appear in order.
|
||||
|
||||
For multi-dimensional inputs, :func:`lexsort` defaults to sorting along the
|
||||
last axis:
|
||||
|
||||
>>> key1 = jnp.array([[2, 4, 2, 3],
|
||||
... [3, 1, 2, 2]])
|
||||
>>> key2 = jnp.array([[1, 2, 1, 3],
|
||||
... [2, 1, 2, 1]])
|
||||
>>> jnp.lexsort([key1, key2])
|
||||
Array([[0, 2, 1, 3],
|
||||
[1, 3, 2, 0]], dtype=int32)
|
||||
|
||||
A different sort axis can be chosen using the ``axis`` keyword; here we sort
|
||||
along the leading axis:
|
||||
|
||||
>>> jnp.lexsort([key1, key2], axis=0)
|
||||
Array([[0, 1, 0, 1],
|
||||
[1, 0, 1, 0]], dtype=int32)
|
||||
"""
|
||||
key_tuple = tuple(keys)
|
||||
util.check_arraylike("lexsort", *key_tuple)
|
||||
key_arrays = tuple(asarray(k) for k in key_tuple)
|
||||
|
Loading…
x
Reference in New Issue
Block a user