mirror of
https://github.com/ROCm/jax.git
synced 2025-04-19 05:16:06 +00:00
Merge pull request #15321 from jakevdp:remove-msort
PiperOrigin-RevId: 520952178
This commit is contained in:
commit
2841bd310e
@ -11,6 +11,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
* Deprecations
|
||||
* The `in_axis_resources` and `out_axis_resources` arguments of pjit have been
|
||||
deprecated. Please use `in_shardings` and `out_shardings` respectively.
|
||||
* The function `jax.numpy.msort` has been removed. It has been deprecated since
|
||||
JAX v0.4.1. Use `jnp.sort(a, axis=0)` instead.
|
||||
|
||||
## jaxlib 0.4.9
|
||||
|
||||
|
@ -270,7 +270,6 @@ namespace; they are listed below.
|
||||
mod
|
||||
modf
|
||||
moveaxis
|
||||
msort
|
||||
multiply
|
||||
nan_to_num
|
||||
nanargmax
|
||||
|
@ -3550,13 +3550,6 @@ def argsort(a: ArrayLike, axis: Optional[int] = -1, kind: str = 'stable', order=
|
||||
return perm
|
||||
|
||||
|
||||
@util._wraps(np.msort)
|
||||
def msort(a):
|
||||
# TODO(jakevdp): remove msort after Feb 2023
|
||||
warnings.warn("jnp.msort is deprecated; use jnp.sort(a, axis=0) instead", DeprecationWarning)
|
||||
return sort(a, axis=0)
|
||||
|
||||
|
||||
@util._wraps(np.partition, lax_description="""
|
||||
The JAX version requires the ``kth`` argument to be a static integer rather than
|
||||
a general array. This is implemented via two calls to :func:`jax.lax.top_k`. If
|
||||
|
@ -170,7 +170,6 @@ from jax._src.numpy.lax_numpy import (
|
||||
matmul as matmul,
|
||||
meshgrid as meshgrid,
|
||||
moveaxis as moveaxis,
|
||||
msort as msort,
|
||||
nan as nan,
|
||||
nan_to_num as nan_to_num,
|
||||
nanargmax as nanargmax,
|
||||
|
@ -3571,21 +3571,6 @@ class LaxBackedNumpyTests(jtu.JaxTestCase):
|
||||
self._CheckAgainstNumpy(np_fun, jnp_fun, args_maker)
|
||||
self._CompileAndCheck(jnp_fun, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
dtype=all_dtypes,
|
||||
shape=nonzerodim_shapes,
|
||||
)
|
||||
def testMsort(self, dtype, shape):
|
||||
rng = jtu.rand_some_equal(self.rng())
|
||||
args_maker = lambda: [rng(shape, dtype)]
|
||||
|
||||
with self.assertWarnsRegex(DeprecationWarning, "jnp.msort is deprecated"):
|
||||
jnp.msort(*args_maker())
|
||||
|
||||
with jtu.ignore_warning(category=DeprecationWarning, message=".*msort is deprecated"):
|
||||
self._CheckAgainstNumpy(np.msort, jnp.msort, args_maker)
|
||||
self._CompileAndCheck(jnp.msort, args_maker)
|
||||
|
||||
@jtu.sample_product(
|
||||
[{'shape': shape, 'axis': axis, 'kth': kth}
|
||||
for shape in nonzerodim_shapes
|
||||
|
Loading…
x
Reference in New Issue
Block a user