mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Change default kind for jnp.argsort to stable
. Warn if anything other than stable
is passed.
This commit is contained in:
parent
613dc4bffa
commit
af5d3675dd
@ -5293,12 +5293,18 @@ def lexsort(keys, axis=-1):
|
||||
return lax.sort((*keys[::-1], iota), dimension=axis, num_keys=len(keys))[-1]
|
||||
|
||||
|
||||
@_wraps(np.argsort)
|
||||
_ARGSORT_DOC = """
|
||||
Only :code:`kind='stable'` is supported. Other :code:`kind` values will produce
|
||||
a warning and be treated as if they were :code:`'stable'`.
|
||||
"""
|
||||
|
||||
@_wraps(np.argsort, lax_description=_ARGSORT_DOC)
|
||||
@partial(jit, static_argnames=('axis', 'kind', 'order'))
|
||||
def argsort(a, axis: Optional[int] = -1, kind='quicksort', order=None):
|
||||
def argsort(a, axis: Optional[int] = -1, kind='stable', order=None):
|
||||
_check_arraylike("argsort", a)
|
||||
if kind != 'quicksort':
|
||||
warnings.warn("'kind' argument to argsort is ignored.")
|
||||
if kind != 'stable':
|
||||
warnings.warn("'kind' argument to argsort is ignored; only 'stable' sorts "
|
||||
"are supported.")
|
||||
if order is not None:
|
||||
raise ValueError("'order' argument to argsort is not supported.")
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user