diff --git a/CHANGELOG.md b/CHANGELOG.md index cf4c59516..d1a115749 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,9 @@ Remember to align the itemized text with the first line of an item within a list Before, it was using the 0th device for the JAX-default backend. * A number of `jax.numpy` functions now have their arguments marked as positional-only, matching NumPy. + * `jnp.msort` is now deprecated, following the deprecation of `np.msort` in numpy 1.24. + It will be removed in a future release, in accordance with the {ref}`api-compatibility` + policy. It can be replaced with `jnp.sort(a, axis=0)`. ## jaxlib 0.4.0 * Changes diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 7a9cd5200..f47913ff0 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -3485,6 +3485,8 @@ def argsort(a, axis: Optional[int] = -1, kind='stable', order=None): @_wraps(np.msort) def msort(a): + # TODO(jakevdp): remove msort after Feb 2023 + warnings.warn("jnp.msort is deprecated; use jnp.sort(a, axis=0) instead", DeprecationWarning) return sort(a, axis=0) diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index 8a56b92e0..60f382f0a 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3583,8 +3583,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase): def testMsort(self, dtype, shape): rng = jtu.rand_some_equal(self.rng()) args_maker = lambda: [rng(shape, dtype)] - self._CheckAgainstNumpy(np.msort, jnp.msort, args_maker) - self._CompileAndCheck(jnp.msort, args_maker) + + with self.assertWarnsRegex(DeprecationWarning, "jnp.msort is deprecated"): + jnp.msort(*args_maker()) + + with jtu.ignore_warning(category=DeprecationWarning, message=".*msort is deprecated"): + self._CheckAgainstNumpy(np.msort, jnp.msort, args_maker) + self._CompileAndCheck(jnp.msort, args_maker) @jtu.sample_product( [dict(shifts=shifts, axis=axis)