Finalize deprecation of jax.random.shuffle

This has been raising a DeprecationWarning for longer than anyone can remember.

PiperOrigin-RevId: 656765001
This commit is contained in:
Jake VanderPlas 2024-07-27 11:21:02 -07:00 committed by jax authors
parent dab15d6fdd
commit a17c8d945b
3 changed files with 5 additions and 22 deletions

View File

@ -41,6 +41,8 @@ Remember to align the itemized text with the first line of an item within a list
or `enable_xla=False` is now deprecated and this support will be removed in
a future version.
Native serialization has been the default since JAX 0.4.16 (September 2023).
* The previously-deprecated function `jax.random.shuffle` has been removed;
instead use `jax.random.permutation` with `independent=True`.
## jaxlib 0.4.31

View File

@ -516,24 +516,6 @@ def _randint(key, shape, minval, maxval, dtype) -> Array:
return lax.add(minval, lax.convert_element_type(random_offset, dtype))
def shuffle(key: KeyArrayLike, x: ArrayLike, axis: int = 0) -> Array:
"""Shuffle the elements of an array uniformly at random along an axis.
Args:
key: a PRNG key used as the random key.
x: the array to be shuffled.
axis: optional, an int axis along which to shuffle (default 0).
Returns:
A shuffled version of x.
"""
msg = ("jax.random.shuffle is deprecated and will be removed in a future release. "
"Use jax.random.permutation with independent=True.")
warnings.warn(msg, FutureWarning)
key, _ = _check_prng_key("shuffle", key)
return _shuffle(key, x, axis)
def permutation(key: KeyArrayLike,
x: int | ArrayLike,
axis: int = 0,

View File

@ -242,7 +242,6 @@ from jax._src.random import (
randint as randint,
random_gamma_p as random_gamma_p,
rayleigh as rayleigh,
shuffle as _deprecated_shuffle,
split as split,
t as t,
triangular as triangular,
@ -254,16 +253,16 @@ from jax._src.random import (
)
_deprecations = {
# Added November 6, 2023; but has been raising a FutureWarning since JAX 0.1.66
# Finalized Jul 26 2024; remove after Nov 2024.
"shuffle": (
"jax.random.shuffle is deprecated. Use jax.random.permutation with independent=True.",
_deprecated_shuffle,
None,
)
}
import typing
if typing.TYPE_CHECKING:
shuffle = _deprecated_shuffle
pass
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)