mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Deprecate jnp.msort following deprecation of numpy.msort
This commit is contained in:
parent
794cec15cf
commit
09d1b6d8d5
@ -34,6 +34,9 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
Before, it was using the 0th device for the JAX-default backend.
|
||||
* A number of `jax.numpy` functions now have their arguments marked as
|
||||
positional-only, matching NumPy.
|
||||
* `jnp.msort` is now deprecated, following the deprecation of `np.msort` in numpy 1.24.
|
||||
It will be removed in a future release, in accordance with the {ref}`api-compatibility`
|
||||
policy. It can be replaced with `jnp.sort(a, axis=0)`.
|
||||
|
||||
## jaxlib 0.4.0
|
||||
* Changes
|
||||
|
@ -3485,6 +3485,8 @@ def argsort(a, axis: Optional[int] = -1, kind='stable', order=None):
|
||||
|
||||
@_wraps(np.msort)
|
||||
def msort(a):
|
||||
# TODO(jakevdp): remove msort after Feb 2023
|
||||
warnings.warn("jnp.msort is deprecated; use jnp.sort(a, axis=0) instead", DeprecationWarning)
|
||||
return sort(a, axis=0)
|
||||
|
||||
|
||||
|
@ -3583,8 +3583,13 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
def testMsort(self, dtype, shape):
|
||||
rng = jtu.rand_some_equal(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
self._CheckAgainstNumpy(np.msort, jnp.msort, args_maker)
|
||||
self._CompileAndCheck(jnp.msort, args_maker)
|
||||
|
||||
with self.assertWarnsRegex(DeprecationWarning, "jnp.msort is deprecated"):
|
||||
jnp.msort(*args_maker())
|
||||
|
||||
with jtu.ignore_warning(category=DeprecationWarning, message=".*msort is deprecated"):
|
||||
self._CheckAgainstNumpy(np.msort, jnp.msort, args_maker)
|
||||
self._CompileAndCheck(jnp.msort, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[dict(shifts=shifts, axis=axis)
|
||||
|
Loading…
x
Reference in New Issue
Block a user