Deprecate jnp.msort following deprecation of numpy.msort

This commit is contained in:
Jake VanderPlas 2022-12-07 10:08:18 -08:00
parent 794cec15cf
commit 09d1b6d8d5
3 changed files with 12 additions and 2 deletions

View File

@ -34,6 +34,9 @@ Remember to align the itemized text with the first line of an item within a list
Before, it was using the 0th device for the JAX-default backend.
* A number of `jax.numpy` functions now have their arguments marked as
positional-only, matching NumPy.
* `jnp.msort` is now deprecated, following the deprecation of `np.msort` in numpy 1.24.
It will be removed in a future release, in accordance with the {ref}`api-compatibility`
policy. It can be replaced with `jnp.sort(a, axis=0)`.
## jaxlib 0.4.0
* Changes

View File

@ -3485,6 +3485,8 @@ def argsort(a, axis: Optional[int] = -1, kind='stable', order=None):
@_wraps(np.msort)
def msort(a):
# TODO(jakevdp): remove msort after Feb 2023
warnings.warn("jnp.msort is deprecated; use jnp.sort(a, axis=0) instead", DeprecationWarning)
return sort(a, axis=0)

View File

@ -3583,8 +3583,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
def testMsort(self, dtype, shape):
rng = jtu.rand_some_equal(self.rng())
args_maker = lambda: [rng(shape, dtype)]
self._CheckAgainstNumpy(np.msort, jnp.msort, args_maker)
self._CompileAndCheck(jnp.msort, args_maker)
with self.assertWarnsRegex(DeprecationWarning, "jnp.msort is deprecated"):
jnp.msort(*args_maker())
with jtu.ignore_warning(category=DeprecationWarning, message=".*msort is deprecated"):
self._CheckAgainstNumpy(np.msort, jnp.msort, args_maker)
self._CompileAndCheck(jnp.msort, args_maker)
@jtu.sample_product(
[dict(shifts=shifts, axis=axis)