Remove deprecated kind argument from jnp.sort and jnp.argsort.

PiperOrigin-RevId: 631429900
This commit is contained in:
Jake VanderPlas 2024-05-07 08:18:05 -07:00 committed by jax authors
parent 500da57e91
commit 9b79f6520a
3 changed files with 23 additions and 23 deletions

View File

@ -69,6 +69,8 @@ Remember to align the itemized text with the first line of an item within a list
{func}`~jax.numpy.apply_over_axes`, {func}`~jax.numpy.inner`,
{func}`~jax.numpy.outer`, {func}`~jax.numpy.cross`,
{func}`~jax.numpy.kron`, and {func}`~jax.numpy.lexsort`.
* The ``kind`` argument to {func}`jax.numpy.sort` and {func}`jax.numpy.argsort`
is now removed. Use `stable=True` or `stable=False` instead.
* Bug fixes
* {func}`jax.numpy.astype` will now always return a copy when `copy=True`.

View File

@ -4502,30 +4502,29 @@ def _nanargmin(a, axis: int | None = None, keepdims : bool = False):
@util.implements(np.sort, extra_params="""
kind : deprecated; specify sort algorithm using stable=True or stable=False
order : not supported
stable : bool, default=True
Specify whether to use a stable sort.
descending : bool, default=False
Specify whether to do a descending sort.
""")
kind : deprecated; specify sort algorithm using stable=True or stable=False
order : not supported
""")
@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
def sort(
a: ArrayLike,
axis: int | None = -1,
kind: str | None = None,
order: None = None, *,
*,
kind: None = None,
order: None = None,
stable: bool = True,
descending: bool = False,
) -> Array:
util.check_arraylike("sort", a)
if kind is not None:
# Deprecated 2024-01-05
warnings.warn("The 'kind' argument to sort has no effect, and is deprecated. "
"Use stable=True or stable=False to specify sort stability.",
category=DeprecationWarning, stacklevel=2)
raise TypeError("'kind' argument to sort is not supported. Use"
" stable=True or stable=False to specify sort stability.")
if order is not None:
raise ValueError("'order' argument to sort is not supported.")
raise TypeError("'order' argument to sort is not supported.")
if axis is None:
arr = ravel(a)
axis = 0
@ -4562,31 +4561,30 @@ def lexsort(keys: Array | np.ndarray | Sequence[ArrayLike], axis: int = -1) -> A
@util.implements(np.argsort, extra_params="""
kind : deprecated; specify sort algorithm using stable=True or stable=False
order : not supported
stable : bool, default=True
Specify whether to use a stable sort.
descending : bool, default=False
Specify whether to do a descending sort.
kind : deprecated; specify sort algorithm using stable=True or stable=False
order : not supported
""")
@partial(jit, static_argnames=('axis', 'kind', 'order', 'stable', 'descending'))
def argsort(
a: ArrayLike,
axis: int | None = -1,
kind: str | None = None,
*,
kind: None = None,
order: None = None,
*, stable: bool = True,
stable: bool = True,
descending: bool = False,
) -> Array:
util.check_arraylike("argsort", a)
arr = asarray(a)
if kind is not None:
# Deprecated 2024-01-05
warnings.warn("The 'kind' argument to argsort has no effect, and is deprecated. "
"Use stable=True or stable=False to specify sort stability.",
category=DeprecationWarning, stacklevel=2)
raise TypeError("'kind' argument to argsort is not supported. Use"
" stable=True or stable=False to specify sort stability.")
if order is not None:
raise ValueError("'order' argument to argsort is not supported.")
raise TypeError("'order' argument to argsort is not supported.")
if axis is None:
arr = ravel(arr)
axis = 0

View File

@ -83,11 +83,11 @@ def argpartition(a: ArrayLike, kth: int, axis: int = ...) -> Array: ...
def argsort(
a: ArrayLike,
axis: Optional[int] = ...,
kind: str | None = ...,
order: None = ...,
*,
stable: builtins.bool = ...,
descending: builtins.bool = ...,
kind: str | None = ...,
order: None = ...,
) -> Array: ...
def argwhere(
a: ArrayLike,
@ -737,11 +737,11 @@ sometrue = any
def sort(
a: ArrayLike,
axis: Optional[int] = ...,
kind: str | None = ...,
order: None = ...,
*,
stable: builtins.bool = ...,
descending: builtins.bool = ...,
kind: str | None = ...,
order: None = ...,
) -> Array: ...
def sort_complex(a: ArrayLike) -> Array: ...
def split(