mirror of
https://github.com/ROCm/jax.git
synced 2025-04-15 19:36:06 +00:00
Finalize deprecation of jax.random.shuffle
This has been raising a DeprecationWarning for longer than anyone can remember. PiperOrigin-RevId: 656765001
This commit is contained in:
parent
dab15d6fdd
commit
a17c8d945b
@ -41,6 +41,8 @@ Remember to align the itemized text with the first line of an item within a list
|
||||
or `enable_xla=False` is now deprecated and this support will be removed in
|
||||
a future version.
|
||||
Native serialization has been the default since JAX 0.4.16 (September 2023).
|
||||
* The previously-deprecated function `jax.random.shuffle` has been removed;
|
||||
instead use `jax.random.permutation` with `independent=True`.
|
||||
|
||||
## jaxlib 0.4.31
|
||||
|
||||
|
@ -516,24 +516,6 @@ def _randint(key, shape, minval, maxval, dtype) -> Array:
|
||||
return lax.add(minval, lax.convert_element_type(random_offset, dtype))
|
||||
|
||||
|
||||
def shuffle(key: KeyArrayLike, x: ArrayLike, axis: int = 0) -> Array:
|
||||
"""Shuffle the elements of an array uniformly at random along an axis.
|
||||
|
||||
Args:
|
||||
key: a PRNG key used as the random key.
|
||||
x: the array to be shuffled.
|
||||
axis: optional, an int axis along which to shuffle (default 0).
|
||||
|
||||
Returns:
|
||||
A shuffled version of x.
|
||||
"""
|
||||
msg = ("jax.random.shuffle is deprecated and will be removed in a future release. "
|
||||
"Use jax.random.permutation with independent=True.")
|
||||
warnings.warn(msg, FutureWarning)
|
||||
key, _ = _check_prng_key("shuffle", key)
|
||||
return _shuffle(key, x, axis)
|
||||
|
||||
|
||||
def permutation(key: KeyArrayLike,
|
||||
x: int | ArrayLike,
|
||||
axis: int = 0,
|
||||
|
@ -242,7 +242,6 @@ from jax._src.random import (
|
||||
randint as randint,
|
||||
random_gamma_p as random_gamma_p,
|
||||
rayleigh as rayleigh,
|
||||
shuffle as _deprecated_shuffle,
|
||||
split as split,
|
||||
t as t,
|
||||
triangular as triangular,
|
||||
@ -254,16 +253,16 @@ from jax._src.random import (
|
||||
)
|
||||
|
||||
_deprecations = {
|
||||
# Added November 6, 2023; but has been raising a FutureWarning since JAX 0.1.66
|
||||
# Finalized Jul 26 2024; remove after Nov 2024.
|
||||
"shuffle": (
|
||||
"jax.random.shuffle is deprecated. Use jax.random.permutation with independent=True.",
|
||||
_deprecated_shuffle,
|
||||
None,
|
||||
)
|
||||
}
|
||||
|
||||
import typing
|
||||
if typing.TYPE_CHECKING:
|
||||
shuffle = _deprecated_shuffle
|
||||
pass
|
||||
else:
|
||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||
|
Loading…
x
Reference in New Issue
Block a user