diff --git a/CHANGELOG.md b/CHANGELOG.md index 379a65532..3cdbefcb3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,11 @@ Remember to align the itemized text with the first line of an item within a list * Deprecations * Removed the deprecated module `jax.abstract_arrays` and all its contents. + * Named key constructors in {mod}`jax.random` are deprecated. Pass the `impl` argument + to {func}`jax.random.PRNGKey` or {func}`jax.random.key` instead: + * `random.threefry2x32_key(seed)` becomes `random.PRNGKey(seed, impl='threefry2x32')` + * `random.rbg_key(seed)` becomes `random.PRNGKey(seed, impl='rbg')` + * `random.unsafe_rbg_key(seed)` becomes `random.PRNGKey(seed, impl='unsafe_rbg')` # jaxlib 0.4.17 diff --git a/jax/_src/random.py b/jax/_src/random.py index bd4353b95..f867f82c6 100644 --- a/jax/_src/random.py +++ b/jax/_src/random.py @@ -195,8 +195,8 @@ def _check_default_impl_with_no_custom_prng(impl, name): default_name = config.jax_default_prng_impl if not config.jax_enable_custom_prng and default_impl is not impl: raise RuntimeError('jax_enable_custom_prng must be enabled in order ' - f'to seed an RNG with an implementation "f{name}" ' - f'differing from the default "f{default_name}".') + f'to seed an RNG with an implementation "{name}" ' + f'differing from the default "{default_name}".') def threefry2x32_key(seed: int) -> KeyArray: """Creates a threefry2x32 PRNG key from an integer seed.""" diff --git a/jax/random.py b/jax/random.py index 6e0843a22..fde8f0a6d 100644 --- a/jax/random.py +++ b/jax/random.py @@ -170,17 +170,17 @@ from jax._src.random import ( randint as randint, random_gamma_p as random_gamma_p, rayleigh as rayleigh, - rbg_key as rbg_key, + rbg_key as _deprecated_rbg_key, shuffle as shuffle, split as split, t as t, threefry_2x32 as threefry_2x32, - threefry2x32_key as threefry2x32_key, + threefry2x32_key as _deprecated_threefry2x32_key, threefry2x32_p as threefry2x32_p, triangular as triangular, truncated_normal as truncated_normal, uniform as uniform, - unsafe_rbg_key as unsafe_rbg_key, + unsafe_rbg_key as _deprecated_unsafe_rbg_key, wald as wald, weibull_min as weibull_min, wrap_key_data as wrap_key_data, @@ -202,12 +202,25 @@ _deprecations = { "jax.dtypes.issubdtype(arr, jax.dtypes.prng_key) for runtime detection of " "typed prng keys.", _PRNGKeyArray ), + # Added September 21, 2023 + "threefry2x32_key": ( + "jax.random.threefry2x32_key(seed) is deprecated. " + "Use jax.random.PRNGKey(seed, 'threefry2x32')", _deprecated_threefry2x32_key), + "rbg_key": ( + "jax.random.rbg_key(seed) is deprecated. " + "Use jax.random.PRNGKey(seed, 'rbg')", _deprecated_rbg_key), + "unsafe_rbg_key": ( + "jax.random.unsafe_rbg_key(seed) is deprecated. " + "Use jax.random.PRNGKey(seed, 'unsafe_rbg')", _deprecated_unsafe_rbg_key), } import typing if typing.TYPE_CHECKING: PRNGKeyArray = typing.Any KeyArray = typing.Any + threefry_2x32_key = _deprecated_threefry2x32_key + rbg_key = _deprecated_rbg_key + unsafe_rbg_key = _deprecated_unsafe_rbg_key else: from jax._src.deprecations import deprecation_getattr as _deprecation_getattr __getattr__ = _deprecation_getattr(__name__, _deprecations) diff --git a/tests/random_test.py b/tests/random_test.py index 37b9f20e0..89e1e9115 100644 --- a/tests/random_test.py +++ b/tests/random_test.py @@ -522,18 +522,21 @@ class PrngTest(jtu.JaxTestCase): @skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag') def test_explicit_threefry2x32_key(self): - self.check_key_has_impl(random.threefry2x32_key(42), - prng_internal.threefry_prng_impl) + with self.assertWarnsRegex(DeprecationWarning, "jax.random.threefry2x32_key"): + self.check_key_has_impl(random.threefry2x32_key(42), + prng_internal.threefry_prng_impl) @skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag') def test_explicit_rbg_key(self): - self.check_key_has_impl(random.rbg_key(42), - prng_internal.rbg_prng_impl) + with self.assertWarnsRegex(DeprecationWarning, "jax.random.rbg_key"): + self.check_key_has_impl(random.rbg_key(42), + prng_internal.rbg_prng_impl) @skipIf(not config.jax_enable_custom_prng, 'relies on typed key upgrade flag') def test_explicit_unsafe_rbg_key(self): - self.check_key_has_impl(random.unsafe_rbg_key(42), - prng_internal.unsafe_rbg_prng_impl) + with self.assertWarnsRegex(DeprecationWarning, "jax.random.unsafe_rbg_key"): + self.check_key_has_impl(random.unsafe_rbg_key(42), + prng_internal.unsafe_rbg_prng_impl) @parameterized.parameters([{'make_key': ctor, 'name': name, 'impl': impl} for ctor in KEY_CTORS @@ -579,7 +582,7 @@ class PrngTest(jtu.JaxTestCase): class ThreefryPrngTest(jtu.JaxTestCase): @parameterized.parameters([{'make_key': ctor} for ctor in [ - random.threefry2x32_key, + jax_random.threefry2x32_key, partial(random.PRNGKey, impl='threefry2x32'), partial(random.key, impl='threefry2x32')]]) def test_seed_no_implicit_transfers(self, make_key): @@ -640,7 +643,7 @@ class LaxRandomTest(jtu.JaxTestCase): f'{expected_freq}\n{actual_freq}') def make_key(self, seed): - return random.threefry2x32_key(seed) + return random.PRNGKey(seed, impl='threefry2x32') @jtu.sample_product( num=(None, 6, (6,), (2, 3), (2, 3, 4)), @@ -2296,7 +2299,7 @@ class LaxRandomWithCustomPRNGTest(LaxRandomTest): @jtu.with_config(jax_default_prng_impl='rbg') class LaxRandomWithRBGPRNGTest(LaxRandomTest): def make_key(self, seed): - return random.rbg_key(seed) + return random.PRNGKey(seed, impl='rbg') def test_split_shape(self): key = self.make_key(73) @@ -2372,7 +2375,7 @@ class LaxRandomWithRBGPRNGTest(LaxRandomTest): @jtu.with_config(jax_default_prng_impl='unsafe_rbg') class LaxRandomWithUnsafeRBGPRNGTest(LaxRandomWithRBGPRNGTest): def make_key(self, seed): - return random.unsafe_rbg_key(seed) + return random.PRNGKey(seed, impl="unsafe_rbg") def _sampler_unimplemented_with_custom_prng(*args, **kwargs):