Merge pull request #15321 from jakevdp:remove-msort

PiperOrigin-RevId: 520952178
This commit is contained in:
jax authors 2023-03-31 10:16:18 -07:00
commit 2841bd310e
5 changed files with 2 additions and 24 deletions

View File

@ -11,6 +11,8 @@ Remember to align the itemized text with the first line of an item within a list
* Deprecations
* The `in_axis_resources` and `out_axis_resources` arguments of pjit have been
deprecated. Please use `in_shardings` and `out_shardings` respectively.
* The function `jax.numpy.msort` has been removed. It has been deprecated since
JAX v0.4.1. Use `jnp.sort(a, axis=0)` instead.
## jaxlib 0.4.9

View File

@ -270,7 +270,6 @@ namespace; they are listed below.
mod
modf
moveaxis
msort
multiply
nan_to_num
nanargmax

View File

@ -3550,13 +3550,6 @@ def argsort(a: ArrayLike, axis: Optional[int] = -1, kind: str = 'stable', order=
return perm
@util._wraps(np.msort)
def msort(a):
# TODO(jakevdp): remove msort after Feb 2023
warnings.warn("jnp.msort is deprecated; use jnp.sort(a, axis=0) instead", DeprecationWarning)
return sort(a, axis=0)
@util._wraps(np.partition, lax_description="""
The JAX version requires the ``kth`` argument to be a static integer rather than
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If

View File

@ -170,7 +170,6 @@ from jax._src.numpy.lax_numpy import (
matmul as matmul,
meshgrid as meshgrid,
moveaxis as moveaxis,
msort as msort,
nan as nan,
nan_to_num as nan_to_num,
nanargmax as nanargmax,

View File

@ -3571,21 +3571,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
self._CompileAndCheck(jnp_fun, args_maker)
@jtu.sample_product(
dtype=all_dtypes,
shape=nonzerodim_shapes,
)
def testMsort(self, dtype, shape):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
with self.assertWarnsRegex(DeprecationWarning, "jnp.msort is deprecated"):
jnp.msort(*args_maker())
with jtu.ignore_warning(category=DeprecationWarning, message=".*msort is deprecated"):
self._CheckAgainstNumpy(np.msort, jnp.msort, args_maker)
self._CompileAndCheck(jnp.msort, args_maker)
@jtu.sample_product(
[{'shape': shape, 'axis': axis, 'kth': kth}
for shape in nonzerodim_shapes