[random] deprecate named key creation functions

This commit is contained in:
Jake VanderPlas 2023-09-21 13:57:49 -07:00
parent 6a551a1efa
commit 22818d664f
4 changed files with 36 additions and 15 deletions

View File

@ -10,6 +10,11 @@ Remember to align the itemized text with the first line of an item within a list
* Deprecations
* Removed the deprecated module `jax.abstract_arrays` and all its contents.
* Named key constructors in {mod}`jax.random` are deprecated. Pass the `impl` argument
to {func}`jax.random.PRNGKey` or {func}`jax.random.key` instead:
* `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')`
* `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')`
* `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')`
# jaxlib 0.4.17

View File

@ -195,8 +195,8 @@ def _check_default_impl_with_no_custom_prng(impl, name):
default_name = config.jax_default_prng_impl
if not config.jax_enable_custom_prng and default_impl is not impl:
raise RuntimeError('jax_enable_custom_prng must be enabled in order '
f'to seed an RNG with an implementation "f{name}" '
f'differing from the default "f{default_name}".')
f'to seed an RNG with an implementation "{name}" '
f'differing from the default "{default_name}".')
def threefry2x32_key(seed: int) -> KeyArray:
"""Creates a threefry2x32 PRNG key from an integer seed."""

View File

@ -170,17 +170,17 @@ from jax._src.random import (
randint as randint,
random_gamma_p as random_gamma_p,
rayleigh as rayleigh,
rbg_key as rbg_key,
rbg_key as _deprecated_rbg_key,
shuffle as shuffle,
split as split,
t as t,
threefry_2x32 as threefry_2x32,
threefry2x32_key as threefry2x32_key,
threefry2x32_key as _deprecated_threefry2x32_key,
threefry2x32_p as threefry2x32_p,
triangular as triangular,
truncated_normal as truncated_normal,
uniform as uniform,
unsafe_rbg_key as unsafe_rbg_key,
unsafe_rbg_key as _deprecated_unsafe_rbg_key,
wald as wald,
weibull_min as weibull_min,
wrap_key_data as wrap_key_data,
@ -202,12 +202,25 @@ _deprecations = {
"jax.dtypes.issubdtype(arr, jax.dtypes.prng_key) for runtime detection of "
"typed prng keys.", _PRNGKeyArray
),
# Added September 21, 2023
"threefry2x32_key": (
"jax.random.threefry2x32_key(seed) is deprecated. "
"Use jax.random.PRNGKey(seed, 'threefry2x32')", _deprecated_threefry2x32_key),
"rbg_key": (
"jax.random.rbg_key(seed) is deprecated. "
"Use jax.random.PRNGKey(seed, 'rbg')", _deprecated_rbg_key),
"unsafe_rbg_key": (
"jax.random.unsafe_rbg_key(seed) is deprecated. "
"Use jax.random.PRNGKey(seed, 'unsafe_rbg')", _deprecated_unsafe_rbg_key),
}
import typing
if typing.TYPE_CHECKING:
PRNGKeyArray = typing.Any
KeyArray = typing.Any
threefry_2x32_key = _deprecated_threefry2x32_key
rbg_key = _deprecated_rbg_key
unsafe_rbg_key = _deprecated_unsafe_rbg_key
else:
from jax._src.deprecations import deprecation_getattr as _deprecation_getattr
__getattr__ = _deprecation_getattr(__name__, _deprecations)

View File

@ -522,18 +522,21 @@ class PrngTest(jtu.JaxTestCase):
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
def test_explicit_threefry2x32_key(self):
self.check_key_has_impl(random.threefry2x32_key(42),
prng_internal.threefry_prng_impl)
with self.assertWarnsRegex(DeprecationWarning, "jax.random.threefry2x32_key"):
self.check_key_has_impl(random.threefry2x32_key(42),
prng_internal.threefry_prng_impl)
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
def test_explicit_rbg_key(self):
self.check_key_has_impl(random.rbg_key(42),
prng_internal.rbg_prng_impl)
with self.assertWarnsRegex(DeprecationWarning, "jax.random.rbg_key"):
self.check_key_has_impl(random.rbg_key(42),
prng_internal.rbg_prng_impl)
@skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag')
def test_explicit_unsafe_rbg_key(self):
self.check_key_has_impl(random.unsafe_rbg_key(42),
prng_internal.unsafe_rbg_prng_impl)
with self.assertWarnsRegex(DeprecationWarning, "jax.random.unsafe_rbg_key"):
self.check_key_has_impl(random.unsafe_rbg_key(42),
prng_internal.unsafe_rbg_prng_impl)
@parameterized.parameters([{'make_key': ctor, 'name': name, 'impl': impl}
for ctor in KEY_CTORS
@ -579,7 +582,7 @@ class PrngTest(jtu.JaxTestCase):
class ThreefryPrngTest(jtu.JaxTestCase):
@parameterized.parameters([{'make_key': ctor} for ctor in [
random.threefry2x32_key,
jax_random.threefry2x32_key,
partial(random.PRNGKey, impl='threefry2x32'),
partial(random.key, impl='threefry2x32')]])
def test_seed_no_implicit_transfers(self, make_key):
@ -640,7 +643,7 @@ class LaxRandomTest(jtu.JaxTestCase):
f'{expected_freq}\n{actual_freq}')
def make_key(self, seed):
return random.threefry2x32_key(seed)
return random.PRNGKey(seed, impl='threefry2x32')
@jtu.sample_product(
num=(None, 6, (6,), (2, 3), (2, 3, 4)),
@ -2296,7 +2299,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest):
@jtu.with_config(jax_default_prng_impl='rbg')
class LaxRandomWithRBGPRNGTest(LaxRandomTest):
def make_key(self, seed):
return random.rbg_key(seed)
return random.PRNGKey(seed, impl='rbg')
def test_split_shape(self):
key = self.make_key(73)
@ -2372,7 +2375,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest):
@jtu.with_config(jax_default_prng_impl='unsafe_rbg')
class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest):
def make_key(self, seed):
return random.unsafe_rbg_key(seed)
return random.PRNGKey(seed, impl="unsafe_rbg")
def _sampler_unimplemented_with_custom_prng(*args, **kwargs):