Deprecate jax.random.shuffle

This has been long deprecated, but this PR uses the standard deprecation
framework to make it easier to finalize.
This commit is contained in:
Jake VanderPlas 2023-11-06 12:21:56 -08:00
parent 189d3aba2d
commit 5f7335fb55
3 changed files with 9 additions and 4 deletions

View File

@ -60,7 +60,6 @@ Random Samplers
rademacher
randint
rayleigh
shuffle
t
triangular
truncated_normal

View File

@ -172,7 +172,7 @@ from jax._src.random import (
random_gamma_p as random_gamma_p,
rayleigh as rayleigh,
rbg_key as _deprecated_rbg_key,
shuffle as shuffle,
shuffle as _deprecated_shuffle,
split as split,
t as t,
threefry2x32_key as _deprecated_threefry2x32_key,
@ -235,6 +235,11 @@ _deprecations = {
"jax.random.key_impl(key), jax.eval_shape(jax.random.key, 0).dtype, or similar.",
_deprecated_default_prng_impl,
),
# Added November 6, 2023; but has been raising a FutureWarning since JAX 0.1.66
"shuffle": (
"jax.random.shuffle is deprecated. Use jax.random.permutation with independent=True.",
_deprecated_shuffle,
)
}
import typing
@ -242,6 +247,7 @@ if typing.TYPE_CHECKING:
PRNGKeyArray = typing.Any
KeyArray = typing.Any
default_prng_impl = _deprecated_default_prng_impl
shuffle = _deprecated_shuffle
threefry_2x32 = _deprecated_threefry_2x32
threefry2x32_p = _deprecated_threefry2x32_p
threefry2x32_key = _deprecated_threefry2x32_key

View File

@ -209,9 +209,9 @@ class LaxRandomTest(jtu.JaxTestCase):
rand = lambda key: random.shuffle(key, x)
crand = jax.jit(rand)
with self.assertWarns(FutureWarning):
with self.assertWarns((DeprecationWarning, FutureWarning)):
perm1 = rand(key)
with self.assertWarns(FutureWarning):
with self.assertWarns((DeprecationWarning, FutureWarning)):
perm2 = crand(key)
self.assertAllClose(perm1, perm2)