Change default kind for jnp.argsort to stable. Warn if anything other than stable is passed.

This commit is contained in:
Peter Hawkins 2021-10-15 15:39:40 -04:00
parent 613dc4bffa
commit af5d3675dd

View File

@ -5293,12 +5293,18 @@ def lexsort(keys, axis=-1):
return lax.sort((*keys[::-1], iota), dimension=axis, num_keys=len(keys))[-1]
@_wraps(np.argsort)
_ARGSORT_DOC = """
Only :code:`kind='stable'` is supported. Other :code:`kind` values will produce
a warning and be treated as if they were :code:`'stable'`.
"""
@_wraps(np.argsort, lax_description=_ARGSORT_DOC)
@partial(jit, static_argnames=('axis', 'kind', 'order'))
def argsort(a, axis: Optional[int] = -1, kind='quicksort', order=None):
def argsort(a, axis: Optional[int] = -1, kind='stable', order=None):
_check_arraylike("argsort", a)
if kind != 'quicksort':
warnings.warn("'kind' argument to argsort is ignored.")
if kind != 'stable':
warnings.warn("'kind' argument to argsort is ignored; only 'stable' sorts "
"are supported.")
if order is not None:
raise ValueError("'order' argument to argsort is not supported.")