Better documentation for jnp.lexsort

This commit is contained in:
Jake VanderPlas 2024-10-21 16:33:14 -07:00
parent 5d3cac6603
commit 8800fe2870

View File

@ -10154,9 +10154,69 @@ def sort_complex(a: ArrayLike) -> Array:
a = lax.sort(asarray(a))
return lax.convert_element_type(a, dtypes.to_complex_dtype(a.dtype))
@util.implements(np.lexsort)
@partial(jit, static_argnames=('axis',))
def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> Array:
"""Sort a sequence of keys in lexicographic order.
JAX implementation of :func:`numpy.lexsort`.
Args:
keys: a sequence of arrays to sort; all arrays must have the same shape.
The last key in the sequence is used as the primary key.
axis: the axis along which to sort (default: -1).
Returns:
An array of integers of shape ``keys[0].shape`` giving the indices of the
entries in lexicographically-sorted order.
See also:
- :func:`jax.numpy.argsort`: sort a single entry by index.
- :func:`jax.lax.sort`: direct XLA sorting API.
Examples:
:func:`lexsort` with a single key is equivalent to :func:`argsort`:
>>> key1 = jnp.array([4, 2, 3, 2, 5])
>>> jnp.lexsort([key1])
Array([1, 3, 2, 0, 4], dtype=int32)
>>> jnp.argsort(key1)
Array([1, 3, 2, 0, 4], dtype=int32)
With multiple keys, :func:`lexsort` uses the last key as the primary key:
>>> key2 = jnp.array([2, 1, 1, 2, 2])
>>> jnp.lexsort([key1, key2])
Array([1, 2, 3, 0, 4], dtype=int32)
The meaning of the indices become more clear when printing the sorted keys:
>>> indices = jnp.lexsort([key1, key2])
>>> print(f"{key1[indices]}\\n{key2[indices]}")
[2 3 2 4 5]
[1 1 2 2 2]
Notice that the elements of ``key2`` appear in order, and within the sequences
of duplicated values the corresponding elements of ```key1`` appear in order.
For multi-dimensional inputs, :func:`lexsort` defaults to sorting along the
last axis:
>>> key1 = jnp.array([[2, 4, 2, 3],
... [3, 1, 2, 2]])
>>> key2 = jnp.array([[1, 2, 1, 3],
... [2, 1, 2, 1]])
>>> jnp.lexsort([key1, key2])
Array([[0, 2, 1, 3],
[1, 3, 2, 0]], dtype=int32)
A different sort axis can be chosen using the ``axis`` keyword; here we sort
along the leading axis:
>>> jnp.lexsort([key1, key2], axis=0)
Array([[0, 1, 0, 1],
[1, 0, 1, 0]], dtype=int32)
"""
key_tuple = tuple(keys)
util.check_arraylike("lexsort", *key_tuple)
key_arrays = tuple(asarray(k) for k in key_tuple)