mirror of
https://github.com/ROCm/jax.git
synced 2025-04-16 03:46:06 +00:00
Finalize deprecation of jax.random.shuffle
This has been raising a DeprecationWarning for longer than anyone can remember. PiperOrigin-RevId: 656765001
This commit is contained in:
parent
dab15d6fdd
commit
a17c8d945b
@ -41,6 +41,8 @@ Remember to align the itemized text with the first line of an item within a list
|
|||||||
or `enable_xla=False` is now deprecated and this support will be removed in
|
or `enable_xla=False` is now deprecated and this support will be removed in
|
||||||
a future version.
|
a future version.
|
||||||
Native serialization has been the default since JAX 0.4.16 (September 2023).
|
Native serialization has been the default since JAX 0.4.16 (September 2023).
|
||||||
|
* The previously-deprecated function `jax.random.shuffle` has been removed;
|
||||||
|
instead use `jax.random.permutation` with `independent=True`.
|
||||||
|
|
||||||
## jaxlib 0.4.31
|
## jaxlib 0.4.31
|
||||||
|
|
||||||
|
@ -516,24 +516,6 @@ def _randint(key, shape, minval, maxval, dtype) -> Array:
|
|||||||
return lax.add(minval, lax.convert_element_type(random_offset, dtype))
|
return lax.add(minval, lax.convert_element_type(random_offset, dtype))
|
||||||
|
|
||||||
|
|
||||||
def shuffle(key: KeyArrayLike, x: ArrayLike, axis: int = 0) -> Array:
|
|
||||||
"""Shuffle the elements of an array uniformly at random along an axis.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
key: a PRNG key used as the random key.
|
|
||||||
x: the array to be shuffled.
|
|
||||||
axis: optional, an int axis along which to shuffle (default 0).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A shuffled version of x.
|
|
||||||
"""
|
|
||||||
msg = ("jax.random.shuffle is deprecated and will be removed in a future release. "
|
|
||||||
"Use jax.random.permutation with independent=True.")
|
|
||||||
warnings.warn(msg, FutureWarning)
|
|
||||||
key, _ = _check_prng_key("shuffle", key)
|
|
||||||
return _shuffle(key, x, axis)
|
|
||||||
|
|
||||||
|
|
||||||
def permutation(key: KeyArrayLike,
|
def permutation(key: KeyArrayLike,
|
||||||
x: int | ArrayLike,
|
x: int | ArrayLike,
|
||||||
axis: int = 0,
|
axis: int = 0,
|
||||||
|
@ -242,7 +242,6 @@ from jax._src.random import (
|
|||||||
randint as randint,
|
randint as randint,
|
||||||
random_gamma_p as random_gamma_p,
|
random_gamma_p as random_gamma_p,
|
||||||
rayleigh as rayleigh,
|
rayleigh as rayleigh,
|
||||||
shuffle as _deprecated_shuffle,
|
|
||||||
split as split,
|
split as split,
|
||||||
t as t,
|
t as t,
|
||||||
triangular as triangular,
|
triangular as triangular,
|
||||||
@ -254,16 +253,16 @@ from jax._src.random import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
_deprecations = {
|
_deprecations = {
|
||||||
# Added November 6, 2023; but has been raising a FutureWarning since JAX 0.1.66
|
# Finalized Jul 26 2024; remove after Nov 2024.
|
||||||
"shuffle": (
|
"shuffle": (
|
||||||
"jax.random.shuffle is deprecated. Use jax.random.permutation with independent=True.",
|
"jax.random.shuffle is deprecated. Use jax.random.permutation with independent=True.",
|
||||||
_deprecated_shuffle,
|
None,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
import typing
|
import typing
|
||||||
if typing.TYPE_CHECKING:
|
if typing.TYPE_CHECKING:
|
||||||
shuffle = _deprecated_shuffle
|
pass
|
||||||
else:
|
else:
|
||||||
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
|
||||||
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
__getattr__ = _deprecation_getattr(__name__, _deprecations)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user